上传所有文件
This commit is contained in:
209
run_tbgen.py
Normal file
209
run_tbgen.py
Normal file
@@ -0,0 +1,209 @@
|
||||
"""
|
||||
TB Generator - 根据DUT和项目要求生成Testbench(完整版)
|
||||
支持多阶段流程: TBgen → TBsim → TBcheck → CGA → TBeval
|
||||
|
||||
用法:
|
||||
from run_tbgen import generate_tb
|
||||
tb_path, result = generate_tb(
|
||||
dut_code="module example(...); endmodule",
|
||||
description="项目描述",
|
||||
header="module example(input clk, ...);",
|
||||
task_id="my_task",
|
||||
model="qwen-max"
|
||||
)
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# 添加项目路径
|
||||
PROJECT_ROOT = Path(__file__).parent
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
|
||||
# 延迟导入和初始化,避免单例顺序问题
|
||||
_config_instance = None
|
||||
_auto_logger_instance = None
|
||||
|
||||
def _ensure_init():
|
||||
"""确保Config和AutoLogger正确初始化"""
|
||||
global _config_instance, _auto_logger_instance
|
||||
|
||||
if _config_instance is None:
|
||||
# 创建临时配置文件
|
||||
_temp_config_path = os.path.join(PROJECT_ROOT, "config", "custom.yaml")
|
||||
_config_content = """
|
||||
run:
|
||||
mode: 'autoline'
|
||||
gpt:
|
||||
model: "qwen-max"
|
||||
key_path: "config/key_API.json"
|
||||
save:
|
||||
en: True
|
||||
root: "./output/"
|
||||
autoline:
|
||||
cga:
|
||||
enabled: True
|
||||
max_iter: 10
|
||||
promptscript: "pychecker"
|
||||
onlyrun: "TBgensimeval"
|
||||
"""
|
||||
with open(_temp_config_path, "w") as f:
|
||||
f.write(_config_content)
|
||||
|
||||
# 必须在导入autoline之前创建Config
|
||||
from config import Config
|
||||
_config_instance = Config(_temp_config_path)
|
||||
|
||||
# 初始化AutoLogger
|
||||
from loader_saver import AutoLogger
|
||||
_auto_logger_instance = AutoLogger()
|
||||
|
||||
return _config_instance
|
||||
|
||||
|
||||
def _create_config_for_task(task_id, model, enable_cga, cga_iter):
|
||||
"""为特定任务创建配置"""
|
||||
config_path = os.path.join(PROJECT_ROOT, "config", "custom.yaml")
|
||||
|
||||
config_content = f"""
|
||||
run:
|
||||
mode: 'autoline'
|
||||
gpt:
|
||||
model: "{model}"
|
||||
key_path: "config/key_API.json"
|
||||
save:
|
||||
en: True
|
||||
root: "{os.path.join(PROJECT_ROOT, 'output', task_id)}/"
|
||||
autoline:
|
||||
cga:
|
||||
enabled: {enable_cga}
|
||||
max_iter: {cga_iter}
|
||||
promptscript: "pychecker"
|
||||
onlyrun: "TBgensimeval"
|
||||
"""
|
||||
with open(config_path, "w") as f:
|
||||
f.write(config_content)
|
||||
|
||||
from config import Config
|
||||
return Config(config_path)
|
||||
|
||||
|
||||
class TBGenerator:
|
||||
"""完整的TB生成器,支持多阶段流程"""
|
||||
|
||||
def __init__(self, api_key_path="config/key_API.json", model="qwen-max"):
|
||||
self.api_key_path = api_key_path
|
||||
self.model = model
|
||||
|
||||
def generate(self, dut_code, description, header, task_id="test",
|
||||
enable_cga=True, cga_iter=10):
|
||||
"""
|
||||
生成Testbench(完整流程)
|
||||
|
||||
参数:
|
||||
dut_code: str, DUT的Verilog代码
|
||||
description: str, 项目描述/需求
|
||||
header: str, DUT的module header
|
||||
task_id: str, 任务ID
|
||||
enable_cga: bool, 是否启用CGA优化
|
||||
cga_iter: int, CGA最大迭代次数
|
||||
|
||||
返回:
|
||||
dict: 包含TB代码、评估结果等
|
||||
"""
|
||||
# 确保单例初始化
|
||||
_ensure_init()
|
||||
|
||||
# 导入autoline(现在可以安全导入了)
|
||||
from autoline.TB_autoline import AutoLine_Task
|
||||
|
||||
# 构建prob_data(符合HDLBitsProbset格式)
|
||||
prob_data = {
|
||||
"task_id": task_id,
|
||||
"task_number": 1,
|
||||
"description": description,
|
||||
"header": header,
|
||||
"module_code": dut_code,
|
||||
"testbench": None,
|
||||
"mutants": [],
|
||||
"llmgen_RTL": []
|
||||
}
|
||||
|
||||
# 为任务创建配置
|
||||
config = _create_config_for_task(task_id, self.model, enable_cga, cga_iter)
|
||||
|
||||
# 创建任务并运行
|
||||
task = AutoLine_Task(prob_data, config)
|
||||
task.run()
|
||||
|
||||
return {
|
||||
"TB_code_v": task.TB_code_v,
|
||||
"TB_code_py": task.TB_code_py,
|
||||
"run_info": task.run_info,
|
||||
"cga_coverage": task.cga_coverage,
|
||||
"full_pass": task.full_pass
|
||||
}
|
||||
|
||||
|
||||
def generate_tb(dut_code, description, header, task_id,
|
||||
api_key_path="config/key_API.json",
|
||||
model="qwen-max",
|
||||
enable_cga=True,
|
||||
output_dir="./output"):
|
||||
"""
|
||||
便捷函数:生成TB并保存
|
||||
|
||||
参数:
|
||||
dut_code: str, DUT代码
|
||||
description: str, 项目描述
|
||||
header: str, module header
|
||||
task_id: str, 任务ID
|
||||
api_key_path: str, API密钥路径
|
||||
model: str, 使用的模型
|
||||
enable_cga: bool, 是否启用CGA
|
||||
output_dir: str, 输出目录
|
||||
|
||||
返回:
|
||||
tuple: (TB文件路径, 结果字典)
|
||||
"""
|
||||
generator = TBGenerator(api_key_path, model)
|
||||
result = generator.generate(dut_code, description, header, task_id, enable_cga)
|
||||
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
tb_path = os.path.join(output_dir, f"{task_id}_tb.v")
|
||||
from loader_saver import save_code
|
||||
save_code(result["TB_code_v"], tb_path)
|
||||
|
||||
return tb_path, result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 示例用法
|
||||
example_dut = """
|
||||
module example(
|
||||
input clk,
|
||||
input rst,
|
||||
input [7:0] a,
|
||||
input [7:0] b,
|
||||
output [15:0] y
|
||||
);
|
||||
assign y = a * b;
|
||||
endmodule
|
||||
"""
|
||||
|
||||
example_desc = "一个8位乘法器,输入两个8位无符号数,输出16位乘积"
|
||||
|
||||
example_header = "module example(input clk, input rst, input [7:0] a, input [7:0] b, output [15:0] y);"
|
||||
|
||||
print("Generating TB for example multiplier...")
|
||||
tb_path, result = generate_tb(
|
||||
dut_code=example_dut,
|
||||
description=example_desc,
|
||||
header=example_header,
|
||||
task_id="example_mul",
|
||||
model="qwen-max"
|
||||
)
|
||||
print(f"TB saved to: {tb_path}")
|
||||
print(f"Coverage: {result.get('cga_coverage', 0)}")
|
||||
print(f"Full Pass: {result.get('full_pass', False)}")
|
||||
Reference in New Issue
Block a user