Files
TBgen_App/run_tbgen.py
2026-03-30 16:46:48 +08:00

210 lines
5.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
TB Generator - generate a testbench from a DUT and the project requirements (full version).

Supports the multi-stage flow: TBgen → TBsim → TBcheck → CGA → TBeval

Usage:
    from run_tbgen import generate_tb
    tb_path, result = generate_tb(
        dut_code="module example(...); endmodule",
        description="project description",
        header="module example(input clk, ...);",
        task_id="my_task",
        model="qwen-max"
    )
"""
import os
import sys
from pathlib import Path
# Directory containing this file. It is pushed to the front of sys.path so the
# project-local modules imported further down (config, loader_saver, autoline)
# resolve against this checkout.
PROJECT_ROOT = Path(__file__).parent
sys.path.insert(0, str(PROJECT_ROOT))

# Lazily-filled singleton slots. Both stay None until _ensure_init() runs;
# initialization is deferred to avoid singleton ordering problems at import time.
_config_instance = _auto_logger_instance = None
def _ensure_init():
    """Initialize the shared Config and AutoLogger once and return the Config.

    On the first call this writes a bootstrap configuration to
    ``<PROJECT_ROOT>/config/custom.yaml``, builds a ``Config`` from that file
    and then instantiates ``AutoLogger``. Subsequent calls return the cached
    ``Config`` without rewriting the file.

    Returns:
        The ``Config`` instance created by the first successful call.

    Raises:
        OSError: If the configuration directory or file cannot be written.
    """
    global _config_instance, _auto_logger_instance
    if _config_instance is not None:
        return _config_instance

    config_dir = os.path.join(PROJECT_ROOT, "config")
    # The directory may be missing on a fresh checkout; open() would fail then.
    os.makedirs(config_dir, exist_ok=True)
    bootstrap_path = os.path.join(config_dir, "custom.yaml")

    # NOTE(review): promptscript/onlyrun are placed under ``autoline`` and
    # enabled/max_iter under ``autoline.cga`` -- confirm against the schema
    # that Config expects.
    bootstrap_yaml = """
run:
  mode: 'autoline'
gpt:
  model: "qwen-max"
  key_path: "config/key_API.json"
save:
  en: True
  root: "./output/"
autoline:
  cga:
    enabled: True
    max_iter: 10
  promptscript: "pychecker"
  onlyrun: "TBgensimeval"
"""
    with open(bootstrap_path, "w", encoding="utf-8") as f:
        f.write(bootstrap_yaml)

    # Config has to be created before autoline is imported, so the
    # project-local imports are deferred to this point.
    from config import Config
    config = Config(bootstrap_path)

    from loader_saver import AutoLogger
    auto_logger = AutoLogger()

    # Publish both singletons only after both steps succeeded. If AutoLogger
    # fails, the next call retries instead of returning a Config with no logger.
    _config_instance = config
    _auto_logger_instance = auto_logger
    return _config_instance
def _create_config_for_task(task_id, model, enable_cga, cga_iter,
                            key_path="config/key_API.json"):
    """Write the per-task configuration file and load it as a ``Config``.

    The file is written to ``<PROJECT_ROOT>/config/custom.yaml`` (the same
    path ``_ensure_init`` uses, so it is overwritten on every call).

    Args:
        task_id: Task identifier; results are saved under
            ``<PROJECT_ROOT>/output/<task_id>/``.
        model: Name of the LLM model written to ``gpt.model``.
        enable_cga: Value written to ``autoline.cga.enabled``.
        cga_iter: Value written to ``autoline.cga.max_iter``.
        key_path: Path of the API key file written to ``gpt.key_path``.
            Defaults to the previously hard-coded ``"config/key_API.json"``.

    Returns:
        A ``Config`` built from the file that was just written.

    Raises:
        OSError: If the configuration directory or file cannot be written.
    """
    # Local import: json is only needed here to quote YAML scalars safely.
    import json

    config_dir = os.path.join(PROJECT_ROOT, "config")
    # The directory may be missing on a fresh checkout; open() would fail then.
    os.makedirs(config_dir, exist_ok=True)
    config_path = os.path.join(config_dir, "custom.yaml")
    save_root = os.path.join(PROJECT_ROOT, "output", task_id) + "/"

    # json.dumps() emits a double-quoted, fully escaped ASCII scalar, which is
    # also a valid YAML double-quoted scalar. Raw interpolation broke the file
    # when a value contained backslashes (Windows paths) or quotes.
    config_content = f"""
run:
  mode: 'autoline'
gpt:
  model: {json.dumps(str(model))}
  key_path: {json.dumps(str(key_path))}
save:
  en: True
  root: {json.dumps(save_root)}
autoline:
  cga:
    enabled: {enable_cga}
    max_iter: {cga_iter}
  promptscript: "pychecker"
  onlyrun: "TBgensimeval"
"""
    # NOTE(review): promptscript/onlyrun are placed under ``autoline`` and
    # enabled/max_iter under ``autoline.cga`` -- confirm against the schema
    # that Config expects.
    with open(config_path, "w", encoding="utf-8") as f:
        f.write(config_content)

    from config import Config
    return Config(config_path)


class TBGenerator:
    """Testbench generator that drives the multi-stage AutoLine flow."""

    def __init__(self, api_key_path="config/key_API.json", model="qwen-max"):
        """Store the settings used by later ``generate`` calls.

        Args:
            api_key_path: Path of the API key file to put in the task config.
            model: Name of the LLM model to put in the task config.
        """
        self.api_key_path = api_key_path
        self.model = model

    def generate(self, dut_code, description, header, task_id="test",
                 enable_cga=True, cga_iter=10):
        """Run the full testbench generation flow for one DUT.

        Args:
            dut_code: Verilog source of the DUT.
            description: Project description / requirements.
            header: Module header of the DUT.
            task_id: Task identifier.
            enable_cga: Whether CGA optimization is enabled.
            cga_iter: Maximum number of CGA iterations.

        Returns:
            dict with the keys ``TB_code_v``, ``TB_code_py``, ``run_info``,
            ``cga_coverage`` and ``full_pass``, each read from the finished
            task object.
        """
        # The Config/AutoLogger singletons must exist before autoline is imported.
        _ensure_init()
        from autoline.TB_autoline import AutoLine_Task

        # Problem record handed to the task (HDLBitsProbset layout, per the
        # original author's note).
        prob_data = {
            "task_id": task_id,
            "task_number": 1,
            "description": description,
            "header": header,
            "module_code": dut_code,
            "testbench": None,
            "mutants": [],
            "llmgen_RTL": []
        }

        # Forward the key path chosen at construction time; it used to be
        # stored on the instance but never reached the configuration.
        config = _create_config_for_task(
            task_id, self.model, enable_cga, cga_iter,
            key_path=self.api_key_path,
        )

        task = AutoLine_Task(prob_data, config)
        task.run()
        return {
            "TB_code_v": task.TB_code_v,
            "TB_code_py": task.TB_code_py,
            "run_info": task.run_info,
            "cga_coverage": task.cga_coverage,
            "full_pass": task.full_pass
        }
def generate_tb(dut_code, description, header, task_id,
                api_key_path="config/key_API.json",
                model="qwen-max",
                enable_cga=True,
                output_dir="./output",
                cga_iter=10):
    """Generate a testbench and save its Verilog code to a file.

    Args:
        dut_code: Verilog source of the DUT.
        description: Project description.
        header: Module header of the DUT.
        task_id: Task identifier; also used to name the output file.
        api_key_path: Path of the API key file.
        model: Name of the LLM model to use.
        enable_cga: Whether CGA is enabled.
        output_dir: Directory that receives ``<task_id>_tb.v``; created if
            it does not exist.
        cga_iter: Maximum number of CGA iterations, forwarded to
            ``TBGenerator.generate``. Defaults to 10, the value that method
            used when this wrapper could not set it.

    Returns:
        tuple: ``(tb_path, result)`` where ``tb_path`` is the path of the
        saved testbench and ``result`` is the dict returned by
        ``TBGenerator.generate``.
    """
    generator = TBGenerator(api_key_path, model)
    result = generator.generate(
        dut_code, description, header, task_id,
        enable_cga=enable_cga, cga_iter=cga_iter,
    )

    os.makedirs(output_dir, exist_ok=True)
    tb_path = os.path.join(output_dir, f"{task_id}_tb.v")

    # Imported only after generate() has run, keeping the project's deferred
    # import order (Config/AutoLogger are initialized inside generate()).
    from loader_saver import save_code
    save_code(result["TB_code_v"], tb_path)
    return tb_path, result
def _run_example():
    """Generate a testbench for a small example multiplier and print a summary."""
    multiplier_rtl = """
module example(
input clk,
input rst,
input [7:0] a,
input [7:0] b,
output [15:0] y
);
assign y = a * b;
endmodule
"""
    multiplier_desc = "一个8位乘法器输入两个8位无符号数输出16位乘积"
    multiplier_header = "module example(input clk, input rst, input [7:0] a, input [7:0] b, output [15:0] y);"

    print("Generating TB for example multiplier...")
    tb_path, result = generate_tb(
        dut_code=multiplier_rtl,
        description=multiplier_desc,
        header=multiplier_header,
        task_id="example_mul",
        model="qwen-max"
    )

    # Report where the testbench was written and the headline results.
    print(f"TB saved to: {tb_path}")
    print(f"Coverage: {result.get('cga_coverage', 0)}")
    print(f"Full Pass: {result.get('full_pass', False)}")


if __name__ == "__main__":
    # Example usage when the file is executed as a script.
    _run_example()