上传所有文件

This commit is contained in:
zice6688
2026-03-30 16:46:48 +08:00
parent 8c2008c738
commit 35c99bac58
110 changed files with 23243 additions and 0 deletions

43
autoline/TB1_gen.py Normal file
View File

@@ -0,0 +1,43 @@
"""
Description : The TB generation stage in the autoline. The main TB generation workflow is implemented in prompt_scriptws
Author : Ruidi Qiu (r.qiu@tum.de)
Time : 2024/7/24 11:27:21
LastEdited : 2024/8/12 23:30:30
"""
from prompt_scripts import get_script, BaseScript
from loader_saver import log_localprefix
class TaskTBgen():
    # TODO: in the future use pythonized prompt scripts and this class to replace the old TaskTBgen
    """TBgen stage: generate the testbench by dispatching to the python prompt
    script workflow selected by `TBgen_prompt_script` (resolved via get_script)."""

    def __init__(self, prob_data: dict, TBgen_prompt_script: str, task_dir: str, config):
        """
        - prob_data: the data dict of the current problem
        - TBgen_prompt_script: name of the prompt script implementing the workflow
        - task_dir: directory where this task's outputs are written
        - config: the global config object
        """
        self.prob_data = prob_data
        self.prompt_script_name = TBgen_prompt_script
        self.task_dir = task_dir
        self.config = config
        # resolve the workflow class by name and instantiate it
        WorkFlowClass = get_script(TBgen_prompt_script)
        self.workflow = WorkFlowClass(
            prob_data = prob_data,
            task_dir = task_dir,
            config = config
        )

    @log_localprefix("TBgen")
    def run(self):
        """Execute the TB generation workflow (the workflow object is callable)."""
        self.workflow()

    @property
    def scenario_num(self):
        # number of test scenarios, if the workflow exposes it (else None)
        return self.get_wf_attr("scenario_num")

    @property
    def scenario_dict(self):
        # scenario dictionary, if the workflow exposes it (else None)
        return self.get_wf_attr("scenario_dict")

    def get_wf_attr(self, attr_name: str):
        """Return the workflow's attribute `attr_name`, or None if absent."""
        # getattr with a default replaces the hasattr/getattr double lookup
        return getattr(self.workflow, attr_name, None)

416
autoline/TB2_syncheck.py Normal file
View File

@@ -0,0 +1,416 @@
"""
Description : This is the TB syntactic checking stage in the autoline (previously named as TaskTBsim in autoline.py)
Author : Ruidi Qiu (r.qiu@tum.de)
Time : 2024/7/24 11:24:31
LastEdited : 2024/8/23 15:53:15
"""
import os
import LLM_call as llm
import iverilog_call as iv
import python_call as py
import loader_saver as ls
from config import Config
from loader_saver import autologger as logger
from loader_saver import log_localprefix
from prompt_scripts import get_script, BaseScript
from utils.utils import Timer, get_time
# Markers that delimit the verilog testbench code inside an LLM reply.
IDENTIFIER = {
    "tb_start" : "```verilog",
    "tb_end" : "```"
}
# Skeleton shown to the LLM so it answers with a fenced verilog block.
TESTBENCH_TEMPLATE = """%s
`timescale 1ns / 1ps
(more verilog testbench code here...)
endmodule
%s""" % (IDENTIFIER["tb_start"], IDENTIFIER["tb_end"])
# Debug prompt for fixing a verilog testbench; the numbered TB code and the
# compiler error message are appended by _debug_prompt_gen_iv.
DEBUG_TEMPLATE = """please fix the verilog testbench code below according to the error message below. please directly give me the corrected verilog testbench codes.
Attention: never remove the irrelevant codes!!!
your verilog testbench should be like:
%s
please only reply the full code modified. NEVER remove other irrelevant codes!!!
The testbench I give you is the one with error. To be convienient, each of the line begins with a line number. The line number also appears at the error message. You should use the line number to locate the error with the help of error message.
""" % (TESTBENCH_TEMPLATE)
# Closing instruction for verilog debug prompts.
# NOTE(review): not appended by _debug_prompt_gen_iv in this file -- confirm whether that is intended.
DEBUG_FINAL_INSTR = """ please directly give me the corrected verilog codes, no other words needed. Your verilog codes should start with [```verilog] and end with [```]."""
# Debug prompt for fixing the python checker code; the numbered code and the
# error message are appended by _debug_prompt_gen_py.
DEBUG_TEMPLATE_PY = """please fix the python code below according to the error message below. please directly give me the corrected python codes.
Attention: never remove the irrelevant codes!!!
please only reply the full code modified. NEVER remove other irrelevant codes!!!
The python code I give you is the one with error. To be convienient, each of the line begins with a line number. The line number also appears at the error message. You should use the line number to locate the error with the help of error message.
"""
# Closing instruction for python debug prompts (appended by _debug_prompt_gen_py).
DEBUG_FINAL_INSTR_PY = """ please directly give me the corrected python codes, no other words needed. Your python codes should start with [```python] and end with [```]."""
# will be discarded by 15/08/2024
# DEBUG_TEMPLATE_END = """
# VERY IMPORTANT: please ONLY reply the full code modified. NEVER remove other irrelevant codes!!!
# Your testbench SHOULD NOT have the line number at the beginning of each line!!!
# """
class TaskTBsim():
    """
    Syntax-check stage of the autoline: compile/run the generated testbench with
    iverilog (and optionally the python checker) and drive an LLM debug/reboot
    loop until the code passes or the debug budget is exhausted.

    #### input:
    - ivcode_path:
        - the path of iverilog dir (xxx/TB_gen/), will contain all verilog files. generated .vvp will also be saved here
    #### output:
    - dict of the simulation result
        - "sim_pass" : bool (whether the simulation is successful. This is only the necessary condition of the correctness of the testbench)
        - "debug_iter" : int (the number of debug iterations)
        - "sim_out" : str (the output of the simulation)
        - "sim_err" : str (the error of the simulation)
        - "TB_gen_debugged" : str or None (the testbench code after debug)
    #### iverilog_path:
    - the path of iverilog dir, will contain all verilog files. generated .vvp will also be saved here
    #### task_id:
    - the name of the problem, will be used as the name of generated files

    file structure:
    - original
        - task_id.v
        - task_id_tb.v
        - task_id_vlist.txt
        - task_id.vvp
    - debug_1
        - task_id.v
        - task_id_tb.v
        - task_id_vlist.txt
        - task_id.vvp
    - debug_2
    - ...
    """

    def __init__(self, TBgen: BaseScript, TB_code: str, module_header: str, task_dir: str, task_id: str, config):
        self.TBgen = TBgen
        self.TB_code_now = TB_code
        self.module_header = module_header
        self.task_dir = task_dir if task_dir.endswith("/") else task_dir + "/"  # for the compatibility with the old version
        self.task_id = task_id
        self.config = config
        # working dir moves from the original TB dir into debug_x/ dirs during debugging
        self.working_dir = TBgen.TB_code_dir if TBgen.TB_code_dir.endswith("/") else TBgen.TB_code_dir + "/"
        # an empty DUT (header + endmodule) so iverilog exercises only the TB's own syntax
        self.DUT_code = module_header + "\n\nendmodule\n"
        self.debug_iter_max = config.autoline.debug.max
        self.debug_iter_to_reboot = config.autoline.debug.reboot
        self.proc_timeout = config.autoline.timeout
        # separate counters for iverilog and python; their sum may exceed debug_iter_max
        self.debug_iter_iv_now = 0
        self.debug_iter_after_reboot_iv = 0
        self.debug_iter_py_now = 0
        self.debug_iter_after_reboot_py = 0
        self.reboot_both = False
        # pychecker related
        self.pychecker_en = self.TBgen.Pychecker_en
        self.PY_code_now = ""
        if self.pychecker_en:
            self.TBout_content = ""  # will get after the last iverilog run
            self.PY_code_now = self.TBgen.Pychecker_code
            # will reboot both iv and py if python simulation failed this many times
            self.py_fail_reboot_both_iter = config.autoline.debug.py_rollback
            self.py_debug_focus = self.TBgen.py_debug_focus
        # infos
        self.sim_pass = False  # this should be com_pass, but it is too late to change it now
        self.py_pass = False
        self.Eval0_pass = False
        self.iverilog_info = None
        self.reboot_both_times = 0
        self.iv_runing_time = 0.0  # the time of running the last iverilog
        self.py_runing_time = 0.0  # the time of running the last python
        self.tokens = {"prompt": 0, "completion": 0}

    @log_localprefix("TBsim")
    def run(self):
        """Run the stage: iverilog only, or the iverilog+python loop when pychecker is enabled."""
        if not self.pychecker_en:
            self.run_iverilog()
            self.Eval0_pass = self.sim_pass
        else:
            exit_en = False
            while (not exit_en):
                self.run_iverilog()
                if self.sim_pass:
                    self.run_python()
                    # run_python sets reboot_both when python keeps failing; loop again in that case
                    if not self.reboot_both:
                        exit_en = True
                else:
                    exit_en = True
                    self.Eval0_pass = False
                    raise ValueError("TBsim: iverilog failed, python simulation is not allowed.")
            self.Eval0_pass = self.sim_pass and self.py_pass
        logger.info("TBsim finished : %s!"%(self.Eval0_pass))

    def run_iverilog(self):
        """
        - the main function of TBsim: run iverilog and debug/reboot until pass or budget exhausted
        """
        if not self.reboot_both:
            # this will only be called at the first time of runing run_iverilog
            self._save_code_run_iverilog()
        self.sim_pass = self.iverilog_info[0]
        while (self.debug_iter_iv_now < self.debug_iter_max) and (not self.sim_pass):
            self.debug_iter_iv_now += 1
            if self.debug_iter_after_reboot_iv < self.debug_iter_to_reboot:
                self.debug_iter_after_reboot_iv += 1
                self._debug_iv()
            else:
                self._reboot_iv()
            self.sim_pass = self.iverilog_info[0]
            self.reboot_both = False
        if self.reboot_both:
            # this means didn't enter the while, because debug_iter_max is already reached
            logger.info("iverilog compilation (reboot from python) : failed! iverilog exceeded max debug iteration (%s)"%(self.debug_iter_max))
        if self.sim_pass:
            logger.info("iverilog compilation : passed!")
        else:
            logger.info("iverilog compilation : failed! exceeded max debug iteration (%s)"%(self.debug_iter_max))
        # clean .vcd wave files
        self.clean_vcd()

    def run_python(self):
        """Run the python checker on the last iverilog output; debug/reboot on failure."""
        # read the TBout.txt into TBout_content in working_dir
        with open(self.TBout_path, "r") as f:
            self.TBout_content = f.read()
        self.debug_iter_after_reboot_py = 0
        py_rollback = 0  # local failure counter that triggers reboot of both iv and py
        self._save_code_run_python()
        while (self.debug_iter_py_now < self.debug_iter_max) and (not self.python_info[0]):
            if (not self.python_info[0]) and (py_rollback >= self.py_fail_reboot_both_iter):
                # +1: debug py fail + [generated py fail]
                self.reboot_both = True
                break
            py_rollback += 1
            self.debug_iter_py_now += 1
            if self.debug_iter_after_reboot_py < self.debug_iter_to_reboot:
                self.debug_iter_after_reboot_py += 1
                self._debug_py()
            else:
                self._reboot_py()
        if self.reboot_both:
            self.py_pass = False
            self.sim_pass = False
            # force the next run_iverilog to reboot instead of debugging
            self.debug_iter_after_reboot_iv = self.debug_iter_to_reboot
            logger.info("python simulation : failed! will reboot both iverilog and python")
        elif self.python_info[0]:
            self.py_pass = True
            logger.info("python simulation : passed!")
        else:
            self.py_pass = False
            logger.info("python simulation : failed! exceeded max debug iteration (%s)"%(self.debug_iter_max))
        self.py_out = self.python_info[1]["out"] if self.python_info[1] is not None else ""
        self.py_err = self.python_info[-1]

    def _debug_iv(self):
        """Ask the LLM to fix the verilog TB from the last error message, then rerun iverilog."""
        with Timer(print_en=False) as debug_time:
            logger.info("iverilog simulation failed! Debuging... (debug_iter = %s)"%(self.debug_iter_iv_now))
            self.working_dir = self.task_dir + "debug_%s/" % (self.total_debug_iter_now)
            os.makedirs(self.working_dir, exist_ok=True)
            debug_prompt = self._debug_prompt_gen_iv()
            debug_message = [{"role": "user", "content": debug_prompt}]
            gpt_response, info = llm.llm_call(debug_message, self.config.gpt.model, self.config.gpt.key_path)
            debug_message = info["messages"]
            self.TB_code_now = llm.extract_code(gpt_response, "verilog")[-1]
            self.TB_code_now = self.del_linemark(self.TB_code_now)
            self._save_code_run_iverilog()
        logger.info("%s: verilog DEBUG finished (%ss used)" % (self.debug_iter_info("iv"), round(debug_time.interval, 2)))
        self.tokens["prompt"] += info["usage"]["prompt_tokens"]
        self.tokens["completion"] += info["usage"]["completion_tokens"]
        ls.save_messages_to_txt(debug_message, self.working_dir+"debug_messages.txt")

    def _reboot_iv(self):
        """Regenerate the TB from scratch via TBgen instead of debugging; rerun iverilog."""
        # change TBgen's code dir
        with Timer (print_en=False) as reboot_time:
            logger.info("iverilog simulation failed! Rebooting... (debug_iter = %s)"%(self.debug_iter_iv_now))
            self.working_dir = self.task_dir + "debug_%s_reboot/" % (self.total_debug_iter_now)
            os.makedirs(self.working_dir, exist_ok=True)
            self.TBgen.run_reboot(self.working_dir, reboot_mode="TB")
            self.TB_code_now = self.TBgen.TB_code
            self._save_code_run_iverilog()
        logger.info("%s: verilog REBOOT finished (%ss used)" % (self.debug_iter_info("iv"), round(reboot_time.interval, 2)))
        # the tokens will be added into TBgen's tokens count, we don't count it again here.
        # reset reboot counter
        self.debug_iter_after_reboot_iv = 0

    def _debug_py(self):
        """Ask the LLM to fix the python checker from the last error message, then rerun it."""
        with Timer(print_en=False) as debug_time:
            logger.info("python compilation failed! Debuging python... (debug_iter = %s)"%(self.debug_iter_py_now))
            self.working_dir = self.task_dir + "debug_%s/" % (self.total_debug_iter_now)
            os.makedirs(self.working_dir, exist_ok=True)
            # run gpt
            debug_prompt = self._debug_prompt_gen_py()
            debug_message = [{"role": "user", "content": debug_prompt}]
            gpt_response, info = llm.llm_call(debug_message, self.config.gpt.model, self.config.gpt.key_path)
            debug_message = info["messages"]
            self.PY_code_now = llm.extract_code(gpt_response, "python")[-1]
            self.PY_code_now = self.del_linemark(self.PY_code_now)
            if self.py_debug_focus:  # currently only support pychecker SEQ mode
                self.PY_code_now = self._py_focus(self.PY_code_now, before=False)
            self._save_code_run_python()
        logger.info("%s: python DEBUG finished (%ss used)" % (self.debug_iter_info("py"), round(debug_time.interval, 2)))
        self.tokens["prompt"] += info["usage"]["prompt_tokens"]
        self.tokens["completion"] += info["usage"]["completion_tokens"]
        ls.save_messages_to_txt(debug_message, self.working_dir+"debug_messages.txt")

    def _reboot_py(self):
        """Regenerate the python checker from scratch via TBgen instead of debugging; rerun it."""
        # change TBgen's code dir
        with Timer (print_en=False) as reboot_time:
            logger.info("python compilation failed! Rebooting... (debug_iter = %s)"%(self.debug_iter_py_now))
            self.working_dir = self.task_dir + "debug_%s_reboot/" % (self.total_debug_iter_now)
            os.makedirs(self.working_dir, exist_ok=True)
            self.TBgen.run_reboot(self.working_dir, reboot_mode="PY")
            self.PY_code_now = self.TBgen.Pychecker_code
            self._save_code_run_python()
        logger.info("%s: python REBOOT finished (%ss used)" % (self.debug_iter_info("py"), round(reboot_time.interval, 2)))
        # the tokens will be added into TBgen's tokens count, we don't count it again here.
        # reset reboot counter
        self.debug_iter_after_reboot_py = 0

    def _save_code_run_iverilog(self):
        """Write TB + empty DUT into working_dir and run iverilog; record timing and errors."""
        with open(self.TB_path, "w") as f:
            f.write(self.TB_code_now)
        with open(self.DUT_path, "w") as f:
            f.write(self.DUT_code)
        with Timer(print_en=False) as iverilog_time:
            self.iverilog_info = iv.iverilog_call_and_save(self.working_dir, silent=True, timeout=self.proc_timeout)
        self.iv_runing_time = round(iverilog_time.interval, 2)
        self.error_message_now = self.iverilog_info[-1]
        if "program is timeout" in self.error_message_now:
            # if the error message is timeout, we will delete the TBout.txt
            # this is to avoid the situation that infinite loop produces a large TBout.txt
            if os.path.exists(self.TBout_path):
                os.remove(self.TBout_path)
        self.clean_vvp()

    def _save_code_run_python(self):
        """Write the python checker + TBout.txt into working_dir and run python; record timing and errors."""
        with open(self.PY_path, "w") as f:
            f.write(self.PY_code_now)
        with open(self.TBout_path, "w") as f:
            f.write(self.TBout_content)
        with Timer(print_en=False) as python_time:
            self.python_info = py.python_call_and_save(pypath=self.PY_path, silent=True, timeout=self.proc_timeout)
        self.py_runing_time = round(python_time.interval, 2)
        self.error_message_now = self.python_info[-1]

    def _debug_prompt_gen_iv(self):
        """Build the verilog debug prompt: template + numbered TB code + error message."""
        debug_prompt = DEBUG_TEMPLATE + "\n previous testbench with error:\n" + self.add_linemark(self.TB_code_now) + "\n compiling error message:\n" + self.error_message_now
        return debug_prompt

    def _debug_prompt_gen_py(self):
        """Build the python debug prompt: template + numbered (possibly focused) code + error message."""
        if self.py_debug_focus:
            py_code = self._py_focus(self.PY_code_now, before=True)
        else:
            py_code = self.PY_code_now
        if not ("program is timeout" in self.error_message_now):
            self.error_message_now = self._py_error_message_simplify(self.error_message_now)
        debug_prompt = DEBUG_TEMPLATE_PY + "\n previous python code with error:\n" + self.add_linemark(py_code) + "\n compiling error message:\n" + self.error_message_now + DEBUG_FINAL_INSTR_PY
        return debug_prompt

    def _py_focus(self, code:str, before:bool):
        """
        - code: the code under debug / after debug
        - before: True, if before debug, will split the code; False, if after debug, will restore the code
        """
        # KEY_WORD = "\ndef check_dut"
        KEY_WORDs_1 = "def check_dut(vectors_in):\n golden_dut = GoldenDUT()\n failed_scenarios = []"
        KEY_WORDs_2 = "\ndef SignalTxt_to_dictlist"
        if before:
            key_words = KEY_WORDs_1 if KEY_WORDs_1 in code else KEY_WORDs_2
            if key_words not in code:
                py_code_focus = code
                self.py_code_nofocus = ""
            else:
                py_code_focus = code.split(key_words)[0]
                self.py_code_nofocus = key_words + code.split(key_words)[1]
            return py_code_focus
        else:
            return code + self.py_code_nofocus

    @staticmethod
    def _py_error_message_simplify(error_message:str, error_depth:int=1):
        """
        - extract the key point of python error message
        - error_depth: how many (how deep, from bottom to top) error messages to extract
        """
        msg_lines = error_message.split("\n")
        msg_out = ""
        for line in reversed(msg_lines):
            msg_out = line + "\n" + msg_out
            if "File" in line:
                error_depth -= 1
                if error_depth == 0:
                    break
        return msg_out

    @property
    def exceed_max_debug(self):
        return (self.debug_iter_iv_now >= self.debug_iter_max) or (self.debug_iter_py_now >= self.debug_iter_max)

    @property
    def total_debug_iter_now(self):
        return self.debug_iter_iv_now + self.debug_iter_py_now

    @property
    def TB_path(self):
        return self.working_dir + self.task_id + "_tb.v"

    @property
    def DUT_path(self):
        return self.working_dir + self.task_id + ".v"

    @property
    def PY_path(self):
        return self.working_dir + self.task_id + "_tb.py"

    @property
    def TBout_path(self):
        return self.working_dir + "TBout.txt"

    def debug_iter_info(self, type):
        """return debug iter info string. Type: "iv" or "py" """
        if self.pychecker_en:
            if type == "iv":
                return "verilog iter - %d/%d, total - %d/%d"%(self.debug_iter_iv_now, self.debug_iter_max, self.total_debug_iter_now, self.debug_iter_max*2)
            elif type == "py":
                # fixed log typo: "python tier" -> "python iter"
                return "python iter - %d/%d, total - %d/%d"%(self.debug_iter_py_now, self.debug_iter_max, self.total_debug_iter_now, self.debug_iter_max*2)
            else:
                raise ValueError("TaskTBsim.debug_iter_info(type): type should be 'iv' or 'py'")
        else:
            # only iverilog
            return "debug iter %d/%d"%(self.debug_iter_iv_now, self.debug_iter_max)

    @staticmethod
    def add_linemark(code: str):
        """add the line mark (1., 2., ...) to the code at the beginning of each line"""
        code = code.split("\n")
        code = [str(i+1) + ". " + line for i, line in enumerate(code)]
        return "\n".join(code)

    @staticmethod
    def del_linemark(code: str):
        """delete the line mark at the beginning of each line if line mark exists"""
        lines = code.split("\n")
        # probe the second line when available (the first line may be empty);
        # fall back to the first line so a single-line reply does not crash
        probe = lines[1] if len(lines) > 1 else lines[0]
        if probe.split(".")[0].isdigit():
            split_lines = [line.split(". ")[1:] for line in lines]
            lines = [". ".join(parts) for parts in split_lines]
        return "\n".join(lines)

    def clean_vcd(self):
        """clean the .vcd files in the task_dir"""
        clean_dir = self.task_dir[:-1] if self.task_dir.endswith("/") else self.task_dir
        for root, dirs, files in os.walk(clean_dir):
            for file in files:
                if file.endswith(".vcd"):
                    os.remove(os.path.join(root, file))

    def clean_vvp(self):
        """clean the .vvp files in the task_dir"""
        clean_dir = self.task_dir[:-1] if self.task_dir.endswith("/") else self.task_dir
        for root, dirs, files in os.walk(clean_dir):
            for file in files:
                if file.endswith(".vvp"):
                    os.remove(os.path.join(root, file))

836
autoline/TB3_funccheck.py Normal file
View File

@@ -0,0 +1,836 @@
"""
Description : The functionality checking of the generated TB, the submodule of Autoline
Author : Ruidi Qiu (r.qiu@tum.de)
Time : 2024/7/22 10:36:06
LastEdited : 2025/2/25 22:11:13
"""
import os
import LLM_call as llm
import iverilog_call as iv
import python_call as py
import numpy as np
import loader_saver as ls
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from loader_saver import autologger as logger
from loader_saver import log_localprefix
class TaskTBcheck():
"""
### description
- this is the self-checking stage of our pipeline; This is the main contribution of AutoBench2。
- This stage is to check the functional correctness of the testbench generated by AutoBench.
"""
    def __init__(self, task_dir:str, task_id:str, description:str, module_header:str, TB_code_v:str, TB_code_py:str|None=None, rtl_list:list[str]=None, rtl_num:int=20, scenario_num=None, correct_max:int=3, runfiles_save:bool=True, discriminator_mode:str="col_full_wrong", corrector_mode:str="naive", circuit_type:str=None, rtl_compens_max_iter:int=3, rtl_compens_en:bool=True, desc_improve:bool=False, **LLM_kwargs) -> None:
        """
        - input:
            - task_dir: the root directory of the taskTBcheck
            - task_id: the name of the problem
            - description: the description of the problem
            - module_header: the header of the module
            - TB_code_v: the generated verilog testbench code
            - TB_code_py (opt.): the generated python testbench code, if None, the tb is in a pure verilog mode
            - rtl_list (opt.): the list of llm-generated RTL codes, if None, will generate 20 RTL codes using LLM
            - rtl_num (opt.): the number of RTL codes to generate, only used when rtl_list is None
            - scenario_num (opt.): the number of scenarios in the testbench, if None, will be calculated from the failed scenarios (not accurate but won't impact the results)
            - correct_max (opt.): the maximum number of correction iterations
            - runfiles_save (opt.): whether to save the compilation files in TB_discrimination
            - discriminator_mode (default: col_full_wrong): the mode of the discriminator
            - corrector_mode (default: naive): the mode of the corrector
            - circuit_type (opt.): the type of the circuit, used in the corrector (better performance if provided)
            - rtl_compens_max_iter (default: 3): the maximum number of iterations of RTL compensation
            - rtl_compens_en (default: True): whether to enable RTL compensation
            - **LLM_kwargs: the keyword arguments for LLM (used in corrector and rtl generation), including:
                - "main_model": the main llm name used in TB_generation and correction
                - "rtlgen_model": the llm name used in RTL generation
                - "api_key_path": the path of the api key
                - "temperature": the temperature of LLM
        """
        self.task_dir = task_dir
        self.working_dir = self.task_dir
        self.task_id = task_id
        self.description = description
        self.module_header = module_header
        self.TB_code_v = TB_code_v
        self.TB_code_py = TB_code_py
        # pychecker mode is implied by the presence of a python testbench
        self.pychecker_en = TB_code_py is not None
        self.rtl_list = rtl_list
        self.scenario_num = scenario_num
        self.correct_max = correct_max
        self.runfiles_save = runfiles_save
        self.main_model = LLM_kwargs.get("main_model", None)
        self.rtlgen_model = LLM_kwargs.get("rtlgen_model", None)
        self.llm_api_key_path = LLM_kwargs.get("api_key_path", "config/key_API.json")
        self.llm_temperature = LLM_kwargs.get("temperature", None)
        self.circuit_type = circuit_type
        self.rtl_compens_max_iter = rtl_compens_max_iter  # see self.discriminate_TB for more info
        self.rtl_compens_en = rtl_compens_en
        self.desc_improve = desc_improve
        # give up after this many consecutive corrections with an unchanged wrong-scenario count
        self.tolerance_for_same_wrong_scen = 2
        self.same_wrong_scen_times = 0
        # discriminator and corrector
        # NOTE(review): TB_discriminator / TB_corrector / SPEC_improver are defined elsewhere in this file
        self.discriminator_mode = discriminator_mode
        self.corrector_mode = corrector_mode
        self.discriminator = TB_discriminator(discriminator_mode)
        self.corrector = TB_corrector(self.corrector_mode, self.pychecker_en, self.circuit_type)
        self.improver = SPEC_improver(description, "naive", self.pychecker_en, self.main_model, self.circuit_type)
        # rtl list and number (set_rtl_num presumably writes _rtl_num -- defined outside this view)
        self.rtl_newly_gen_num = 0
        if self.rtl_list is None:
            # self.rtl_num = rtl_num
            self.set_rtl_num(rtl_num)
            self.rtl_list_gen()
        else:
            # self.rtl_num = len(self.rtl_list)
            self.set_rtl_num(len(self.rtl_list))
        self.scenario_matrix = None
        self.wrong_scen_num = 0  # a small number as default, will be replaced later
        self.previous_wrong_scen_num = 9999  # a very large number as default, will be replaced later
        self.TB_syntax_error = False
        # tb analysis results
        self.tb_pass = None
        self.wrong_col_index = None
        self.correct_col_index = None
        self.unsure_col_index = None
        # next_action
        self.next_action = None
        self.iter_now = 0
        self.corrected = False  # this means the TB has been corrected
        if self.main_model is None:
            logger.warning("main_model not found, may have trouble while correcting tb")
        # record and runinfo
        self.op_record = []  # record the order of the operations, similar to the var in autoline; will be added to the funccheck's op_record in the final runinfo.
    @property
    def rtl_num(self):
        """protected attr. rtl_num is initialized at the beginning and will not be changed"""
        # _rtl_num is written by set_rtl_num() (defined outside this view)
        return self._rtl_num
    @log_localprefix("TBcheck")
    def run(self):
        """
        - the main function of TaskTBcheck
        - the TB check stage contains several sub-stages:
            - 1. TB discriminating
            - 2. TB correcting
        - output: will update the next action of the task, including:
            - "pass": the TB already passed the selfcheck, will go to the evaluation stage
            - "reboot": the whole pipeline will start from the very beginning
        - workflow: initial discrimination -> correction-discrimination loop -> pass or reboot
        """
        # TODO: if error occurs, go to reboot.
        # initial discrimination
        tolerance = 1  # how many corrector-introduced syntax errors we roll back before giving up
        syntax_error = False  # set once any syntax error has been seen during this run
        self.discriminate_TB()
        if self.TB_syntax_error:
            logger.negative("Testbench has syntax error, I give up. Reboot the whole process")
            syntax_error = True
            self.next_action = "reboot"
        elif self.tb_pass:
            logger.info("Testbench passed the funccheck")
            self.next_action = "pass"
        else:
            # enter the correction loop, the initial TB has no syntax error when entering
            if self.correct_max == 0:
                logger.negative("No correction is allowed, I give up. Reboot the whole autoline process")
                self.next_action = "reboot"
            for self.iter_now in range(1, self.correct_max+1):
                if (self.iter_now > 1) and (self.wrong_scen_num > self.previous_wrong_scen_num) and (not syntax_error):
                    # give up, the correction makes it worse
                    logger.negative(f"wrong scenarios increased ({self.wrong_scen_num} > {self.previous_wrong_scen_num}), I give up, quiting the funccheck stage...")
                    self.next_action = "reboot"
                    break
                elif (self.iter_now > 1) and (self.wrong_scen_num == self.previous_wrong_scen_num) and (not syntax_error):
                    self.same_wrong_scen_times += 1
                    if self.same_wrong_scen_times >= self.tolerance_for_same_wrong_scen:
                        logger.info(f"wrong scenarios not decreased for {self.tolerance_for_same_wrong_scen} times ({self.wrong_scen_num} = {self.previous_wrong_scen_num}), I give up, quiting the funccheck stage...")
                        self.next_action = "reboot"
                        break
                    else:
                        logger.info(f"wrong scenarios not decreased for {self.same_wrong_scen_times} times ({self.wrong_scen_num} = {self.previous_wrong_scen_num}), continue the correction")
                self.correct_TB()
                self.discriminate_TB()
                if self.tb_pass:
                    logger.info("Testbench passed the funccheck after correction")
                    self.next_action = "pass"
                    break
                elif self.TB_syntax_error:
                    # if the syntax error is from the corrector, we should roll back before correction
                    if tolerance > 0:
                        logger.negative(f"Testbench has syntax error after correction, I still have tolerance for that (tolerance={tolerance}). roll back and retry the self correction.")
                        self.TB_code_v, self.TB_code_py = self.TB_code_v_before_cor, self.TB_code_py_before_cor
                        tolerance -= 1
                        syntax_error = True
                    else:
                        logger.negative("Testbench has syntax error after correction, I don't have tolerance. I give up. Reboot the whole autoline process")
                        self.next_action = "reboot"
                        syntax_error = True
                        break
                # NOTE(review): source indentation was lost; placing this inside the loop is the
                # only reading that does not clobber an early "pass" break -- confirm against original
                self.next_action = "reboot" if self.iter_now == self.correct_max else None
        if (self.next_action == "reboot") and (not syntax_error) and (self.desc_improve):
            # the desc improver does not work well so we do not use it in this work
            self.improve_SPEC()
        logger.info(f"self funccheck finished. Next Action: [{self.next_action}]")
        return self.next_action
    @log_localprefix("discriminator")
    def discriminate_TB(self, no_any_files:bool=False):
        """
        - check the correctness of the testbench and return the rtl analysis results in matrix form
        - important data: the rtl list, the TB code
        - update the following data: `scenario_matrix`, `tb_pass`, `wrong_col_index`, `correct_col_index`, `unsure_col_index`, `wrong_scen_num`
        """
        rtl_dir_prefix = "RTL_"
        self.op_record.append("discrim")
        self.working_dir = os.path.join(self.task_dir, f"discrim_{self.iter_now}")
        logger.info(f"Discriminating the testbench, NO.{self.iter_now} discrimination")
        for i in range(self.rtl_compens_max_iter):
            # the loop is for the case that too few RTL passed the syntax check, generate more rtls and recheck
            failed_scenario_matrix = []
            # for rtl_code in self.rtl_list:
            self.TB_syntax_error = True  # cleared as soon as any RTL compiles, i.e. TB itself is syntactically fine
            syntax_error_rtl = []
            for rtl_idx, rtl_code in enumerate(self.rtl_list):
                rtl_dir = os.path.join(self.working_dir, f"{rtl_dir_prefix}{rtl_idx+1}")
                # run_testbench returns the failed-scenario list, or [-1] on a syntax error
                scenario_vector = self.run_testbench(rtl_dir, self.TB_code_v, rtl_code, self.TB_code_py, rtl_idx+1, self.runfiles_save and (not no_any_files))
                failed_scenario_matrix.append(scenario_vector)  # like [[2, 5], [3, 4, 5]]
                if scenario_vector != [-1]:
                    self.TB_syntax_error = False
                else:
                    syntax_error_rtl.append(rtl_idx+1)
            if syntax_error_rtl != []:
                logger.info(f"RTL(s) {syntax_error_rtl} have syntax error during discrimination")
            if self.TB_syntax_error:
                # there are two cases for TB syntax error:
                # 1. this syntax error is from the previous stage, if so, we should directly reboot the whole autoline process
                # 2. this syntax error is from the corrector, if so, we roll back to the version before correction and retry the correction
                self.tb_pass = False
                return None
            # the len of each scenario vector should be the same, thus we transform each vector into a onehot vector [[1,1,0,1,1,0], [1,1,1,0,0,0]]
            # (failed_scenarios_to_onehot_array / draw_scenario_matrix / gen_rtl are defined outside this view)
            self.scenario_matrix = self.failed_scenarios_to_onehot_array(failed_scenario_matrix, max_scen_idx=self.scenario_num, taskid=self.task_id)
            if not no_any_files:
                # save this matrix into the working dir in a human readable form
                np.savetxt(os.path.join(self.working_dir, "scenario_matrix.csv"), self.scenario_matrix, delimiter=",", fmt="%d")
                # save this matrix in a plot form
                self.draw_scenario_matrix(self.scenario_matrix, self.task_id, os.path.join(self.working_dir, "scenario_matrix.png"))
            # we delete the syntax errored rtl, if no syntax error in TB
            self.rtl_list = [rtl for rtl, scen in zip(self.rtl_list, failed_scenario_matrix) if scen != [-1]]
            failed_scenario_matrix = [scen for scen in failed_scenario_matrix if scen != [-1]]
            if len(self.rtl_list) < 0.5*self.rtl_num:
                # too few RTL passed the syntax check
                logger.info(f"too few RTL passed the syntax check ({len(self.rtl_list)}/{self.rtl_num}), I will generate more and recheck. This is not TB's fault.")
                self.rtl_list, gen_num = self.gen_rtl(self.rtl_num-len(self.rtl_list), self.description, self.module_header, self.rtlgen_model, self.rtl_list)
                self.rtl_newly_gen_num += gen_num
                # delete the previous rtl dirs (dir start with rtl_dir_prefix and under the working dir)
                for subdir in os.listdir(self.working_dir):
                    if subdir.startswith(rtl_dir_prefix) and os.path.isdir(os.path.join(self.working_dir, subdir)):
                        os.system(f"rm -rf {os.path.join(self.working_dir, subdir)}")
                logger.info(f"re-discriminate the testbench with updated RTL list")
                if i == self.rtl_compens_max_iter-1:
                    logger.info(f"no re-discrimination since the max iteration reached")
            else:
                break
        # discriminate the testbench according to the one hot matrix
        self.tb_pass, self.wrong_col_index, self.correct_col_index, self.unsure_col_index = self.discriminator.discriminate(self.scenario_matrix)
        self.previous_wrong_scen_num = self.wrong_scen_num
        self.wrong_scen_num = len(self.wrong_col_index)
        return self.tb_pass, self.wrong_col_index, self.correct_col_index, self.unsure_col_index
    @log_localprefix("corrector")
    def correct_TB(self):
        """
        - correct the testbench by using the RTL analysis results
        """
        self.op_record.append("correct")
        self.working_dir = os.path.join(self.task_dir, f"correct_{self.iter_now}")
        # snapshot the code so run() can roll back if the corrector introduces a syntax error
        self.TB_code_v_before_cor, self.TB_code_py_before_cor = self.TB_code_v, self.TB_code_py
        self.TB_code_v, self.TB_code_py = self.corrector.correct(self.description, self.wrong_col_index, self.TB_code_v, self.main_model, self.TB_code_py, self.working_dir)
        self.corrected = True
        # self check if the testbench is corrected
        # self.working_dir = os.path.join(self.task_dir, f"TBcheck_after_correct")
        # self.discriminate_TB()
        # if self.tb_pass:
        #     logger.info(f"[{self.task_id}] - Testbench passed the selfcheck after correction")
        #     self.next_action = "pass"
        # else:
        #     logger.warning(f"[{self.task_id}] - Testbench failed the selfcheck after correction, failed scenarios: {self.wrong_col_index}")
    @log_localprefix("improver")
    def improve_SPEC(self):
        """
        - improve the specification of the task according to the discrimination and correction results
        """
        self.op_record.append("improve")
        self.working_dir = os.path.join(self.task_dir, "improve_Desc")
        # replaces self.description with the improver's rewritten spec
        self.description = self.improver.improve(self.wrong_col_index, self.correct_col_index, self.TB_code_v, self.TB_code_py, working_dir=self.working_dir)
@staticmethod
def run_testbench(dir, driver_code:str, DUT_code:str, checker_code:str, rtl_index:int, save_en:bool=True):
"""
- modified from autoline.py TBEval.run_testbench
- it has two mode: pychecker mode or verilog testbench mode
-input:
- dir: the dir to save the TB, DUT and pychecker code
- driver_code: str; the testbench code
- DUT_code: str; the DUT code
- checker_code: str; the pychecker code
- output:
- a list of failed scenarios (if rtl has syntax error, return [-1])
"""
# iverilog part
# save the TB and DUT
os.makedirs(dir, exist_ok=True)
v_driver_path = os.path.join(dir, "driver.v")
py_checker_path = os.path.join(dir, "checker.py")
dut_path = os.path.join(dir, "DUT.v")
with open(v_driver_path, "w") as f:
f.write(driver_code)
with open(dut_path, "w") as f:
f.write(DUT_code)
iv_run_info = iv.iverilog_call_and_save(dir, silent=True)
if not iv_run_info[0]:
# logger.trace(f"RTL index [{rtl_index}]: Iverilog Compilation Failed, the PREREQUISITE of 'Evaluation' is no syntactic error from Testbench!!!")
# raise RuntimeError("Iverilog Compilation Failed")
return [-1]
with open(py_checker_path, "w") as f:
f.write(checker_code)
py_run_info = py.python_call_and_save(pypath=py_checker_path, silent=True)
if not py_run_info[0]:
# logger.trace(f"RTL index [{rtl_index}]: Iverilog Compilation Failed: the PREREQUISITE of 'Evaluation' is no syntactic error from Python code!!!")
# raise RuntimeError("Python Compilation Failed")
return [-1]
python_info_out = py_run_info[1]["out"]
python_info_out : str
# find the last ] in the out
last_bracket_end = python_info_out.rfind("]")
# find the last [ in the out
last_bracket_start = python_info_out.rfind("[")
# if the last [..] is a [], return []
if last_bracket_end == last_bracket_start+1:
return []
# extract the digits
failed_scenarios = python_info_out[last_bracket_start+1:last_bracket_end].replace("'", "").split(",")
# if the item is not pure digit such as 2b, then we only extract the digit part
failed_scenarios = [int("".join([char for char in scenario if char.isdigit()])) for scenario in failed_scenarios]
failed_scenarios = list(map(int, failed_scenarios))
# if save_en false, we delete the dir
if not save_en:
os.system(f"rm -rf {dir}")
return list(set(failed_scenarios))
def rtl_list_gen(self)->list[str]:
"""
- generate the RTL list using LLM, will empty the old rtl list
- attr needed: description, module_header, rtl_num, llm_model
- attr changed: rtl_list
"""
self.rtl_list = []
logger.info(f"rtl list not found, generating naive rtls for testbench checking")
self.rtl_list, gen_num = self.gen_rtl(self.rtl_num, self.description, self.module_header, self.rtlgen_model, self.rtl_list)
self.rtl_newly_gen_num += gen_num
# save the rtl list
save_path = os.path.join(self.task_dir, "rtl_list.json")
os.makedirs(self.task_dir, exist_ok=True)
ls.save_json_lines([{"task_id": self.task_id, "llmgen_RTL": self.rtl_list}], save_path)
@staticmethod
def gen_rtl(num:int, description:str, header:str, llm_mode:str, rtl_list:list=[]):
"""
- input:
- num (int): the number of RTLs to generate
- description (str): the description of the rtl problem
- header (str): the header of the module
- llm (str): the llm model to use (official model name)
- rtl_list (list) [optional]: the newly generated RTLs will be appended to this list, can be empty
- output:
- rtl_list (list): the list of the newly generated RTLs (and the old ones, if have)
"""
rtl_gen_num = 0
prompt = "Your task is to write a verilog RTL design according to the design specification. The infomation we have is the problem description that guides student to write the RTL code (DUT) and the header of the desired module. here is the problem description:\n"
prompt += description
prompt += "\nHere is the header of the desired module:\n"
prompt += header
prompt += "\nPlease only return the module code (header should be included) in verilog, please do not include any other words."
for i in range(num):
# call llm
answer = llm.llm_call(prompt, llm_mode)[0]
# extract the module code
module_code = llm.extract_code(answer, "verilog")[0]
# logger.trace(f"[{self.task_id}] - {i+1} RTLs generated")
rtl_list.append(module_code)
rtl_gen_num += 1
logger.info("%d naive rtls generated"%(rtl_gen_num))
return rtl_list, rtl_gen_num
@staticmethod
def failed_scenarios_to_onehot_array(failed_scenarios:list[list], max_scen_idx:int|None=None, taskid:str=""):
"""
- input: [failed_scenarios:list[int]] (for example: [[1,2,3], [2,3,4], [1,3,4], [-1]]), if one failed scenario list is [-1], it means the rtl has syntax error, should be skipped
- output (np.array): a onehot array (for example: [[0,0,0,1], [1,0,0,0], [0,1,0,0], [-1,-1,-1,-1]]) (1 denots pass, 0 denotes fail, -1 means syntax error)
"""
# find the max scenario index
listlen = len(failed_scenarios)
max_idx_given = max_scen_idx if max_scen_idx is not None else 1
# we calculate the max_index_cal, and define the final max scenario index using max(max_index_cal, max_index_given)
max_idx_cal = max(map(lambda x: max(x) if x != [] else 0, failed_scenarios))
if max_idx_cal in [-1, 0]:
# -1: all the scenarios in this rtl are -1
# 0: usually not possible because the scenario index is from 1, but if exists, set to 1.
max_idx_cal = 1
if failed_scenarios == list(map(lambda x: [], range(listlen))):
# this means all rtl passed
max_idx_cal = 1 # set to 1 otherwise the plot will be empty
max_idx = max(max_idx_cal, max_idx_given)
# if the failed scenario list is [-1], then all the scenarios in this rtl are -1
# create the onehot array
grid_map = [[1]*max_idx for _ in range(listlen)]
for rtl_idx, failed_scens in enumerate(failed_scenarios):
if failed_scens == [-1]:
grid_map[rtl_idx] = [-1]*max_idx
continue
for scen_idx in failed_scens:
grid_map[rtl_idx][scen_idx-1] = 0
return np.array(grid_map)
@staticmethod
def draw_scenario_matrix(scenario_matrix:np.ndarray, task_id:str, saving_path:str):
"""
- draw the 2D failed scenario array. The element in the array can only be 0, 1, -1. We use red for 0, green for 1, and gray for -1.
- if the scenario is empty, will return a gray color block.
"""
if len(scenario_matrix) == 0:
scenario_matrix = np.array([[-1]])
# if the element in data not in [0, 1, -1], change the element to -1
scenario_matrix = np.where(np.logical_or(scenario_matrix == 0, np.logical_or(scenario_matrix == 1, scenario_matrix == -1)), scenario_matrix, -1)
# get the RGB values for salmon, grey and mediumseagreen
salmon = mcolors.to_rgb("salmon")
grey = mcolors.to_rgb("grey")
mediumseagreen = mcolors.to_rgb("mediumseagreen")
color_mapping = {
0: salmon,
1: mediumseagreen,
-1: grey
}
rgb_image = np.array([[color_mapping[value] for value in row] for row in scenario_matrix])
# assign the color to the scenario_matrix
for value, color in color_mapping.items():
rgb_image[scenario_matrix == value] = color
plt.imshow(rgb_image)
plt.ylabel("RTL index")
plt.xlabel("Scenario index")
current_xticks = np.arange(scenario_matrix.shape[1])
plt.xticks(current_xticks, current_xticks + 1)
current_yticks = np.arange(scenario_matrix.shape[0])
plt.yticks(current_yticks, current_yticks + 1)
plt.title(f"[{task_id}] - Matrix of RTL-TB Scenario Correctness")
plt.savefig(saving_path)
plt.close()
    def update_description(self)->str:
        """
        Return the task description.

        NOTE(review): despite the log message, this currently returns
        `self.description` unchanged — the actual rewriting is done elsewhere
        (presumably `improve_SPEC`), or is not implemented yet. Confirm before
        relying on this method to modify anything.
        """
        logger.info("the description of the task is updated")
        return self.description
    def set_rtl_num(self, value:int):
        """Set the number of naive RTLs to generate (backs `_rtl_num`; presumably
        exposed via an `rtl_num` property defined elsewhere — TODO confirm)."""
        self._rtl_num = value
    def __call__(self, *args, **kwds):
        """Make the task object callable; delegates to run() (extra args/kwargs are ignored)."""
        return self.run()
class TB_discriminator():
    """
    Decide whether a testbench passes, based on the RTL-vs-scenario onehot matrix
    (1 = pass, 0 = fail, -1 = syntax error; see failed_scenarios_to_onehot_array).

    A mode string selects the decision rule:
    - "col_full_wrong": a scenario (column) is wrong only if it fails on every RTL
    - "col_<P>_wrong": a column is wrong if >= P% of the RTL rows fail it
      (P in 40/50/60/70/80/90)
    - "col_<P>_wrong_row_<Q>_correct": as above, but the TB is additionally
      accepted if >= Q% of the RTL rows are fully correct (pass every scenario)

    (refactor: the former per-mode match branches were near-identical copies;
    they are now driven by one threshold table with unchanged behavior)
    """
    # mode -> (column wrong-threshold, row fully-correct rescue ratio);
    # a None column threshold selects the strict "every row fails" rule
    _MODE_TABLE = {
        "col_full_wrong":              (None, None),
        "col_40_wrong":                (0.4,  None),
        "col_50_wrong":                (0.5,  None),
        "col_60_wrong":                (0.6,  None),
        "col_70_wrong":                (0.7,  None),
        "col_80_wrong":                (0.8,  None),
        "col_90_wrong":                (0.9,  None),
        "col_50_wrong_row_25_correct": (0.5,  0.25),
        "col_70_wrong_row_25_correct": (0.7,  0.25),
        "col_70_wrong_row_1_correct":  (0.7,  0.01),
        "col_70_wrong_row_10_correct": (0.7,  0.1),
    }

    def __init__(self, mode:str) -> None:
        self.mode = mode
        if mode not in self._MODE_TABLE:
            # same as before: only log here; discriminate() raises on use
            logger.critical("class discriminator - mode not found!!!")

    def discriminate(self, failed_matrix:np.ndarray)->tuple[bool, list[int], list[int], list[int]]:
        """
        - input: the failed matrix of the testbench in onehot form
        - output:
            - the indexes in the scen lists are starting from 1
            - bool: whether the testbench is correct (None if the whole matrix is
              -1, i.e. the TB itself has a syntax error)
            - list[int]: the list of the wrong scenarios
            - list[int]: the list of the correct scenarios
            - list[int]: the list of the scenarios that discriminator is not sure about
            - rows that are entirely -1 are dropped before judging. See function
              'failed_scenarios_to_onehot_array'
        - raises RuntimeError for an unknown mode
        """
        # first check if all the scenarios are -1, which means the tb has syntax error
        if np.all(failed_matrix == -1):
            return None, [], [], []
        # drop syntax-error RTL rows
        failed_matrix = failed_matrix[~np.all(failed_matrix == -1, axis=1)]
        if self.mode not in self._MODE_TABLE:
            logger.critical("TB discriminator - mode not found!!!")
            raise RuntimeError("TB discriminator - mode not found!!!")
        col_threshold, row_correct_ratio = self._MODE_TABLE[self.mode]
        rtl_count = len(failed_matrix)
        # a column is "correct" when no remaining RTL fails it
        correct_col_index = np.where(np.all(np.isin(failed_matrix, [1, -1]), axis=0))[0] + 1
        if col_threshold is None:
            # strict rule: wrong only if every row fails the column
            wrong_col_index = np.where(np.all(np.isin(failed_matrix, [0, -1]), axis=0))[0] + 1
            unsure_col_index = np.where(np.any(failed_matrix == 0, axis=0) & np.any(failed_matrix == 1, axis=0))[0] + 1
        else:
            fail_counts = np.sum(failed_matrix == 0, axis=0)
            wrong_col_index = np.where(fail_counts >= col_threshold*rtl_count)[0] + 1
            unsure_col_index = np.where(np.logical_and(fail_counts < col_threshold*rtl_count, np.any(failed_matrix == 0, axis=0)))[0] + 1
        # as long as there is no wrong column, the tb is correct (loose criterion)
        tb_pass = len(wrong_col_index) == 0
        if row_correct_ratio is not None:
            # rescue rule: enough fully-correct RTL rows also make the TB pass
            if np.sum(np.all(failed_matrix == 1, axis=1)) >= row_correct_ratio*rtl_count:
                tb_pass = True
        verdict = "passed" if tb_pass else "failed"
        logger.match_level(tb_pass, "positive", "negative", f"TB_discriminating finished, TB {verdict}, wrong scenarios: {wrong_col_index}, scenario pass ratio: {len(correct_col_index)}/{len(failed_matrix[0])}")
        return tb_pass, wrong_col_index, correct_col_index, unsure_col_index
COR_PROMPT_1 = """Your task is to correct the testbench according to the failing scenarios. the information we have is the failed/passed scenarios of the testbench, the problem description and the testbench code.
the testbench code is consisted of both verilog and python code. The verilog code aims to generate test stimulus (under test scenarios) and drive the DUT to generate the output signal; the python code aims to check if the output vector from the DUT is correct.
ATTENTION: The python code contains error, and your target is to find it and tell me how to correct it (you don't need to give me the code in this stage).
"""
HINT_SEQ = """
Hints - explaination of the given python code:
the python class "GoldenDUT": This python class can represent the golden DUT (the ideal one). In "GoldenDUT", following methods are defined:
- 1. a method "def __init__(self)": Set the inner states/values of the golden DUT. These values have suffix "_reg". The initial value of these inner values is "x", but later will be digits. The "__init__" method has no input parameters except "self".
- 2. a method "def load(self, signal_vector)": This method is to load the important input signals and the inner values of "GoldenDUT" shall change according to the input signals. There is no clock signal in the input signal vector, every time the "load" method is called, it means a new clock cycle. The initial values "x" should be changed according to the input signals. This method has no return value.
- 3. a method "def check(self, signal_vector)": This method is to determine the expected output values and compare them with output signals from DUT. It should return True or False only.
"""
HINT_CMB = """
Hints - explaination of the given python code:
The given python code contains one class "GoldenDUT". this python class can represent the golden DUT (the ideal one). By calling the inner method "check", the signal vector from DUT will be checked. The details of the golden DUT are as follows:
- a. a method "def __init__(self)". Set the inner states/values of the golden DUT. The "__init__" method has no input parameters except "self".
- b. a method "def load(self, signal_vector)". This method is to load the important input signals and get the expected output signals. it should return the expected output values. It can call other methods to help computing the expected output. It will be called by other inner methods later.
- c. a method "def check(self, signal_vector)". This method is to call "load" to get the expected output values, and compare them with output signals from DUT. It should return True or False only. It can call other methods to help checking.
- d. other methods, they can be called by "__init__", "load" or "check".
- e. the input of "load" and "check" is the signal vector. The signal vector is a dictionary, the key is the signal name, the value is the signal value.
"""
COR_PROMPT_2_PART1 = """
please correct the python code according to the following rules:
PYTHON code rule: please do not change the original high level structure of the python code. i.e., if python code only contains one class and several functions such as init, load, check and more, only modify the implementation of the function, but do not change the name or delete the functions/class methods. you can add new class methods or functions if needed. you can use python libraries such as numpy or math.
"""
COR_PROMPT_2_PART2 = """
i.e., your python code format in response should still be like:
class <class_name>:
def __init__(self):
...(omitted)
def load(self, ...):
...
def check(self, ...):
...
def <other_functions>(self, ...):
...
ATTENTION: please give me the corrected python code according to our previous conversation and the hints above. please give me the corrected full python code (not the part but the whole python code like I give you in our previous conversation).
"""
class TB_corrector():
    """
    Correct a failing testbench (verilog driver + python checker) with an LLM.

    The corrector is told which scenarios failed, asks the LLM for a root-cause
    analysis of the python checker, then asks for the fully corrected checker.
    Only the "naive" mode is implemented, and only the pychecker flow is
    supported (pychecker_en must be True when calling correct()).
    """
    def __init__(self, mode:str, pychecker_en:bool, circuit_type:str="") -> None:
        # mode: correction strategy; currently only "naive" is implemented
        # pychecker_en: True if the TB uses a python checker (required by correct())
        # circuit_type: "CMB"/"SEQ"; anything else maps to "unknown"
        self.mode = mode
        self.pychecker_en = pychecker_en
        circuit_type_dict = {"CMB": "combinational", "SEQ": "sequential"}
        self.circuit_type = circuit_type_dict.get(circuit_type, "unknown")
        # logger.debug(f"TB_corrector class - mode: {self.mode}, pychecker_en: {self.pychecker_en}, circuit_type: {self.circuit_type}; The input circuit type is {circuit_type}")
        match self.mode:
            case "naive":
                # the most naive mode;
                pass
            case _:
                # only log here; correct() raises on use, mirroring TB_discriminator
                logger.critical("TB_corrector class - mode not found!!!")
    def correct(self, description, failed_scenarios, TB_code_v:str, llm_model:str, TB_code_py:str|None=None, working_dir:str|None=None) -> tuple[str, str]:
        """
        Run one two-round LLM correction on the python checker.

        - description: the problem description
        - failed_scenarios: scenario indexes judged wrong by the discriminator
        - TB_code_v: verilog driver code (returned unchanged)
        - llm_model: official LLM model name to call
        - TB_code_py: python checker code to be corrected (required when pychecker_en)
        - working_dir: if given, the conversation and resulting code are saved there
        - returns (TB_code_v, corrected TB_code_py)
        - raises RuntimeError if pychecker is disabled or the mode is unknown
        """
        match self.mode:
            case "naive":
                if self.pychecker_en:
                    # keep only the GoldenDUT part of the checker for the LLM;
                    # the boilerplate tail is stashed and re-attached afterwards
                    TB_code_py = self._py_focus(TB_code_py, before=True)
                    py_code_hint = HINT_CMB if self.circuit_type == "combinational" else HINT_SEQ
                    # round 1: build the analysis prompt
                    prompt = COR_PROMPT_1
                    prompt += "Here is the problem description:\n"
                    prompt += description
                    prompt += "\nHere is the testbench code:\n"
                    prompt += "ATTENTION: the following scenarios are wrong: " + str(failed_scenarios) + "\n"
                    # circuit type
                    prompt += f"the circuit type of this task is {self.circuit_type}\n"
                    prompt += "Here is the verilog code. it contains the meaning of each scenario. you can combine the wong scenario info above and the following code to better understand the reason of failing:\n"
                    prompt += TB_code_v
                    prompt += "\nHere is the python code, it contains error, please combine it with the wrong scenario info and the verilog code to understand:\n"
                    prompt += TB_code_py
                    prompt += "\nHere is some hints for your better understanding of the python codes above:"
                    prompt += py_code_hint
                    prompt += "\nplease reply me with the following steps:"
                    prompt += "\n1. please analyze the reason of the failed scenarios. If possible, please find the in common between the failed scenarios."
                    prompt += f"\n2. please analyze which part of the python code is related to the failed test scenarios ({str(failed_scenarios)})."
                    prompt += "\n3. please tell me how to correct the wrong part (in natural language, do not give me the complete code implementation. please explain it in English.)"
                    prompt += "\nhere is an example of the reply:"
                    prompt += "\n1. the failed scenarios are all related to the same signal x\n2. the mid part of the function_X is related to the failed scenarios\n3. the correct logic of signal x should be y."
                    logger.info("naive corrector mode begins")
                    answer = llm.llm_call(prompt, llm_model)[0]
                    # round 2: ask for the fully corrected python code
                    prompt_2 = COR_PROMPT_2_PART1 + py_code_hint + COR_PROMPT_2_PART2
                    message = [{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}, {"role": "user", "content": prompt_2}]
                    answer_2, more_info = llm.llm_call(message, llm_model)
                    # if "VERILOG" in answer_2:
                    #     TB_code_v = llm.extract_code(answer_2, "verilog")[0]
                    # else:
                    #     TB_code_py = llm.extract_code(answer_2, "python")[0]
                    TB_code_py = llm.extract_code(answer_2, "python")[0]
                    # re-attach the stashed tail of the checker
                    TB_code_py = self._py_focus(TB_code_py, before=False)
                    if working_dir is not None:
                        ls.save_messages_to_txt(more_info["messages"], os.path.join(working_dir, "conversation.txt"))
                        ls.save_code(TB_code_v, os.path.join(working_dir, "TB.v"))
                        if TB_code_py is not None:
                            ls.save_code(TB_code_py, os.path.join(working_dir, "TB.py"))
                        logger.info("naive corrector mode ends; conversation and codes saved")
                else:
                    logger.critical("TB_corrector - pychecker not enabled")
                    raise RuntimeError("TB_corrector - pychecker not enabled")
                return TB_code_v, TB_code_py
            case _:
                logger.critical("TB_corrector - mode not found!!!")
                raise RuntimeError("TB_corrector - mode not found!!!")
    def _py_focus(self, code:str, before:bool):
        """
        - imigrated from TB2_syncheck.py
        - code: the code under debug / after debug
        - before: True, if before debug, will split the code (the tail starting at
          the key words is stashed in self._py_code_nofocus); False, if after
          debug, will restore the code by re-appending the stashed tail
        """
        KEY_WORDs_1 = "def check_dut(vectors_in):\n    golden_dut = GoldenDUT()\n    failed_scenarios = []"
        KEY_WORDs_2 = "\ndef SignalTxt_to_dictlist"
        if before:
            key_words = KEY_WORDs_1 if KEY_WORDs_1 in code else KEY_WORDs_2 # for compatibility with the old version
            if key_words not in code:
                # neither marker present: focus on the whole code, nothing stashed
                py_code_focus = code
                self._py_code_nofocus = ""
            else:
                py_code_focus = code.split(key_words)[0]
                self._py_code_nofocus = key_words + code.split(key_words)[1]
            return py_code_focus
        else:
            return code + self._py_code_nofocus
# --- prompt fragments for SPEC_improver ---
# NOTE(review): these strings are sent verbatim to the LLM; existing typos
# ("imporved", "descriptin", "descriptino") are left untouched on purpose,
# since changing them would change runtime behavior.

# Opening prompt: ask the LLM to rewrite the problem description, given a
# partially-correct testbench.
IMPR_PROMPT_1 = """Your task is to improve the quality of an RTL problem description using the following given information. Our final target is using the description to generate a testbench for the RTL design. Currently we already have the testbench but it is not perfect correct. Now in this stage the target is to generate a better description.
The information we have the is original RTL description, the testbench code. The testbench code includes the verilog code for test scenario generation, and the python code for checking the output vector.The verilog code aims to generate test stimulus (under test scenarios) and drive the DUT to generate the output signal; the python code aims to check if the output vector from the DUT is correct.
Attention: the testbench we provide is not the perfect one, and it contains error. However, we already know the test scenarios where the testbench works correctly and the scenarios where the testbench has problem. We will provide you the scenarios index later.
Here, firstly, is the problem description to be imporved:
\n"""
# markers the LLM must wrap the improved description in, so it can be extracted
DESC_MARK_BEGIN = "***description begins***"
DESC_MARK_END = "***description ends***"
# step-by-step reply instructions, including the required marker format
DESC_STEP_INSTRUCT = f"""
please reply me with the following steps:
1. please analyze which part of the testbench (especially the python checker code) is correct and can be used to improve the description.
2. please analyze how can we improve the descriptin. for example, which part of the technical details can be more detailed, which part can be more clear, which part can be more concise.
3. please provide the improved complete description. We will directly use it in the later stages.
the format of description should be like:
{DESC_MARK_BEGIN}
... (the improved description, should be complete)
{DESC_MARK_END}
"""
# final guardrails: trust the description over the testbench on conflicts, never
# drop content, never mention scenario indexes, always use the begin/end markers
DESC_FINAL_INSTRUCT = f"""
ATTENTION: please know that the provided testbench is not perfect and may contains many errors. Thus, your modification on the description should not change the function of the original description. When there are conflicts between the testbench and the description, always believe the description is correct. Do not delete the information in the description, but you can rewrite it in a better way. You can also add more details to it. But NEVER mention any scenario index because the scenarios will not be the same at the next stage.
when you answer the last question (provide the improved complete description), the descriptino should start with "{DESC_MARK_BEGIN}" and end with "{DESC_MARK_END}". Only in this way can we recognize the improved description.
"""
class SPEC_improver():
def __init__(self, description, mode:str, pychecker_en:bool, llm_model:str, circuit_type:str="") -> None:
self.description = description
self.mode = mode
self.pychecker_en = pychecker_en
circuit_type_dict = {"CMB": "combinational", "SEQ": "sequential"}
self.llm_model = llm_model
self.circuit_type = circuit_type_dict.get(circuit_type, "unknown")
def improve(self, wrong_scenarios, correct_scenarios, TB_code_v:str, TB_code_py:str|None=None, working_dir:str|None=None) -> str:
# not implemented yet
match self.mode:
case "naive":
logger.info("naive description improver mode begins")
prompt = ""
prompt += IMPR_PROMPT_1
prompt += DESC_MARK_BEGIN + "\n"
prompt += self.description + "\n"
prompt += DESC_MARK_END + "\n"
prompt += "\nHere is the testbench codes:\n"
prompt += "ATTENTION: the following scenarios are wrong: " + str(wrong_scenarios) + "\n"
prompt += "ATTENTION: the following scenarios are correct, you can rely on these scenarios to improve the description: " + str(correct_scenarios) + "\n"
prompt += TB_code_v + "\n"
if self.pychecker_en:
prompt += f"\nHere is the python code (python checker). please note that the python checker has correct function under the scenarios {str(correct_scenarios)}, but wrong under the scenarios {str(wrong_scenarios)}:\n"
prompt += TB_code_py + "\n"
prompt += DESC_STEP_INSTRUCT
prompt += DESC_FINAL_INSTRUCT
message = [{"role": "user", "content": prompt}]
answer, more_info = llm.llm_call(message, self.llm_model)
try:
improved_description = answer.split(DESC_MARK_BEGIN)[1].split(DESC_MARK_END)[0]
if improved_description == "":
improved_description = self.description
except:
improved_description = self.description
if working_dir is not None:
ls.save_messages_to_txt(more_info["messages"], os.path.join(working_dir, "conversation.txt"))
# save description
with open(os.path.join(working_dir, "description.txt"), "w") as f:
f.write(improved_description)
logger.info("naive description improver mode ends")
return improved_description
case "hint":
logger.info("description improver 'hint' mode begins")
return self.description
def test():
    """Smoke test: build a onehot scenario matrix and run the discriminator on it."""
    from loguru import logger  # kept for parity with the original; not used directly here
    sample_failures = [
        [1, 3, 5],
        [-1],
        [2, 3],
    ]
    scenario_count = 7
    onehot = TaskTBcheck.failed_scenarios_to_onehot_array(sample_failures, scenario_count)
    print(np.array(onehot))
    discriminator = TB_discriminator("col_full_wrong")
    print(discriminator.discriminate(np.array(onehot)))

245
autoline/TB4_eval.py Normal file
View File

@@ -0,0 +1,245 @@
"""
Description : This is the testbench eval stage in autoline
Author : Ruidi Qiu (r.qiu@tum.de)
Time : 2024/7/24 11:24:43
LastEdited : 2024/8/28 21:08:21
"""
import os
import iverilog_call as iv
import python_call as py
from loader_saver import autologger as logger
from loader_saver import log_localprefix
from utils.utils import Timer, get_time
# Marker substrings searched for in simulator stdout to decide "all test cases passed".
# Generated TBs print an explicit pass banner (several capitalizations are accepted).
TC_PASS_CHECK_LIST_TB_GEN = ["All test cases passed", "all test cases passed", "All Test Cases Passed"]
# Golden TBs report mismatch counts instead of a banner; zero mismatches means pass.
TC_PASS_CHECK_LIST_TB_GOLDEN = ['Mismatches: 0 in ', 'Hint: Total mismatched samples is 0 out of']
# Python checkers print the list of failed scenarios; an empty list "[]" means pass.
TC_PASS_CHECK_LIST_PYCHECKER = ["[]"]
class TaskTBeval():
    """
    ### description
    - this is the evaluation stage of our pipeline; the priority of this stage is that TB is generated and the empty DUT compilation is passed;
    - please use `try` to catch the exception of this function.
    - this module is independent from the previous modules.
    #### input
    - task_id: the name of the problem
    - root_dir: the dir of one problem
    - TB_gen: the testbench under evaluation (str)
    - TB_golden: the golden testbench (str)
    - DUT_golden: the golden RTL DUT (str)
    - DUT_mutant_list: the list of RTL DUT mutants modified from DUT_golden;[str]
    #### output
    - dict
        - "Eval1_pass" : bool (whether the golden RTL checking passed)
        - "Eval2_pass" : bool (whether the golden TB comparison on RTL mutants passed)
        - "Eval2_failed_mutant_idxes" : list of int (the index of the failed mutants)
    """
    """main structure: run(), run_Eval1(), run_Eval2()"""
    def __init__(self, task_id: str, task_dir: str, TB_gen: str, TB_golden:str=None, DUT_golden:str=None, DUT_mutant_list:list=None, DUT_gptgen_list:list = None, pychecker_en:bool = False, pychecker_code:str = "", runfiles_save:bool = True):
        self.task_id = task_id
        self.task_dir = task_dir
        self.TB_gen = TB_gen
        self.TB_golden = TB_golden
        self.DUT_golden = DUT_golden
        self.DUT_mutant_list = DUT_mutant_list
        self.DUT_gptgen_list = DUT_gptgen_list
        self.pychecker_en = pychecker_en
        self.save_en = runfiles_save
        # with the python checker enabled, the generated TB is judged through the Pychecker flow
        self.TB_gen_mode = "TB_gen" if not self.pychecker_en else "Pychecker"
        self.pychecker_code = pychecker_code
        # set by run_testbench() before each compile/sim round
        self.working_dir = ""
        # Eval1 related (generated TB vs golden RTL)
        self.Eval1_exist = False
        # self.Eval1_dir = task_dir + "eval1_GoldenRTL/"
        self.Eval1_dir = os.path.join(task_dir, "eval1_GoldenRTL")
        self.Eval1_results = None
        self.Eval1_pass = None
        # Eval2 related (generated TB vs golden TB, judged on RTL mutants)
        self.Eval2_exist = False
        # self.Eval2_dir = task_dir + "eval2_GoldenTB_and_mutants/"
        self.Eval2_dir = os.path.join(task_dir, "eval2_GoldenTB_and_mutants")
        self.Eval2_pass = None
        self.Eval2_failed_mutant_idx = None
        self.Eval2_passed_mutant_idx = None
        # Eval2b related (same as Eval2 but on LLM-generated RTL candidates)
        self.Eval2b_exist = False
        # self.Eval2b_dir = task_dir + "eval2b_GPTgenTB/"
        self.Eval2b_dir = os.path.join(task_dir, "eval2b_GPTgenTB")
        self.Eval2b_pass = None
        self.Eval2b_failed_mutant_idx = None
        self.Eval2b_passed_mutant_idx = None
    @log_localprefix("TBeval")
    def run(self):
        """Run Eval1 first; Eval2/2b only make sense when the TB accepts the golden RTL."""
        # Eval 1
        if self.DUT_golden is not None:
            self.run_Eval1()
        if self.Eval1_pass:
            # Eval 2
            if self.TB_golden is not None and self.DUT_mutant_list is not None:
                self.run_Eval2(mode="mutant")
            # Eval 2b
            if self.TB_golden is not None and self.DUT_gptgen_list is not None:
                self.run_Eval2(mode="gptgen")
        else:
            logger.info("[%s] Eval 2/2b is skipped because Eval 1 failed" % (self.task_id))
        self.clean_wave_vcd() # some golden TBs may generate wave.vcd files
    def run_Eval1(self):
        """Eval 1: the generated TB must declare the golden RTL correct."""
        silent = True  # NOTE(review): 'silent' is unused in this method
        ### Eval 1: Golden RTL checking
        logger.info("Eval 1: Golden RTL checking begins")
        # raise_when_fail=True: a TB that does not even compile is a hard error at this stage
        self.Eval1_pass = self.run_testbench(self.Eval1_dir, self.TB_gen, self.DUT_golden, self.TB_gen_mode, self.pychecker_code, raise_when_fail=True, save_en=self.save_en)
        logger.match_level(self.Eval1_pass, "positive", "failed", "Eval 1: Golden RTL checking %s!" % ("passed" if self.Eval1_pass else "failed"))
        # my_log = logger.positive if self.Eval1_pass else logger.failed
        # my_log("[%s] Eval 1: Golden RTL checking %s!" % (self.task_id, "passed" if self.Eval1_pass else "failed"))
        self.Eval1_exist = True
    def run_Eval2(self, mode:str="mutant"):
        """ mode: "mutant" or "gptgen" """
        silent = True  # NOTE(review): 'silent' is unused in this method
        assert mode in ["mutant", "gptgen"], "Invalid mode in run_Eval2: " + mode
        if mode == "mutant": # Eval2
            print_str = "Eval 2: Golden TB checking on RTL mutants"
            mutant_subdir_name = "mutant"
            DUT_list = self.DUT_mutant_list
            eval_dir = self.Eval2_dir
        elif mode == "gptgen": # Eval2b
            print_str = "Eval 2b: Golden TB checking on GPT generated RTL codes"
            mutant_subdir_name = "gptgen_DUT"
            DUT_list = self.DUT_gptgen_list
            eval_dir = self.Eval2b_dir
        ### Eval 2: Golden TB comparison on RTL mutants
        logger.info(print_str)
        mutant_results = []
        for idx, DUT_mutant in enumerate(DUT_list):
            # mutant_subdir = eval_dir + "%s_%d/"%(mutant_subdir_name, idx+1)
            mutant_subdir = os.path.join(eval_dir, "%s_%d"%(mutant_subdir_name, idx+1))
            # GoldenTB_subsubdir = mutant_subdir + "GoldenTB/"
            GoldenTB_subsubdir = os.path.join(mutant_subdir, "GoldenTB")
            # GenedTB_subsubdir = mutant_subdir + "GeneratedTB/"
            GenedTB_subsubdir = os.path.join(mutant_subdir, "GeneratedTB")
            try: #in case the mutant has syntax error
                TBgolden_pass = self.run_testbench(GoldenTB_subsubdir, self.TB_golden, DUT_mutant, "TB_golden", save_en=self.save_en)
            except:
                TBgolden_pass = False
            try:
                TBgen_pass = self.run_testbench(GenedTB_subsubdir, self.TB_gen, DUT_mutant, self.TB_gen_mode, self.pychecker_code, save_en=self.save_en)
            except:
                TBgen_pass = False
            # the generated TB agrees with the golden TB when both accept or both reject this DUT
            if not TBgolden_pass and not TBgen_pass:
                mutant_pass = True
            elif TBgolden_pass and TBgen_pass:
                mutant_pass = True
            else:
                mutant_pass = False
            mutant_results.append(mutant_pass)
        eval_pass = all(mutant_results)
        # mutant indices are reported 1-based
        failed_mutant_idx = [idx + 1 for idx, result in enumerate(mutant_results) if not result]
        passed_mutant_idx = [idx + 1 for idx, result in enumerate(mutant_results) if result]
        if mode == "mutant":
            self.Eval2_pass, self.Eval2_failed_mutant_idx, self.Eval2_passed_mutant_idx, self.Eval2_exist = eval_pass, failed_mutant_idx, passed_mutant_idx, True
        elif mode == "gptgen":
            self.Eval2b_pass, self.Eval2b_failed_mutant_idx, self.Eval2b_passed_mutant_idx, self.Eval2b_exist = eval_pass, failed_mutant_idx, passed_mutant_idx, True
        result = "perfectly passed" if eval_pass else ("finished (%d/%d)" % (len(passed_mutant_idx), len(mutant_results)))
        # an 80% agreement ratio is still logged as a success
        my_log = logger.success if (eval_pass or (len(passed_mutant_idx)/len(mutant_results)>=0.8)) else logger.failed
        my_log("%s %s!" % (print_str, result))
    def run_testbench(self, dir, TB_code, DUT_code, TB_type, pychecker_code = "", raise_when_fail = False, save_en = True):
        """
        it has two mode: pychecker mode or verilog testbench mode
        -input:
            - dir: the dir to save the TB, DUT and pychecker code
            - TB_code: str; the testbench code
            - DUT_code: str; the DUT code
            - TB_type: str: TB_gen, TB_golden, Pychecker
            - pychecker_code: str; the pychecker code
        - output:
            - pass: bool; if the DUT passed the testbench
        """
        # iverilog part
        # save the TB and DUT
        assert TB_type in ["TB_gen", "TB_golden", "Pychecker"], "Invalid TB_type in run_testbench: " + TB_type
        os.makedirs(dir, exist_ok=True)
        self.working_dir = dir
        with open(self.TB_path, "w") as f:
            f.write(TB_code)
        with open(self.DUT_path, "w") as f:
            f.write(DUT_code)
        iv_run_info = iv.iverilog_call_and_save(dir, silent=True)
        if raise_when_fail:
            assert iv_run_info[0], "%s Iverilog Compilation Failed: the PREREQUISITE of 'Evaluation' is no syntactic error from Testbench!!!"%(TB_type)
        # pychecker part (if enabled)
        if TB_type == "Pychecker":
            with open(self.PY_path, "w") as f:
                f.write(pychecker_code)
            py_run_info = py.python_call_and_save(pypath=self.PY_path, silent=True)
            if raise_when_fail:
                assert py_run_info[0], "%s Python Compilation Failed: the PREREQUISITE of 'Evaluation' is no syntactic error from Python code!!!"%(TB_type)
            # check if the DUT passed the testbench
            # py_run_info[1]["out"] presumably holds the checker stdout -- confirm against python_call
            TC_pass = self.TC_pass_from_TC_out(sim_pass=True, sim_out=py_run_info[1]["out"], TB_type="Pychecker") & iv_run_info[0] & py_run_info[0]
        else:
            # iv_run_info[4]["out"] presumably holds the simulation stdout -- confirm against iverilog_call
            TC_pass = self.TC_pass_from_TC_out(sim_pass=True, sim_out=iv_run_info[4]["out"], TB_type=TB_type) & iv_run_info[0]
        if not save_en:
            # keep only the run_info* files; delete the copied sources and sim artifacts
            # os.system(f"rm -rf {dir}")
            cmd = f"find {dir} -type f ! -name 'run_info*'" + r" -exec rm -f {} +"
            os.system(cmd)
        return TC_pass
    def clean_wave_vcd(self):
        """clean the .vcd files in the task_dir"""
        # clean_dir = self.task_dir[:-1] if self.task_dir.endswith("/") else self.task_dir
        clean_dir = self.task_dir
        for root, dirs, files in os.walk(clean_dir):
            for file in files:
                # clean wave.vcd
                if file.endswith(".vcd"):
                    os.remove(os.path.join(root, file))
    @property
    def TB_path(self):
        # testbench file inside the current working dir
        # return self.working_dir + self.task_id + "_tb.v"
        return os.path.join(self.working_dir, self.task_id + "_tb.v")
    @property
    def DUT_path(self):
        # DUT file inside the current working dir
        # return self.working_dir + self.task_id + ".v"
        return os.path.join(self.working_dir, self.task_id + ".v")
    @property
    def PY_path(self):
        # python checker file inside the current working dir
        # return self.working_dir + self.task_id + "_tb.py"
        return os.path.join(self.working_dir, self.task_id + "_tb.py")
    @staticmethod
    def TC_pass_from_TC_out(sim_pass: bool, sim_out: str, TB_type="TB_gen"):
        """
        get the information if DUT passed all the test cases from the testbench
        #### input
        - sim_pass: bool; if TB passed the compilation. if not, will return False without check
        - sim_out: the simulation output message;
        - TB_type: "TB_gen" or "TB_golden" or "Pychecker"; the type of the testbench
        """
        if not sim_pass:
            return False
        assert TB_type in ["TB_gen", "TB_golden", "Pychecker"], "Invalid TB_type during 'TC_pass_from_TC_out': " + TB_type
        tc_pass_check_list_dict = {"TB_gen": TC_PASS_CHECK_LIST_TB_GEN, "TB_golden": TC_PASS_CHECK_LIST_TB_GOLDEN, "Pychecker": TC_PASS_CHECK_LIST_PYCHECKER}
        tc_pass_check_list = tc_pass_check_list_dict[TB_type]
        if TB_type in ["TB_gen", "TB_golden"]:
            # any pass-marker substring in the output means the DUT passed
            for check_str in tc_pass_check_list:
                if check_str in sim_out:
                    return True
            return False
        elif TB_type in ['Pychecker']:
            # the checker prints the failed-scenario list last; "[]" means none failed
            # check if the last [] contains any element
            # find the last ] in the out
            last_bracket_end = sim_out.rfind("]")
            # find the last [ in the out
            last_bracket_start = sim_out.rfind("[")
            # check if the last bracket pair is "[]", containing no element
            if (last_bracket_end - last_bracket_start) == 1:
                return True
            else:
                return False

537
autoline/TB_autoline.py Normal file
View File

@@ -0,0 +1,537 @@
"""
Description : The main function of autoline, originally the first part of autoline.py in AutoBench 1.0
Author : Ruidi Qiu (r.qiu@tum.de)
Time : 2024/7/24 11:44:15
LastEdited : 2024/9/1 10:32:18
"""
import os
import analyze as al
import loader_saver as ls
import time
from config import Config
from loader_saver import save_dict_json_form, log_localprefix
from data.probset import HDLBitsProbset
from loader_saver import autologger as logger
from utils.utils import Timer
from autoline.TB1_gen import TaskTBgen
from autoline.TB2_syncheck import TaskTBsim
from autoline.TB3_funccheck import TaskTBcheck
from autoline.TB4_eval import TaskTBeval
from prompt_scripts import BaseScript
from LLM_call import llm_manager
# [新增] 引入我们刚写的模块
from autoline.TB_cga import TaskTBCGA
def run_autoline():
    """Entry point: build an AutoLine pipeline from the global config and execute it."""
    pipeline = AutoLine(Config())
    pipeline()
class AutoLine():
    """Top-level driver: runs one AutoLine_Task per problem in the configured problem set."""
    def __init__(self, config: Config):
        self.config = config
        self.logger = logger
        self.logger.assert_(
            config.get_item("autoline", "promptscript") is not None,
            "config.autoline.promptscript is None, please check the config file."
        )
        self.load_data()
        # aggregated run info, rewritten to disk after every finished task
        self.run_info_path = os.path.join(config.save.root, "Chatbench_RunInfo.json")
        self.run_info = []
        # the analyzer only runs for full runs (or the TBgensimeval partial mode)
        only = config.autoline.onlyrun
        self.analyzer_en = only is None or only == "TBgensimeval"
    def run(self):
        """Iterate the problem set, collecting and persisting per-task run info."""
        for task_idx, prob_data in enumerate(self.probset.data, start=1):
            task_name = prob_data["task_id"]
            self.logger.info("")
            self.logger.info("######################### task %d/%d [%s] #########################" % (task_idx, self.probset.num, task_name))
            single_info = AutoLine_Task(prob_data, self.config).run()
            self.run_info.append(single_info)
            # persist after each task; overwrites the previous snapshot
            save_dict_json_form(self.run_info, self.run_info_path)
        if self.analyzer_en:
            self.run_analyzer()
    def __call__(self, *args, **kwargs):
        return self.run(*args, **kwargs)
    def load_data(self):
        """Populate self.probset from the configured HDLBits problem subset."""
        self.probset = HDLBitsProbset()
        self.probset.load_by_config(self.config.autoline.probset)
    def run_analyzer(self):
        """Run the post-hoc analyzer over the collected run info and log its report."""
        report = al.Analyzer(self.run_info, self.config.gpt.model)
        report.run()
        logger.info(report.messages)
class AutoLine_Task():
    """Runs the full AutoLine pipeline for one problem:
    TBgen -> TBsim (syntax check) -> TBcheck (functional check) -> [CGA] -> TBeval,
    with up to `iter_max` reboots of the gen/sim/check loop, then writes run_info
    and result JSON files into the task directory."""
    def __init__(self, prob_data:dict, config:Config):
        # config:
        self.config = config
        # probdata:
        self.prob_data = prob_data
        self.main_model = self.config.gpt.model # The main llm model used in the autoline (generation, correction...)
        self.task_id = prob_data["task_id"]
        self.task_NO = prob_data["task_number"]
        self.prob_description = prob_data["description"]
        self.header = prob_data["header"]
        self.DUT_golden = prob_data['module_code']
        self.TB_golden = prob_data.get("testbench", None)
        self.mutant_list = prob_data.get("mutants", None)
        self.rtlgen_list = prob_data.get('llmgen_RTL', None)
        self.rtlgen_model = self.config.gpt.rtlgen_model # if llmgen_list is none, this will be used
        self.rtl_num = self.config.autoline.TBcheck.rtl_num # will be covered if llmgen_list is not None
        # system config:
        # self.task_dir = self.config.save.root + self.task_id + "/"
        self.task_dir = os.path.join(self.config.save.root, self.task_id)
        self.working_dir = self.task_dir
        os.makedirs(self.task_dir, exist_ok=True)
        # === [CGA Mod] Save DUT immediately to task dir for CGA access ===
        self.dut_path = os.path.join(self.task_dir, "DUT.v")
        ls.save_code(self.DUT_golden, self.dut_path)
        # ==============================================================
        self.update_desc = config.autoline.update_desc
        self.error_interuption = config.autoline.error_interruption # for debug'
        self.save_codes = config.autoline.save_finalcodes
        self.save_compile = self.config.autoline.save_compile # save the compiling codes in TBcheck and TBeval or not.
        # TBgen paras:
        self.TBgen_prompt_script = config.autoline.promptscript
        self.circuit_type = None
        self.scenario_dict = None
        self.scenario_num = None
        self.checklist_worked = None
        # TBcheck paras:
        self.TBcheck_correct_max = self.config.autoline.TBcheck.correct_max
        self.iter_max = config.autoline.itermax
        self.discrim_mode = config.autoline.TBcheck.discrim_mode
        self.correct_mode = config.autoline.TBcheck.correct_mode
        self.rtl_compens_en = config.autoline.TBcheck.rtl_compens_en
        self.rtl_compens_max_iter = config.autoline.TBcheck.rtl_compens_max_iter
        self.cga_enabled = config.autoline.cga.enabled
        # stages:
        self.TBgen_manager:TaskTBgen = None
        self.TBgen:BaseScript = None
        self.TBsim:TaskTBsim = None
        self.TBcheck:TaskTBcheck = None
        self.TBeval:TaskTBeval = None
        self.stage_now = "initialization"
        # changing paras:
        self.autoline_iter_now = 0
        self.TB_code_v = None
        self.TB_code_py = None
        self.next_action = None
        # results:
        self.incomplete_running = True
        self.full_pass = False
        self.TB_corrected = False
        self.run_info = {}
        self.run_info_short = {}
        self.TBcheck_rtl_newly_gen_num = 0 # in autoline, "funccheck" = "TBcheck"
        self.op_record = [] # will record the order of each stage, for example: ["gen", "syncheck", "funccheck", "gen", "syncheck", "funccheck", "eval"]
        self.funccheck_op_record = []
        self.funccheck_iters = []
        # coverage score reported by the CGA stage (stays 0.0 until CGA runs)
        self.cga_coverage = 0.0
        # === [CGA Mod] Initialize result dictionary for final reporting ===
        self.result_dict = {
            "task_id": self.task_id,
            "stage": "Init",
            "pass": False,
            "coverage": 0.0,
            "cga_enabled": self.cga_enabled
        }
        # =================================================================
        # renew current section of llm_manager and logger
        llm_manager.new_section()
        logger.set_temp_log()
    def run(self):
        """
        The main function of running the autoline for one problem
        """
        with log_localprefix(self.task_id):
            self.run_stages()
            self.runinfo_update()
            if self.save_codes:
                self.save_TB_codes()
            # === [CGA Mod] Save Result JSON for Analyzer ===
            self.result_dict['stage'] = self.stage_now
            try:
                result_save_path = self.config.autoline.result_path
            except AttributeError:
                # the config object may not define this attribute (or key, if it is a dict)
                result_save_path = "results"
            # make sure the result directory exists (absolute, or relative to project root)
            if not os.path.exists(result_save_path):
                os.makedirs(result_save_path, exist_ok=True)
            ls.save_dict_json_form(self.result_dict, os.path.join(result_save_path, f"{self.task_id}.json"))
            # ===============================================
            return self.run_info
    def run_TBgen(self, subdir:str=None):
        """Stage 1: generate the testbench (and optional python checker) via the prompt script."""
        # TODO: export the circuit type and scenario number
        self.op_record.append("gen")
        working_dir = os.path.join(self.task_dir, subdir) if subdir is not None else self.task_dir
        self.stage_now = "TBgen"
        self.TBgen_manager = TaskTBgen(self.prob_data, self.TBgen_prompt_script, working_dir, self.config)
        self.TBgen = self.TBgen_manager.workflow
        with log_localprefix("TBgen"):
            self.TBgen()
        # pull the generated artifacts out of the workflow object
        self.TB_code_v = self.TBgen.get_attr("TB_code_v")
        self.TB_code_py = self.TBgen.get_attr("TB_code_py")
        self.scenario_dict = self.TBgen.get_attr("scenario_dict")
        self.scenario_num = self.TBgen.get_attr("scenario_num")
        self.circuit_type = self.TBgen.get_attr("circuit_type")
        self.checklist_worked = self.TBgen.get_attr("checklist_worked")
        self.incomplete_running = True
        self._blank_log()
    def run_TBsim(self, subdir:str=None):
        """Stage 2: syntax-check (and auto-debug) the generated TB via simulation."""
        self.op_record.append("syncheck")
        working_dir = os.path.join(self.task_dir, subdir) if subdir is not None else self.task_dir
        self.stage_now = "TBsim"
        self.TBsim = TaskTBsim(
            self.TBgen,
            self.TBgen.TB_code,
            self.header,
            working_dir,
            self.task_id,
            self.config
        )
        self.TBsim.run()
        # TBsim may have rewritten the codes while fixing syntax errors
        self.TB_code_v = self.TBsim.TB_code_now
        self.TB_code_py = self.TBsim.PY_code_now
        self._blank_log()
    def run_TBcheck(self, subdir:str=None):
        """Stage 3: functional check of the TB against multiple RTL candidates; may correct the TB."""
        self.op_record.append("funccheck")
        working_dir = os.path.join(self.task_dir, subdir) if subdir is not None else self.task_dir
        self.stage_now = "TBcheck"
        self.TBcheck = TaskTBcheck(
            task_dir = working_dir,
            task_id = self.task_id,
            description = self.prob_description,
            module_header = self.header,
            TB_code_v = self.TB_code_v,
            TB_code_py = self.TB_code_py,
            rtl_list = self.rtlgen_list,
            rtl_num = self.rtl_num,
            scenario_num = self.scenario_num,
            correct_max = self.TBcheck_correct_max,
            runfiles_save=self.save_compile,
            discriminator_mode=self.discrim_mode,
            corrector_mode=self.correct_mode,
            circuit_type=self.circuit_type,
            rtl_compens_en=self.rtl_compens_en,
            rtl_compens_max_iter=self.rtl_compens_max_iter,
            main_model = self.main_model,
            rtlgen_model = self.rtlgen_model,
            desc_improve=self.update_desc
        )
        # reuse the RTL candidates generated inside TBcheck across reboots
        self.rtlgen_list = self.TBcheck.rtl_list
        self.TBcheck.run()
        self.TB_code_v = self.TBcheck.TB_code_v
        self.TB_code_py = self.TBcheck.TB_code_py
        self.TB_corrected = self.TBcheck.corrected
        self.funccheck_op_record.append(self.TBcheck.op_record)
        self.funccheck_iters.append(self.TBcheck.iter_now)
        self.TBcheck_rtl_newly_gen_num += self.TBcheck.rtl_newly_gen_num
        # "pass" or "reboot": decided by the checker, consumed in run_stages_core
        self.next_action = self.TBcheck.next_action
        if self.update_desc:
            # the checker may rewrite the problem description based on check results
            self.prob_data['description'] = self.TBcheck.update_description()
            self.prob_description = self.prob_data['description']
        self._blank_log()
    def run_TBeval(self, subdir:str=None):
        """Stage 4: evaluate the final TB against the golden RTL / golden TB / mutants."""
        self.op_record.append("eval")
        working_dir = os.path.join(self.task_dir, subdir) if subdir is not None else self.task_dir
        self.stage_now = "TBeval"
        self.TBeval = TaskTBeval(
            self.task_id,
            working_dir,
            TB_gen=self.TB_code_v,
            TB_golden=self.TB_golden,
            DUT_golden=self.DUT_golden,
            DUT_mutant_list=self.mutant_list,
            DUT_gptgen_list=None,
            pychecker_en=self.TBsim.pychecker_en,
            pychecker_code=self.TB_code_py,
            runfiles_save=self.save_compile
        )
        # attention: the rtls in DUT_gptgen_list are not the same as the rtls used in TBcheck, so currently we just block this feature
        try:
            self.TBeval.run()
        except:
            logger.failed("error when running TBeval, the autoline for this task stopped.")
            self.incomplete_running = True
        self._blank_log()
    # new stage added alongside run_TBeval
    def run_TBCGA(self, work_subdir="CGA", optimize=True, op_name="cga"):
        """
        Coverage-Guided Agent (CGA) stage: measure the TB's coverage and,
        when `optimize` is True, iteratively improve it up to cga.max_iter rounds.
        """
        self.stage_now = "TBCGA"
        self.op_record.append(op_name)
        cga = TaskTBCGA(
            task_dir=self.task_dir,
            task_id=self.task_id,
            header=self.header,
            DUT_code=self.DUT_golden,
            TB_code=self.TB_code_v,
            config=self.config,
            work_subdir=work_subdir,
            max_iter=(self.config.autoline.cga.max_iter if optimize else 0)
        )
        # receive the final coverage score together with the (possibly optimized) TB
        final_tb, final_score = cga.run()
        self.cga_coverage = final_score
        # update pipeline state
        self.TB_code_v = final_tb
        self.result_dict['coverage'] = final_score
        # always archive final_TB.v into the task working directory
        final_tb_path = os.path.join(self.task_dir, "final_TB.v")
        ls.save_code(final_tb, final_tb_path)
        logger.info(f"Saved optimized TB to: {final_tb_path}")
    def run_stages(self):
        # NOTE(review): the branches below look swapped -- with error_interuption False
        # exceptions propagate uncaught, and with it True they are logged but re-raised
        # anyway; confirm the intended debug semantics before changing.
        with Timer(print_en=False) as self.running_time:
            if not self.error_interuption:
                self.run_stages_core()
            else:
                try:
                    self.run_stages_core()
                except Exception as e:
                    self.incomplete_running = True
                    logger.error("error when running %s, the autoline for this task stopped. error message: %s"%(self.stage_now, str(e)))
                    if self.error_interuption:
                        # if True, stop the pipeline
                        raise e
        # NOTE(review): unconditionally marks the run complete, even when
        # run_stages_core flagged it incomplete -- TODO confirm
        self.incomplete_running = False
    def run_stages_core(self):
        """Dispatch on config.autoline.onlyrun: partial runs or the full gen/sim/check/eval loop."""
        match self.config.autoline.onlyrun:
            case "TBgen":
                self.run_TBgen()
            case "TBgensim":
                self.run_TBgen()
                self.run_TBsim()
            # case _: # default, run all
            case "TBgensimeval":
                try:
                    self.run_TBgen("1_TBgen")
                    self.run_TBsim("2_TBsim")
                    self.run_TBeval("3_TBeval")
                except Exception as e:
                    self.incomplete_running = True
                    logger.error("error when running %s, the autoline for this task stopped. error message: %s"%(self.stage_now, str(e)))
                else:
                    self.incomplete_running = False
            case _: # default, run all
                for i in range(self.iter_max):
                    self.autoline_iter_now = i
                    try:
                        self.run_TBgen(f"{i+1}_1_TBgen")
                        self.run_TBsim(f"{i+1}_2_TBsim")
                        self.run_TBcheck(f"{i+1}_3_TBcheck")
                    except Exception as e:
                        err_msg = str(e)
                        logger.error(f"Error when running {self.stage_now}, iter: {i+1}. Message: {err_msg}")
                        # === [API cool-down period] ===
                        # after an iverilog failure or API timeout, force a 15s pause;
                        # this avoids 429 rate-limit errors / connection resets from the API provider
                        logger.warning("⚠️ Pipeline interrupted. Cooling down for 15s to avoid API Rate Limit...")
                        time.sleep(15)
                        # ================================
                        # re-raise immediately if the config asks to stop on errors
                        if getattr(self.config.autoline, 'error_interruption', False):
                            raise e
                        # otherwise mark for reboot and go to the next iteration
                        self.next_action = "reboot"
                        self.incomplete_running = True # mark the current run as incomplete
                        continue
                    match self.next_action:
                        case "pass":
                            break
                        case "reboot":
                            continue
                # === [CGA insertion point START] ===
                # only when the pipeline finished normally (no reboot requested)
                if self.next_action == "pass":
                    # mark the run complete before entering CGA so downstream logic is not misled
                    self.incomplete_running = False
                    try:
                        if self.cga_enabled:
                            self.run_TBCGA(work_subdir="CGA", optimize=True, op_name="cga")
                        else:
                            # coverage measurement only, no optimization (baseline)
                            self.run_TBCGA(work_subdir="CGA_baseline", optimize=False, op_name="coverage_eval")
                    except Exception as e:
                        logger.error(f"CGA Stage Failed: {e}. Fallback to original TB.")
                        self.result_dict['error'] = str(e)
                # === [CGA insertion point END] ===
                try:
                    self.run_TBeval(f"{self.autoline_iter_now+1}_4_TBeval")
                except Exception as e:
                    self.incomplete_running = True
                    logger.error("error when running %s, the autoline for this task stopped. error message: %s"%(self.stage_now, str(e)))
    def runinfo_update(self):
        """Assemble run_info / run_info_short / result_dict from all stages and write them to disk."""
        # general
        self.run_info = {
            "task_id": self.task_id,
            "task_number": self.task_NO,
            "time": round(self.running_time.interval, 2),
            "prompt_tokens": llm_manager.tokens_in_section,
            "completion_tokens": llm_manager.tokens_out_section,
            "token_cost": llm_manager.cost_section,
            "ERROR(incomplete)": self.incomplete_running,
            "op_record": self.op_record,
            "reboot_times": self.autoline_iter_now,
            "max_iter": self.iter_max,
            # === write CGA coverage into the final report ===
            "coverage": self.cga_coverage
        }
        # token and cost from llm_manager
        # TBgen
        if self.TBgen is not None:
            # self.run_info["prompt_tokens"] += self.TBgen.tokens["prompt"]
            # self.run_info["completion_tokens"] += self.TBgen.tokens["completion"]
            self.run_info["circuit_type"] = self.circuit_type
            self.run_info["checklist_worked"] = self.checklist_worked
            self.run_info["scenario_num"] = self.scenario_num
        # TBsim
        if self.TBsim is not None:
            # self.run_info["prompt_tokens"] += self.TBsim.tokens["prompt"]
            # self.run_info["completion_tokens"] += self.TBsim.tokens["completion"]
            self.run_info.update({
                "Eval0_pass": self.TBsim.Eval0_pass,
                "Eval0_iv_pass": self.TBsim.sim_pass,
                "debug_iter_iv": self.TBsim.debug_iter_iv_now,
                "iv_runing_time": self.TBsim.iv_runing_time
            })
            if self.TBsim.pychecker_en:
                self.run_info.update({
                    "Eval0_py_pass": self.TBsim.py_pass,
                    "debug_iter_py": self.TBsim.debug_iter_py_now,
                    "py_runing_time": self.TBsim.py_runing_time
                })
        # TODO: TBcheck runinfo update
        if self.TBcheck is not None:
            self.run_info.update({
                "TB_corrected": self.TB_corrected,
                "TBcheck_oprecord": self.funccheck_op_record,
                "rtl_num_newly_gen": self.TBcheck_rtl_newly_gen_num
            })
        # TBeval
        if self.TBeval is not None:
            if self.TBeval.Eval1_exist:
                self.run_info.update({"Eval1_pass": self.TBeval.Eval1_pass})
                self.result_dict["Eval1_pass"] = self.TBeval.Eval1_pass
            if self.TBeval.Eval2_exist:
                self.run_info.update({
                    "Eval2_pass": self.TBeval.Eval2_pass,
                    "Eval2_ratio": "%d/%d"%(len(self.TBeval.Eval2_passed_mutant_idx), len(self.prob_data['mutants'])),
                    "Eval2_failed_mutant_idxes": self.TBeval.Eval2_failed_mutant_idx
                })
                self.result_dict.update({
                    "Eval2_pass": self.TBeval.Eval2_pass,
                    "Eval2_ratio": "%d/%d"%(len(self.TBeval.Eval2_passed_mutant_idx), len(self.prob_data['mutants'])),
                    "Eval2_failed_mutant_idxes": self.TBeval.Eval2_failed_mutant_idx
                })
            if self.TBeval.Eval2b_exist:
                self.run_info.update({
                    "Eval2b_pass": self.TBeval.Eval2b_pass,
                    "Eval2b_ratio": "%d/%d"%(len(self.TBeval.Eval2b_passed_mutant_idx), len(self.prob_data['gptgen_RTL'])),
                    "Eval2b_failed_mutant_idxes": self.TBeval.Eval2b_failed_mutant_idx
                })
                self.result_dict.update({
                    "Eval2b_pass": self.TBeval.Eval2b_pass,
                    "Eval2b_ratio": "%d/%d"%(len(self.TBeval.Eval2b_passed_mutant_idx), len(self.prob_data['gptgen_RTL'])),
                    "Eval2b_failed_mutant_idxes": self.TBeval.Eval2b_failed_mutant_idx
                })
        # full pass
        if not self.incomplete_running:
            self.full_pass = self.TBsim.sim_pass and self.TBeval.Eval1_pass and self.TBeval.Eval2_pass
            self.run_info.update({
                "full_pass": self.full_pass
            })
            self.result_dict["full_pass"] = self.full_pass
            self.result_dict["pass"] = self.full_pass
        else:
            self.result_dict["full_pass"] = False
            self.result_dict["pass"] = False
        self.result_dict["stage"] = self.stage_now
        self.result_dict["coverage"] = self.cga_coverage
        save_dict_json_form(self.run_info, os.path.join(self.task_dir, "run_info.json"))
        # short run info: one-line evaluation progress summary
        if "Eval2_ratio" in self.run_info.keys():
            eval_progress = "Eval2 - " + self.run_info["Eval2_ratio"]
        elif "Eval1_pass" in self.run_info.keys() and self.run_info["Eval1_pass"]:
            eval_progress = "Eval1 - passed"
        elif "Eval0_pass" in self.run_info.keys() and self.run_info["Eval0_pass"]:
            eval_progress = "Eval1 - failed"
        elif "Eval0_pass" in self.run_info.keys() and not self.run_info["Eval0_pass"]:
            eval_progress = "Eval0 - failed"
        else:
            eval_progress = "Eval0 - not found"
        self.run_info_short = {
            "task_id": self.run_info.get("task_id", None),
            "eval_progress": eval_progress,
            "TB_corrected": self.run_info.get("TB_corrected", None),
            "reboot_times": self.run_info.get("reboot_times", None),
            "time": self.run_info.get("time", None),
            "cost": self.run_info.get("token_cost", None),
        }
        save_dict_json_form(self.run_info_short, os.path.join(self.task_dir, "run_info_short.json"))
        # run log: dump the per-task temp log accumulated since set_temp_log()
        running_log = logger.reset_temp_log()
        tasklog_path = os.path.join(self.task_dir, "task_log.log")
        os.makedirs(os.path.dirname(tasklog_path), exist_ok=True)
        with open(tasklog_path, "w") as f:
            f.write(running_log)
        return self.run_info
    def save_TB_codes(self):
        """Write the final Verilog TB and python checker into the task directory."""
        save_dir = self.task_dir
        ls.save_code(self.TB_code_v if isinstance(self.TB_code_v, str) else "// TB code (Verilog) unavailable", os.path.join(save_dir, "final_TB.v"))
        ls.save_code(self.TB_code_py if isinstance(self.TB_code_py, str) else "## TB code (Python) unavailable", os.path.join(save_dir, "final_TB.py"))
    @staticmethod
    def _blank_log():
        # visual separator between stages in the log
        logger.info("")
    def __call__(self, *args, **kwargs):
        return self.run(*args, **kwargs)

1453
autoline/TB_cga.py Normal file

File diff suppressed because it is too large Load Diff

22
autoline/__init__.py Normal file
View File

@@ -0,0 +1,22 @@
"""
Description : Automatic pipeline of Chatbench: from HDLBits problem to simulation
Author : Ruidi Qiu (r.qiu@tum.de)
Time : 2023/12/7 15:13:00
LastEdited : 2024/8/16 13:37:31
autoline.py (c) 2023
"""
from autoline.TB_autoline import run_autoline
from autoline.TB1_gen import TaskTBgen
from autoline.TB2_syncheck import TaskTBsim
from autoline.TB3_funccheck import TaskTBcheck, TB_corrector, TB_discriminator
from autoline.TB4_eval import TaskTBeval
if __name__ == "__main__":
    # this package is import-only; the pipeline entry point is run_autoline()
    raise RuntimeError("you cannot run autoline.py directly!")
    # probset = Probset("data/HDLBits/HDLBits_data.jsonl", "data/HDLBits/HDLBits_data_miniset_mutants.jsonl", "data/HDLBits/HDLBits_circuit_type.jsonl", exclude_tasks=['rule110'], filter_content={'circuit_type': 'SEQ'})
    # print(probset.num)

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

5444
autoline/cga_utils.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,601 @@
"""
Description : Diversity Constraint Injector (Layer 1)
- Analyze existing test sequences
- Detect overused patterns
- Generate diversity constraints for Prompt
- Recommend new test scenarios
Author : CGA Enhancement Project
Time : 2026/03/16
"""
import logging
import re
from typing import List, Dict, Optional, Any, Tuple, Set
from dataclasses import dataclass, field
from enum import Enum
# Support both import styles: package-relative import and direct script loading.
try:
    from .test_history import (
        TestHistoryManager,
        TestRecord,
        InputSequence,
        SequencePattern
    )
except ImportError:
    from test_history import (
        TestHistoryManager,
        TestRecord,
        InputSequence,
        SequencePattern
    )
logger = logging.getLogger(__name__)
# ============================================================================
# 配置常量
# ============================================================================
class DiversityConfig:
    """Tunable thresholds and weights for diversity-constraint generation."""
    # minimum edit distance required between a new sequence and historical ones
    MIN_EDIT_DISTANCE = 3
    # a pattern used at least this many times counts as overused
    OVERUSE_THRESHOLD = 3
    # number of new test scenarios to recommend per round
    NEW_SCENARIO_COUNT = 3
    # cap on sequence length when generating constraints
    MAX_SEQUENCE_LENGTH = 10
    # weights of the diversity-score components (pattern / edit-distance / coverage)
    PATTERN_WEIGHT = 0.4
    EDIT_DISTANCE_WEIGHT = 0.3
    COVERAGE_WEIGHT = 0.3
# ============================================================================
# 约束类型定义
# ============================================================================
class ConstraintType(Enum):
    """Kinds of diversity constraints that can be injected into the prompt."""
    FORBID_SEQUENCE = "forbid_sequence"         # forbid a specific sequence
    MIN_EDIT_DISTANCE = "min_edit_distance"     # require a minimum edit distance
    AVOID_PATTERN = "avoid_pattern"             # avoid an overused pattern
    TRY_SCENARIO = "try_scenario"               # try a new test scenario
    EXPLORE_RANGE = "explore_range"             # explore a value range
# ============================================================================
# 约束数据结构
# ============================================================================
@dataclass
class DiversityConstraint:
    """
    A single diversity constraint.

    Attributes:
        constraint_type: kind of constraint (ConstraintType)
        description: human-readable fallback description
        details: type-specific payload (pattern, signal, ranges, counts, ...)
        priority: priority from 1 (lowest) to 5 (highest)
    """
    constraint_type: ConstraintType
    description: str
    details: Dict[str, Any] = field(default_factory=dict)
    priority: int = 3
    def to_prompt_text(self) -> str:
        """Render this constraint as one bullet line for the LLM prompt."""
        if self.constraint_type == ConstraintType.FORBID_SEQUENCE:
            return f"- AVOID using this sequence pattern: {self.details.get('pattern', 'unknown')}"
        elif self.constraint_type == ConstraintType.MIN_EDIT_DISTANCE:
            return f"- Your test sequence MUST differ from previous tests (edit distance >= {self.details.get('min_distance', 3)})"
        elif self.constraint_type == ConstraintType.AVOID_PATTERN:
            signal = self.details.get('signal', '')
            pattern = self.details.get('pattern', '')
            return f"- AVOID the pattern '{pattern}' for signal '{signal}' (already used {self.details.get('count', 0)} times)"
        elif self.constraint_type == ConstraintType.TRY_SCENARIO:
            return f"- TRY this new approach: {self.details.get('scenario', 'unknown')}"
        elif self.constraint_type == ConstraintType.EXPLORE_RANGE:
            return f"- EXPLORE values in range [{self.details.get('min', 0)}, {self.details.get('max', 255)}] for {self.details.get('signal', 'signal')}"
        # fallback for unknown constraint types
        return f"- {self.description}"
# ============================================================================
# 序列分析器
# ============================================================================
class SequenceAnalyzer:
    """
    Sequence analyzer.

    Static helpers that characterise an input-value sequence extracted from
    a generated testbench: numeric value range, transition shape, and a
    rough stimulus-step count.
    """

    # Radix for each base letter in a sized Verilog literal (e.g. 8'hFF).
    _BASE_MAP = {'b': 2, 'd': 10, 'h': 16}

    # Sized-literal pattern "<width>'<base><digits>", e.g. 4'b1010, 8'hFF.
    # BUGFIX: the original character class was [0-9a-fA-fA-FxXzZ_]; the
    # stray range "A-f" also matched '[', '\\', ']', '^', '_' and '`'.
    _SIZED_LITERAL_RE = re.compile(r"(\d+)'([bdh])([0-9a-fA-FxXzZ_]+)")

    @staticmethod
    def extract_value_range(values: List[Tuple[int, Any]]) -> Tuple[Any, Any]:
        """Return (min, max) over the numerically-interpretable values.

        Args:
            values: (time, value) pairs; values may be ints/floats,
                single-bit strings ('0'/'1'/'x'/'z'), or sized Verilog
                literals such as "8'hFF".

        Returns:
            (min, max) of the numeric values, or (0, 0) when none parse.
        """
        if not values:
            return (0, 0)
        numeric_values = []
        for _, v in values:
            if isinstance(v, (int, float)):
                numeric_values.append(v)
            elif isinstance(v, str):
                # Single-bit values; 'x'/'z' are treated as 0.
                if v in ['0', '1', 'x', 'z']:
                    numeric_values.append(int(v) if v.isdigit() else 0)
                # Sized literals: convert using the radix declared by the
                # base letter. BUGFIX: the original always used base 16,
                # mis-parsing binary ('b) and decimal ('d) literals.
                match = SequenceAnalyzer._SIZED_LITERAL_RE.match(v)
                if match:
                    base = SequenceAnalyzer._BASE_MAP[match.group(2).lower()]
                    try:
                        numeric_values.append(int(match.group(3).replace('_', ''), base))
                    except ValueError:
                        # Literals containing x/z digits cannot be converted.
                        pass
        if numeric_values:
            return (min(numeric_values), max(numeric_values))
        return (0, 0)

    @staticmethod
    def detect_transition_pattern(values: List[Tuple[int, Any]]) -> str:
        """Classify the transition shape of a value sequence.

        Returns one of: "single", "incremental", "decremental",
        "alternating", "pulse", "random".
        """
        if len(values) < 2:
            return "single"
        val_seq = [v for _, v in values]
        # BUGFIX: compare numerically when every value is numeric; the
        # original compared str(...) lexicographically, so e.g. [2, 10]
        # was not recognised as incremental ("2" > "10").
        if all(isinstance(v, (int, float)) for v in val_seq):
            keys = val_seq
        else:
            keys = [str(v) for v in val_seq]
        if all(keys[i] <= keys[i + 1] for i in range(len(keys) - 1)):
            return "incremental"
        if all(keys[i] >= keys[i + 1] for i in range(len(keys) - 1)):
            return "decremental"
        # Alternating: ABAB over the first four values.
        if len(val_seq) >= 4:
            if val_seq[0] == val_seq[2] and val_seq[1] == val_seq[3]:
                return "alternating"
        # Pulse: a single excursion that returns to the initial value.
        if len(val_seq) == 3 and val_seq[0] == val_seq[2] != val_seq[1]:
            return "pulse"
        return "random"

    @staticmethod
    def calculate_sequence_length(code: str) -> int:
        """Estimate the number of stimulus steps in *code*.

        Counts blocking-assignment statements plus the total cycles of all
        repeat(N) statements.
        """
        assignments = len(re.findall(r'\w+\s*=\s*\S+\s*;', code))
        repeats = re.findall(r'repeat\s*\(\s*(\d+)\s*\)', code)
        repeat_cycles = sum(int(r) for r in repeats)
        return assignments + repeat_cycles
# ============================================================================
# 场景推荐器
# ============================================================================
class ScenarioRecommender:
    """
    Scenario recommender.

    Suggests new test scenarios based on uncovered function points and on
    input-pattern shapes the test history has not used yet.
    """

    # Scenario templates keyed by function-point type.
    SCENARIO_TEMPLATES = {
        'fsm': [
            "Test state transition from {state_a} to {state_b}",
            "Test illegal state transition handling",
            "Test state machine reset behavior",
            "Test state holding under stable inputs"
        ],
        'counter': [
            "Test counter overflow behavior (count to max value)",
            "Test counter underflow (if applicable)",
            "Test counter reset during counting",
            "Test counter enable/disable control"
        ],
        'branch': [
            "Test boundary condition: {condition} at threshold",
            "Test all branches of nested if-else",
            "Test case statement with all possible values"
        ],
        'protocol': [
            "Test handshake timeout scenario",
            "Test back-to-back transactions",
            "Test protocol violation handling"
        ],
        'general': [
            "Apply random input patterns for extended duration",
            "Test with boundary values (all 0s, all 1s)",
            "Test rapid signal transitions",
            "Test power-on/reset sequence variations"
        ]
    }

    def __init__(self, history_manager: TestHistoryManager):
        self.history = history_manager

    def recommend_scenarios(self,
                            uncovered_functions: List[Dict],
                            covered_patterns: Set[str] = None) -> List[str]:
        """
        Recommend new test scenarios.

        Args:
            uncovered_functions: uncovered function points.
            covered_patterns: scenarios that are already covered.

        Returns:
            Up to DiversityConfig.NEW_SCENARIO_COUNT recommendations.
        """
        covered_patterns = covered_patterns or set()
        recommendations = []
        # Template-based suggestions for up to three uncovered points
        # (one template per function point).
        for func in uncovered_functions[:3]:
            templates = self.SCENARIO_TEMPLATES.get(func.get('type', 'general'),
                                                    self.SCENARIO_TEMPLATES['general'])
            for template in templates[:1]:
                scenario = self._fill_template(template, func)
                if scenario not in covered_patterns:
                    recommendations.append(scenario)
        # History-based suggestion: propose a transition shape that no
        # previous test has used.
        if self.history.records:
            used_patterns = set()
            for record in self.history.records:
                for seq in record.input_sequences:
                    used_patterns.add(
                        SequenceAnalyzer.detect_transition_pattern(seq.values))
            all_patterns = {'incremental', 'decremental', 'alternating', 'pulse', 'random'}
            unused_patterns = all_patterns - used_patterns
            if unused_patterns:
                recommendations.append(f"Try {list(unused_patterns)[0]} input pattern (different from your usual approach)")
        # Pad with a generic suggestion until the quota is met.
        while len(recommendations) < DiversityConfig.NEW_SCENARIO_COUNT:
            recommendations.append("Explore a completely different input sequence than before")
        return recommendations[:DiversityConfig.NEW_SCENARIO_COUNT]

    def _fill_template(self, template: str, func: Dict) -> str:
        """Substitute template placeholders from *func* metadata."""
        result = template
        if '{state_a}' in template or '{state_b}' in template:
            states = func.get('states', ['STATE_A', 'STATE_B'])
            if len(states) >= 2:
                result = result.replace('{state_a}', states[0])
                result = result.replace('{state_b}', states[1])
        if '{condition}' in template:
            result = result.replace('{condition}', func.get('condition', 'signal'))
        return result
# ============================================================================
# 约束生成器
# ============================================================================
class ConstraintGenerator:
    """
    Constraint generator.

    Derives diversity constraints for the next generation attempt from the
    accumulated test history: overused patterns to avoid, a minimum edit
    distance from recent tests, and value-range exploration hints.
    """

    def __init__(self, history_manager: TestHistoryManager):
        self.history = history_manager
        self.analyzer = SequenceAnalyzer()

    def generate_constraints(self,
                             target_function: str = None,
                             uncovered_functions: List[Dict] = None) -> List[DiversityConstraint]:
        """
        Generate the diversity-constraint list.

        Args:
            target_function: currently targeted function point (kept for
                interface stability; not used directly here).
            uncovered_functions: uncovered function points.

        Returns:
            Constraints sorted by descending priority; empty when there is
            no history yet.
        """
        constraints = []
        if not self.history.records:
            return constraints
        # 1. Avoid patterns that have been used too often (at most 3).
        overused = self.history.get_overused_patterns(DiversityConfig.OVERUSE_THRESHOLD)
        for pattern in overused[:3]:
            constraints.append(DiversityConstraint(
                constraint_type=ConstraintType.AVOID_PATTERN,
                description=f"Avoid overused pattern for {pattern.signal_name}",
                details={
                    'signal': pattern.signal_name,
                    'pattern': pattern.pattern,
                    'count': pattern.count
                },
                priority=5
            ))
        # 2. Require a minimum edit distance from recent tests.
        recent_count = min(5, len(self.history.records))
        # BUGFIX: was "recent_count > 00" — a legal but confusing zero
        # literal that reads like a typo (and would be a SyntaxError as 01).
        if recent_count > 0:
            constraints.append(DiversityConstraint(
                constraint_type=ConstraintType.MIN_EDIT_DISTANCE,
                description="Maintain minimum edit distance from recent tests",
                details={
                    'min_distance': DiversityConfig.MIN_EDIT_DISTANCE,
                    'reference_count': recent_count
                },
                priority=4
            ))
        # 3. Value-range exploration constraints for counter-type points.
        if uncovered_functions:
            for func in uncovered_functions[:2]:
                if func.get('type') == 'counter':
                    max_val = func.get('max_value', 255)
                    constraints.append(DiversityConstraint(
                        constraint_type=ConstraintType.EXPLORE_RANGE,
                        # Dropped the pointless f-prefix (no placeholders).
                        description="Explore counter boundary values",
                        details={
                            'signal': func.get('name', 'counter'),
                            'min': 0,
                            'max': max_val
                        },
                        priority=3
                    ))
        # Highest priority first.
        constraints.sort(key=lambda c: c.priority, reverse=True)
        return constraints

    def generate_forbidden_sequence_prompt(self) -> str:
        """Render overused patterns as an 'avoid these' prompt section."""
        overused = self.history.get_overused_patterns(DiversityConfig.OVERUSE_THRESHOLD)
        if not overused:
            return ""
        lines = ["[DIVERSITY CONSTRAINTS - AVOID THESE OVERUSED PATTERNS]"]
        for i, pattern in enumerate(overused[:5], 1):
            lines.append(f"{i}. Signal '{pattern.signal_name}': {pattern.pattern[:50]}")
            lines.append(f" (This pattern has been used {pattern.count} times already)")
        lines.append("\nPlease create a DIFFERENT input sequence to improve test diversity.")
        return "\n".join(lines)
# ============================================================================
# 多样性约束注入器(主入口)
# ============================================================================
class DiversityInjector:
    """
    Diversity-constraint injector — the Layer-1 entry point.

    Combines sequence analysis, pattern detection and constraint generation
    behind one interface that augments generation prompts with diversity
    requirements and records/evaluates generated tests.
    """
    def __init__(self, history_manager: TestHistoryManager = None):
        """
        Args:
            history_manager: test-history manager (a fresh one is created
                when omitted).
        """
        self.history = history_manager or TestHistoryManager()
        self.constraint_generator = ConstraintGenerator(self.history)
        self.scenario_recommender = ScenarioRecommender(self.history)
    def inject_diversity_constraints(self,
                                     prompt: str,
                                     target_function: str = None,
                                     uncovered_functions: List[Dict] = None) -> str:
        """
        Inject diversity constraints into *prompt*.

        Args:
            prompt: the original prompt.
            target_function: currently targeted function point.
            uncovered_functions: uncovered function points.

        Returns:
            The prompt with a diversity-constraint section inserted.
        """
        if not self.history.records:
            return prompt  # nothing to constrain against without history
        # Build the constraint list from history.
        constraints = self.constraint_generator.generate_constraints(
            target_function=target_function,
            uncovered_functions=uncovered_functions
        )
        # Build the recommended-scenario list.
        recommendations = self.scenario_recommender.recommend_scenarios(
            uncovered_functions=uncovered_functions or []
        )
        # Render both as one text section.
        constraint_text = self._build_constraint_section(constraints, recommendations)
        # Insert just before the [OUTPUT REQUIREMENTS...] section if present,
        # otherwise append at the end of the prompt.
        insert_marker = "[OUTPUT REQUIREMENTS"
        if insert_marker in prompt:
            parts = prompt.split(insert_marker, 1)
            enhanced_prompt = parts[0] + constraint_text + "\n\n" + insert_marker + parts[1]
        else:
            enhanced_prompt = prompt + "\n\n" + constraint_text
        return enhanced_prompt
    def _build_constraint_section(self,
                                  constraints: List[DiversityConstraint],
                                  recommendations: List[str]) -> str:
        """Render constraints and recommendations as a prompt section."""
        lines = []
        lines.append("[DIVERSITY CONSTRAINTS - CRITICAL]")
        lines.append("To improve test effectiveness, follow these diversity requirements:")
        lines.append("")
        # One bullet per constraint.
        for constraint in constraints:
            lines.append(constraint.to_prompt_text())
        lines.append("")
        # Numbered list of recommended scenarios.
        if recommendations:
            lines.append("[RECOMMENDED NEW APPROACHES]")
            for i, rec in enumerate(recommendations, 1):
                lines.append(f"{i}. {rec}")
            lines.append("")
        lines.append("IMPORTANT: Repeated test patterns reduce coverage improvement efficiency.")
        lines.append("Generate a DISTINCTLY DIFFERENT test sequence from previous attempts.")
        return "\n".join(lines)
    def get_diversity_context(self) -> str:
        """Return a short diversity-context summary for prompt use."""
        if not self.history.records:
            return ""
        stats = self.history.get_statistics()
        overused = self.history.get_overused_patterns(DiversityConfig.OVERUSE_THRESHOLD)
        context_lines = []
        context_lines.append(f"Test History: {stats['total_tests']} tests generated")
        context_lines.append(f"Unique Patterns: {stats['total_patterns']}")
        if overused:
            context_lines.append(f"Overused Patterns: {len(overused)} (avoid these)")
        return "\n".join(context_lines)
    def evaluate_diversity(self,
                           new_code: str,
                           known_signals: List[str] = None) -> Dict[str, float]:
        """
        Evaluate the diversity of newly generated test code.

        Args:
            new_code: newly generated test code.
            known_signals: known signal names (restricts extraction).

        Returns:
            Dict with 'sequence_diversity' (only when known_signals given),
            'edit_distance_diversity' and weighted 'overall_diversity'.
        """
        results = {}
        # 1. Sequence diversity (requires a known-signal list).
        if known_signals:
            self.history.sequence_extractor.set_known_signals(known_signals)
        new_sequences = self.history.sequence_extractor.extract(new_code)
        results['sequence_diversity'] = self.history.calculate_sequence_diversity(new_sequences)
        # 2. Edit-distance diversity against stored code.
        results['edit_distance_diversity'] = self.history.calculate_edit_distance_diversity(new_code)
        # 3. Weighted overall score.
        results['overall_diversity'] = (
            DiversityConfig.PATTERN_WEIGHT * results['sequence_diversity'] +
            DiversityConfig.EDIT_DISTANCE_WEIGHT * results['edit_distance_diversity']
        )
        return results
    def record_test(self,
                    code: str,
                    target_function: str = "",
                    coverage_score: float = 0.0,
                    success: bool = False,
                    iteration: int = 0,
                    known_signals: List[str] = None) -> TestRecord:
        """
        Record a new test case in the history.

        Args:
            code: the test code.
            target_function: targeted function point.
            coverage_score: achieved coverage score.
            success: whether the test succeeded.
            iteration: iteration index.
            known_signals: known signal names.

        Returns:
            The created test record.
        """
        return self.history.add_record(
            code=code,
            target_function=target_function,
            coverage_score=coverage_score,
            success=success,
            iteration=iteration,
            known_signals=known_signals
        )
    def get_statistics(self) -> Dict[str, Any]:
        """Return history statistics."""
        return self.history.get_statistics()
    def generate_diversity_report(self) -> str:
        """Return the history's diversity report."""
        return self.history.get_diversity_report()
# ============================================================================
# 便捷函数
# ============================================================================
def create_diversity_injector(history_file: str = None) -> DiversityInjector:
    """
    Build a DiversityInjector backed by a TestHistoryManager.

    Args:
        history_file: optional path of the persisted history file.

    Returns:
        A ready-to-use diversity injector.
    """
    manager = TestHistoryManager(history_file=history_file)
    return DiversityInjector(history_manager=manager)

View File

@@ -0,0 +1,787 @@
"""
Description : Energy Allocation Layer (Layer 4)
- Adaptive Resource Scheduling
- Dynamic energy distribution based on function point importance
Author : CGA Enhancement Project
Time : 2026/03/11
"""
import logging
from typing import List, Dict, Optional, Any, Tuple
from dataclasses import dataclass, field
from enum import Enum
logger = logging.getLogger(__name__)
# ============================================================================
# 数据结构定义
# ============================================================================
class EnergyState(Enum):
    """Lifecycle states of a function point's energy budget."""
    ACTIVE = "active"        # still has energy to spend
    DEPLETED = "depleted"    # budget exhausted
    COMPLETED = "completed"  # target covered
    SUSPENDED = "suspended"  # paused after too many consecutive failures
@dataclass
class EnergyAllocation:
    """
    Energy-allocation record for one function point.

    Attributes:
        function_point: function-point name.
        importance: importance score (0.0 - 1.0).
        allocated: total energy allocated.
        consumed: energy consumed so far.
        remaining: remaining energy.
        consecutive_failures: current consecutive-failure streak.
        state: current energy state.
        total_attempts: total generation attempts against this target.
        successful_attempts: attempts that covered the target.
    """
    function_point: str
    importance: float
    allocated: float = 0.0
    consumed: float = 0.0
    remaining: float = 0.0
    consecutive_failures: int = 0
    state: EnergyState = EnergyState.ACTIVE
    total_attempts: int = 0
    successful_attempts: int = 0
@dataclass
class GenerationResult:
    """
    Result of one generation attempt.

    Attributes:
        function_point: targeted function point.
        success: whether the target was covered.
        coverage_delta: coverage change achieved.
        energy_cost: energy spent on the attempt.
        code_generated: the generated code.
        quality_score: code-quality score.
    """
    function_point: str
    success: bool
    coverage_delta: float = 0.0
    energy_cost: float = 1.0
    code_generated: str = ""
    quality_score: float = 0.0
# ============================================================================
# 能量初始化器
# ============================================================================
class EnergyInitializer:
    """
    Energy initializer.

    Splits a total energy budget across uncovered function points in
    proportion to their importance scores, reserving a buffer fraction for
    later redistribution.
    """

    # Default configuration.
    DEFAULT_TOTAL_ENERGY = 10.0  # default budget (matches max iterations)
    MIN_ENERGY_PER_FP = 1.0      # minimum energy per function point
    ENERGY_BUFFER_RATIO = 0.1    # fraction reserved for redistribution

    def __init__(self,
                 total_energy: float = None,
                 min_energy: float = None,
                 buffer_ratio: float = None):
        """
        Args:
            total_energy: total energy budget (defaults to max iterations).
            min_energy: minimum energy per function point.
            buffer_ratio: fraction of the budget kept as a buffer.
        """
        # BUGFIX: compare against None explicitly. The previous
        # "x or DEFAULT" form silently replaced legitimate zero values
        # (e.g. buffer_ratio=0.0 could never disable the buffer).
        self.total_energy = self.DEFAULT_TOTAL_ENERGY if total_energy is None else total_energy
        self.min_energy = self.MIN_ENERGY_PER_FP if min_energy is None else min_energy
        self.buffer_ratio = self.ENERGY_BUFFER_RATIO if buffer_ratio is None else buffer_ratio

    def initialize(self,
                   function_points: List[Dict],
                   max_iterations: int = None) -> Dict[str, EnergyAllocation]:
        """
        Build the initial allocation map.

        Args:
            function_points: function-point dicts containing name,
                importance, covered, etc.
            max_iterations: when given (and truthy), overrides the total
                energy budget.

        Returns:
            Mapping of function-point name -> EnergyAllocation.
        """
        # A provided iteration budget becomes the total energy.
        if max_iterations:
            self.total_energy = float(max_iterations)
        # Only uncovered function points receive energy.
        uncovered_fps = [fp for fp in function_points if not fp.get('covered', False)]
        if not uncovered_fps:
            logger.info("All function points are covered. No energy allocation needed.")
            return {}
        # Total importance normalises the proportional split.
        total_importance = sum(fp.get('importance', 0.5) for fp in uncovered_fps)
        # Reserve the buffer; distribute the rest.
        buffer_energy = self.total_energy * self.buffer_ratio
        available_energy = self.total_energy - buffer_energy
        allocations = {}
        for fp in uncovered_fps:
            name = fp.get('name', 'unknown')
            importance = fp.get('importance', 0.5)
            if total_importance > 0:
                proportional_energy = (importance / total_importance) * available_energy
            else:
                proportional_energy = available_energy / len(uncovered_fps)
            # Apply the per-target floor. NOTE: the floor can push the sum
            # of allocations slightly above the available budget.
            allocated = max(self.min_energy, proportional_energy)
            allocations[name] = EnergyAllocation(
                function_point=name,
                importance=importance,
                allocated=allocated,
                consumed=0.0,
                remaining=allocated,
                consecutive_failures=0,
                state=EnergyState.ACTIVE,
                total_attempts=0,
                successful_attempts=0
            )
        logger.info(f"Energy initialized: total={self.total_energy:.1f}, "
                    f"allocated={sum(a.allocated for a in allocations.values()):.1f}, "
                    f"buffer={buffer_energy:.1f}, "
                    f"targets={len(allocations)}")
        return allocations
# ============================================================================
# 目标选择器
# ============================================================================
class TargetSelector:
    """
    Target selector.

    Picks the next function point to attack. Priority is
    importance x (remaining/allocated) x 1/(1 + 0.5 * consecutive_failures).
    """

    # Consecutive-failure threshold.
    MAX_CONSECUTIVE_FAILURES = 3

    def __init__(self, allocations: Dict[str, EnergyAllocation]):
        """
        Args:
            allocations: energy-allocation map.
        """
        self.allocations = allocations

    @staticmethod
    def _priority(alloc: EnergyAllocation) -> float:
        """Priority score of one allocation (higher is better)."""
        energy_ratio = alloc.remaining / alloc.allocated if alloc.allocated > 0 else 0
        failure_penalty = 1.0 / (1.0 + alloc.consecutive_failures * 0.5)
        return alloc.importance * energy_ratio * failure_penalty

    def _active_candidates(self) -> List[EnergyAllocation]:
        """Allocations that are ACTIVE and still have energy left."""
        return [alloc for alloc in self.allocations.values()
                if alloc.state == EnergyState.ACTIVE and alloc.remaining > 0]

    def select_next_target(self) -> Optional[EnergyAllocation]:
        """
        Select the next target function point.

        Returns:
            The highest-priority candidate allocation, or None when no
            active target with remaining energy exists.
        """
        candidates = self._active_candidates()
        if not candidates:
            logger.info("No active targets with remaining energy.")
            return None
        selected = max(candidates, key=self._priority)
        logger.debug(f"Selected target: {selected.function_point} "
                     f"(importance={selected.importance:.2f}, "
                     f"remaining={selected.remaining:.1f}, "
                     f"failures={selected.consecutive_failures})")
        return selected

    def get_candidates_count(self) -> int:
        """Number of currently selectable targets."""
        return len(self._active_candidates())

    def get_top_candidates(self, n: int = 3) -> List[EnergyAllocation]:
        """Top-N candidates ordered by descending priority."""
        ranked = sorted(self._active_candidates(), key=self._priority, reverse=True)
        return ranked[:n]
# ============================================================================
# 能量消耗跟踪器
# ============================================================================
class EnergyConsumptionTracker:
    """
    Energy consumption tracker.

    Records each generation attempt, charges its energy cost to the target
    function point, and transitions the target's state (completed /
    suspended / depleted) according to the outcome.
    """
    # Decay factor applied to remaining energy after 3+ consecutive failures.
    ENERGY_DECAY_FACTOR = 0.7
    def __init__(self, allocations: Dict[str, EnergyAllocation]):
        """
        Args:
            allocations: energy-allocation map, updated in place.
        """
        self.allocations = allocations
        self.history: List[GenerationResult] = []
    def record_generation(self, result: GenerationResult) -> Dict[str, Any]:
        """
        Record one generation attempt.

        Args:
            result: the generation outcome.

        Returns:
            Status dict: 'completed', 'suspended', 'depleted', 'failed', or
            'unknown' for an unrecognised function point.
        """
        self.history.append(result)
        fp_name = result.function_point
        if fp_name not in self.allocations:
            logger.warning(f"Unknown function point: {fp_name}")
            return {'status': 'unknown', 'message': 'Unknown function point'}
        alloc = self.allocations[fp_name]
        alloc.total_attempts += 1
        # Charge the energy cost; remaining never goes below zero.
        energy_cost = result.energy_cost
        alloc.consumed += energy_cost
        alloc.remaining = max(0, alloc.remaining - energy_cost)
        if result.success:
            # Success: reset the failure streak and mark the target done.
            alloc.consecutive_failures = 0
            alloc.successful_attempts += 1
            alloc.state = EnergyState.COMPLETED
            logger.info(f"[SUCCESS] Target covered: {fp_name} (attempts={alloc.total_attempts}, "
                        f"energy_used={alloc.consumed:.1f})")
            return {
                'status': 'completed',
                'function_point': fp_name,
                'attempts': alloc.total_attempts,
                'energy_used': alloc.consumed
            }
        else:
            # Failure: extend the failure streak.
            alloc.consecutive_failures += 1
            # After 3+ consecutive failures, decay the remaining energy.
            if alloc.consecutive_failures >= 3:
                old_remaining = alloc.remaining
                alloc.remaining *= self.ENERGY_DECAY_FACTOR
                logger.warning(f"Consecutive failures for {fp_name}: {alloc.consecutive_failures}. "
                               f"Energy reduced: {old_remaining:.1f} -> {alloc.remaining:.1f}")
                # Suspend the target when decay leaves too little energy.
                if alloc.remaining < 0.5:
                    alloc.state = EnergyState.SUSPENDED
                    logger.warning(f"Target suspended due to low energy: {fp_name}")
                    return {
                        'status': 'suspended',
                        'function_point': fp_name,
                        'consecutive_failures': alloc.consecutive_failures,
                        'remaining_energy': alloc.remaining
                    }
            # Mark the target depleted once its budget is fully spent.
            if alloc.remaining <= 0:
                alloc.state = EnergyState.DEPLETED
                logger.warning(f"Target depleted: {fp_name}")
                return {
                    'status': 'depleted',
                    'function_point': fp_name,
                    'total_attempts': alloc.total_attempts
                }
            return {
                'status': 'failed',
                'function_point': fp_name,
                'consecutive_failures': alloc.consecutive_failures,
                'remaining_energy': alloc.remaining
            }
    def get_statistics(self) -> Dict[str, Any]:
        """Aggregate attempt counts, success rate and per-target energy."""
        total = len(self.history)
        successful = sum(1 for r in self.history if r.success)
        energy_by_fp = {}
        for result in self.history:
            fp = result.function_point
            if fp not in energy_by_fp:
                energy_by_fp[fp] = {'consumed': 0, 'attempts': 0, 'success': False}
            energy_by_fp[fp]['consumed'] += result.energy_cost
            energy_by_fp[fp]['attempts'] += 1
            if result.success:
                energy_by_fp[fp]['success'] = True
        return {
            'total_attempts': total,
            'successful_attempts': successful,
            'success_rate': successful / total if total > 0 else 0,
            'energy_by_function_point': energy_by_fp
        }
# ============================================================================
# 能量重分配器
# ============================================================================
class EnergyRedistributor:
    """
    Energy redistributor.

    When a function point is covered (or suspended), its leftover energy is
    handed to the remaining active targets proportionally to importance.
    """
    def __init__(self, allocations: Dict[str, EnergyAllocation]):
        """
        Args:
            allocations: energy-allocation map, updated in place.
        """
        self.allocations = allocations
    def redistribute(self, completed_fp: str) -> Dict[str, float]:
        """
        Redistribute the remaining energy of a completed function point.

        Args:
            completed_fp: name of the completed function point.

        Returns:
            Mapping {target_fp: gained_energy}; empty when there is nothing
            to recover or nobody to give it to.
        """
        if completed_fp not in self.allocations:
            return {}
        completed_alloc = self.allocations[completed_fp]
        # Reclaim whatever energy is left on the source.
        recovered_energy = completed_alloc.remaining
        if recovered_energy <= 0:
            logger.debug(f"No remaining energy to recover from {completed_fp}")
            return {}
        # Recipients: active targets other than the source itself.
        active_targets = [
            alloc for alloc in self.allocations.values()
            if alloc.state == EnergyState.ACTIVE and alloc.function_point != completed_fp
        ]
        if not active_targets:
            logger.info(f"No active targets to redistribute energy to.")
            return {}
        # Split proportionally to importance (evenly when all zero).
        total_importance = sum(a.importance for a in active_targets)
        redistribution = {}
        for alloc in active_targets:
            if total_importance > 0:
                gain = (alloc.importance / total_importance) * recovered_energy
            else:
                gain = recovered_energy / len(active_targets)
            alloc.allocated += gain
            alloc.remaining += gain
            redistribution[alloc.function_point] = gain
        # Zero out the source's remaining energy.
        completed_alloc.remaining = 0
        logger.info(f"Redistributed {recovered_energy:.1f} energy from {completed_fp} "
                    f"to {len(redistribution)} targets")
        return redistribution
    def redistribute_all(self) -> Dict[str, Dict[str, float]]:
        """
        Redistribute leftovers of every completed/suspended target.

        Returns:
            Per-source redistribution details.
        """
        all_redistributions = {}
        # Sources: completed or suspended targets that still hold energy.
        completed_fps = [
            name for name, alloc in self.allocations.items()
            if alloc.state in [EnergyState.COMPLETED, EnergyState.SUSPENDED]
            and alloc.remaining > 0
        ]
        for fp in completed_fps:
            redistribution = self.redistribute(fp)
            if redistribution:
                all_redistributions[fp] = redistribution
        return all_redistributions
    def revive_suspended(self, min_energy: float = 1.0) -> List[str]:
        """
        Try to revive suspended targets using recovered energy.

        Args:
            min_energy: energy granted to each revived target.

        Returns:
            Names of the revived targets.

        NOTE(review): available_energy is summed from COMPLETED targets'
        remaining energy but never deducted from those targets, so reviving
        can effectively create energy out of thin air — confirm intent.
        """
        revived = []
        # Budget available for revival (from completed targets).
        available_energy = sum(
            alloc.remaining for alloc in self.allocations.values()
            if alloc.state == EnergyState.COMPLETED and alloc.remaining > 0
        )
        # Candidates: every suspended target.
        suspended = [
            alloc for alloc in self.allocations.values()
            if alloc.state == EnergyState.SUSPENDED
        ]
        for alloc in suspended:
            if available_energy >= min_energy:
                # Revive with a fresh minimal budget and a clean streak.
                alloc.state = EnergyState.ACTIVE
                alloc.remaining = min_energy
                alloc.allocated += min_energy
                alloc.consecutive_failures = 0
                available_energy -= min_energy
                revived.append(alloc.function_point)
                logger.info(f"Revived suspended target: {alloc.function_point}")
        return revived
# ============================================================================
# 能量分配器(主入口)
# ============================================================================
class EnergyAllocator:
    """
    Energy allocator — the Layer-4 entry point.

    Wires together the initializer, target selector, consumption tracker
    and redistributor behind one energy-management interface.
    """
    def __init__(self,
                 max_iterations: int = 5,
                 total_energy: float = None):
        """
        Args:
            max_iterations: maximum iteration count.
            total_energy: total energy budget (defaults to max_iterations).
        """
        self.max_iterations = max_iterations
        self.total_energy = total_energy or float(max_iterations)
        # Sub-modules (selector/tracker/redistributor are built in initialize()).
        self.initializer = EnergyInitializer(total_energy=self.total_energy)
        self.allocations: Dict[str, EnergyAllocation] = {}
        self.selector: Optional[TargetSelector] = None
        self.tracker: Optional[EnergyConsumptionTracker] = None
        self.redistributor: Optional[EnergyRedistributor] = None
        # State flags.
        self.initialized = False
        self.current_target: Optional[EnergyAllocation] = None
    def initialize(self, function_points: List[Dict]) -> Dict[str, Any]:
        """
        Initialize the energy allocation.

        Args:
            function_points: function-point list.

        Returns:
            Summary of the initial allocation.
        """
        self.allocations = self.initializer.initialize(
            function_points,
            max_iterations=self.max_iterations
        )
        self.selector = TargetSelector(self.allocations)
        self.tracker = EnergyConsumptionTracker(self.allocations)
        self.redistributor = EnergyRedistributor(self.allocations)
        self.initialized = True
        return {
            'total_energy': self.total_energy,
            'targets': len(self.allocations),
            'allocation_details': {
                name: {
                    'importance': alloc.importance,
                    'allocated': alloc.allocated,
                    'state': alloc.state.value
                }
                for name, alloc in self.allocations.items()
            }
        }
    def select_next_target(self) -> Optional[str]:
        """
        Select the next generation target.

        Returns:
            Target function-point name, or None when no target is available.
        """
        if not self.initialized:
            logger.warning("Energy allocator not initialized.")
            return None
        self.current_target = self.selector.select_next_target()
        return self.current_target.function_point if self.current_target else None
    def record_generation(self,
                          success: bool,
                          coverage_delta: float = 0.0,
                          energy_cost: float = 1.0,
                          quality_score: float = 0.0) -> Dict[str, Any]:
        """
        Record one generation attempt against the current target.

        Args:
            success: whether the target was covered.
            coverage_delta: coverage change.
            energy_cost: energy spent.
            quality_score: code-quality score.

        Returns:
            Tracker update result (or an error dict without a target).
        """
        if not self.current_target:
            return {'status': 'error', 'message': 'No current target'}
        result = GenerationResult(
            function_point=self.current_target.function_point,
            success=success,
            coverage_delta=coverage_delta,
            energy_cost=energy_cost,
            quality_score=quality_score
        )
        update_result = self.tracker.record_generation(result)
        # On success, hand leftover energy to the remaining targets.
        if success:
            self.redistributor.redistribute(self.current_target.function_point)
        return update_result
    def get_status(self) -> Dict[str, Any]:
        """Snapshot of the allocator's current state."""
        if not self.initialized:
            return {'initialized': False}
        active_count = sum(1 for a in self.allocations.values()
                           if a.state == EnergyState.ACTIVE and a.remaining > 0)
        completed_count = sum(1 for a in self.allocations.values()
                              if a.state == EnergyState.COMPLETED)
        return {
            'initialized': True,
            'total_energy': self.total_energy,
            'total_targets': len(self.allocations),
            'active_targets': active_count,
            'completed_targets': completed_count,
            'current_target': self.current_target.function_point if self.current_target else None,
            'statistics': self.tracker.get_statistics() if self.tracker else None
        }
    def get_target_context(self, target_name: str = None) -> str:
        """
        Context lines about a target function point (for prompt use).

        Args:
            target_name: target name (defaults to the current target).

        Returns:
            Context string; empty when the target is unknown.
        """
        if not target_name and self.current_target:
            target_name = self.current_target.function_point
        if not target_name or target_name not in self.allocations:
            return ""
        alloc = self.allocations[target_name]
        context = []
        context.append(f"[TARGET: {target_name}]")
        context.append(f"Importance: {alloc.importance:.2f}")
        context.append(f"Remaining Energy: {alloc.remaining:.1f} / {alloc.allocated:.1f}")
        context.append(f"Previous Attempts: {alloc.total_attempts}")
        if alloc.consecutive_failures > 0:
            context.append(f"Warning: {alloc.consecutive_failures} consecutive failures")
            context.append("Consider a different approach or sequence")
        return "\n".join(context)
    def mark_targets_completed(self, function_names: List[str]) -> Dict[str, str]:
        """
        Directly mark already-covered function points as completed.

        Used for baseline synchronisation, or when one iteration covers
        several function points, so completion does not depend solely on
        the current target's coverage-gain signal.
        """
        if not self.initialized:
            return {}
        updates = {}
        for name in function_names:
            if name not in self.allocations:
                continue
            alloc = self.allocations[name]
            if alloc.state == EnergyState.COMPLETED:
                updates[name] = "already_completed"
                continue
            alloc.state = EnergyState.COMPLETED
            alloc.consecutive_failures = 0
            alloc.remaining = 0.0
            updates[name] = "completed"
            # NOTE(review): remaining is zeroed before redistribute(name),
            # so the completed target's leftover energy cannot actually be
            # redistributed here — confirm whether that is intended.
            self.redistributor.redistribute(name)
        return updates
    def generate_report(self) -> str:
        """Render a human-readable energy-allocation report."""
        if not self.initialized:
            return "Energy allocator not initialized."
        lines = []
        lines.append("=" * 60)
        lines.append("ENERGY ALLOCATION REPORT")
        lines.append("=" * 60)
        lines.append(f"Total Energy: {self.total_energy:.1f}")
        lines.append(f"Max Iterations: {self.max_iterations}")
        lines.append("")
        lines.append("FUNCTION POINT STATUS:")
        lines.append("-" * 60)
        for name, alloc in sorted(self.allocations.items(),
                                  key=lambda x: x[1].importance, reverse=True):
            # NOTE(review): several status icons appear as empty strings in
            # this file — possibly lost in an encoding conversion; confirm.
            status_icon = {
                EnergyState.ACTIVE: "🔄",
                EnergyState.COMPLETED: "",
                EnergyState.DEPLETED: "",
                EnergyState.SUSPENDED: "⏸️"
            }.get(alloc.state, "")
            efficiency = (alloc.successful_attempts / alloc.total_attempts * 100
                          if alloc.total_attempts > 0 else 0)
            lines.append(f"{status_icon} {name}")
            lines.append(f" Importance: {alloc.importance:.2f} | "
                         f"Energy: {alloc.remaining:.1f}/{alloc.allocated:.1f} | "
                         f"Efficiency: {efficiency:.0f}%")
            lines.append(f" Attempts: {alloc.total_attempts} | "
                         f"Consecutive Failures: {alloc.consecutive_failures}")
        lines.append("")
        lines.append("SUMMARY:")
        lines.append("-" * 60)
        stats = self.tracker.get_statistics()
        lines.append(f"Total Attempts: {stats['total_attempts']}")
        lines.append(f"Successful: {stats['successful_attempts']}")
        lines.append(f"Success Rate: {stats['success_rate']*100:.1f}%")
        completed = sum(1 for a in self.allocations.values()
                        if a.state == EnergyState.COMPLETED)
        lines.append(f"Targets Covered: {completed} / {len(self.allocations)}")
        lines.append("=" * 60)
        return "\n".join(lines)
# ============================================================================
# 便捷函数
# ============================================================================
def create_energy_allocator(function_points: List[Dict],
                            max_iterations: int = 5) -> EnergyAllocator:
    """
    Convenience constructor: create and initialize an EnergyAllocator.

    Args:
        function_points: function-point dicts to allocate energy for.
        max_iterations: iteration budget (used as the total energy).

    Returns:
        An initialized energy allocator.
    """
    instance = EnergyAllocator(max_iterations=max_iterations)
    instance.initialize(function_points)
    return instance

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

580
autoline/test_history.py Normal file
View File

@@ -0,0 +1,580 @@
"""
Description : Test History Manager (Layer 1 Support Module)
- Store and manage test case history
- Support sequence pattern analysis
- Provide diversity statistics
Author : CGA Enhancement Project
Time : 2026/03/16
"""
import json
import os
import logging
from typing import List, Dict, Optional, Any, Tuple, Set
from dataclasses import dataclass, field, asdict
from datetime import datetime
from collections import defaultdict
import hashlib
import re
logger = logging.getLogger(__name__)
# ============================================================================
# 数据结构定义
# ============================================================================
@dataclass
class InputSequence:
    """
    One input signal's assignment sequence.

    Attributes:
        signal_name: name of the driven signal.
        values: ordered (time, value) pairs.
    """
    signal_name: str
    values: List[Tuple[int, Any]] = field(default_factory=list)

    def to_pattern_string(self) -> str:
        """Join the assigned values into a "v1->v2->..." pattern string."""
        parts = [str(value) for _, value in self.values]
        return "->".join(parts)

    def get_hash(self) -> str:
        """Short (8 hex chars) MD5 digest of the pattern string."""
        digest = hashlib.md5(self.to_pattern_string().encode())
        return digest.hexdigest()[:8]
@dataclass
class TestRecord:
    """
    One generated test case and its evaluation results.

    Attributes:
        test_id: unique test identifier.
        code: generated test code.
        input_sequences: per-signal input sequences extracted from the code.
        target_function: targeted function point.
        covered_lines: covered code lines.
        covered_functions: covered function points.
        coverage_score: coverage score achieved.
        diversity_scores: per-metric diversity scores.
        iteration: iteration index.
        timestamp: ISO-format creation time.
        success: whether the test met its goal.
    """
    test_id: str
    code: str = ""
    input_sequences: List[InputSequence] = field(default_factory=list)
    target_function: str = ""
    covered_lines: List[int] = field(default_factory=list)
    covered_functions: List[str] = field(default_factory=list)
    coverage_score: float = 0.0
    diversity_scores: Dict[str, float] = field(default_factory=dict)
    iteration: int = 0
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
    success: bool = False
    def get_sequence_patterns(self) -> Dict[str, str]:
        """Map each input signal name to its pattern string."""
        return {seq.signal_name: seq.to_pattern_string() for seq in self.input_sequences}
@dataclass
class SequencePattern:
    """Usage statistics for one observed stimulus pattern.

    Attributes:
        pattern: Pattern string (e.g. "1->0->1").
        count: Number of times the pattern has been seen.
        signal_name: Signal the pattern belongs to.
        test_ids: Ids of the tests that used this pattern.
    """
    pattern: str
    count: int = 0
    signal_name: str = ""
    test_ids: List[str] = field(default_factory=list)

    def is_overused(self, threshold: int = 3) -> bool:
        """True when the pattern has been seen at least *threshold* times."""
        return not self.count < threshold
# ============================================================================
# Sequence extractor
# ============================================================================
class SequenceExtractor:
    """
    Extract per-signal input sequences from Verilog test code.

    Walks the code line by line, maintaining a coarse notion of simulation
    time, and records every assignment made to each signal.
    """
    # Assignment statement patterns.
    # NOTE: the patterns overlap on purpose (the third is a strict subset of
    # the first); overlapping hits are de-duplicated inside extract().
    ASSIGNMENT_PATTERNS = [
        # blocking assignment: signal = value;
        r'(\w+)\s*=\s*([0-9]+\'[bdh][0-9a-fA-FxXzZ_]+|\d+|x|z)\s*;',
        # non-blocking assignment: signal <= value;
        r'(\w+)\s*<=\s*([0-9]+\'[bdh][0-9a-fA-FxXzZ_]+|\d+|x|z)\s*;',
        # simple assignment (no bit width)
        r'(\w+)\s*=\s*(\d+)\s*;',
    ]
    # Stand-alone delay statement "#N;". A delay prefixing another statement
    # ("#5 a = 1;") has no trailing ';' and is intentionally NOT matched.
    DELAY_PATTERN = r'#\s*(\d+)\s*;'
    # Clock-cycle wait: repeat (N) @(posedge clk)
    CLOCK_WAIT_PATTERN = r'repeat\s*\(\s*(\d+)\s*\)\s*@\s*\(\s*posedge\s+(\w+)\s*\)'

    def __init__(self):
        # Optional whitelist of signal names; empty set means keep everything.
        self.known_signals: Set[str] = set()

    def set_known_signals(self, signals: List[str]):
        """Restrict extraction to the given signal names."""
        self.known_signals = set(signals)

    def extract(self, code: str) -> List[InputSequence]:
        """
        Extract input sequences from *code*.

        Args:
            code: Verilog test code.

        Returns:
            One InputSequence per assigned signal.
        """
        sequences = {}
        current_time = 0

        for line in code.split('\n'):
            line = line.strip()
            # Skip blank lines and // comments.
            if not line or line.startswith('//'):
                continue

            # "#N;" delay: advance the clock and move on.
            delay_match = re.search(self.DELAY_PATTERN, line)
            if delay_match:
                current_time += int(delay_match.group(1))
                continue

            # "repeat (N) @(posedge clk)": advance N clock cycles.
            clock_match = re.search(self.CLOCK_WAIT_PATTERN, line, re.IGNORECASE)
            if clock_match:
                cycles = int(clock_match.group(1))
                current_time += cycles * 10  # assume 10 time units per cycle
                continue

            # Assignment statements. The patterns overlap, so remember which
            # (position, signal) pairs have already been consumed on this line.
            # BUG FIX: previously a plain "sig = 5;" matched both the first
            # and the third pattern and was recorded twice (also advancing
            # current_time twice).
            consumed = set()
            for pattern in self.ASSIGNMENT_PATTERNS:
                for match in re.finditer(pattern, line, re.IGNORECASE):
                    signal = match.group(1)
                    value = match.group(2)
                    key = (match.start(1), signal)
                    if key in consumed:
                        continue
                    consumed.add(key)
                    # Drop signals outside the whitelist (when one is set).
                    if self.known_signals and signal not in self.known_signals:
                        continue
                    # Skip obvious loop counters / temporaries.
                    if signal.lower() in ['i', 'j', 'k', 'cnt', 'count', 'temp']:
                        continue
                    if signal not in sequences:
                        sequences[signal] = InputSequence(signal_name=signal)
                    sequences[signal].values.append((current_time, value))
                    current_time += 1  # each assignment consumes 1 time unit

        return list(sequences.values())
# ============================================================================
# Test history manager
# ============================================================================
class TestHistoryManager:
    """
    Manage the history of generated test cases.

    Responsibilities:
      - store and retrieve test records
      - aggregate stimulus-pattern statistics
      - provide diversity metrics against the accumulated history
    """

    def __init__(self, history_file: str = None):
        """
        Args:
            history_file: Optional JSON history file path; when the file
                exists, its records are loaded immediately.
        """
        # history_file must be assigned first: save()/load() fall back to it.
        self.history_file = history_file
        self.records: List[TestRecord] = []
        self.patterns: Dict[str, SequencePattern] = {}  # pattern_hash -> SequencePattern
        self.signal_patterns: Dict[str, List[str]] = defaultdict(list)  # signal_name -> [pattern_hashes]
        self.sequence_extractor = SequenceExtractor()

        # Aggregate statistics.
        # FIX: declare 'avg_coverage' here so the declared keys match what
        # _update_stats() actually writes ('avg_diversity' is kept for
        # backward compatibility with previously saved histories).
        self.stats = {
            'total_tests': 0,
            'successful_tests': 0,
            'total_coverage': 0.0,
            'avg_coverage': 0.0,
            'avg_diversity': 0.0
        }

        if history_file and os.path.exists(history_file):
            self.load(history_file)

    # ==================== record management ====================
    def add_record(self,
                   code: str,
                   test_id: str = None,
                   target_function: str = "",
                   covered_lines: List[int] = None,
                   covered_functions: List[str] = None,
                   coverage_score: float = 0.0,
                   iteration: int = 0,
                   success: bool = False,
                   known_signals: List[str] = None) -> TestRecord:
        """
        Create and store a test record.

        Args:
            code: Test code.
            test_id: Test id (auto-generated when omitted).
            target_function: Targeted functional point.
            covered_lines: Covered source lines.
            covered_functions: Covered functional points.
            coverage_score: Coverage score.
            iteration: Generation iteration.
            success: Whether the test succeeded.
            known_signals: Signal whitelist handed to the extractor.

        Returns:
            The newly created TestRecord.
        """
        if test_id is None:
            test_id = f"test_{len(self.records)}_{datetime.now().strftime('%H%M%S')}"

        # Extract the stimulus sequences from the code.
        if known_signals:
            self.sequence_extractor.set_known_signals(known_signals)
        input_sequences = self.sequence_extractor.extract(code)

        record = TestRecord(
            test_id=test_id,
            code=code,
            input_sequences=input_sequences,
            target_function=target_function,
            covered_lines=covered_lines or [],
            covered_functions=covered_functions or [],
            coverage_score=coverage_score,
            iteration=iteration,
            success=success
        )
        self.records.append(record)

        # Keep pattern statistics and aggregates in sync.
        self._update_patterns(record)
        self._update_stats()

        logger.debug(f"Added test record: {test_id}, sequences: {len(input_sequences)}")
        return record

    def get_record(self, test_id: str) -> Optional[TestRecord]:
        """Return the record with the given id, or None if absent."""
        for record in self.records:
            if record.test_id == test_id:
                return record
        return None

    def get_recent_records(self, n: int = 10) -> List[TestRecord]:
        """Return the most recent *n* records as a new list."""
        # A slice handles n >= len(records) naturally and — unlike the
        # previous implementation — never aliases the internal list.
        return self.records[-n:]

    def get_successful_records(self) -> List[TestRecord]:
        """Return every record flagged as successful."""
        return [r for r in self.records if r.success]

    # ==================== pattern analysis ====================
    def _update_patterns(self, record: TestRecord):
        """Fold *record*'s sequences into the pattern statistics."""
        for seq in record.input_sequences:
            pattern_str = seq.to_pattern_string()
            pattern_hash = seq.get_hash()

            if pattern_hash not in self.patterns:
                self.patterns[pattern_hash] = SequencePattern(
                    pattern=pattern_str,
                    count=1,
                    signal_name=seq.signal_name,
                    test_ids=[record.test_id]
                )
            else:
                self.patterns[pattern_hash].count += 1
                self.patterns[pattern_hash].test_ids.append(record.test_id)

            # Index the pattern under its signal name.
            if pattern_hash not in self.signal_patterns[seq.signal_name]:
                self.signal_patterns[seq.signal_name].append(pattern_hash)

    def get_overused_patterns(self, threshold: int = 3) -> List[SequencePattern]:
        """
        Return the overused patterns.

        Args:
            threshold: Minimum use count to be considered overused.

        Returns:
            Patterns seen at least *threshold* times.
        """
        return [p for p in self.patterns.values() if p.is_overused(threshold)]

    def get_common_patterns(self, top_n: int = 5) -> List[Tuple[str, int]]:
        """
        Return the most frequent patterns.

        Args:
            top_n: Number of entries to return.

        Returns:
            [(pattern, count), ...] in descending count order.
        """
        sorted_patterns = sorted(
            self.patterns.items(),
            key=lambda x: x[1].count,
            reverse=True
        )
        return [(p[1].pattern, p[1].count) for p in sorted_patterns[:top_n]]

    def get_pattern_for_signal(self, signal_name: str) -> List[SequencePattern]:
        """Return every pattern recorded for *signal_name*."""
        pattern_hashes = self.signal_patterns.get(signal_name, [])
        return [self.patterns[h] for h in pattern_hashes if h in self.patterns]

    # ==================== diversity analysis ====================
    def calculate_sequence_diversity(self, new_sequences: List[InputSequence]) -> float:
        """
        Score how novel *new_sequences* are versus the recorded history.

        Args:
            new_sequences: Candidate input sequences.

        Returns:
            Fraction of previously unseen patterns, in [0.0, 1.0].
        """
        if not self.records:
            return 1.0  # no history yet: everything counts as novel
        if not new_sequences:
            return 0.0  # nothing to compare

        # Pattern overlap with the recorded history.
        new_patterns = {seq.get_hash() for seq in new_sequences}
        total_patterns = len(new_patterns)
        if total_patterns == 0:
            return 0.0

        new_pattern_count = sum(1 for h in new_patterns if h not in self.patterns)
        return new_pattern_count / total_patterns

    def calculate_edit_distance_diversity(self, new_code: str) -> float:
        """
        Diversity of *new_code* versus the most recent records, based on a
        (truncated) Levenshtein distance normalized to [0, 1].
        """
        if not self.records:
            return 1.0

        # Compare against the most recent records only.
        recent_records = self.get_recent_records(5)
        min_distance = float('inf')
        for record in recent_records:
            distance = self._levenshtein_distance(new_code, record.code)
            min_distance = min(min_distance, distance)

        # Normalize by the longest code involved.
        max_len = max(len(new_code), max(len(r.code) for r in recent_records))
        if max_len == 0:
            return 0.0
        return min_distance / max_len

    def _levenshtein_distance(self, s1: str, s2: str) -> int:
        """Levenshtein edit distance (inputs capped at 500 chars, so the
        result is approximate for long code)."""
        if len(s1) < len(s2):
            return self._levenshtein_distance(s2, s1)
        if len(s2) == 0:
            return len(s1)

        # Cap the cost on long inputs.
        if len(s1) > 500:
            s1 = s1[:500]
        if len(s2) > 500:
            s2 = s2[:500]

        previous_row = range(len(s2) + 1)
        for i, c1 in enumerate(s1):
            current_row = [i + 1]
            for j, c2 in enumerate(s2):
                insertions = previous_row[j + 1] + 1
                deletions = current_row[j] + 1
                substitutions = previous_row[j] + (c1 != c2)
                current_row.append(min(insertions, deletions, substitutions))
            previous_row = current_row
        return previous_row[-1]

    # ==================== statistics ====================
    def _update_stats(self):
        """Refresh the aggregate statistics from the record list."""
        self.stats['total_tests'] = len(self.records)
        self.stats['successful_tests'] = sum(1 for r in self.records if r.success)
        if self.records:
            self.stats['total_coverage'] = sum(r.coverage_score for r in self.records)
            self.stats['avg_coverage'] = self.stats['total_coverage'] / len(self.records)

    def get_statistics(self) -> Dict[str, Any]:
        """Return the aggregate statistics plus pattern-level counters."""
        return {
            **self.stats,
            'total_patterns': len(self.patterns),
            'overused_patterns': len(self.get_overused_patterns()),
            'unique_signals': len(self.signal_patterns)
        }

    def get_diversity_report(self) -> str:
        """Render a human-readable diversity report."""
        lines = []
        lines.append("=" * 50)
        lines.append("TEST HISTORY DIVERSITY REPORT")
        lines.append("=" * 50)
        lines.append(f"Total Tests: {self.stats['total_tests']}")
        lines.append(f"Successful Tests: {self.stats['successful_tests']}")
        lines.append(f"Total Patterns: {len(self.patterns)}")
        lines.append("")

        # Most common patterns.
        lines.append("TOP 5 COMMON PATTERNS:")
        common = self.get_common_patterns(5)
        for i, (pattern, count) in enumerate(common, 1):
            lines.append(f"  {i}. {pattern[:40]}... (x{count})")

        # Overused patterns.
        overused = self.get_overused_patterns()
        if overused:
            lines.append("")
            lines.append("OVERUSED PATTERNS (need diversification):")
            for p in overused[:5]:
                lines.append(f"  - {p.signal_name}: {p.pattern[:30]}... (used {p.count} times)")

        lines.append("=" * 50)
        return "\n".join(lines)

    # ==================== persistence ====================
    def save(self, filepath: str = None):
        """Save the history to *filepath* (defaults to self.history_file)."""
        filepath = filepath or self.history_file
        if not filepath:
            return

        # Build a JSON-serializable structure by hand (dataclass instances
        # holding tuples are flattened explicitly).
        records_data = []
        for r in self.records:
            record_dict = {
                'test_id': r.test_id,
                'code': r.code,
                'input_sequences': [],
                'target_function': r.target_function,
                'covered_lines': r.covered_lines,
                'covered_functions': r.covered_functions,
                'coverage_score': r.coverage_score,
                'diversity_scores': r.diversity_scores,
                'iteration': r.iteration,
                'timestamp': r.timestamp,
                'success': r.success
            }
            # Flatten the InputSequence objects.
            for seq in r.input_sequences:
                record_dict['input_sequences'].append({
                    'signal_name': seq.signal_name,
                    'values': seq.values
                })
            records_data.append(record_dict)

        data = {
            'records': records_data,
            'stats': self.stats
        }
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2, ensure_ascii=False)
        logger.info(f"Test history saved to {filepath}")

    def load(self, filepath: str):
        """Load history records from *filepath* (no-op if it is missing)."""
        if not os.path.exists(filepath):
            return

        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)

        self.records = []
        for r in data.get('records', []):
            # JSON turns tuples into lists; normalize back to (time, value)
            # tuples so loaded records match freshly-created ones.
            sequences = [
                InputSequence(
                    signal_name=s.get('signal_name', ''),
                    values=[tuple(v) for v in s.get('values', [])]
                )
                for s in r.get('input_sequences', [])
            ]
            record = TestRecord(
                test_id=r['test_id'],
                code=r['code'],
                input_sequences=sequences,
                target_function=r.get('target_function', ''),
                covered_lines=r.get('covered_lines', []),
                covered_functions=r.get('covered_functions', []),
                coverage_score=r.get('coverage_score', 0.0),
                diversity_scores=r.get('diversity_scores', {}),  # FIX: was dropped on load
                iteration=r.get('iteration', 0),
                timestamp=r.get('timestamp', ''),
                success=r.get('success', False)
            )
            self.records.append(record)
            self._update_patterns(record)

        self.stats = data.get('stats', self.stats)
        logger.info(f"Loaded {len(self.records)} test records from {filepath}")
# ============================================================================
# Convenience functions
# ============================================================================
def create_test_history(history_file: str = None) -> TestHistoryManager:
    """Factory: build a TestHistoryManager bound to *history_file*."""
    manager = TestHistoryManager(history_file=history_file)
    return manager