210 lines
5.7 KiB
Python
210 lines
5.7 KiB
Python
"""
|
||
TB Generator - 根据DUT和项目要求生成Testbench(完整版)
|
||
支持多阶段流程: TBgen → TBsim → TBcheck → CGA → TBeval
|
||
|
||
用法:
|
||
from run_tbgen import generate_tb
|
||
tb_path, result = generate_tb(
|
||
dut_code="module example(...); endmodule",
|
||
description="项目描述",
|
||
header="module example(input clk, ...);",
|
||
task_id="my_task",
|
||
model="qwen-max"
|
||
)
|
||
"""
|
||
import os
|
||
import sys
|
||
from pathlib import Path
|
||
|
||
# Make project-local packages (config, loader_saver, autoline) importable by
# putting this file's directory at the front of the module search path.
PROJECT_ROOT = Path(__file__).parent
sys.path[0:0] = [str(PROJECT_ROOT)]

# Lazily created singletons (see _ensure_init). Deferring their creation and
# the related imports avoids singleton ordering problems: Config must exist
# before autoline is imported.
_config_instance = _auto_logger_instance = None
|
||
|
||
def _config_file_path():
    """Return the path of the generated config file: <PROJECT_ROOT>/config/custom.yaml."""
    return os.path.join(PROJECT_ROOT, "config", "custom.yaml")


def _yaml_quote(value):
    """Render ``value`` as a double-quoted YAML scalar.

    A JSON string literal is also a valid YAML double-quoted scalar, so
    ``json.dumps`` correctly escapes quotes, backslashes and non-ASCII
    characters. Plain f-string interpolation inside quotes breaks on such
    input; for example, Windows paths contain backslashes, which YAML treats
    as escape characters inside double quotes.
    """
    import json

    return json.dumps(str(value))


def _render_config(model, save_root, enable_cga, cga_iter, key_path):
    """Build the YAML text for an ``autoline`` run.

    Args:
        model: Value for ``gpt.model``.
        save_root: Value for ``save.root``.
        enable_cga: Truthiness decides ``autoline.cga.enabled``.
        cga_iter: Value for ``autoline.cga.max_iter`` (converted with ``int``).
        key_path: Value for ``gpt.key_path``.

    Returns:
        The config file contents (pure ASCII, newline terminated).
    """
    # NOTE(review): promptscript/onlyrun are assumed to be direct children of
    # `autoline` (siblings of `cga`); confirm against the Config schema.
    lines = [
        "run:",
        "  mode: 'autoline'",
        "gpt:",
        f"  model: {_yaml_quote(model)}",
        f"  key_path: {_yaml_quote(key_path)}",
        "save:",
        "  en: True",
        f"  root: {_yaml_quote(save_root)}",
        "autoline:",
        "  cga:",
        f"    enabled: {'True' if enable_cga else 'False'}",
        f"    max_iter: {int(cga_iter)}",
        '  promptscript: "pychecker"',
        '  onlyrun: "TBgensimeval"',
    ]
    return "\n".join(lines) + "\n"


def _write_config(content):
    """Write ``content`` to the generated config file and return its path.

    Creates the ``config`` directory first if it does not exist.
    """
    path = _config_file_path()
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        f.write(content)
    return path


def _ensure_init():
    """Create the module-level Config and AutoLogger on first call.

    Config must be constructed before ``autoline`` is imported, so this is
    called lazily from ``TBGenerator.generate`` rather than at import time.
    As a side effect it writes a default ``config/custom.yaml``.

    Returns:
        The Config instance created on the first call.
    """
    global _config_instance, _auto_logger_instance

    if _config_instance is None:
        config_path = _write_config(_render_config(
            model="qwen-max",
            save_root="./output/",
            enable_cga=True,
            cga_iter=10,
            key_path="config/key_API.json",
        ))

        # Must be created before autoline is imported.
        from config import Config
        _config_instance = Config(config_path)

        from loader_saver import AutoLogger
        _auto_logger_instance = AutoLogger()

    return _config_instance


def _create_config_for_task(task_id, model, enable_cga, cga_iter,
                            key_path="config/key_API.json"):
    """Write a task-specific config file and return a Config loaded from it.

    Args:
        task_id: Task identifier; outputs go to <PROJECT_ROOT>/output/<task_id>/.
        model: LLM model name (``gpt.model``).
        enable_cga: Whether CGA optimisation is enabled.
        cga_iter: Maximum number of CGA iterations.
        key_path: Path of the API key file (``gpt.key_path``).

    Returns:
        A new ``Config`` built from the freshly written file.
    """
    # NOTE(review): this overwrites the same config/custom.yaml that
    # _ensure_init() writes (and any hand-maintained custom.yaml).
    save_root = os.path.join(PROJECT_ROOT, "output", str(task_id)) + "/"
    config_path = _write_config(_render_config(
        model=model,
        save_root=save_root,
        enable_cga=enable_cga,
        cga_iter=cga_iter,
        key_path=key_path,
    ))

    from config import Config
    return Config(config_path)


class TBGenerator:
    """Complete testbench generator supporting the multi-stage autoline flow."""

    def __init__(self, api_key_path="config/key_API.json", model="qwen-max"):
        """
        Args:
            api_key_path: Path of the API key file written into the task config.
            model: LLM model name written into the task config.
        """
        self.api_key_path = api_key_path
        self.model = model

    def generate(self, dut_code, description, header, task_id="test",
                 enable_cga=True, cga_iter=10):
        """Generate a testbench by running the full autoline flow.

        Args:
            dut_code: str, Verilog source of the DUT.
            description: str, project description / requirements.
            header: str, the DUT's module header.
            task_id: str, task identifier.
            enable_cga: bool, whether to enable CGA optimisation.
            cga_iter: int, maximum number of CGA iterations.

        Returns:
            dict with keys "TB_code_v", "TB_code_py", "run_info",
            "cga_coverage" and "full_pass", copied from the finished task.
        """
        # Singletons must exist before autoline is imported.
        _ensure_init()

        from autoline.TB_autoline import AutoLine_Task

        # Problem record in the HDLBitsProbset format.
        prob_data = {
            "task_id": task_id,
            "task_number": 1,
            "description": description,
            "header": header,
            "module_code": dut_code,
            "testbench": None,
            "mutants": [],
            "llmgen_RTL": [],
        }

        config = _create_config_for_task(
            task_id, self.model, enable_cga, cga_iter,
            key_path=self.api_key_path,
        )

        task = AutoLine_Task(prob_data, config)
        task.run()

        return {
            "TB_code_v": task.TB_code_v,
            "TB_code_py": task.TB_code_py,
            "run_info": task.run_info,
            "cga_coverage": task.cga_coverage,
            "full_pass": task.full_pass,
        }
|
||
|
||
|
||
def generate_tb(dut_code, description, header, task_id,
                api_key_path="config/key_API.json",
                model="qwen-max",
                enable_cga=True,
                output_dir="./output",
                cga_iter=10):
    """Convenience wrapper: generate a testbench and save it to disk.

    Args:
        dut_code: str, DUT source code.
        description: str, project description.
        header: str, module header.
        task_id: str, task identifier; also names the output file.
        api_key_path: str, API key path, passed to TBGenerator.
        model: str, model to use.
        enable_cga: bool, whether to enable CGA.
        output_dir: str, directory the testbench is written to (created if missing).
        cga_iter: int, maximum number of CGA iterations (default 10, the same
            default TBGenerator.generate uses).

    Returns:
        tuple: (path of the saved "<task_id>_tb.v" file, result dict from
        TBGenerator.generate).
    """
    generator = TBGenerator(api_key_path, model)
    # Keyword arguments keep the forwarding robust against signature changes.
    result = generator.generate(
        dut_code, description, header,
        task_id=task_id,
        enable_cga=enable_cga,
        cga_iter=cga_iter,
    )

    os.makedirs(output_dir, exist_ok=True)
    tb_path = os.path.join(output_dir, f"{task_id}_tb.v")

    from loader_saver import save_code
    save_code(result["TB_code_v"], tb_path)

    return tb_path, result
|
||
|
||
|
||
if __name__ == "__main__":
    # Demo: build a testbench for a small 8-bit multiplier.
    demo_kwargs = dict(
        dut_code="""
module example(
    input clk,
    input rst,
    input [7:0] a,
    input [7:0] b,
    output [15:0] y
);
    assign y = a * b;
endmodule
""",
        description="一个8位乘法器,输入两个8位无符号数,输出16位乘积",
        header="module example(input clk, input rst, input [7:0] a, input [7:0] b, output [15:0] y);",
        task_id="example_mul",
        model="qwen-max",
    )

    print("Generating TB for example multiplier...")
    saved_path, outcome = generate_tb(**demo_kwargs)

    print(f"TB saved to: {saved_path}")
    print(f"Coverage: {outcome.get('cga_coverage', 0)}")
    print(f"Full Pass: {outcome.get('full_pass', False)}")
|