Files
TBgen_App/run_tbgen.py

210 lines
5.7 KiB
Python
Raw Permalink Normal View History

2026-03-30 16:46:48 +08:00
"""
TB Generator - generate a complete testbench from a DUT and the project requirements.

Supports a multi-stage flow with the stages: TBgen, TBsim, TBcheck, CGA, TBeval.

Usage:
    from run_tbgen import generate_tb

    tb_path, result = generate_tb(
        dut_code="module example(...); endmodule",
        description="project description",
        header="module example(input clk, ...);",
        task_id="my_task",
        model="qwen-max"
    )
"""
import os
import sys
from pathlib import Path
# Add the project directory (the directory containing this file) to the
# front of the import path.
PROJECT_ROOT = Path(__file__).parent
sys.path.insert(0, str(PROJECT_ROOT))
# Config and AutoLogger are imported and created lazily (see _ensure_init)
# to avoid singleton initialization-order problems. These two globals cache
# the instances once they have been created.
_config_instance = None
_auto_logger_instance = None
def _ensure_init():
    """Create the module-wide Config and AutoLogger on first use.

    On the first call a default configuration is written to
    ``config/custom.yaml`` under ``PROJECT_ROOT``, a ``Config`` is built from
    that file and an ``AutoLogger`` is created afterwards. Both objects are
    cached in module globals; later calls return the cached ``Config``
    without rewriting the file.

    Returns:
        The cached ``Config`` instance.
    """
    global _config_instance, _auto_logger_instance
    if _config_instance is not None:
        return _config_instance

    # The YAML is assembled from explicit lines so that its nesting does not
    # depend on how this source file happens to be indented.
    # NOTE(review): the nesting below (settings grouped under run / gpt /
    # save / autoline, with cga as a sub-section of autoline) is assumed from
    # the section headers -- confirm against the schema Config expects.
    default_yaml = "\n".join([
        "",
        "run:",
        "    mode: 'autoline'",
        "gpt:",
        '    model: "qwen-max"',
        '    key_path: "config/key_API.json"',
        "save:",
        "    en: True",
        '    root: "./output/"',
        "autoline:",
        "    cga:",
        "        enabled: True",
        "        max_iter: 10",
        '    promptscript: "pychecker"',
        '    onlyrun: "TBgensimeval"',
        "",
    ])

    config_path = os.path.join(PROJECT_ROOT, "config", "custom.yaml")
    # Explicit encoding: do not let the platform locale decide how the
    # configuration file is written.
    with open(config_path, "w", encoding="utf-8") as handle:
        handle.write(default_yaml)

    # Config must be created before autoline is imported anywhere.
    from config import Config
    config = Config(config_path)

    from loader_saver import AutoLogger
    auto_logger = AutoLogger()

    # Publish both objects only after both were created, so a failure while
    # creating the logger does not leave a half-initialised module that
    # later calls would treat as ready.
    _config_instance = config
    _auto_logger_instance = auto_logger
    return _config_instance
def _create_config_for_task(task_id, model, enable_cga, cga_iter,
                            api_key_path="config/key_API.json"):
    """Write a task-specific configuration file and load it as a Config.

    The file is written to ``config/custom.yaml`` under ``PROJECT_ROOT``
    (the same path ``_ensure_init`` writes to), replacing whatever was there.

    Args:
        task_id: Task identifier; results are saved under
            ``<PROJECT_ROOT>/output/<task_id>/``.
        model: Model name written to ``gpt.model``.
        enable_cga: Value written to ``autoline.cga.enabled``.
        cga_iter: Value written to ``autoline.cga.max_iter``.
        api_key_path: Path written to ``gpt.key_path``. Defaults to the
            location that was previously hard-coded.

    Returns:
        A ``Config`` built from the freshly written file.
    """
    import json  # function-scope import, like the project imports below

    def yaml_quoted(value):
        # A JSON string literal is also a valid YAML double-quoted scalar.
        # Going through json.dumps escapes backslashes and quotes; a raw
        # Windows path such as C:\proj\output would otherwise be read as
        # YAML escape sequences and break parsing.
        return json.dumps(str(value))

    save_root = os.path.join(PROJECT_ROOT, "output", task_id) + "/"

    # Assembled from explicit lines so the YAML nesting does not depend on
    # the indentation of this source file.
    # NOTE(review): nesting (cga under autoline; promptscript/onlyrun as
    # direct children of autoline) is assumed -- confirm against Config.
    config_text = "\n".join([
        "",
        "run:",
        "    mode: 'autoline'",
        "gpt:",
        f"    model: {yaml_quoted(model)}",
        f"    key_path: {yaml_quoted(api_key_path)}",
        "save:",
        "    en: True",
        f"    root: {yaml_quoted(save_root)}",
        "autoline:",
        "    cga:",
        f"        enabled: {enable_cga}",
        f"        max_iter: {cga_iter}",
        '    promptscript: "pychecker"',
        '    onlyrun: "TBgensimeval"',
        "",
    ])

    config_path = os.path.join(PROJECT_ROOT, "config", "custom.yaml")
    with open(config_path, "w", encoding="utf-8") as handle:
        handle.write(config_text)

    from config import Config
    return Config(config_path)


class TBGenerator:
    """Complete testbench generator that drives the multi-stage flow."""

    def __init__(self, api_key_path="config/key_API.json", model="qwen-max"):
        """Remember the API key path and model used for later generate() calls.

        Args:
            api_key_path: Path of the API key file, written to the task
                configuration as ``gpt.key_path``.
            model: Model name, written to the task configuration as
                ``gpt.model``.
        """
        self.api_key_path = api_key_path
        self.model = model

    def generate(self, dut_code, description, header, task_id="test",
                 enable_cga=True, cga_iter=10):
        """Run the full testbench generation flow for one DUT.

        Args:
            dut_code: Verilog source of the DUT.
            description: Project description / requirements.
            header: Module header of the DUT.
            task_id: Task identifier.
            enable_cga: Whether CGA optimisation is enabled.
            cga_iter: Maximum number of CGA iterations.

        Returns:
            dict with the keys ``TB_code_v``, ``TB_code_py``, ``run_info``,
            ``cga_coverage`` and ``full_pass``, each copied from the
            attribute of the same name on the finished task.
        """
        # Config and AutoLogger must exist before autoline is imported.
        _ensure_init()
        from autoline.TB_autoline import AutoLine_Task

        # Problem record, laid out in the HDLBitsProbset format.
        prob_data = {
            "task_id": task_id,
            "task_number": 1,
            "description": description,
            "header": header,
            "module_code": dut_code,
            "testbench": None,
            "mutants": [],
            "llmgen_RTL": [],
        }

        # Forward the key path as well; it used to be stored on the instance
        # but never reached the configuration.
        config = _create_config_for_task(
            task_id,
            self.model,
            enable_cga,
            cga_iter,
            api_key_path=self.api_key_path,
        )

        task = AutoLine_Task(prob_data, config)
        task.run()
        return {
            "TB_code_v": task.TB_code_v,
            "TB_code_py": task.TB_code_py,
            "run_info": task.run_info,
            "cga_coverage": task.cga_coverage,
            "full_pass": task.full_pass,
        }
def generate_tb(dut_code, description, header, task_id,
                api_key_path="config/key_API.json",
                model="qwen-max",
                enable_cga=True,
                output_dir="./output",
                cga_iter=10):
    """Generate a testbench and save its Verilog code to a file.

    Args:
        dut_code: Verilog source of the DUT.
        description: Project description.
        header: Module header of the DUT.
        task_id: Task identifier; also used in the output file name.
        api_key_path: Path of the API key file.
        model: Name of the model to use.
        enable_cga: Whether CGA is enabled.
        output_dir: Directory the testbench file is written to; created if
            it does not exist.
        cga_iter: Maximum number of CGA iterations. Added as a trailing
            parameter so existing callers are unaffected; the default
            matches ``TBGenerator.generate``.

    Returns:
        tuple: ``(tb_path, result)`` where ``tb_path`` is
        ``<output_dir>/<task_id>_tb.v`` and ``result`` is the dict returned
        by ``TBGenerator.generate``.
    """
    generator = TBGenerator(api_key_path, model)
    result = generator.generate(
        dut_code,
        description,
        header,
        task_id,
        enable_cga=enable_cga,
        cga_iter=cga_iter,
    )

    os.makedirs(output_dir, exist_ok=True)
    tb_path = os.path.join(output_dir, f"{task_id}_tb.v")

    from loader_saver import save_code
    save_code(result["TB_code_v"], tb_path)
    return tb_path, result
if __name__ == "__main__":
    # Demo run: generate a testbench for a small 8-bit multiplier.
    demo_header = "module example(input clk, input rst, input [7:0] a, input [7:0] b, output [15:0] y);"
    demo_description = "一个8位乘法器输入两个8位无符号数输出16位乘积"
    demo_dut = """
module example(
input clk,
input rst,
input [7:0] a,
input [7:0] b,
output [15:0] y
);
assign y = a * b;
endmodule
"""

    print("Generating TB for example multiplier...")
    saved_path, outcome = generate_tb(
        dut_code=demo_dut,
        description=demo_description,
        header=demo_header,
        task_id="example_mul",
        model="qwen-max",
    )

    # Report where the testbench went and how the run turned out.
    print(f"TB saved to: {saved_path}")
    coverage = outcome.get('cga_coverage', 0)
    print(f"Coverage: {coverage}")
    passed = outcome.get('full_pass', False)
    print(f"Full Pass: {passed}")