"""TB Generator - generate a Testbench from a DUT and the project requirements (full version).

Supports the multi-stage flow: TBgen -> TBsim -> TBcheck -> CGA -> TBeval

Usage:
    from run_tbgen import generate_tb
    tb_path, result = generate_tb(
        dut_code="module example(...); endmodule",
        description="project description",
        header="module example(input clk, ...);",
        task_id="my_task",
        model="qwen-max"
    )
"""
import json
import os
import sys
from pathlib import Path

# Put the project directory on the import path so the project-local modules
# (config, loader_saver, autoline) can be imported lazily below.
PROJECT_ROOT = Path(__file__).parent
sys.path.insert(0, str(PROJECT_ROOT))

# Lazily created singletons; imports/initialization are deferred to avoid
# singleton ordering problems.
_config_instance = None
_auto_logger_instance = None

# Both the bootstrap config and the per-task config are written to this file.
_CONFIG_PATH = os.path.join(PROJECT_ROOT, "config", "custom.yaml")


def _yaml_quote(value):
    """Return ``str(value)`` rendered as a double-quoted YAML scalar.

    ``json.dumps`` escapes backslashes, double quotes and control characters,
    and its output is also a valid YAML double-quoted scalar. Windows paths
    and names containing quotes therefore no longer yield invalid YAML.
    """
    return json.dumps(str(value))


def _render_config(model, key_path, save_root, enable_cga, cga_iter):
    """Build the text of the autoline YAML configuration.

    Args:
        model: LLM model name (``gpt.model``).
        key_path: path of the API key file (``gpt.key_path``).
        save_root: output root directory (``save.root``).
        enable_cga: whether CGA is enabled; coerced to a bool.
        cga_iter: maximum number of CGA iterations; coerced to an int.

    Returns:
        str: the YAML document.
    """
    # NOTE(review): assumes promptscript/onlyrun are keys of the ``autoline``
    # section (siblings of ``cga``) -- confirm against the Config schema.
    return (
        "run:\n"
        "  mode: 'autoline'\n"
        "gpt:\n"
        f"  model: {_yaml_quote(model)}\n"
        f"  key_path: {_yaml_quote(key_path)}\n"
        "save:\n"
        "  en: True\n"
        f"  root: {_yaml_quote(save_root)}\n"
        "autoline:\n"
        "  cga:\n"
        f"    enabled: {bool(enable_cga)}\n"
        f"    max_iter: {int(cga_iter)}\n"
        '  promptscript: "pychecker"\n'
        '  onlyrun: "TBgensimeval"\n'
    )


def _write_config(text):
    """Write *text* to the shared custom config file and return that path."""
    # The config directory may not exist in a fresh checkout.
    os.makedirs(os.path.dirname(_CONFIG_PATH), exist_ok=True)
    with open(_CONFIG_PATH, "w", encoding="utf-8") as f:
        f.write(text)
    return _CONFIG_PATH


def _ensure_init():
    """Make sure Config and AutoLogger are initialized (only on the first call).

    Returns:
        The bootstrap Config instance.
    """
    global _config_instance, _auto_logger_instance
    if _config_instance is None:
        temp_config_path = _write_config(
            _render_config(
                model="qwen-max",
                key_path="config/key_API.json",
                save_root="./output/",
                enable_cga=True,
                cga_iter=10,
            )
        )
        # Config must be created before autoline is imported.
        from config import Config
        _config_instance = Config(temp_config_path)
        # Initialize AutoLogger.
        from loader_saver import AutoLogger
        _auto_logger_instance = AutoLogger()
    return _config_instance


def _create_config_for_task(task_id, model, enable_cga, cga_iter,
                            key_path="config/key_API.json"):
    """Create a Config for one task.

    Args:
        task_id: task identifier; output goes to ``PROJECT_ROOT/output/<task_id>/``.
        model: LLM model name.
        enable_cga: whether CGA is enabled.
        cga_iter: maximum number of CGA iterations.
        key_path: path of the API key file.

    Returns:
        A Config built from the freshly written config file.
    """
    save_root = os.path.join(PROJECT_ROOT, "output", str(task_id)) + "/"
    # NOTE(review): this overwrites the same custom.yaml used by _ensure_init;
    # whether a second Config(...) picks up the new values depends on Config
    # (possibly a singleton) -- confirm.
    config_path = _write_config(
        _render_config(model, key_path, save_root, enable_cga, cga_iter)
    )
    from config import Config
    return Config(config_path)


class TBGenerator:
    """Full TB generator supporting the multi-stage flow."""

    def __init__(self, api_key_path="config/key_API.json", model="qwen-max"):
        self.api_key_path = api_key_path
        self.model = model

    def generate(self, dut_code, description, header,
                 task_id="test", enable_cga=True, cga_iter=10):
        """Generate a Testbench (full flow).

        Args:
            dut_code: str, Verilog code of the DUT.
            description: str, project description / requirements.
            header: str, module header of the DUT.
            task_id: str, task ID.
            enable_cga: bool, whether to enable CGA optimization.
            cga_iter: int, maximum number of CGA iterations.

        Returns:
            dict: TB code, run info, CGA coverage and the full-pass flag.
        """
        # Make sure the singletons are initialized.
        _ensure_init()

        # autoline can be imported safely only after Config exists.
        from autoline.TB_autoline import AutoLine_Task

        # Build prob_data in the HDLBitsProbset format.
        prob_data = {
            "task_id": task_id,
            "task_number": 1,
            "description": description,
            "header": header,
            "module_code": dut_code,
            "testbench": None,
            "mutants": [],
            "llmgen_RTL": []
        }

        # Per-task config; honours the API key path given to the constructor.
        config = _create_config_for_task(
            task_id, self.model, enable_cga, cga_iter,
            key_path=self.api_key_path,
        )

        # Create the task and run it.
        task = AutoLine_Task(prob_data, config)
        task.run()

        return {
            "TB_code_v": task.TB_code_v,
            "TB_code_py": task.TB_code_py,
            "run_info": task.run_info,
            "cga_coverage": task.cga_coverage,
            "full_pass": task.full_pass
        }


def generate_tb(dut_code, description, header, task_id,
                api_key_path="config/key_API.json",
                model="qwen-max",
                enable_cga=True,
                output_dir="./output",
                cga_iter=10):
    """Convenience wrapper: generate a TB and save it.

    Args:
        dut_code: str, DUT code.
        description: str, project description.
        header: str, module header.
        task_id: str, task ID.
        api_key_path: str, path of the API key file.
        model: str, model to use.
        enable_cga: bool, whether to enable CGA.
        output_dir: str, directory where ``<task_id>_tb.v`` is written.
        cga_iter: int, maximum number of CGA iterations.

    Returns:
        tuple: (TB file path, result dict)
    """
    generator = TBGenerator(api_key_path, model)
    result = generator.generate(
        dut_code, description, header,
        task_id=task_id, enable_cga=enable_cga, cga_iter=cga_iter,
    )

    os.makedirs(output_dir, exist_ok=True)
    tb_path = os.path.join(output_dir, f"{task_id}_tb.v")

    from loader_saver import save_code
    save_code(result["TB_code_v"], tb_path)

    return tb_path, result


if __name__ == "__main__":
    # Example usage.
    example_dut = """
module example(
    input clk,
    input rst,
    input [7:0] a,
    input [7:0] b,
    output [15:0] y
);
    assign y = a * b;
endmodule
"""

    example_desc = "一个8位乘法器,输入两个8位无符号数,输出16位乘积"
    example_header = "module example(input clk, input rst, input [7:0] a, input [7:0] b, output [15:0] y);"

    print("Generating TB for example multiplier...")
    tb_path, result = generate_tb(
        dut_code=example_dut,
        description=example_desc,
        header=example_header,
        task_id="example_mul",
        model="qwen-max"
    )

    print(f"TB saved to: {tb_path}")
    print(f"Coverage: {result.get('cga_coverage', 0)}")
    print(f"Full Pass: {result.get('full_pass', False)}")