"""
Description :   TB Dispatcher - Multi-TB Generation Coordinator
                Coordinates generation of multiple targeted TBs based on RTL functional analysis
Author      :   CorrectBench
Time        :   2026/04/19
"""

import json
import logging
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Dict, List, Optional, Tuple

from autoline.semantic_analyzer import SemanticAnalyzer
from autoline.semantic_analyzer import FunctionPointType

logger = logging.getLogger(__name__)


class TBTask:
    """A single TB generation task focused on one group of function points."""

    def __init__(self, name: str, targets: List[str], fp_type: str,
                 description: str = "", task_dir: Optional[str] = None):
        self.name = name
        self.targets = targets            # names of the targeted function points
        self.fp_type = fp_type            # group key, e.g. 'fsm', 'protocol', 'datapath', 'exception', 'main'
        self.description = description
        self.task_dir = task_dir          # per-task working directory (may be None)
        self.tb_code_v = None             # generated Verilog TB code
        self.tb_code_py = None            # generated Python rule/checker code
        self.coverage = 0.0               # coverage reported for this task
        self.passed = False               # whether the task passed its checks
        self.error = None                 # error message, if any

    def to_dict(self) -> Dict:
        """Return a serializable summary of the task (generated code is omitted)."""
        return {
            'name': self.name,
            'targets': self.targets,
            'fp_type': self.fp_type,
            'description': self.description,
            'coverage': self.coverage,
            'passed': self.passed,
            'error': self.error,
        }


class CoverageMerger:
    """Merge the coverage artifacts produced by several TBs into one report."""

    # Marker characters that may precede a hit count in an annotated coverage
    # line, e.g. "%000000 code" or "~000003 code". Unmarked counts such as
    # "000012 code" are accepted as well.
    _COUNT_MARKERS = ('%', '~', '^')

    @staticmethod
    def _ensure_parent_dir(path: str) -> None:
        """Create the directory that will contain *path*, if it names one."""
        parent = os.path.dirname(path)
        if parent:  # os.makedirs('') raises, so skip bare file names
            os.makedirs(parent, exist_ok=True)

    @staticmethod
    def _parse_count(token: str) -> Optional[int]:
        """Return the hit count encoded in *token*, or None if it is not a count.

        Accepts an optional marker from ``_COUNT_MARKERS`` followed by decimal
        digits, e.g. ``%000000``, ``~000003`` or ``000012``.
        """
        digits = token[1:] if token[:1] in CoverageMerger._COUNT_MARKERS else token
        return int(digits) if digits.isdecimal() else None

    @staticmethod
    def _read_coverage_lines(cov_file: str) -> List[Tuple[Tuple[str, int], int]]:
        """Read ``((code, occurrence), count)`` entries from one coverage file.

        The same source text (``end``, ``begin`` ...) usually appears many times
        in one file, so each entry is keyed by its text plus the number of times
        that text was already seen earlier in the same file. This assumes every
        merged file annotates the same RTL, so the n-th occurrence of a text
        refers to the same source line in each file.
        Lines that do not start with a count token are ignored.
        """
        entries: List[Tuple[Tuple[str, int], int]] = []
        seen: Dict[str, int] = {}
        with open(cov_file, 'r', encoding='utf-8', errors='ignore') as f:
            for raw in f:
                parts = raw.strip().split(None, 1)
                if len(parts) < 2:
                    continue
                count = CoverageMerger._parse_count(parts[0])
                if count is None:
                    continue
                code = parts[1]
                occurrence = seen.get(code, 0)
                seen[code] = occurrence + 1
                entries.append(((code, occurrence), count))
        return entries

    @staticmethod
    def merge_coverage_dat(coverage_files: List[str], output_file: str) -> Dict:
        """
        Merge several annotated coverage files into one.

        Hit counts for the same source line are summed across files; a line is
        considered covered when its merged count is greater than zero.
        Missing or unreadable inputs are logged and skipped.

        Args:
            coverage_files: paths of the coverage files to merge
            output_file: path of the merged file to write

        Returns:
            merged_info: {
                'total_lines': int,
                'covered_lines': int,
                'coverage_percent': float,
                'file': str
            }
        """
        merged: Dict[Tuple[str, int], int] = {}

        for cov_file in coverage_files:
            if not os.path.exists(cov_file):
                logger.warning(f"Coverage file not found: {cov_file}")
                continue
            try:
                entries = CoverageMerger._read_coverage_lines(cov_file)
            except OSError as e:
                logger.warning(f"Error reading coverage file {cov_file}: {e}")
                continue
            for key, count in entries:
                merged[key] = merged.get(key, 0) + count

        total_lines = len(merged)
        covered_lines = sum(1 for count in merged.values() if count > 0)
        coverage_percent = (covered_lines / total_lines * 100) if total_lines > 0 else 0.0

        # Write the merged file, preserving first-seen line order.
        CoverageMerger._ensure_parent_dir(output_file)
        with open(output_file, 'w', encoding='utf-8') as f:
            f.writelines(f"%{count} {code}\n" for (code, _), count in merged.items())

        return {
            'total_lines': total_lines,
            'covered_lines': covered_lines,
            'coverage_percent': coverage_percent,
            'file': output_file,
        }

    @staticmethod
    def merge_unreachable_reports(report_files: List[str], output_file: str) -> Dict:
        """
        Merge several unreachable-branch reports into one summary report.

        Each report is scanned for the literal labels "Truly unreachable" and
        "Potentially coverable"; their occurrence counts are totalled.
        Missing files are skipped silently, unreadable ones are logged.

        Args:
            report_files: list of unreachable_branches_report.txt paths
            output_file: path of the merged report to write

        Returns:
            merged_info: {
                'total_unreachable': int,
                'total_potentially_coverable': int,
                'reports': list of per-file dicts,
                'file': str
            }
        """
        total_unreachable = 0
        total_potentially_coverable = 0
        all_reports = []

        for report_file in report_files:
            if not os.path.exists(report_file):
                continue
            try:
                with open(report_file, 'r', encoding='utf-8') as f:
                    content = f.read()
            except (OSError, UnicodeDecodeError) as e:
                logger.warning(f"Error reading report {report_file}: {e}")
                continue

            # Lightweight parsing: count occurrences of the classification labels.
            unreachable_count = content.count("Truly unreachable")
            potentially_count = content.count("Potentially coverable")
            total_unreachable += unreachable_count
            total_potentially_coverable += potentially_count
            all_reports.append({
                'file': report_file,
                'unreachable': unreachable_count,
                'potentially_coverable': potentially_count,
            })

        CoverageMerger._ensure_parent_dir(output_file)
        with open(output_file, 'w', encoding='utf-8') as f:
            f.write("=" * 60 + "\n")
            f.write("MERGED UNREACHABLE BRANCHES REPORT\n")
            f.write("=" * 60 + "\n\n")
            for rp in all_reports:
                f.write(f"From: {rp['file']}\n")
                f.write(f" Truly unreachable: {rp['unreachable']}\n")
                f.write(f" Potentially coverable: {rp['potentially_coverable']}\n\n")
            f.write("=" * 60 + "\n")
            f.write(f"Total Truly Unreachable: {total_unreachable}\n")
            f.write(f"Total Potentially Coverable: {total_potentially_coverable}\n")
            f.write("=" * 60 + "\n")

        return {
            'total_unreachable': total_unreachable,
            'total_potentially_coverable': total_potentially_coverable,
            'reports': all_reports,
            'file': output_file,
        }


class TBDispatcher:
    """
    TB dispatcher - coordinates generation of multiple targeted TBs.

    Workflow:
        1. Analyze the RTL with the semantic analyzer to find function points
        2. Group the function points by kind (FSM, protocol, datapath, ...)
        3. Create one TB task per group
        4. Execute the tasks (sequentially or in a thread pool)
        5. Merge the per-task coverage reports
    """

    # Focus description of each TB group; embedded in task descriptions and logs.
    _FOCUS_DESCRIPTIONS = {
        'main': 'Comprehensive test covering all functional points',
        'fsm': 'FSM state machine test - focus on state transitions and state coverage',
        'protocol': 'Protocol interface test - focus on handshake timing and data transfer',
        'datapath': 'Datapath test - focus on data processing and counter logic',
        'exception': 'Exception handling test - focus on boundary conditions and error handling',
    }

    # Task creation order (protocol > fsm > datapath > exception > main); when
    # max_tb_count is reached, groups later in this list are dropped.
    _PRIORITY_ORDER = ('protocol', 'fsm', 'datapath', 'exception', 'main')

    def __init__(self, prob_data: Dict, config, task_dir: str):
        """
        Args:
            prob_data: problem data; this class reads 'module_code' and 'task_id'
            config: configuration object; optional settings are read from
                ``config.autoline.multi_tb`` when present
            task_dir: working directory under which per-task directories are created
        """
        self.prob_data = prob_data
        self.config = config
        self.task_dir = task_dir

        # Multi-TB settings; every attribute falls back to a default when absent.
        multi_tb_config = getattr(config.autoline, 'multi_tb', None) if hasattr(config, 'autoline') else None

        def _option(key: str, default):
            return getattr(multi_tb_config, key, default) if multi_tb_config else default

        self.enabled = _option('enabled', False)
        self.strategy = _option('strategy', 'functional')
        self.max_tb_count = _option('max_tb_count', 5)
        self.parallel = _option('parallel', False)
        self.auto_threshold_lines = _option('auto_threshold_lines', 500)

        self.tasks: List[TBTask] = []
        self.results: List[Dict] = []
        self.merged_coverage = {}

    def _count_rtl_lines(self) -> int:
        """
        Count RTL source lines, excluding blank lines and comment lines.

        Heuristic: lines starting with ``//`` are skipped, and a line starting
        with ``/*`` is skipped together with every following line up to the
        one containing the closing ``*/``. Code sharing a line with a comment
        opener is not counted separately.
        """
        code = self.prob_data.get('module_code') or ''
        count = 0
        in_block_comment = False
        for line in code.split('\n'):
            stripped = line.strip()
            if in_block_comment:
                if '*/' in stripped:
                    in_block_comment = False
                continue
            if not stripped or stripped.startswith('//'):
                continue
            if stripped.startswith('/*'):
                # A block comment that does not close on this line also hides
                # the following lines until '*/' appears.
                in_block_comment = '*/' not in stripped[2:]
                continue
            count += 1
        return count

    def should_use_multi_tb(self) -> bool:
        """
        Decide whether multi-TB mode should be used.

        Returns:
            True when multi-TB is enabled in the config, or when the RTL has more
            than ``auto_threshold_lines`` non-comment lines; False otherwise.
        """
        # 1. Explicitly enabled in the configuration.
        if self.enabled:
            return True

        # 2. Auto-enable for large RTL.
        rtl_lines = self._count_rtl_lines()
        if rtl_lines > self.auto_threshold_lines:
            logger.info(f"[TBDispatcher] RTL has {rtl_lines} lines (>{self.auto_threshold_lines}), enabling multi-TB mode")
            return True

        return False

    def analyze_rtl(self) -> Dict[str, Any]:
        """
        Run the semantic analyzer over ``prob_data['module_code']``.

        Returns:
            analysis_result: the analyzer's result dict; this class reads its
            'function_points' list.
        """
        logger.info(f"[TBDispatcher] Analyzing RTL for task: {self.prob_data.get('task_id', 'unknown')}")

        analyzer = SemanticAnalyzer(self.prob_data['module_code'])
        result = analyzer.analyze()

        logger.info(f"[TBDispatcher] Found {len(result.get('function_points', []))} function points")
        return result

    def group_function_points(self, analysis_result: Dict[str, Any]) -> Dict[str, List[Dict]]:
        """
        Group function points by kind so each TB has a clear focus.

        A point is classified by substrings of its lower-cased name or by its
        'type' field; anything unmatched goes to 'datapath'. The 'main' group
        always contains every function point. Empty groups are removed.

        Returns:
            groups: {
                'main': [...],       # main test, all function points
                'fsm': [...],        # FSM-focused
                'protocol': [...],   # protocol/interface-focused
                'datapath': [...],   # datapath/counter/condition-focused
                'exception': [...]   # exception-handling-focused
            }
        """
        groups: Dict[str, List[Dict]] = {
            'main': [],
            'fsm': [],
            'protocol': [],
            'datapath': [],
            'exception': [],
        }

        all_fps = analysis_result.get('function_points', [])

        for fp in all_fps:
            fp_type = fp.get('type', 'condition')
            fp_name = (fp.get('name') or '').lower()

            if 'fsm' in fp_name or fp_type == 'fsm':
                groups['fsm'].append(fp)
            elif 'protocol' in fp_name or fp_type == 'protocol':
                groups['protocol'].append(fp)
            elif 'counter' in fp_name or 'datapath' in fp_name or fp_type == 'counter':
                groups['datapath'].append(fp)
            elif 'exception' in fp_name or fp_type == 'exception':
                groups['exception'].append(fp)
            else:
                # Unclassified points default to the datapath group.
                groups['datapath'].append(fp)

        # Copy so later mutation of a group does not alter the analyzer's list.
        groups['main'] = list(all_fps)

        groups = {k: v for k, v in groups.items() if v}

        logger.info(f"[TBDispatcher] Grouped function points: { {k: len(v) for k, v in groups.items()} }")
        return groups

    def create_tasks(self, groups: Dict[str, List[Dict]]) -> List[TBTask]:
        """
        Create one TB task per function-point group.

        Focus of each TB:
            - tb_main: main test, tries to cover all function points
            - tb_fsm: state transitions and state coverage
            - tb_protocol: handshake timing and data transfer
            - tb_datapath: data processing and counters
            - tb_exception: boundary conditions and error handling

        Groups are visited in ``_PRIORITY_ORDER``; non-main groups with at most
        one function point are skipped, and creation stops once
        ``max_tb_count`` tasks exist. A directory ``<task_dir>/tb_<group>`` is
        created for every task.

        Args:
            groups: function points grouped by kind

        Returns:
            tasks: the created TB tasks
        """
        tasks = []

        for fp_type in self._PRIORITY_ORDER:
            if fp_type not in groups or not groups[fp_type]:
                continue

            if len(tasks) >= self.max_tb_count:
                logger.info(f"[TBDispatcher] Reached max TB count ({self.max_tb_count}), stopping task creation")
                break

            fps = groups[fp_type]

            if len(fps) <= 1 and fp_type != 'main':
                logger.info(f"[TBDispatcher] Skipping {fp_type} group (only {len(fps)} function point)")
                continue

            focus = self._FOCUS_DESCRIPTIONS.get(fp_type, 'Test')
            task_name = f"tb_{fp_type}"
            # Tolerate unnamed points, consistent with group_function_points.
            targets = [str(fp.get('name', 'unnamed')) for fp in fps]
            description = f"[{focus}] Targets: {', '.join(targets)}"

            task_subdir = os.path.join(self.task_dir, task_name)
            os.makedirs(task_subdir, exist_ok=True)

            task = TBTask(
                name=task_name,
                targets=targets,
                fp_type=fp_type,
                description=description,
                task_dir=task_subdir,
            )
            tasks.append(task)

            logger.info(f"[TBDispatcher] Created task: {task_name}")
            logger.info(f" Focus: {self._FOCUS_DESCRIPTIONS.get(fp_type, 'N/A')}")
            logger.info(f" Targets: {targets}")

        logger.info(f"[TBDispatcher] Created {len(tasks)} TB tasks")
        return tasks

    @staticmethod
    def _load_saved_result(result_file: str) -> Optional[Dict]:
        """Load a previously saved result.json; return None if absent or invalid."""
        if not os.path.exists(result_file):
            return None
        try:
            with open(result_file, 'r', encoding='utf-8') as f:
                saved = json.load(f)
        except (OSError, ValueError) as e:  # ValueError covers JSONDecodeError
            logger.warning(f"[TBDispatcher] Could not load {result_file}: {e}")
            return None
        if not isinstance(saved, dict):
            logger.warning(f"[TBDispatcher] Ignoring {result_file}: expected a JSON object")
            return None
        return saved

    def run_single_task(self, task: TBTask) -> Dict:
        """
        Run one TB task (called once per function-point group).

        Currently a placeholder: it builds a default result and overlays any
        ``result.json`` found in the task directory. The task's coverage,
        passed and error fields are updated from the final result.

        Args:
            task: the TB task

        Returns:
            result: the task's result dict
        """
        logger.info(f"[TBDispatcher] Running task: {task.name}")
        logger.info(f"[TBDispatcher] Targets: {task.targets}")

        # TODO: invoke the existing TBgen -> TBsim -> TBcheck -> CGA pipeline here.
        result = {
            'name': task.name,
            'targets': task.targets,
            'fp_type': task.fp_type,
            'coverage': 0.0,
            'passed': False,
            'error': None,
            'tb_dir': task.task_dir,
        }

        if task.task_dir:
            saved_result = self._load_saved_result(os.path.join(task.task_dir, "result.json"))
            if saved_result:
                result.update(saved_result)

        # Keep the task object in sync so TBTask.to_dict() reflects the outcome.
        task.coverage = result.get('coverage') or 0.0
        task.passed = bool(result.get('passed', False))
        task.error = result.get('error')

        return result

    def _execute_tasks(self) -> List[Dict]:
        """Run every task and return the results in task order."""
        if not self.parallel:
            # Sequential mode lets task exceptions propagate to the caller.
            return [self.run_single_task(task) for task in self.tasks]

        results: List[Optional[Dict]] = [None] * len(self.tasks)
        with ThreadPoolExecutor(max_workers=len(self.tasks)) as executor:
            futures = {
                executor.submit(self.run_single_task, task): idx
                for idx, task in enumerate(self.tasks)
            }
            for future in as_completed(futures):
                idx = futures[future]
                task = self.tasks[idx]
                try:
                    results[idx] = future.result()
                except Exception as e:  # one failing task must not abort the rest
                    logger.error(f"[TBDispatcher] Task {task.name} failed: {e}")
                    results[idx] = {
                        'name': task.name,
                        'targets': task.targets,
                        'fp_type': task.fp_type,
                        'coverage': 0.0,
                        'error': str(e),
                        'passed': False,
                    }
        return results

    def run_all(self) -> Dict:
        """
        Run all TB tasks and merge their results.

        Returns:
            merged_result: ``{'enabled': False}`` when multi-TB is not needed, a
            fallback dict when no function points/tasks exist, otherwise the
            dict produced by ``merge_results``.
        """
        if not self.should_use_multi_tb():
            logger.info("[TBDispatcher] Multi-TB mode not needed, skipping")
            return {'enabled': False}

        logger.info("[TBDispatcher] Starting multi-TB dispatch")

        # Step 1: analyze the RTL.
        analysis_result = self.analyze_rtl()

        # Step 2: group function points.
        groups = self.group_function_points(analysis_result)
        if not groups:
            logger.warning("[TBDispatcher] No function points found, falling back to single TB")
            return {'enabled': True, 'fallback': True, 'reason': 'no_function_points'}

        # Step 3: create TB tasks.
        self.tasks = self.create_tasks(groups)
        if not self.tasks:
            logger.warning("[TBDispatcher] No tasks created, falling back to single TB")
            return {'enabled': True, 'fallback': True, 'reason': 'no_tasks'}

        # Step 4: execute (results are reset so repeated runs do not accumulate).
        self.results = self._execute_tasks()

        # Step 5: merge coverage.
        self.merged_coverage = self.merge_results()

        merged_percent = self.merged_coverage.get('merged_coverage', {}).get('coverage_percent', 0)
        logger.info(f"[TBDispatcher] All tasks completed. Merged coverage: {merged_percent:.2f}%")
        return self.merged_coverage

    def merge_results(self) -> Dict:
        """
        Merge the results of all TB tasks.

        Averages per-task coverage over all tasks, counts passed tasks, and
        merges any ``<tb_dir>/CGA/coverage.dat`` and
        ``<tb_dir>/CGA/unreachable_branches_report.txt`` files into
        ``<task_dir>/CGA_merged/``.

        Returns:
            merged: summary dict including per-task 'results'
        """
        total_coverage = 0.0
        passed_count = 0
        total_tasks = len(self.results)
        coverage_reports = []
        unreachable_reports = []

        for result in self.results:
            coverage = result.get('coverage') or 0
            if coverage > 0:
                total_coverage += coverage
            if result.get('passed', False):
                passed_count += 1

            # Without a task directory, joining '' would probe ./CGA in the CWD.
            tb_dir = result.get('tb_dir')
            if not tb_dir:
                continue
            cov_file = os.path.join(tb_dir, 'CGA', 'coverage.dat')
            if os.path.exists(cov_file):
                coverage_reports.append(cov_file)
            unreach_file = os.path.join(tb_dir, 'CGA', 'unreachable_branches_report.txt')
            if os.path.exists(unreach_file):
                unreachable_reports.append(unreach_file)

        avg_coverage = total_coverage / total_tasks if total_tasks > 0 else 0.0

        merged_cov_file = os.path.join(self.task_dir, "CGA_merged", "coverage.dat")
        merged_unreach_file = os.path.join(self.task_dir, "CGA_merged", "unreachable_branches_report.txt")

        coverage_info = (CoverageMerger.merge_coverage_dat(coverage_reports, merged_cov_file)
                         if coverage_reports else {})
        unreachable_info = (CoverageMerger.merge_unreachable_reports(unreachable_reports, merged_unreach_file)
                            if unreachable_reports else {})

        return {
            'enabled': True,
            'total_tasks': total_tasks,
            'passed_tasks': passed_count,
            'average_coverage': avg_coverage,
            'merged_coverage': coverage_info,
            'merged_unreachable': unreachable_info,
            'results': self.results,
        }

    def generate_summary_report(self) -> str:
        """
        Build a human-readable multi-TB summary report.

        Returns:
            report: the formatted report, or "No results to report" if
            ``run_all`` has not produced merged results yet.
        """
        if not self.merged_coverage:
            return "No results to report"

        lines = []
        lines.append("=" * 70)
        lines.append("MULTI-TB COVERAGE SUMMARY REPORT")
        lines.append("=" * 70)
        lines.append("")
        lines.append(f"Total TB Tasks: {self.merged_coverage.get('total_tasks', 0)}")
        lines.append(f"Passed Tasks: {self.merged_coverage.get('passed_tasks', 0)}")
        lines.append(f"Average Coverage: {self.merged_coverage.get('average_coverage', 0):.2f}%")
        lines.append("")

        # Per-task details.
        lines.append("Individual TB Results:")
        lines.append("-" * 70)
        for result in self.results:
            status = "PASS" if result.get('passed') else "FAIL"
            coverage = result.get('coverage') or 0
            targets = result.get('targets') or []
            lines.append(f" [{status}] {result.get('name', 'unknown')}")
            lines.append(f" Targets: {', '.join(str(t) for t in targets)}")
            lines.append(f" Coverage: {coverage:.2f}%")
            if result.get('error'):
                lines.append(f" Error: {result['error']}")
            lines.append("")

        # Merged coverage.
        merged_cov = self.merged_coverage.get('merged_coverage', {})
        if merged_cov:
            lines.append("-" * 70)
            lines.append("Merged Coverage:")
            lines.append(f" Total Lines: {merged_cov.get('total_lines', 0)}")
            lines.append(f" Covered Lines: {merged_cov.get('covered_lines', 0)}")
            lines.append(f" Coverage: {merged_cov.get('coverage_percent', 0):.2f}%")
            lines.append("")

        # Unreachable branches.
        merged_unreach = self.merged_coverage.get('merged_unreachable', {})
        if merged_unreach:
            lines.append("-" * 70)
            lines.append("Unreachable Branches:")
            lines.append(f" Truly Unreachable: {merged_unreach.get('total_unreachable', 0)}")
            lines.append(f" Potentially Coverable: {merged_unreach.get('total_potentially_coverable', 0)}")
            lines.append("")

        lines.append("=" * 70)
        return "\n".join(lines)


def create_dispatcher(prob_data: Dict, config, task_dir: str) -> TBDispatcher:
    """
    Factory function: create a TB dispatcher.

    Args:
        prob_data: problem data
        config: configuration object
        task_dir: task working directory

    Returns:
        dispatcher: a TBDispatcher instance
    """
    return TBDispatcher(prob_data, config, task_dir)


if __name__ == "__main__":
    # Simple manual smoke test.
    logging.basicConfig(level=logging.INFO)

    # Sample RTL.
    sample_rtl = """
module example_fsm (
    input clk, rst_n,
    input [1:0] cmd,
    output reg [3:0] out
);
    localparam IDLE = 2'b00;
    localparam RUN  = 2'b01;
    localparam DONE = 2'b10;

    reg [1:0] state, next_state;

    always @(posedge clk or negedge rst_n) begin
        if (!rst_n) state <= IDLE;
        else state <= next_state;
    end

    always @(*) begin
        case (state)
            IDLE: next_state = cmd[0] ? RUN : IDLE;
            RUN:  next_state = DONE;
            DONE: next_state = IDLE;
            default: next_state = IDLE;
        endcase
    end

    always @(posedge clk) begin
        if (state == RUN) out <= out + 1;
    end
endmodule
"""

    prob_data = {
        'task_id': 'test_fsm',
        'module_code': sample_rtl,
        'header': 'module example_fsm (input clk, rst_n, input [1:0] cmd, output [3:0] out);',
        'description': 'Test FSM module'
    }

    class MockConfig:
        class AutoLine:
            multi_tb = None
        autoline = AutoLine()

    config = MockConfig()

    dispatcher = TBDispatcher(prob_data, config, "/tmp/test_dispatcher")
    dispatcher.enabled = True  # force-enable for the test

    # Analyze the RTL.
    result = dispatcher.analyze_rtl()
    print(f"\nFunction Points Found: {len(result.get('function_points', []))}")

    # Group.
    groups = dispatcher.group_function_points(result)
    print(f"\nGroups: {groups}")

    # Create tasks.
    tasks = dispatcher.create_tasks(groups)
    print(f"\nTasks Created: {len(tasks)}")
    for task in tasks:
        print(f" - {task.name}: {task.targets}")

    print("\n" + "=" * 50)
    print("Test completed successfully")