701 lines
24 KiB
Python
701 lines
24 KiB
Python
"""
|
||
Description : TB Dispatcher - Multi-TB Generation Coordinator
|
||
Coordinates generation of multiple targeted TBs based on RTL functional analysis
|
||
Author : CorrectBench
|
||
Time : 2026/04/19
|
||
"""
|
||
|
||
import os
|
||
import logging
|
||
from typing import List, Dict, Tuple, Optional, Any
|
||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||
|
||
from autoline.semantic_analyzer import SemanticAnalyzer
|
||
from autoline.semantic_analyzer import FunctionPointType
|
||
|
||
# Module-level logger used by CoverageMerger and TBDispatcher for progress/warning messages.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class TBTask:
    """
    A single testbench (TB) generation task.

    A task targets one group of function points and, once executed, carries the
    generated TB code and its outcome (coverage, pass/fail, error message).
    """

    def __init__(self, name: str, targets: List[str], fp_type: str,
                 description: str = "", task_dir: Optional[str] = None):
        # --- inputs describing the task ---
        self.name = name
        self.targets = targets          # names of the targeted function points
        self.fp_type = fp_type          # group key, e.g. 'fsm', 'protocol', 'datapath', 'exception', 'main'
        self.description = description
        self.task_dir = task_dir        # working directory of this task (may be None)

        # --- outputs, filled in after the task has been run ---
        self.tb_code_v = None           # generated Verilog TB source
        self.tb_code_py = None          # generated Python rule/checker source
        self.coverage = 0.0             # coverage reached by this TB
        self.passed = False             # whether the TB passed its checks
        self.error = None               # error message, if any

    def to_dict(self) -> Dict:
        """Return a summary dict of the task (generated code and task_dir are not included)."""
        exported = ('name', 'targets', 'fp_type', 'description', 'coverage', 'passed', 'error')
        return {attr: getattr(self, attr) for attr in exported}
|
||
|
||
|
||
class CoverageMerger:
    """
    Merge the coverage reports produced by several TBs.

    All methods are static; the class only groups the merge helpers.
    """

    # Prefix characters accepted in front of the hit count of a coverage line.
    _COUNT_PREFIXES = ('%', '~', '^')

    @staticmethod
    def _ensure_parent_dir(path: str) -> None:
        """Create the parent directory of ``path`` if it has one (a bare file name needs nothing)."""
        parent = os.path.dirname(path)
        # os.makedirs('') raises FileNotFoundError, so only create real directories.
        if parent:
            os.makedirs(parent, exist_ok=True)

    @staticmethod
    def _parse_coverage_line(line: str) -> Optional[Tuple[str, int]]:
        """
        Parse one ``<prefix><count> <code>`` line into ``(code, count)``.

        Returns None for blank lines, lines without a code part, lines whose
        first token does not start with one of ``_COUNT_PREFIXES`` and lines
        whose count is not an integer.
        """
        parts = line.strip().split(None, 1)
        if len(parts) < 2:
            return None
        count_str, code = parts
        if not count_str.startswith(CoverageMerger._COUNT_PREFIXES):
            return None
        try:
            return code, int(count_str[1:])
        except ValueError:
            return None

    @staticmethod
    def merge_coverage_dat(coverage_files: List[str], output_file: str) -> Dict:
        """
        Merge several line-annotated coverage files into one.

        Every input line of the form ``<prefix><count> <code>`` (prefix one of
        ``%``, ``~``, ``^``; e.g. ``%000003 out <= out + 1;``) contributes one
        entry; any other line, or a line whose count is not an integer, is
        skipped on its own without affecting the rest of the file.

        Entries are keyed by their code text, so the same line reported by
        several TBs is merged by summing its hit counts; a line counts as
        covered when its merged count is > 0.
        NOTE: because the key is the code text only, identical text at
        different source locations is merged into a single entry.

        Missing or unreadable input files are logged and skipped.

        Args:
            coverage_files: paths of the coverage files to merge.
            output_file: path of the merged file.  Its parent directory is
                created when needed; each entry is written as ``%<count> <code>``.

        Returns:
            merged_info: {
                'total_lines': int,         # distinct code lines seen
                'covered_lines': int,       # lines whose merged count is > 0
                'coverage_percent': float,  # covered / total * 100, 0.0 when no lines
                'file': str                 # output_file
            }
        """
        # code text -> merged hit count (insertion order = first time seen)
        all_lines: Dict[str, int] = {}

        for cov_file in coverage_files:
            if not os.path.exists(cov_file):
                logger.warning(f"Coverage file not found: {cov_file}")
                continue

            try:
                with open(cov_file, 'r', encoding='utf-8', errors='ignore') as f:
                    for raw_line in f:
                        parsed = CoverageMerger._parse_coverage_line(raw_line)
                        if parsed is None:
                            continue
                        code, count = parsed
                        # Sum hit counts of the same line across TBs.
                        all_lines[code] = all_lines.get(code, 0) + count
            except OSError as e:
                logger.warning(f"Error reading coverage file {cov_file}: {e}")

        total_lines = len(all_lines)
        covered_lines = sum(1 for count in all_lines.values() if count > 0)
        coverage_percent = (covered_lines / total_lines * 100) if total_lines > 0 else 0.0

        # Write the merged file in a form this method can read back.
        CoverageMerger._ensure_parent_dir(output_file)
        with open(output_file, 'w', encoding='utf-8') as f:
            f.writelines(f"%{count} {code}\n" for code, count in all_lines.items())

        return {
            'total_lines': total_lines,
            'covered_lines': covered_lines,
            'coverage_percent': coverage_percent,
            'file': output_file
        }

    @staticmethod
    def merge_unreachable_reports(report_files: List[str], output_file: str) -> Dict:
        """
        Merge several unreachable-branch reports.

        Each report is summarized by counting the occurrences of the phrases
        "Truly unreachable" and "Potentially coverable" in its text.  Missing
        report files are skipped silently; unreadable ones are logged and skipped.

        Args:
            report_files: paths of unreachable_branches_report.txt files.
            output_file: path of the merged report; its parent directory is
                created when needed.

        Returns:
            merged_info: {
                'total_unreachable': int,
                'total_potentially_coverable': int,
                'reports': list of {'file', 'unreachable', 'potentially_coverable'} dicts,
                'file': str   # output_file
            }
        """
        total_unreachable = 0
        total_potentially_coverable = 0
        all_reports = []

        for report_file in report_files:
            if not os.path.exists(report_file):
                continue

            try:
                with open(report_file, 'r', encoding='utf-8') as f:
                    content = f.read()
            except (OSError, UnicodeDecodeError) as e:
                logger.warning(f"Error reading report {report_file}: {e}")
                continue

            # Simple phrase counting, not a structural parse of the report.
            unreachable_count = content.count("Truly unreachable")
            potentially_count = content.count("Potentially coverable")

            total_unreachable += unreachable_count
            total_potentially_coverable += potentially_count
            all_reports.append({
                'file': report_file,
                'unreachable': unreachable_count,
                'potentially_coverable': potentially_count
            })

        # Write the merged report.
        CoverageMerger._ensure_parent_dir(output_file)
        with open(output_file, 'w', encoding='utf-8') as f:
            f.write("=" * 60 + "\n")
            f.write("MERGED UNREACHABLE BRANCHES REPORT\n")
            f.write("=" * 60 + "\n\n")

            for rp in all_reports:
                f.write(f"From: {rp['file']}\n")
                f.write(f" Truly unreachable: {rp['unreachable']}\n")
                f.write(f" Potentially coverable: {rp['potentially_coverable']}\n\n")

            f.write("=" * 60 + "\n")
            f.write(f"Total Truly Unreachable: {total_unreachable}\n")
            f.write(f"Total Potentially Coverable: {total_potentially_coverable}\n")
            f.write("=" * 60 + "\n")

        return {
            'total_unreachable': total_unreachable,
            'total_potentially_coverable': total_potentially_coverable,
            'reports': all_reports,
            'file': output_file
        }
|
||
|
||
|
||
class TBDispatcher:
    """
    TB dispatcher - coordinates the generation of multiple targeted testbenches.

    Workflow:
        1. Analyze the RTL with the semantic analyzer to find function points.
        2. Group the function points by type (FSM, protocol, datapath, exception).
        3. Create one TB task per group (plus a 'main' task covering everything).
        4. Run the TB tasks, sequentially or in a thread pool.
        5. Merge the per-task coverage and unreachable-branch reports.
    """

    def __init__(self, prob_data: Dict, config, task_dir: str):
        """
        Args:
            prob_data: problem data; ``module_code`` holds the RTL source and the
                optional ``task_id`` is used in log messages.
            config: configuration object.  Multi-TB options are read from
                ``config.autoline.multi_tb``; each level may be an attribute-style
                object or a plain dict, and any missing level/option falls back
                to the defaults below.
            task_dir: working directory; each TB task gets a sub-directory here
                and merged reports are written to ``<task_dir>/CGA_merged``.
        """
        self.prob_data = prob_data
        self.config = config
        self.task_dir = task_dir

        # Multi-TB options and their defaults
        autoline_config = self._cfg_get(config, 'autoline', None)
        multi_tb_config = self._cfg_get(autoline_config, 'multi_tb', None)
        self.enabled = self._cfg_get(multi_tb_config, 'enabled', False)
        self.strategy = self._cfg_get(multi_tb_config, 'strategy', 'functional')  # stored only; not read in this class
        self.max_tb_count = self._cfg_get(multi_tb_config, 'max_tb_count', 5)
        self.parallel = self._cfg_get(multi_tb_config, 'parallel', False)
        self.auto_threshold_lines = self._cfg_get(multi_tb_config, 'auto_threshold_lines', 500)

        self.tasks: List["TBTask"] = []
        self.results: List[Dict] = []
        self.merged_coverage: Dict = {}

    @staticmethod
    def _cfg_get(node, key: str, default):
        """
        Read option ``key`` from a config node.

        ``node`` may be None (-> ``default``), a dict (-> ``node.get``) or any
        other object (-> ``getattr``).
        """
        if node is None:
            return default
        if isinstance(node, dict):
            return node.get(key, default)
        return getattr(node, key, default)

    def _count_rtl_lines(self) -> int:
        """
        Count RTL source lines, excluding blank lines and comments.

        ``//`` line comments and ``/* ... */`` block comments (including block
        comments spanning several lines) are removed first; a line is counted
        only if code remains on it.  A missing or None ``module_code`` counts as 0.
        """
        import re

        code = self.prob_data.get('module_code') or ''
        # Left-to-right scan: whichever comment opener comes first wins, so a
        # "/*" inside a "//" comment does not start a block comment.  Newlines
        # inside block comments are kept so the line structure is preserved;
        # an unterminated "/*" runs to the end of the source.
        without_comments = re.sub(
            r'//[^\n]*|/\*.*?(?:\*/|\Z)',
            lambda m: '\n' * m.group(0).count('\n'),
            code,
            flags=re.DOTALL,
        )
        return sum(1 for line in without_comments.split('\n') if line.strip())

    def should_use_multi_tb(self) -> bool:
        """
        Decide whether multi-TB mode should be used.

        Returns:
            True if multi-TB mode is enabled in the config, or if the RTL has
            more non-comment lines than ``auto_threshold_lines``; False
            otherwise (single-TB mode).
        """
        # 1. Explicitly enabled in the config
        if self.enabled:
            return True

        # 2. Automatic decision based on RTL size
        rtl_lines = self._count_rtl_lines()
        if rtl_lines > self.auto_threshold_lines:
            logger.info(f"[TBDispatcher] RTL has {rtl_lines} lines (>{self.auto_threshold_lines}), enabling multi-TB mode")
            return True

        return False

    def analyze_rtl(self) -> Dict[str, Any]:
        """
        Run the semantic analyzer on ``prob_data['module_code']``.

        Returns:
            analysis_result: the value returned by ``SemanticAnalyzer.analyze()``;
            this class reads its ``'function_points'`` list.

        Raises:
            KeyError: if ``prob_data`` has no ``'module_code'`` entry.
        """
        logger.info(f"[TBDispatcher] Analyzing RTL for task: {self.prob_data.get('task_id', 'unknown')}")

        analyzer = SemanticAnalyzer(self.prob_data['module_code'])
        result = analyzer.analyze()

        logger.info(f"[TBDispatcher] Found {len(result.get('function_points', []))} function points")
        return result

    @staticmethod
    def _fp_type_text(fp_type) -> str:
        """
        Normalize a function-point type to lower-case text.

        Plain strings are used as-is; Enum-like members (e.g. FunctionPointType)
        are reduced to their string ``value`` or, failing that, their ``name``.
        """
        value = getattr(fp_type, 'value', fp_type)
        if not isinstance(value, str):
            value = getattr(fp_type, 'name', value)
        return str(value).lower()

    def group_function_points(self, analysis_result: Dict[str, Any]) -> Dict[str, List[Dict]]:
        """
        Group function points by type, defining the focus of each TB.

        A point goes to the first matching group, checking its name (substring)
        or its type: 'fsm', then 'protocol', then 'counter'/'datapath' (datapath
        group), then 'exception'; anything else goes to 'datapath'.  The 'main'
        group always holds every function point.  Empty groups are dropped.

        Returns:
            groups: {
                'main': [...],       # main test, all function points
                'fsm': [...],        # FSM-focused
                'protocol': [...],   # protocol/interface-focused
                'datapath': [...],   # datapath / counters / conditions
                'exception': [...]   # exception-handling-focused
            }
        """
        groups: Dict[str, List[Dict]] = {
            'main': [],
            'fsm': [],
            'protocol': [],
            'datapath': [],
            'exception': [],
        }

        all_fps = analysis_result.get('function_points') or []

        for fp in all_fps:
            fp_type = self._fp_type_text(fp.get('type', 'condition'))
            fp_name = str(fp.get('name') or '').lower()

            if 'fsm' in fp_name or fp_type == 'fsm':
                groups['fsm'].append(fp)
            elif 'protocol' in fp_name or fp_type == 'protocol':
                groups['protocol'].append(fp)
            elif 'counter' in fp_name or 'datapath' in fp_name or fp_type == 'counter':
                groups['datapath'].append(fp)
            elif 'exception' in fp_name or fp_type == 'exception':
                groups['exception'].append(fp)
            else:
                # Unclassified points (e.g. plain conditions) default to datapath
                groups['datapath'].append(fp)

        # The main group contains every function point
        groups['main'] = all_fps

        # Drop empty groups
        groups = {k: v for k, v in groups.items() if v}

        logger.info(f"[TBDispatcher] Grouped function points: { {k: len(v) for k, v in groups.items()} }")
        return groups

    def create_tasks(self, groups: Dict[str, List[Dict]]) -> List["TBTask"]:
        """
        Create one TB task per function-point group.

        Focus of each TB:
            - tb_main: main test, tries to cover all function points
            - tb_fsm: FSM test, focuses on state transitions and state coverage
            - tb_protocol: protocol test, focuses on handshake timing and data transfer
            - tb_datapath: datapath test, focuses on data processing and counters
            - tb_exception: exception test, focuses on boundary conditions and error handling

        Groups are visited in priority order protocol > fsm > datapath >
        exception > main; creation stops once ``max_tb_count`` tasks exist.
        Non-main groups with at most one function point are skipped.
        Side effect: creates ``<task_dir>/tb_<group>`` for every task.

        Args:
            groups: function points grouped by type.

        Returns:
            tasks: the created TB tasks.
        """
        # Focus description of each TB
        focus_descriptions = {
            'main': 'Comprehensive test covering all functional points',
            'fsm': 'FSM state machine test - focus on state transitions and state coverage',
            'protocol': 'Protocol interface test - focus on handshake timing and data transfer',
            'datapath': 'Datapath test - focus on data processing and counter logic',
            'exception': 'Exception handling test - focus on boundary conditions and error handling'
        }

        # Creation priority: protocol > fsm > datapath > exception > main
        priority_order = ['protocol', 'fsm', 'datapath', 'exception', 'main']

        tasks = []
        for fp_type in priority_order:
            fps = groups.get(fp_type)
            if not fps:
                continue

            if len(tasks) >= self.max_tb_count:
                logger.info(f"[TBDispatcher] Reached max TB count ({self.max_tb_count}), stopping task creation")
                break

            # Skip groups with only one function point (except main)
            if len(fps) <= 1 and fp_type != 'main':
                logger.info(f"[TBDispatcher] Skipping {fp_type} group (only {len(fps)} function point)")
                continue

            task_name = f"tb_{fp_type}"
            targets = [str(fp.get('name', '')) for fp in fps]
            description = f"[{focus_descriptions.get(fp_type, 'Test')}] Targets: {', '.join(targets)}"

            # Per-task working directory
            task_subdir = os.path.join(self.task_dir, task_name)
            os.makedirs(task_subdir, exist_ok=True)

            task = TBTask(
                name=task_name,
                targets=targets,
                fp_type=fp_type,
                description=description,
                task_dir=task_subdir
            )
            tasks.append(task)

            logger.info(f"[TBDispatcher] Created task: {task_name}")
            logger.info(f" Focus: {focus_descriptions.get(fp_type, 'N/A')}")
            logger.info(f" Targets: {targets}")

        logger.info(f"[TBDispatcher] Created {len(tasks)} TB tasks")
        return tasks

    def run_single_task(self, task: "TBTask") -> Dict:
        """
        Run a single TB task (called once per function group).

        Placeholder: the TBgen -> TBsim -> TBcheck -> CGA pipeline is not wired
        in yet.  The method returns a default (not passed, 0.0 coverage) result,
        overlaid with ``<task_dir>/result.json`` when that file exists and holds
        a JSON object.  The task's coverage/passed/error fields are updated to
        match the returned result.

        Args:
            task: the TB task.

        Returns:
            result: the task result dict.
        """
        import json

        logger.info(f"[TBDispatcher] Running task: {task.name}")
        logger.info(f"[TBDispatcher] Targets: {task.targets}")

        # TODO: call the existing TBgen -> TBsim -> TBcheck -> CGA flow here.
        result = {
            'name': task.name,
            'targets': task.targets,
            'fp_type': task.fp_type,
            'coverage': 0.0,
            'passed': False,
            'error': None,
            'tb_dir': task.task_dir
        }

        # Temporary: pick up results saved by an external run, if any.
        if task.task_dir:
            result_file = os.path.join(task.task_dir, "result.json")
            if os.path.exists(result_file):
                try:
                    with open(result_file, 'r', encoding='utf-8') as f:
                        saved_result = json.load(f)
                except (OSError, ValueError) as e:
                    logger.warning(f"[TBDispatcher] Could not read {result_file}: {e}")
                else:
                    if isinstance(saved_result, dict):
                        result.update(saved_result)
                    else:
                        logger.warning(f"[TBDispatcher] Ignoring {result_file}: not a JSON object")

        # Keep the task object consistent with the reported result.
        task.coverage = result.get('coverage', 0.0)
        task.passed = result.get('passed', False)
        task.error = result.get('error')
        return result

    def _failed_result(self, task: "TBTask", error: BaseException) -> Dict:
        """Build the result dict for a task whose execution raised ``error`` and mark the task failed."""
        task.passed = False
        task.error = str(error)
        return {
            'name': task.name,
            'targets': task.targets,
            'fp_type': task.fp_type,
            'coverage': 0.0,
            'passed': False,
            'error': str(error),
            'tb_dir': task.task_dir
        }

    def _run_task_safely(self, task: "TBTask") -> Dict:
        """Run one task; an exception is logged and turned into a failed result."""
        try:
            return self.run_single_task(task)
        except Exception as e:  # one failing task must not abort the remaining ones
            logger.error(f"[TBDispatcher] Task {task.name} failed: {e}")
            return self._failed_result(task, e)

    def _run_tasks_parallel(self, tasks: List["TBTask"]) -> List[Dict]:
        """Run ``tasks`` in a thread pool; the returned results follow the order of ``tasks``."""
        results: List[Optional[Dict]] = [None] * len(tasks)
        with ThreadPoolExecutor(max_workers=max(1, len(tasks))) as executor:
            futures = {executor.submit(self._run_task_safely, task): idx
                       for idx, task in enumerate(tasks)}
            for future in as_completed(futures):
                # Store by index so completion order does not reorder results.
                results[futures[future]] = future.result()
        return results

    def run_all(self) -> Dict:
        """
        Run all TB tasks and merge their results.

        Returns:
            ``{'enabled': False}`` when multi-TB mode is not needed;
            ``{'enabled': True, 'fallback': True, 'reason': ...}`` when no
            function points or no tasks were produced;
            otherwise the merged result of ``merge_results()``.
        """
        # Reset state so repeated calls do not accumulate stale tasks/results.
        self.tasks = []
        self.results = []
        self.merged_coverage = {}

        # Enabled explicitly or automatically via RTL line count
        if not self.should_use_multi_tb():
            logger.info("[TBDispatcher] Multi-TB mode not needed, skipping")
            return {'enabled': False}

        logger.info("[TBDispatcher] Starting multi-TB dispatch")

        # Step 1: analyze the RTL
        analysis_result = self.analyze_rtl()

        # Step 2: group function points
        groups = self.group_function_points(analysis_result)
        if not groups:
            logger.warning("[TBDispatcher] No function points found, falling back to single TB")
            return {'enabled': True, 'fallback': True, 'reason': 'no_function_points'}

        # Step 3: create TB tasks
        self.tasks = self.create_tasks(groups)
        if not self.tasks:
            logger.warning("[TBDispatcher] No tasks created, falling back to single TB")
            return {'enabled': True, 'fallback': True, 'reason': 'no_tasks'}

        # Step 4: execute tasks; both modes keep results in task order and
        # convert a failing task into a failed result.
        if self.parallel:
            self.results = self._run_tasks_parallel(self.tasks)
        else:
            self.results = [self._run_task_safely(task) for task in self.tasks]

        # Step 5: merge coverage
        self.merged_coverage = self.merge_results()

        # The merged percentage lives under 'merged_coverage', not at the top level.
        merged_cov = self.merged_coverage.get('merged_coverage') or {}
        if 'coverage_percent' in merged_cov:
            logger.info(f"[TBDispatcher] All tasks completed. Merged coverage: {merged_cov['coverage_percent']:.2f}%")
        else:
            logger.info(f"[TBDispatcher] All tasks completed. No coverage files to merge; "
                        f"average coverage: {self.merged_coverage.get('average_coverage', 0.0):.2f}%")

        return self.merged_coverage

    def merge_results(self) -> Dict:
        """
        Merge the results of all executed TB tasks.

        Returns:
            merged: {
                'enabled': True,
                'total_tasks': int,
                'passed_tasks': int,
                'average_coverage': float,   # sum of positive per-task coverage / total_tasks
                'merged_coverage': dict,     # CoverageMerger.merge_coverage_dat result, {} if no coverage.dat found
                'merged_unreachable': dict,  # CoverageMerger.merge_unreachable_reports result, {} if none found
                'results': list              # self.results
            }
        """
        total_coverage = 0.0
        passed_count = 0
        total_tasks = len(self.results)

        coverage_reports = []
        unreachable_reports = []

        for result in self.results:
            coverage = result.get('coverage') or 0.0  # None counts as 0
            if coverage > 0:
                total_coverage += coverage
            if result.get('passed', False):
                passed_count += 1

            # Collect report file paths.  Results without a TB directory are
            # skipped: joining '' would look for 'CGA/...' relative to the CWD.
            tb_dir = result.get('tb_dir')
            if not tb_dir:
                continue

            cov_file = os.path.join(tb_dir, 'CGA', 'coverage.dat')
            if os.path.exists(cov_file):
                coverage_reports.append(cov_file)

            unreach_file = os.path.join(tb_dir, 'CGA', 'unreachable_branches_report.txt')
            if os.path.exists(unreach_file):
                unreachable_reports.append(unreach_file)

        # Average over all tasks (failed ones contribute 0)
        avg_coverage = total_coverage / total_tasks if total_tasks > 0 else 0.0

        # Merge report files into <task_dir>/CGA_merged
        merged_dir = os.path.join(self.task_dir, "CGA_merged")
        coverage_info = (
            CoverageMerger.merge_coverage_dat(coverage_reports, os.path.join(merged_dir, "coverage.dat"))
            if coverage_reports else {}
        )
        unreachable_info = (
            CoverageMerger.merge_unreachable_reports(
                unreachable_reports, os.path.join(merged_dir, "unreachable_branches_report.txt"))
            if unreachable_reports else {}
        )

        return {
            'enabled': True,
            'total_tasks': total_tasks,
            'passed_tasks': passed_count,
            'average_coverage': avg_coverage,
            'merged_coverage': coverage_info,
            'merged_unreachable': unreachable_info,
            'results': self.results
        }

    def generate_summary_report(self) -> str:
        """
        Build a text summary of the multi-TB run.

        Returns:
            report: the formatted report, or "No results to report" when
            ``merged_coverage`` is empty (e.g. ``run_all`` has not produced results).
        """
        if not self.merged_coverage:
            return "No results to report"

        lines = []
        lines.append("=" * 70)
        lines.append("MULTI-TB COVERAGE SUMMARY REPORT")
        lines.append("=" * 70)
        lines.append("")

        lines.append(f"Total TB Tasks: {self.merged_coverage.get('total_tasks', 0)}")
        lines.append(f"Passed Tasks: {self.merged_coverage.get('passed_tasks', 0)}")
        lines.append(f"Average Coverage: {self.merged_coverage.get('average_coverage', 0):.2f}%")
        lines.append("")

        # Per-task details
        lines.append("Individual TB Results:")
        lines.append("-" * 70)
        for result in self.results:
            status = "PASS" if result.get('passed') else "FAIL"
            coverage = result.get('coverage') or 0  # None counts as 0
            targets = result.get('targets') or []
            lines.append(f" [{status}] {result.get('name', 'unknown')}")
            lines.append(f" Targets: {', '.join(str(t) for t in targets)}")
            lines.append(f" Coverage: {coverage:.2f}%")
            if result.get('error'):
                lines.append(f" Error: {result['error']}")
            lines.append("")

        # Merged coverage
        merged_cov = self.merged_coverage.get('merged_coverage') or {}
        if merged_cov:
            lines.append("-" * 70)
            lines.append("Merged Coverage:")
            lines.append(f" Total Lines: {merged_cov.get('total_lines', 0)}")
            lines.append(f" Covered Lines: {merged_cov.get('covered_lines', 0)}")
            lines.append(f" Coverage: {merged_cov.get('coverage_percent', 0):.2f}%")
            lines.append("")

        # Unreachable branches
        merged_unreach = self.merged_coverage.get('merged_unreachable') or {}
        if merged_unreach:
            lines.append("-" * 70)
            lines.append("Unreachable Branches:")
            lines.append(f" Truly Unreachable: {merged_unreach.get('total_unreachable', 0)}")
            lines.append(f" Potentially Coverable: {merged_unreach.get('total_potentially_coverable', 0)}")
            lines.append("")

        lines.append("=" * 70)

        return "\n".join(lines)
|
||
|
||
|
||
def create_dispatcher(prob_data: Dict, config, task_dir: str) -> "TBDispatcher":
    """
    Factory: build a TBDispatcher for one problem.

    Args:
        prob_data: problem data (RTL source under ``module_code``).
        config: configuration object.
        task_dir: working directory of the task.

    Returns:
        dispatcher: a new, not-yet-run TBDispatcher.
    """
    dispatcher = TBDispatcher(prob_data=prob_data, config=config, task_dir=task_dir)
    return dispatcher
|
||
|
||
|
||
if __name__ == "__main__":
    # Smoke test: analyze a small FSM, group its function points and create the
    # per-group TB tasks.  No TB is generated or simulated here.
    import tempfile

    logging.basicConfig(level=logging.INFO)

    # Sample RTL under test
    sample_rtl = """
    module example_fsm (
        input clk, rst_n,
        input [1:0] cmd,
        output reg [3:0] out
    );
        localparam IDLE = 2'b00;
        localparam RUN = 2'b01;
        localparam DONE = 2'b10;

        reg [1:0] state, next_state;

        always @(posedge clk or negedge rst_n) begin
            if (!rst_n)
                state <= IDLE;
            else
                state <= next_state;
        end

        always @(*) begin
            case (state)
                IDLE: next_state = cmd[0] ? RUN : IDLE;
                RUN: next_state = DONE;
                DONE: next_state = IDLE;
                default: next_state = IDLE;
            endcase
        end

        always @(posedge clk) begin
            if (state == RUN)
                out <= out + 1;
        end
    endmodule
    """

    prob_data = {
        'task_id': 'test_fsm',
        'module_code': sample_rtl,
        'header': 'module example_fsm (input clk, rst_n, input [1:0] cmd, output [3:0] out);',
        'description': 'Test FSM module'
    }

    class MockConfig:
        class AutoLine:
            multi_tb = None

        autoline = AutoLine()

    config = MockConfig()

    # A fresh temporary directory keeps the test portable and avoids reusing
    # task directories left over from earlier runs.
    test_dir = tempfile.mkdtemp(prefix="test_dispatcher_")

    dispatcher = TBDispatcher(prob_data, config, test_dir)
    dispatcher.enabled = True  # force multi-TB mode for the test

    # Analyze the RTL
    result = dispatcher.analyze_rtl()
    print(f"\nFunction Points Found: {len(result.get('function_points', []))}")

    # Group function points
    groups = dispatcher.group_function_points(result)
    print(f"\nGroups: {groups}")

    # Create tasks (creates one sub-directory per task under test_dir)
    tasks = dispatcher.create_tasks(groups)
    print(f"\nTasks Created: {len(tasks)}")
    for task in tasks:
        print(f" - {task.name}: {task.targets}")

    print(f"\nTask directory: {test_dir}")
    print("\n" + "=" * 50)
    print("Test completed successfully")
|