210 lines
5.7 KiB
Python
210 lines
5.7 KiB
Python
|
|
"""
TB Generator - generate a testbench from a DUT and project requirements (full version).

Supports the multi-stage flow: TBgen -> TBsim -> TBcheck -> CGA -> TBeval

Usage:
    from run_tbgen import generate_tb

    tb_path, result = generate_tb(
        dut_code="module example(...); endmodule",
        description="项目描述",  # project description / requirements
        header="module example(input clk, ...);",
        task_id="my_task",
        model="qwen-max",
    )
"""
|
|||
|
|
import os
|
|||
|
|
import sys
|
|||
|
|
from pathlib import Path
|
|||
|
|
|
|||
|
|
# Make the project's own packages (config, loader_saver, autoline) importable
# no matter which directory the caller runs from. resolve() makes the root
# absolute: with a relative __file__ (e.g. `python run_tbgen.py` on Python
# < 3.9) a bare .parent would be "." and every derived path would silently
# depend on the current working directory.
PROJECT_ROOT = Path(__file__).resolve().parent
sys.path.insert(0, str(PROJECT_ROOT))


# Lazily-created singletons, populated by _ensure_init(). Creation is
# deferred so Config exists before autoline is imported (avoids singleton
# ordering problems).
_config_instance = None
_auto_logger_instance = None
|
|||
|
|
|
|||
|
|
def _ensure_init():
    """Create the global Config and AutoLogger singletons on first use.

    On the first call this writes a bootstrap ``config/custom.yaml`` under
    PROJECT_ROOT, builds ``config.Config`` from it and instantiates
    ``loader_saver.AutoLogger``. Later calls do no work.

    Returns:
        The module-level Config instance.
    """
    global _config_instance, _auto_logger_instance

    # Already initialised: nothing to do.
    if _config_instance is not None:
        return _config_instance

    bootstrap_path = os.path.join(PROJECT_ROOT, "config", "custom.yaml")
    bootstrap_yaml = """
run:
  mode: 'autoline'
gpt:
  model: "qwen-max"
  key_path: "config/key_API.json"
save:
  en: True
  root: "./output/"
autoline:
  cga:
    enabled: True
    max_iter: 10
  promptscript: "pychecker"
  onlyrun: "TBgensimeval"
"""
    with open(bootstrap_path, "w") as handle:
        handle.write(bootstrap_yaml)

    # Config has to exist before autoline is imported anywhere.
    from config import Config
    _config_instance = Config(bootstrap_path)

    from loader_saver import AutoLogger
    _auto_logger_instance = AutoLogger()

    return _config_instance
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _create_config_for_task(task_id, model, enable_cga, cga_iter,
                            key_path="config/key_API.json"):
    """Write a per-task ``config/custom.yaml`` and load it as a Config.

    The file lives at the same path _ensure_init() writes to and is
    overwritten on every call.

    Args:
        task_id: Task identifier; outputs go to
            ``<PROJECT_ROOT>/output/<task_id>/``.
        model: LLM model name, written to ``gpt.model``.
        enable_cga: Whether CGA optimisation is enabled
            (``autoline.cga.enabled``).
        cga_iter: Maximum CGA iterations (``autoline.cga.max_iter``).
        key_path: Path of the API-key JSON file, written to ``gpt.key_path``.

    Returns:
        A ``config.Config`` built from the written file.
    """
    import json

    def _yaml_str(value):
        # A JSON string literal is also a valid YAML double-quoted scalar, so
        # this escapes backslashes (e.g. the "\U" in a Windows "C:\Users"
        # path), quotes and control characters instead of pasting them raw.
        return json.dumps(str(value))

    config_path = os.path.join(PROJECT_ROOT, "config", "custom.yaml")
    save_root = os.path.join(PROJECT_ROOT, "output", str(task_id)) + "/"

    config_content = f"""
run:
  mode: 'autoline'
gpt:
  model: {_yaml_str(model)}
  key_path: {_yaml_str(key_path)}
save:
  en: True
  root: {_yaml_str(save_root)}
autoline:
  cga:
    enabled: {bool(enable_cga)}
    max_iter: {int(cga_iter)}
  promptscript: "pychecker"
  onlyrun: "TBgensimeval"
"""
    with open(config_path, "w", encoding="utf-8") as f:
        f.write(config_content)

    from config import Config
    return Config(config_path)


class TBGenerator:
    """Testbench generator that runs the full multi-stage AutoLine flow."""

    def __init__(self, api_key_path="config/key_API.json", model="qwen-max"):
        """Store generation settings.

        Args:
            api_key_path: Path of the API-key JSON file; written into each
                per-task config as ``gpt.key_path``.
            model: LLM model name, written as ``gpt.model``.
        """
        self.api_key_path = api_key_path
        self.model = model

    def generate(self, dut_code, description, header, task_id="test",
                 enable_cga=True, cga_iter=10):
        """Generate a testbench by running the complete AutoLine flow.

        Args:
            dut_code: str, Verilog source of the DUT.
            description: str, project description / requirements.
            header: str, the DUT's module header.
            task_id: str, task identifier (also names the output directory).
            enable_cga: bool, whether to enable CGA optimisation.
            cga_iter: int, maximum number of CGA iterations.

        Returns:
            dict with keys ``TB_code_v``, ``TB_code_py``, ``run_info``,
            ``cga_coverage`` and ``full_pass``, copied from the finished
            ``AutoLine_Task``.
        """
        # The Config/AutoLogger singletons must exist before autoline is imported.
        _ensure_init()

        # Deferred import: only safe once _ensure_init() has run.
        from autoline.TB_autoline import AutoLine_Task

        # Problem record in the HDLBitsProbset format expected by AutoLine_Task.
        prob_data = {
            "task_id": task_id,
            "task_number": 1,
            "description": description,
            "header": header,
            "module_code": dut_code,
            "testbench": None,
            "mutants": [],
            "llmgen_RTL": [],
        }

        # Per-task config; pass the configured key path instead of letting the
        # YAML fall back to the hard-coded default.
        config = _create_config_for_task(
            task_id, self.model, enable_cga, cga_iter,
            key_path=self.api_key_path,
        )

        task = AutoLine_Task(prob_data, config)
        task.run()

        return {
            "TB_code_v": task.TB_code_v,
            "TB_code_py": task.TB_code_py,
            "run_info": task.run_info,
            "cga_coverage": task.cga_coverage,
            "full_pass": task.full_pass,
        }
|
|||
|
|
|
|||
|
|
|
|||
|
|
def generate_tb(dut_code, description, header, task_id,
                api_key_path="config/key_API.json",
                model="qwen-max",
                enable_cga=True,
                output_dir="./output",
                cga_iter=10):
    """Convenience wrapper: generate a testbench and save it to disk.

    Args:
        dut_code: str, DUT Verilog source.
        description: str, project description.
        header: str, the DUT's module header.
        task_id: str, task identifier; also names the saved file.
        api_key_path: str, path of the API-key JSON file.
        model: str, LLM model to use.
        enable_cga: bool, whether to enable CGA.
        output_dir: str, directory the ``<task_id>_tb.v`` file is saved to
            (created if missing).
        cga_iter: int, maximum number of CGA iterations (defaults to 10,
            the value previously fixed by TBGenerator.generate).

    Returns:
        tuple: (path of the saved Verilog testbench, result dict from
        TBGenerator.generate).
    """
    generator = TBGenerator(api_key_path, model)
    # Pass the options by keyword so they cannot silently shift if the
    # generate() signature gains parameters.
    result = generator.generate(
        dut_code, description, header,
        task_id=task_id,
        enable_cga=enable_cga,
        cga_iter=cga_iter,
    )

    os.makedirs(output_dir, exist_ok=True)
    tb_path = os.path.join(output_dir, f"{task_id}_tb.v")

    from loader_saver import save_code
    save_code(result["TB_code_v"], tb_path)

    return tb_path, result
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
    # Demo: build a testbench for a small 8-bit multiplier.
    demo_args = dict(
        dut_code="""
module example(
    input clk,
    input rst,
    input [7:0] a,
    input [7:0] b,
    output [15:0] y
);
    assign y = a * b;
endmodule
""",
        description="一个8位乘法器,输入两个8位无符号数,输出16位乘积",
        header="module example(input clk, input rst, input [7:0] a, input [7:0] b, output [15:0] y);",
        task_id="example_mul",
        model="qwen-max",
    )

    print("Generating TB for example multiplier...")
    saved_path, outcome = generate_tb(**demo_args)

    print(f"TB saved to: {saved_path}")
    print(f"Coverage: {outcome.get('cga_coverage', 0)}")
    print(f"Full Pass: {outcome.get('full_pass', False)}")
|