# Source file: CGA-bench/autoline/tb_dispatcher.py
# Repository viewer metadata: 701 lines, 24 KiB, Python
# Snapshot timestamp: 2026-05-22 10:02:42 +08:00
"""
Description : TB Dispatcher - Multi-TB Generation Coordinator
Coordinates generation of multiple targeted TBs based on RTL functional analysis
Author : CorrectBench
Time : 2026/04/19
"""
import os
import logging
from typing import List, Dict, Tuple, Optional, Any
from concurrent.futures import ThreadPoolExecutor, as_completed
from autoline.semantic_analyzer import SemanticAnalyzer
from autoline.semantic_analyzer import FunctionPointType
logger = logging.getLogger(__name__)
class TBTask:
    """
    One testbench (TB) generation task.

    Holds the inputs that describe the task (name, targeted function points,
    function-point type, description, working directory) together with
    placeholders for the artefacts and status that are filled in later.
    """

    def __init__(self, name: str, targets: List[str], fp_type: str,
                 description: str = "", task_dir: str = None):
        # Task description.
        self.name, self.targets, self.fp_type = name, targets, fp_type  # targets: function point names
        # fp_type is one of: fsm, counter, condition, protocol, exception
        self.description, self.task_dir = description, task_dir
        # Generated artefacts: Verilog TB code and Python rule code.
        self.tb_code_v = self.tb_code_py = None
        # Status: coverage value, check verdict, error message.
        self.coverage, self.passed, self.error = 0.0, False, None

    def to_dict(self) -> Dict:
        """Return the task summary (without code artefacts or task_dir) as a dict."""
        exported = ('name', 'targets', 'fp_type', 'description',
                    'coverage', 'passed', 'error')
        return {field: getattr(self, field) for field in exported}
class CoverageMerger:
    """
    Merge the coverage artefacts produced by several testbenches.
    """

    # Prefix characters that introduce a per-line hit counter.
    _COUNT_MARKERS = ('%', '~', '^')

    @staticmethod
    def _ensure_parent_dir(path: str) -> None:
        """Create the parent directory of *path* if it has one.

        os.path.dirname() returns '' for a bare file name, and
        os.makedirs('') raises, so the empty case is skipped.
        """
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)

    @staticmethod
    def merge_coverage_dat(coverage_files: List[str], output_file: str) -> Dict:
        """
        Merge several coverage.dat files into one.

        Each input line is expected to look like ``%NNNNNN code`` (the marker
        may also be ``~`` or ``^``). Lines are keyed by their code text, so
        textually identical lines share one counter; counters for the same
        key are summed across all inputs.

        Args:
            coverage_files: paths of the coverage.dat files to merge
            output_file: path of the merged file to write
        Returns:
            merged_info: {
                'total_lines': int,
                'covered_lines': int,
                'coverage_percent': float,
                'file': str
            }
        """
        merged_counts: Dict[str, int] = {}
        for cov_file in coverage_files:
            if not os.path.exists(cov_file):
                logger.warning(f"Coverage file not found: {cov_file}")
                continue
            try:
                with open(cov_file, 'r', encoding='utf-8', errors='ignore') as f:
                    for raw_line in f:
                        parts = raw_line.strip().split(None, 1)
                        if len(parts) < 2:
                            continue
                        count_str, code = parts
                        # NOTE(review): lines whose first token has no marker
                        # prefix are ignored - confirm that covered lines also
                        # carry a marker in the files this tool consumes.
                        if not count_str.startswith(CoverageMerger._COUNT_MARKERS):
                            continue
                        try:
                            count = int(count_str[1:])
                        except ValueError:
                            # A malformed counter only invalidates its own
                            # line, not the remainder of the file.
                            continue
                        merged_counts[code] = merged_counts.get(code, 0) + count
            except OSError as e:
                logger.warning(f"Error reading coverage file {cov_file}: {e}")

        total_lines = len(merged_counts)
        covered_lines = sum(1 for hits in merged_counts.values() if hits > 0)
        coverage_percent = (covered_lines / total_lines * 100) if total_lines > 0 else 0.0

        CoverageMerger._ensure_parent_dir(output_file)
        with open(output_file, 'w', encoding='utf-8') as f:
            # Always emit "%<count> <code>" so the merged file can be parsed
            # again by this same method, whatever the code text starts with.
            f.writelines(f"%{hits} {code}\n" for code, hits in merged_counts.items())

        return {
            'total_lines': total_lines,
            'covered_lines': covered_lines,
            'coverage_percent': coverage_percent,
            'file': output_file
        }

    @staticmethod
    def merge_unreachable_reports(report_files: List[str], output_file: str) -> Dict:
        """
        Merge several unreachable-branch reports into one summary.

        The per-report numbers are the occurrence counts of the phrases
        "Truly unreachable" and "Potentially coverable" in the report text.
        Report files that do not exist are skipped.

        Args:
            report_files: paths of unreachable_branches_report.txt files
            output_file: path of the merged report to write
        Returns:
            merged_info: {
                'total_unreachable': int,
                'total_potentially_coverable': int,
                'reports': list of per-file dicts,
                'file': str
            }
        """
        all_reports = []
        for report_file in report_files:
            if not os.path.exists(report_file):
                continue
            try:
                with open(report_file, 'r', encoding='utf-8') as f:
                    content = f.read()
            except (OSError, UnicodeDecodeError) as e:
                logger.warning(f"Error reading report {report_file}: {e}")
                continue
            all_reports.append({
                'file': report_file,
                'unreachable': content.count("Truly unreachable"),
                'potentially_coverable': content.count("Potentially coverable")
            })

        total_unreachable = sum(rp['unreachable'] for rp in all_reports)
        total_potentially_coverable = sum(rp['potentially_coverable'] for rp in all_reports)

        CoverageMerger._ensure_parent_dir(output_file)
        separator = "=" * 60 + "\n"
        with open(output_file, 'w', encoding='utf-8') as f:
            f.write(separator)
            f.write("MERGED UNREACHABLE BRANCHES REPORT\n")
            f.write("=" * 60 + "\n\n")
            for rp in all_reports:
                f.write(f"From: {rp['file']}\n")
                f.write(f" Truly unreachable: {rp['unreachable']}\n")
                f.write(f" Potentially coverable: {rp['potentially_coverable']}\n\n")
            f.write(separator)
            f.write(f"Total Truly Unreachable: {total_unreachable}\n")
            f.write(f"Total Potentially Coverable: {total_potentially_coverable}\n")
            f.write(separator)

        return {
            'total_unreachable': total_unreachable,
            'total_potentially_coverable': total_potentially_coverable,
            'reports': all_reports,
            'file': output_file
        }
class TBDispatcher:
    """
    TB dispatcher - coordinates the generation of multiple testbenches.

    Workflow:
    1. Analyze the RTL with the semantic analyzer to find function points.
    2. Group the function points by kind (FSM, protocol, datapath, exception).
    3. Create one TB task per group.
    4. Run the TB tasks, sequentially or in a thread pool.
    5. Merge the per-task coverage and unreachable-branch reports.
    """

    def __init__(self, prob_data: Dict, config, task_dir: str):
        """
        Args:
            prob_data: problem data; contains module_code, description, header
            config: configuration object; multi-TB options are read from
                config.autoline.multi_tb when present
            task_dir: working directory of the task
        """
        self.prob_data = prob_data
        self.config = config
        self.task_dir = task_dir

        multi_tb_config = getattr(config.autoline, 'multi_tb', None) if hasattr(config, 'autoline') else None

        def option(name, default):
            # Fall back to the default when there is no multi-TB section at all.
            return getattr(multi_tb_config, name, default) if multi_tb_config else default

        self.enabled = option('enabled', False)
        self.strategy = option('strategy', 'functional')
        self.max_tb_count = option('max_tb_count', 5)
        self.parallel = option('parallel', False)
        self.auto_threshold_lines = option('auto_threshold_lines', 500)

        self.tasks: List["TBTask"] = []
        self.results: List[Dict] = []
        self.merged_coverage = {}

    @staticmethod
    def _strip_comments(text: str, in_block: bool) -> Tuple[str, bool]:
        """Remove // and /* */ comments from one line.

        Args:
            text: one source line
            in_block: whether the line starts inside an open /* */ comment
        Returns:
            (code, in_block): the line without comment text, and whether a
            block comment is still open at the end of the line.

        String literals are not recognised, so comment markers inside a
        quoted string are treated as comments.
        """
        kept = []
        pos = 0
        length = len(text)
        while pos < length:
            if in_block:
                end = text.find('*/', pos)
                if end < 0:
                    break
                pos = end + 2
                in_block = False
            elif text.startswith('//', pos):
                break
            elif text.startswith('/*', pos):
                in_block = True
                pos += 2
            else:
                kept.append(text[pos])
                pos += 1
        return ''.join(kept), in_block

    def _count_rtl_lines(self) -> int:
        """Count RTL lines that contain code (blank and comment-only lines excluded)."""
        code = self.prob_data.get('module_code', '')
        count = 0
        in_block = False
        for line in code.split('\n'):
            # Track block-comment state across lines so that the interior
            # lines of a multi-line /* ... */ comment are not counted.
            stripped, in_block = self._strip_comments(line, in_block)
            if stripped.strip():
                count += 1
        return count

    def should_use_multi_tb(self) -> bool:
        """
        Decide whether multi-TB mode should be used.

        Returns:
            True: use multi-TB mode
            False: use single-TB mode
        """
        # 1. Explicitly enabled in the configuration.
        if self.enabled:
            return True
        # 2. Otherwise decide automatically from the RTL size.
        rtl_lines = self._count_rtl_lines()
        if rtl_lines > self.auto_threshold_lines:
            logger.info(f"[TBDispatcher] RTL has {rtl_lines} lines (>{self.auto_threshold_lines}), enabling multi-TB mode")
            return True
        return False

    def analyze_rtl(self) -> Dict[str, Any]:
        """
        Analyze the RTL with the semantic analyzer.

        Returns:
            analysis_result: result returned by SemanticAnalyzer.analyze()
        """
        logger.info(f"[TBDispatcher] Analyzing RTL for task: {self.prob_data.get('task_id', 'unknown')}")
        analyzer = SemanticAnalyzer(self.prob_data['module_code'])
        result = analyzer.analyze()
        logger.info(f"[TBDispatcher] Found {len(result.get('function_points', []))} function points")
        return result

    def group_function_points(self, analysis_result: Dict[str, Any]) -> Dict[str, List[Dict]]:
        """
        Group function points by kind so that each TB has a clear focus.

        A function point goes to the first group whose keyword appears in its
        name or matches its type; anything unmatched goes to 'datapath'.
        Empty groups are dropped from the result.

        Returns:
            groups: {
                'main': [...],      # main test, contains every function point
                'fsm': [...],       # FSM state machine
                'protocol': [...],  # protocol interface
                'datapath': [...],  # datapath / counter / everything else
                'exception': [...]  # exception handling
            }
        """
        groups = {
            'main': [],
            'fsm': [],
            'protocol': [],
            'datapath': [],
            'exception': []
        }
        all_fps = analysis_result.get('function_points') or []
        for fp in all_fps:
            raw_type = fp.get('type', 'condition')
            # NOTE(review): 'type' may be a plain string or an enum member
            # (FunctionPointType is imported by this module) - confirm; both
            # are normalised to a lower-case string here.
            fp_type = str(getattr(raw_type, 'value', raw_type)).lower()
            fp_name = str(fp.get('name') or '').lower()
            if 'fsm' in fp_name or fp_type == 'fsm':
                groups['fsm'].append(fp)
            elif 'protocol' in fp_name or fp_type == 'protocol':
                groups['protocol'].append(fp)
            elif 'counter' in fp_name or 'datapath' in fp_name or fp_type == 'counter':
                groups['datapath'].append(fp)
            elif 'exception' in fp_name or fp_type == 'exception':
                groups['exception'].append(fp)
            else:
                # Default bucket.
                groups['datapath'].append(fp)
        # The main group covers every function point (copied so that callers
        # mutating it do not alter the analysis result).
        groups['main'] = list(all_fps)
        groups = {k: v for k, v in groups.items() if v}
        logger.info(f"[TBDispatcher] Grouped function points: { {k: len(v) for k, v in groups.items()} }")
        return groups

    def create_tasks(self, groups: Dict[str, List[Dict]]) -> List["TBTask"]:
        """
        Create one TB task per function-point group.

        Focus of each TB:
        - tb_main: main test, tries to cover every function point
        - tb_fsm: state transitions and state coverage
        - tb_protocol: handshake timing and data transfer
        - tb_datapath: data processing and counters
        - tb_exception: boundary conditions and error handling

        Groups are visited in the order protocol, fsm, datapath, exception,
        main, and creation stops once max_tb_count tasks exist. Non-main
        groups with at most one function point are skipped. A sub-directory
        named after the task is created under task_dir for every task.

        Args:
            groups: function points grouped by kind
        Returns:
            tasks: list of TB tasks
        """
        tasks = []
        focus_descriptions = {
            'main': 'Comprehensive test covering all functional points',
            'fsm': 'FSM state machine test - focus on state transitions and state coverage',
            'protocol': 'Protocol interface test - focus on handshake timing and data transfer',
            'datapath': 'Datapath test - focus on data processing and counter logic',
            'exception': 'Exception handling test - focus on boundary conditions and error handling'
        }
        priority_order = ['protocol', 'fsm', 'datapath', 'exception', 'main']
        for fp_type in priority_order:
            fps = groups.get(fp_type)
            if not fps:
                continue
            if len(tasks) >= self.max_tb_count:
                logger.info(f"[TBDispatcher] Reached max TB count ({self.max_tb_count}), stopping task creation")
                break
            if len(fps) <= 1 and fp_type != 'main':
                logger.info(f"[TBDispatcher] Skipping {fp_type} group (only {len(fps)} function point)")
                continue

            task_name = f"tb_{fp_type}"
            # Grouping tolerates function points without a name, so do the
            # same here instead of raising KeyError.
            targets = [str(fp['name']) for fp in fps if fp.get('name')]
            description = f"[{focus_descriptions.get(fp_type, 'Test')}] Targets: {', '.join(targets)}"

            task_subdir = os.path.join(self.task_dir, task_name)
            os.makedirs(task_subdir, exist_ok=True)

            task = TBTask(
                name=task_name,
                targets=targets,
                fp_type=fp_type,
                description=description,
                task_dir=task_subdir
            )
            tasks.append(task)
            logger.info(f"[TBDispatcher] Created task: {task_name}")
            logger.info(f" Focus: {focus_descriptions.get(fp_type, 'N/A')}")
            logger.info(f" Targets: {targets}")
        logger.info(f"[TBDispatcher] Created {len(tasks)} TB tasks")
        return tasks

    def run_single_task(self, task: "TBTask") -> Dict:
        """
        Run a single TB task.

        Called once per function-point group. This is currently a
        placeholder: it returns a default result and, when a result.json
        file exists in the task directory, overlays its contents.

        Args:
            task: TB task
        Returns:
            result: task execution result
        """
        logger.info(f"[TBDispatcher] Running task: {task.name}")
        logger.info(f"[TBDispatcher] Targets: {task.targets}")
        # TODO: call the existing TBgen -> TBsim -> TBcheck -> CGA flow here.
        result = {
            'name': task.name,
            'targets': task.targets,
            'fp_type': task.fp_type,
            'coverage': 0.0,
            'passed': False,
            'error': None,
            'tb_dir': task.task_dir
        }
        result_file = os.path.join(task.task_dir, "result.json")
        if os.path.exists(result_file):
            import json
            try:
                with open(result_file, 'r', encoding='utf-8') as f:
                    saved_result = json.load(f)
            except (OSError, ValueError) as e:
                # Unreadable or invalid JSON: keep the defaults, but say so
                # instead of hiding the problem.
                logger.warning(f"[TBDispatcher] Could not load {result_file}: {e}")
            else:
                if isinstance(saved_result, dict):
                    result.update(saved_result)
                else:
                    logger.warning(f"[TBDispatcher] Ignoring {result_file}: expected a JSON object")
        return result

    def _run_task_safely(self, task: "TBTask") -> Dict:
        """Run one task and turn any exception into a failed result record."""
        try:
            return self.run_single_task(task)
        except Exception as e:
            logger.error(f"[TBDispatcher] Task {task.name} failed: {e}")
            return {
                'name': task.name,
                'targets': task.targets,
                'fp_type': task.fp_type,
                'coverage': 0.0,
                'error': str(e),
                'passed': False
            }

    def run_all(self) -> Dict:
        """
        Run every TB task and merge the results.

        Returns:
            merged_result: {'enabled': False} when multi-TB mode is not
            needed, a fallback marker dict when no function points or tasks
            are found, otherwise the dict built by merge_results().
        """
        # should_use_multi_tb() also covers the automatic RTL-size decision.
        if not self.should_use_multi_tb():
            logger.info("[TBDispatcher] Multi-TB mode not needed, skipping")
            return {'enabled': False}
        logger.info("[TBDispatcher] Starting multi-TB dispatch")

        # Start from a clean slate so a repeated call does not accumulate
        # results from the previous run.
        self.tasks = []
        self.results = []
        self.merged_coverage = {}

        # Step 1: analyze the RTL.
        analysis_result = self.analyze_rtl()
        # Step 2: group the function points.
        groups = self.group_function_points(analysis_result)
        if not groups:
            logger.warning("[TBDispatcher] No function points found, falling back to single TB")
            return {'enabled': True, 'fallback': True, 'reason': 'no_function_points'}
        # Step 3: create the TB tasks.
        self.tasks = self.create_tasks(groups)
        if not self.tasks:
            logger.warning("[TBDispatcher] No tasks created, falling back to single TB")
            return {'enabled': True, 'fallback': True, 'reason': 'no_tasks'}
        # Step 4: run the tasks. Both modes record a failed result for a task
        # that raises, and both keep the results in task order.
        if self.parallel:
            with ThreadPoolExecutor(max_workers=len(self.tasks)) as executor:
                self.results = list(executor.map(self._run_task_safely, self.tasks))
        else:
            self.results = [self._run_task_safely(task) for task in self.tasks]
        # Step 5: merge the coverage.
        self.merged_coverage = self.merge_results()
        # coverage_percent lives in the nested 'merged_coverage' dict.
        merged_percent = (self.merged_coverage.get('merged_coverage') or {}).get('coverage_percent', 0)
        logger.info(f"[TBDispatcher] All tasks completed. Merged coverage: {merged_percent:.2f}%")
        return self.merged_coverage

    def merge_results(self) -> Dict:
        """
        Merge the results of all TB tasks.

        The average coverage is the sum of positive per-task coverage values
        divided by the number of results. Coverage and unreachable-branch
        report files are collected from <tb_dir>/CGA of each result that has
        a tb_dir, and merged into <task_dir>/CGA_merged.

        Returns:
            merged: merged result dict
        """
        total_coverage = 0.0
        passed_count = 0
        total_tasks = len(self.results)
        coverage_reports = []
        unreachable_reports = []
        for result in self.results:
            coverage = result.get('coverage') or 0
            if coverage > 0:
                total_coverage += coverage
            if result.get('passed', False):
                passed_count += 1
            tb_dir = result.get('tb_dir')
            if not tb_dir:
                # Without a task directory the path would be resolved
                # relative to the current directory and could pick up
                # unrelated files.
                continue
            cov_file = os.path.join(tb_dir, 'CGA', 'coverage.dat')
            if os.path.exists(cov_file):
                coverage_reports.append(cov_file)
            unreach_file = os.path.join(tb_dir, 'CGA', 'unreachable_branches_report.txt')
            if os.path.exists(unreach_file):
                unreachable_reports.append(unreach_file)

        avg_coverage = total_coverage / total_tasks if total_tasks > 0 else 0.0

        merged_cov_file = os.path.join(self.task_dir, "CGA_merged", "coverage.dat")
        merged_unreach_file = os.path.join(self.task_dir, "CGA_merged", "unreachable_branches_report.txt")
        coverage_info = (
            CoverageMerger.merge_coverage_dat(coverage_reports, merged_cov_file)
            if coverage_reports else {}
        )
        unreachable_info = (
            CoverageMerger.merge_unreachable_reports(unreachable_reports, merged_unreach_file)
            if unreachable_reports else {}
        )
        return {
            'enabled': True,
            'total_tasks': total_tasks,
            'passed_tasks': passed_count,
            'average_coverage': avg_coverage,
            'merged_coverage': coverage_info,
            'merged_unreachable': unreachable_info,
            'results': self.results
        }

    def generate_summary_report(self) -> str:
        """
        Build the multi-TB summary report.

        Returns:
            report: formatted report text, or "No results to report" when
            nothing has been merged yet
        """
        if not self.merged_coverage:
            return "No results to report"
        lines = []
        lines.append("=" * 70)
        lines.append("MULTI-TB COVERAGE SUMMARY REPORT")
        lines.append("=" * 70)
        lines.append("")
        lines.append(f"Total TB Tasks: {self.merged_coverage.get('total_tasks', 0)}")
        lines.append(f"Passed Tasks: {self.merged_coverage.get('passed_tasks', 0)}")
        lines.append(f"Average Coverage: {self.merged_coverage.get('average_coverage', 0):.2f}%")
        lines.append("")
        # Per-task details.
        lines.append("Individual TB Results:")
        lines.append("-" * 70)
        for result in self.results:
            status = "PASS" if result.get('passed') else "FAIL"
            # Failed or externally loaded results may carry None here.
            coverage = result.get('coverage') or 0
            targets = result.get('targets') or []
            lines.append(f" [{status}] {result.get('name', 'unknown')}")
            lines.append(f" Targets: {', '.join(str(t) for t in targets)}")
            lines.append(f" Coverage: {coverage:.2f}%")
            if result.get('error'):
                lines.append(f" Error: {result['error']}")
            lines.append("")
        # Merged coverage.
        merged_cov = self.merged_coverage.get('merged_coverage', {})
        if merged_cov:
            lines.append("-" * 70)
            lines.append("Merged Coverage:")
            lines.append(f" Total Lines: {merged_cov.get('total_lines', 0)}")
            lines.append(f" Covered Lines: {merged_cov.get('covered_lines', 0)}")
            lines.append(f" Coverage: {merged_cov.get('coverage_percent', 0):.2f}%")
            lines.append("")
        # Unreachable branches.
        merged_unreach = self.merged_coverage.get('merged_unreachable', {})
        if merged_unreach:
            lines.append("-" * 70)
            lines.append("Unreachable Branches:")
            lines.append(f" Truly Unreachable: {merged_unreach.get('total_unreachable', 0)}")
            lines.append(f" Potentially Coverable: {merged_unreach.get('total_potentially_coverable', 0)}")
            lines.append("")
        lines.append("=" * 70)
        return "\n".join(lines)
def create_dispatcher(prob_data: Dict, config, task_dir: str) -> TBDispatcher:
    """
    Factory function: build a TB dispatcher.

    Args:
        prob_data: problem data
        config: configuration object
        task_dir: task directory
    Returns:
        dispatcher: a new TBDispatcher instance
    """
    dispatcher = TBDispatcher(prob_data=prob_data, config=config, task_dir=task_dir)
    return dispatcher
if __name__ == "__main__":
    # Manual smoke test: analyze a small FSM, group its function points and
    # create the TB tasks, printing each intermediate result.
    import logging
    logging.basicConfig(level=logging.INFO)

    # Sample RTL used as input.
    sample_rtl = """
module example_fsm (
input clk, rst_n,
input [1:0] cmd,
output reg [3:0] out
);
localparam IDLE = 2'b00;
localparam RUN = 2'b01;
localparam DONE = 2'b10;
reg [1:0] state, next_state;
always @(posedge clk or negedge rst_n) begin
if (!rst_n)
state <= IDLE;
else
state <= next_state;
end
always @(*) begin
case (state)
IDLE: next_state = cmd[0] ? RUN : IDLE;
RUN: next_state = DONE;
DONE: next_state = IDLE;
default: next_state = IDLE;
endcase
end
always @(posedge clk) begin
if (state == RUN)
out <= out + 1;
end
endmodule
"""

    prob_data = dict(
        task_id='test_fsm',
        module_code=sample_rtl,
        header='module example_fsm (input clk, rst_n, input [1:0] cmd, output [3:0] out);',
        description='Test FSM module',
    )

    class MockConfig:
        # Minimal stand-in for the real configuration: an autoline section
        # without any multi_tb settings.
        class AutoLine:
            multi_tb = None

        autoline = AutoLine()

    config = MockConfig()

    # Build the dispatcher and force multi-TB mode on for the test.
    dispatcher = TBDispatcher(prob_data, config, "/tmp/test_dispatcher")
    dispatcher.enabled = True

    # Analyze the RTL.
    result = dispatcher.analyze_rtl()
    found = len(result.get('function_points', []))
    print(f"\nFunction Points Found: {found}")

    # Group the function points.
    groups = dispatcher.group_function_points(result)
    print(f"\nGroups: {groups}")

    # Create the tasks.
    tasks = dispatcher.create_tasks(groups)
    print(f"\nTasks Created: {len(tasks)}")
    for task in tasks:
        print(f" - {task.name}: {task.targets}")

    print("\n" + "=" * 50)
    print("Test completed successfully")