Files
CGA-bench/autoline/tb_dispatcher.py
2026-05-22 10:02:42 +08:00

701 lines
24 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Description : TB Dispatcher - Multi-TB Generation Coordinator
Coordinates generation of multiple targeted TBs based on RTL functional analysis
Author : CorrectBench
Time : 2026/04/19
"""
import os
import logging
from typing import List, Dict, Tuple, Optional, Any
from concurrent.futures import ThreadPoolExecutor, as_completed
from autoline.semantic_analyzer import SemanticAnalyzer
from autoline.semantic_analyzer import FunctionPointType
# Module-level logger; all logging done by the classes in this file goes through it.
logger = logging.getLogger(__name__)
class TBTask:
    """
    A single TB generation task: one testbench targeting a group of function points.
    """

    def __init__(self, name: str, targets: List[str], fp_type: str,
                 description: str = "", task_dir: Optional[str] = None):
        """
        Args:
            name: task name (e.g. "tb_fsm").
            targets: names of the function points this TB targets.
            fp_type: function-point type/group this TB focuses on
                (e.g. fsm, counter, condition, protocol, exception).
            description: free-form description of the TB's focus.
            task_dir: working directory of this task, or None if not assigned.
        """
        self.name = name
        self.targets = targets          # function-point names
        self.fp_type = fp_type          # function-point type / group
        self.description = description
        self.task_dir = task_dir
        self.tb_code_v = None           # Verilog TB code
        self.tb_code_py = None          # Python rule code
        self.coverage = 0.0             # coverage
        self.passed = False             # whether the task passed its checks
        self.error = None               # error message, if any

    def to_dict(self) -> Dict:
        """
        Return a serializable snapshot of the task's identity and outcome.

        The generated TB code and ``task_dir`` are not included. ``targets`` is
        copied so that mutating the returned dict does not alter the task.
        """
        return {
            'name': self.name,
            'targets': list(self.targets),
            'fp_type': self.fp_type,
            'description': self.description,
            'coverage': self.coverage,
            'passed': self.passed,
            'error': self.error
        }

    def __repr__(self) -> str:
        return (f"{type(self).__name__}(name={self.name!r}, fp_type={self.fp_type!r}, "
                f"targets={self.targets!r}, passed={self.passed!r}, "
                f"coverage={self.coverage!r})")
class CoverageMerger:
    """
    Merge the coverage reports produced by several TBs.
    """

    # Prefixes marking a line that carries a hit count in the annotated
    # coverage files ("%NNNNNN code" / "~NNNNNN code"; '^' is accepted too).
    _COUNT_MARKERS = ('%', '~', '^')

    @staticmethod
    def _ensure_parent_dir(path: str) -> None:
        """Create the parent directory of ``path`` if it has one."""
        parent = os.path.dirname(path)
        # os.makedirs('') raises FileNotFoundError, so a bare file name (no
        # directory component) is simply written to the current directory.
        if parent:
            os.makedirs(parent, exist_ok=True)

    @staticmethod
    def _parse_coverage_line(line: str) -> Optional[Tuple[str, int]]:
        """
        Parse one annotated coverage line.

        Returns:
            (code, hit_count) for a line of the form "<marker><count> <code>";
            None for blank lines, lines without a count marker, and lines
            whose count is not an integer.
        """
        parts = line.strip().split(None, 1)
        if len(parts) < 2:
            return None
        count_str, code = parts
        if not count_str.startswith(CoverageMerger._COUNT_MARKERS):
            return None
        try:
            count = int(count_str[1:])
        except ValueError:
            # Skip only this malformed line instead of the rest of the file.
            return None
        return code, count

    @staticmethod
    def merge_coverage_dat(coverage_files: List[str], output_file: str) -> Dict:
        """
        Merge several annotated coverage files into one.

        Input lines follow the Verilator coverage annotation style
        ("%NNNNNN code" or "~NNNNNN code"); lines without a count marker and
        malformed counts are ignored. Entries are keyed by their code text, so
        textually identical source lines collapse into one entry. Hit counts
        of the same key are summed, and a key counts as covered when its
        merged count is > 0. The merged file is written as "%<count> <code>"
        lines, a format this method can read back. Missing or unreadable
        input files are logged and skipped.

        Args:
            coverage_files: paths of the coverage files to merge.
            output_file: path of the merged file to write; its parent
                directory is created if needed.
        Returns:
            merged_info: {
                'total_lines': int,
                'covered_lines': int,
                'coverage_percent': float,   # 0-100, 0.0 when there are no lines
                'file': str                  # output_file
            }
        """
        all_lines: Dict[str, int] = {}
        for cov_file in coverage_files:
            if not os.path.exists(cov_file):
                logger.warning(f"Coverage file not found: {cov_file}")
                continue
            try:
                with open(cov_file, 'r', encoding='utf-8', errors='ignore') as f:
                    for raw_line in f:
                        parsed = CoverageMerger._parse_coverage_line(raw_line)
                        if parsed is not None:
                            code, count = parsed
                            all_lines[code] = all_lines.get(code, 0) + count
            except OSError as e:
                logger.warning(f"Error reading coverage file {cov_file}: {e}")

        total_lines = len(all_lines)
        covered_lines = sum(1 for count in all_lines.values() if count > 0)
        coverage_percent = (covered_lines / total_lines * 100) if total_lines > 0 else 0.0

        CoverageMerger._ensure_parent_dir(output_file)
        with open(output_file, 'w', encoding='utf-8') as f:
            f.writelines(f"%{count} {code}\n" for code, count in all_lines.items())

        return {
            'total_lines': total_lines,
            'covered_lines': covered_lines,
            'coverage_percent': coverage_percent,
            'file': output_file
        }

    @staticmethod
    def merge_unreachable_reports(report_files: List[str], output_file: str) -> Dict:
        """
        Merge several unreachable-branch reports.

        Each report is summarized by counting occurrences of the phrases
        "Truly unreachable" and "Potentially coverable" in its text. Missing
        or unreadable reports are logged and skipped.

        Args:
            report_files: paths of unreachable_branches_report.txt files.
            output_file: path of the merged report to write; its parent
                directory is created if needed.
        Returns:
            merged_info: {
                'total_unreachable': int,
                'total_potentially_coverable': int,
                'reports': list of per-report dicts
                           ({'file', 'unreachable', 'potentially_coverable'}),
                'file': str    # output_file
            }
        """
        total_unreachable = 0
        total_potentially_coverable = 0
        all_reports = []
        for report_file in report_files:
            if not os.path.exists(report_file):
                logger.warning(f"Unreachable report not found: {report_file}")
                continue
            try:
                with open(report_file, 'r', encoding='utf-8') as f:
                    content = f.read()
            except (OSError, UnicodeDecodeError) as e:
                logger.warning(f"Error reading report {report_file}: {e}")
                continue
            # Simple phrase-count parsing.
            unreachable_count = content.count("Truly unreachable")
            potentially_count = content.count("Potentially coverable")
            total_unreachable += unreachable_count
            total_potentially_coverable += potentially_count
            all_reports.append({
                'file': report_file,
                'unreachable': unreachable_count,
                'potentially_coverable': potentially_count
            })

        CoverageMerger._ensure_parent_dir(output_file)
        with open(output_file, 'w', encoding='utf-8') as f:
            f.write("=" * 60 + "\n")
            f.write("MERGED UNREACHABLE BRANCHES REPORT\n")
            f.write("=" * 60 + "\n\n")
            for rp in all_reports:
                f.write(f"From: {rp['file']}\n")
                f.write(f" Truly unreachable: {rp['unreachable']}\n")
                f.write(f" Potentially coverable: {rp['potentially_coverable']}\n\n")
            f.write("=" * 60 + "\n")
            f.write(f"Total Truly Unreachable: {total_unreachable}\n")
            f.write(f"Total Potentially Coverable: {total_potentially_coverable}\n")
            f.write("=" * 60 + "\n")

        return {
            'total_unreachable': total_unreachable,
            'total_potentially_coverable': total_potentially_coverable,
            'reports': all_reports,
            'file': output_file
        }
class TBDispatcher:
    """
    TB dispatcher - coordinates the generation of multiple targeted TBs.

    Workflow:
    1. Analyze the RTL with the semantic analyzer to identify functional regions.
    2. Group the function points by type (FSM, protocol, datapath, exception).
    3. Create a TB task for each group.
    4. Execute the TB tasks (sequentially, or in a thread pool when
       ``parallel`` is set).
    5. Merge the per-task coverage and unreachable-branch reports.
    """

    def __init__(self, prob_data: Dict, config, task_dir: str):
        """
        Args:
            prob_data: problem data. This class reads 'module_code' (the RTL
                source) and, for logging, 'task_id'.
            config: configuration object. Multi-TB settings come from the
                optional ``config.autoline.multi_tb`` section; any missing
                section or attribute falls back to its default.
            task_dir: working directory; per-task subdirectories and the
                merged "CGA_merged" reports are created under it.
        """
        self.prob_data = prob_data
        self.config = config
        self.task_dir = task_dir
        # getattr(None, name, default) returns the default, so a missing
        # `autoline` section or `multi_tb` entry simply yields the defaults.
        autoline_config = getattr(config, 'autoline', None)
        multi_tb_config = getattr(autoline_config, 'multi_tb', None)
        self.enabled = getattr(multi_tb_config, 'enabled', False)
        self.strategy = getattr(multi_tb_config, 'strategy', 'functional')
        self.max_tb_count = getattr(multi_tb_config, 'max_tb_count', 5)
        self.parallel = getattr(multi_tb_config, 'parallel', False)
        self.auto_threshold_lines = getattr(multi_tb_config, 'auto_threshold_lines', 500)
        self.tasks: List["TBTask"] = []
        self.results: List[Dict] = []
        self.merged_coverage: Dict = {}

    def _count_rtl_lines(self) -> int:
        """
        Count the RTL lines that contain code (blank lines and comments excluded).

        ``//`` line comments and ``/* ... */`` block comments (including the
        interior lines of multi-line ones) are removed first; a line mixing
        code and a comment still counts. String literals are not special-cased.
        """
        import re
        code = self.prob_data.get('module_code', '') or ''
        # One left-to-right pass: whichever comment opener appears first wins,
        # so a "/*" inside a "//" comment does not start a block comment.
        code = re.sub(r'//[^\n]*|/\*.*?\*/', '', code, flags=re.DOTALL)
        count = 0
        for line in code.split('\n'):
            stripped = line.strip()
            # An unterminated "/*" survives the regex; still treat it as a comment.
            if stripped and not stripped.startswith('/*'):
                count += 1
        return count

    def should_use_multi_tb(self) -> bool:
        """
        Decide whether multi-TB mode should be used.

        Returns:
            True if multi-TB mode is enabled in the config, or if the RTL has
            more than ``auto_threshold_lines`` code lines; False otherwise.
        """
        # 1. Explicitly enabled in the config.
        if self.enabled:
            return True
        # 2. Otherwise decide automatically from the RTL size.
        rtl_lines = self._count_rtl_lines()
        if rtl_lines > self.auto_threshold_lines:
            logger.info(f"[TBDispatcher] RTL has {rtl_lines} lines (>{self.auto_threshold_lines}), enabling multi-TB mode")
            return True
        return False

    def analyze_rtl(self) -> Dict[str, Any]:
        """
        Run the semantic analyzer over ``prob_data['module_code']``.

        Returns:
            analysis_result: the analyzer's result; this class reads its
            'function_points' list.
        """
        logger.info(f"[TBDispatcher] Analyzing RTL for task: {self.prob_data.get('task_id', 'unknown')}")
        analyzer = SemanticAnalyzer(self.prob_data['module_code'])
        result = analyzer.analyze()
        logger.info(f"[TBDispatcher] Found {len(result.get('function_points', []))} function points")
        return result

    def group_function_points(self, analysis_result: Dict[str, Any]) -> Dict[str, List[Dict]]:
        """
        Group the function points by type so each TB has a clear focus.

        Each function point goes to the first matching group, checked against
        its lower-cased 'name' and its 'type' field (default 'condition'):
        fsm, then protocol, then datapath ('counter'/'datapath' in the name or
        type 'counter'), then exception; anything else falls back to datapath.

        Returns:
            groups: a subset of {
                'main': [...],       # main test, all function points
                'fsm': [...],        # FSM-focused
                'protocol': [...],   # protocol-interface-focused
                'datapath': [...],   # datapath / counter / condition branches
                'exception': [...]   # exception-handling-focused
            } with empty groups removed.
        """
        groups = {
            'main': [],
            'fsm': [],
            'protocol': [],
            'datapath': [],
            'exception': []
        }
        all_fps = analysis_result.get('function_points', [])
        for fp in all_fps:
            # NOTE(review): 'type' is compared with plain strings; if the analyzer
            # emits FunctionPointType members, confirm they compare equal.
            fp_type = fp.get('type', 'condition')
            fp_name = (fp.get('name') or '').lower()
            if 'fsm' in fp_name or fp_type == 'fsm':
                groups['fsm'].append(fp)
            elif 'protocol' in fp_name or fp_type == 'protocol':
                groups['protocol'].append(fp)
            elif 'counter' in fp_name or 'datapath' in fp_name or fp_type == 'counter':
                groups['datapath'].append(fp)
            elif 'exception' in fp_name or fp_type == 'exception':
                groups['exception'].append(fp)
            else:
                # Default bucket.
                groups['datapath'].append(fp)
        # The main group holds every function point (copied, not aliased).
        groups['main'] = list(all_fps)
        groups = {k: v for k, v in groups.items() if v}
        logger.info(f"[TBDispatcher] Grouped function points: { {k: len(v) for k, v in groups.items()} }")
        return groups

    def create_tasks(self, groups: Dict[str, List[Dict]]) -> List["TBTask"]:
        """
        Create one TB task per function-point group, each with an explicit focus.

        Focus of each TB:
        - tb_main: main test, tries to cover every function point
        - tb_fsm: FSM test, state transitions and state coverage
        - tb_protocol: protocol interface test, handshake timing and data transfer
        - tb_datapath: datapath test, data processing and counters
        - tb_exception: exception test, boundary conditions and error handling

        Groups are visited in priority order protocol > fsm > datapath >
        exception > main. Non-main groups with at most one function point are
        skipped, and creation stops once ``max_tb_count`` tasks exist (so
        'main', visited last, is the first one dropped). A subdirectory
        ``<task_dir>/tb_<group>`` is created for every task.

        Args:
            groups: function points grouped by type.
        Returns:
            tasks: the created TB tasks.
        """
        tasks = []
        focus_descriptions = {
            'main': 'Comprehensive test covering all functional points',
            'fsm': 'FSM state machine test - focus on state transitions and state coverage',
            'protocol': 'Protocol interface test - focus on handshake timing and data transfer',
            'datapath': 'Datapath test - focus on data processing and counter logic',
            'exception': 'Exception handling test - focus on boundary conditions and error handling'
        }
        priority_order = ['protocol', 'fsm', 'datapath', 'exception', 'main']
        for fp_type in priority_order:
            fps = groups.get(fp_type)
            if not fps:
                continue
            if len(tasks) >= self.max_tb_count:
                logger.info(f"[TBDispatcher] Reached max TB count ({self.max_tb_count}), stopping task creation")
                break
            # Skip single-point groups (except main): the main TB already covers them.
            if len(fps) <= 1 and fp_type != 'main':
                logger.info(f"[TBDispatcher] Skipping {fp_type} group (only {len(fps)} function point)")
                continue
            task_name = f"tb_{fp_type}"
            targets = [fp['name'] for fp in fps]
            description = f"[{focus_descriptions.get(fp_type, 'Test')}] Targets: {', '.join(targets)}"
            task_subdir = os.path.join(self.task_dir, task_name)
            os.makedirs(task_subdir, exist_ok=True)
            task = TBTask(
                name=task_name,
                targets=targets,
                fp_type=fp_type,
                description=description,
                task_dir=task_subdir
            )
            tasks.append(task)
            logger.info(f"[TBDispatcher] Created task: {task_name}")
            logger.info(f" Focus: {focus_descriptions.get(fp_type, 'N/A')}")
            logger.info(f" Targets: {targets}")
        logger.info(f"[TBDispatcher] Created {len(tasks)} TB tasks")
        return tasks

    def run_single_task(self, task: "TBTask") -> Dict:
        """
        Run a single TB task (called once per function-point group).

        Currently a placeholder: it builds a default result and, if
        ``<task_dir>/result.json`` exists and holds a JSON object, overlays it.
        An unreadable or malformed result file is logged and ignored. The
        outcome (coverage/passed/error) is mirrored onto ``task``.

        Args:
            task: the TB task.
        Returns:
            result: dict with 'name', 'targets', 'fp_type', 'coverage',
            'passed', 'error', 'tb_dir' (plus any keys from result.json).
        """
        logger.info(f"[TBDispatcher] Running task: {task.name}")
        logger.info(f"[TBDispatcher] Targets: {task.targets}")
        # TODO: invoke the existing TBgen -> TBsim -> TBcheck -> CGA flow here.
        result = {
            'name': task.name,
            'targets': task.targets,
            'fp_type': task.fp_type,
            'coverage': 0.0,
            'passed': False,
            'error': None,
            'tb_dir': task.task_dir
        }
        # Temporary: pick up results saved by an external run until the flow is integrated.
        if task.task_dir:
            result_file = os.path.join(task.task_dir, "result.json")
            if os.path.exists(result_file):
                import json
                try:
                    with open(result_file, 'r', encoding='utf-8') as f:
                        saved_result = json.load(f)
                except (OSError, ValueError) as e:
                    logger.warning(f"[TBDispatcher] Ignoring unreadable result file {result_file}: {e}")
                else:
                    if isinstance(saved_result, dict):
                        result.update(saved_result)
                    else:
                        logger.warning(f"[TBDispatcher] Ignoring result file {result_file}: not a JSON object")
        task.coverage = result.get('coverage', 0.0)
        task.passed = result.get('passed', False)
        task.error = result.get('error')
        return result

    @staticmethod
    def _failed_result(task: "TBTask", error: BaseException) -> Dict:
        """Log a task failure and build a result dict shaped like run_single_task's."""
        logger.error(f"[TBDispatcher] Task {task.name} failed: {error}")
        task.passed = False
        task.error = str(error)
        return {
            'name': task.name,
            'targets': task.targets,
            'fp_type': task.fp_type,
            'coverage': 0.0,
            'passed': False,
            'error': str(error),
            # Keep tb_dir so merge_results never resolves report paths against the CWD.
            'tb_dir': task.task_dir
        }

    def run_all(self) -> Dict:
        """
        Run every TB task and merge the results.

        Returns:
            {'enabled': False} when multi-TB mode is not needed;
            {'enabled': True, 'fallback': True, 'reason': ...} when no function
            points or no tasks were produced; otherwise the merge_results() dict.
            A task that raises is recorded as a failed result (in both
            sequential and parallel mode) instead of aborting the run.
        """
        if not self.should_use_multi_tb():
            logger.info("[TBDispatcher] Multi-TB mode not needed, skipping")
            return {'enabled': False}
        logger.info("[TBDispatcher] Starting multi-TB dispatch")
        # Reset per-run state so repeated calls do not accumulate old results.
        self.results = []
        self.merged_coverage = {}

        # Step 1: analyze the RTL.
        analysis_result = self.analyze_rtl()
        # Step 2: group the function points.
        groups = self.group_function_points(analysis_result)
        if not groups:
            logger.warning("[TBDispatcher] No function points found, falling back to single TB")
            return {'enabled': True, 'fallback': True, 'reason': 'no_function_points'}
        # Step 3: create the TB tasks.
        self.tasks = self.create_tasks(groups)
        if not self.tasks:
            logger.warning("[TBDispatcher] No tasks created, falling back to single TB")
            return {'enabled': True, 'fallback': True, 'reason': 'no_tasks'}

        # Step 4: execute the tasks.
        if self.parallel:
            # Store results by task index so their order does not depend on
            # completion order; every slot is filled by a result or a failure.
            ordered: List[Optional[Dict]] = [None] * len(self.tasks)
            with ThreadPoolExecutor(max_workers=len(self.tasks)) as executor:
                futures = {executor.submit(self.run_single_task, task): index
                           for index, task in enumerate(self.tasks)}
                for future in as_completed(futures):
                    index = futures[future]
                    try:
                        ordered[index] = future.result()
                    except Exception as e:
                        ordered[index] = self._failed_result(self.tasks[index], e)
            self.results = ordered
        else:
            for task in self.tasks:
                try:
                    result = self.run_single_task(task)
                except Exception as e:
                    result = self._failed_result(task, e)
                self.results.append(result)

        # Step 5: merge the coverage.
        self.merged_coverage = self.merge_results()
        merged_percent = self.merged_coverage.get('merged_coverage', {}).get('coverage_percent', 0.0)
        logger.info(f"[TBDispatcher] All tasks completed. Merged coverage: {merged_percent:.2f}%")
        return self.merged_coverage

    def merge_results(self) -> Dict:
        """
        Merge the results of all TB tasks.

        Returns:
            merged: {
                'enabled': True,
                'total_tasks': int,
                'passed_tasks': int,
                'average_coverage': float,   # sum of coverages / total_tasks
                'merged_coverage': dict,     # merge_coverage_dat info, {} if no reports
                'merged_unreachable': dict,  # merge_unreachable_reports info, {} if none
                'results': list of per-task result dicts
            }
        """
        total_coverage = 0.0
        passed_count = 0
        total_tasks = len(self.results)
        coverage_reports = []
        unreachable_reports = []
        for result in self.results:
            # Missing or JSON-null coverage counts as 0.
            coverage = result.get('coverage') or 0.0
            if coverage > 0:
                total_coverage += coverage
            if result.get('passed', False):
                passed_count += 1
            tb_dir = result.get('tb_dir')
            # Without a task directory the report paths would resolve against
            # the current working directory and could pick up unrelated files.
            if not tb_dir:
                continue
            cov_file = os.path.join(tb_dir, 'CGA', 'coverage.dat')
            if os.path.exists(cov_file):
                coverage_reports.append(cov_file)
            unreach_file = os.path.join(tb_dir, 'CGA', 'unreachable_branches_report.txt')
            if os.path.exists(unreach_file):
                unreachable_reports.append(unreach_file)

        avg_coverage = total_coverage / total_tasks if total_tasks > 0 else 0.0

        merged_dir = os.path.join(self.task_dir, "CGA_merged")
        coverage_info = (
            CoverageMerger.merge_coverage_dat(coverage_reports, os.path.join(merged_dir, "coverage.dat"))
            if coverage_reports else {}
        )
        unreachable_info = (
            CoverageMerger.merge_unreachable_reports(
                unreachable_reports, os.path.join(merged_dir, "unreachable_branches_report.txt"))
            if unreachable_reports else {}
        )
        return {
            'enabled': True,
            'total_tasks': total_tasks,
            'passed_tasks': passed_count,
            'average_coverage': avg_coverage,
            'merged_coverage': coverage_info,
            'merged_unreachable': unreachable_info,
            'results': self.results
        }

    def generate_summary_report(self) -> str:
        """
        Build a human-readable multi-TB summary report.

        Returns:
            The formatted report, or "No results to report" when no merged
            results are available yet.
        """
        if not self.merged_coverage:
            return "No results to report"
        lines = []
        lines.append("=" * 70)
        lines.append("MULTI-TB COVERAGE SUMMARY REPORT")
        lines.append("=" * 70)
        lines.append("")
        lines.append(f"Total TB Tasks: {self.merged_coverage.get('total_tasks', 0)}")
        lines.append(f"Passed Tasks: {self.merged_coverage.get('passed_tasks', 0)}")
        lines.append(f"Average Coverage: {self.merged_coverage.get('average_coverage', 0):.2f}%")
        lines.append("")
        # Per-task details.
        lines.append("Individual TB Results:")
        lines.append("-" * 70)
        for result in self.results:
            status = "PASS" if result.get('passed') else "FAIL"
            # Missing / JSON-null values must not crash the format specs below.
            coverage = result.get('coverage') or 0
            targets = result.get('targets') or []
            lines.append(f" [{status}] {result.get('name', 'unknown')}")
            lines.append(f" Targets: {', '.join(str(t) for t in targets)}")
            lines.append(f" Coverage: {coverage:.2f}%")
            if result.get('error'):
                lines.append(f" Error: {result['error']}")
            lines.append("")
        # Merged coverage.
        merged_cov = self.merged_coverage.get('merged_coverage', {})
        if merged_cov:
            lines.append("-" * 70)
            lines.append("Merged Coverage:")
            lines.append(f" Total Lines: {merged_cov.get('total_lines', 0)}")
            lines.append(f" Covered Lines: {merged_cov.get('covered_lines', 0)}")
            lines.append(f" Coverage: {merged_cov.get('coverage_percent', 0):.2f}%")
            lines.append("")
        # Unreachable branches.
        merged_unreach = self.merged_coverage.get('merged_unreachable', {})
        if merged_unreach:
            lines.append("-" * 70)
            lines.append("Unreachable Branches:")
            lines.append(f" Truly Unreachable: {merged_unreach.get('total_unreachable', 0)}")
            lines.append(f" Potentially Coverable: {merged_unreach.get('total_potentially_coverable', 0)}")
            lines.append("")
        lines.append("=" * 70)
        return "\n".join(lines)
def create_dispatcher(prob_data: Dict, config, task_dir: str) -> TBDispatcher:
    """
    Factory helper that builds a TBDispatcher.

    Args:
        prob_data: problem data forwarded to the dispatcher.
        config: configuration object forwarded to the dispatcher.
        task_dir: working directory forwarded to the dispatcher.
    Returns:
        A new TBDispatcher constructed from the three arguments.
    """
    dispatcher = TBDispatcher(prob_data=prob_data, config=config, task_dir=task_dir)
    return dispatcher
if __name__ == "__main__":
    # Manual smoke test: analyze a small FSM, group its function points and
    # create the matching TB tasks. Nothing is simulated.
    logging.basicConfig(level=logging.INFO)

    sample_rtl = """
module example_fsm (
input clk, rst_n,
input [1:0] cmd,
output reg [3:0] out
);
localparam IDLE = 2'b00;
localparam RUN = 2'b01;
localparam DONE = 2'b10;
reg [1:0] state, next_state;
always @(posedge clk or negedge rst_n) begin
if (!rst_n)
state <= IDLE;
else
state <= next_state;
end
always @(*) begin
case (state)
IDLE: next_state = cmd[0] ? RUN : IDLE;
RUN: next_state = DONE;
DONE: next_state = IDLE;
default: next_state = IDLE;
endcase
end
always @(posedge clk) begin
if (state == RUN)
out <= out + 1;
end
endmodule
"""
    demo_problem = dict(
        task_id='test_fsm',
        module_code=sample_rtl,
        header='module example_fsm (input clk, rst_n, input [1:0] cmd, output [3:0] out);',
        description='Test FSM module',
    )

    # Stand-in config: an `autoline` section without multi_tb settings.
    class _AutoLineSection:
        multi_tb = None

    class _DemoConfig:
        autoline = _AutoLineSection()

    demo_dispatcher = TBDispatcher(demo_problem, _DemoConfig(), "/tmp/test_dispatcher")
    demo_dispatcher.enabled = True  # force multi-TB mode for the demo

    analysis = demo_dispatcher.analyze_rtl()
    function_points = analysis.get('function_points', [])
    print(f"\nFunction Points Found: {len(function_points)}")

    demo_groups = demo_dispatcher.group_function_points(analysis)
    print(f"\nGroups: {demo_groups}")

    demo_tasks = demo_dispatcher.create_tasks(demo_groups)
    print(f"\nTasks Created: {len(demo_tasks)}")
    for demo_task in demo_tasks:
        print(f" - {demo_task.name}: {demo_task.targets}")

    print("\n" + "=" * 50)
    print("Test completed successfully")