上传所有文件
This commit is contained in:
43
autoline/TB1_gen.py
Normal file
43
autoline/TB1_gen.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""
|
||||
Description : The TB generation stage in the autoline. The main TB generation workflow is implemented in prompt_scriptws
|
||||
Author : Ruidi Qiu (r.qiu@tum.de)
|
||||
Time : 2024/7/24 11:27:21
|
||||
LastEdited : 2024/8/12 23:30:30
|
||||
"""
|
||||
|
||||
|
||||
from prompt_scripts import get_script, BaseScript
|
||||
from loader_saver import log_localprefix
|
||||
|
||||
class TaskTBgen():
    """TB-generation stage: instantiates and runs the prompt-script workflow selected by name."""
    # TODO: in the future use pythonized prompt scripts and this class to replace the old TaskTBgen

    def __init__(self, prob_data: dict, TBgen_prompt_script: str, task_dir: str, config):
        self.prob_data = prob_data
        self.prompt_script_name = TBgen_prompt_script
        self.task_dir = task_dir
        self.config = config
        # look up the workflow class registered under this script name and build it
        workflow_cls = get_script(TBgen_prompt_script)
        self.workflow = workflow_cls(prob_data=prob_data, task_dir=task_dir, config=config)

    @log_localprefix("TBgen")
    def run(self):
        # the workflow object is callable and performs the whole TB generation
        self.workflow()

    @property
    def scenario_num(self):
        return self.get_wf_attr("scenario_num")

    @property
    def scenario_dict(self):
        return self.get_wf_attr("scenario_dict")

    def get_wf_attr(self, attr_name: str):
        # None when the selected workflow does not expose this attribute
        return getattr(self.workflow, attr_name, None)
|
||||
416
autoline/TB2_syncheck.py
Normal file
416
autoline/TB2_syncheck.py
Normal file
@@ -0,0 +1,416 @@
|
||||
"""
|
||||
Description : This is the TB syntactic checking stage in the autoline (previously named as TaskTBsim in autoline.py)
|
||||
Author : Ruidi Qiu (r.qiu@tum.de)
|
||||
Time : 2024/7/24 11:24:31
|
||||
LastEdited : 2024/8/23 15:53:15
|
||||
"""
|
||||
|
||||
import os
|
||||
import LLM_call as llm
|
||||
import iverilog_call as iv
|
||||
import python_call as py
|
||||
import loader_saver as ls
|
||||
from config import Config
|
||||
from loader_saver import autologger as logger
|
||||
from loader_saver import log_localprefix
|
||||
from prompt_scripts import get_script, BaseScript
|
||||
from utils.utils import Timer, get_time
|
||||
|
||||
# Markers that delimit the verilog testbench inside an LLM response.
IDENTIFIER = {
    "tb_start" : "```verilog",
    "tb_end" : "```"
}

# Skeleton shown to the LLM so its reply keeps the expected fenced-code shape.
TESTBENCH_TEMPLATE = """%s
`timescale 1ns / 1ps
(more verilog testbench code here...)
endmodule
%s""" % (IDENTIFIER["tb_start"], IDENTIFIER["tb_end"])

# Debug-prompt prefix for fixing a failing verilog testbench.
# NOTE: the prompt text (including its typos) is live LLM input — do not edit casually.
DEBUG_TEMPLATE = """please fix the verilog testbench code below according to the error message below. please directly give me the corrected verilog testbench codes.
Attention: never remove the irrelevant codes!!!
your verilog testbench should be like:
%s
please only reply the full code modified. NEVER remove other irrelevant codes!!!
The testbench I give you is the one with error. To be convienient, each of the line begins with a line number. The line number also appears at the error message. You should use the line number to locate the error with the help of error message.
""" % (TESTBENCH_TEMPLATE)

# Trailing instruction for verilog debug prompts (NOTE(review): appears unused by
# _debug_prompt_gen_iv below — confirm whether it should be appended there).
DEBUG_FINAL_INSTR = """ please directly give me the corrected verilog codes, no other words needed. Your verilog codes should start with [```verilog] and end with [```]."""

# Debug-prompt prefix for fixing a failing python checker script.
DEBUG_TEMPLATE_PY = """please fix the python code below according to the error message below. please directly give me the corrected python codes.
Attention: never remove the irrelevant codes!!!
please only reply the full code modified. NEVER remove other irrelevant codes!!!
The python code I give you is the one with error. To be convienient, each of the line begins with a line number. The line number also appears at the error message. You should use the line number to locate the error with the help of error message.
"""

# Trailing instruction appended to every python debug prompt.
DEBUG_FINAL_INSTR_PY = """ please directly give me the corrected python codes, no other words needed. Your python codes should start with [```python] and end with [```]."""

# will be discarded by 15/08/2024
# DEBUG_TEMPLATE_END = """
# VERY IMPORTANT: please ONLY reply the full code modified. NEVER remove other irrelevant codes!!!
# Your testbench SHOULD NOT have the line number at the beginning of each line!!!
# """
|
||||
|
||||
class TaskTBsim():
    """
    #### input:
    - ivcode_path:
        - the path of iverilog dir (xxx/TB_gen/), will contain all verilog files. generated .vvp will also be saved here
    #### output:
    - dict of the simulation result
        - "sim_pass" : bool (whether the simulation is successful. This is only the necessary condition of the correctness of the testbench)
        - "debug_iter" : int (the number of debug iterations)
        - "sim_out" : str (the output of the simulation)
        - "sim_err" : str (the error of the simulation)
        - "TB_gen_debugged" : str or None (the testbench code after debug)
    #### iverilog_path:
    - the path of iverilog dir, will contain all verilog files. generated .vvp will also be saved here
    #### task_id:
    - the name of the problem, will be used as the name of generated files
    file structure:
    - original
        - task_id.v
        - task_id_tb.v
        - task_id_vlist.txt
        - task_id.vvp
    - debug_1
        - task_id.v
        - task_id_tb.v
        - task_id_vlist.txt
        - task_id.vvp
    - debug_2
    - ...
    """
    def __init__(self, TBgen: BaseScript, TB_code: str, module_header: str, task_dir: str, task_id: str, config):
        # inputs from the TB-generation stage
        self.TBgen = TBgen
        self.TB_code_now = TB_code  # current TB source; replaced on every debug/reboot
        self.module_header = module_header
        self.task_dir = task_dir if task_dir.endswith("/") else task_dir + "/" # for the compatibility with the old version
        self.task_id = task_id
        self.config = config
        self.working_dir = TBgen.TB_code_dir if TBgen.TB_code_dir.endswith("/") else TBgen.TB_code_dir + "/" # will change during the debug process
        # an empty DUT shell so the TB at least elaborates against the module interface
        self.DUT_code = module_header + "\n\nendmodule\n"
        # debug budgets from config
        self.debug_iter_max = config.autoline.debug.max
        self.debug_iter_to_reboot = config.autoline.debug.reboot
        self.proc_timeout = config.autoline.timeout
        # self.debug_iter_now = 0 # this is a counter for both iverilog and python so it is possible to be larger than debug_iter_max
        # per-tool debug counters and per-reboot sub-counters
        self.debug_iter_iv_now = 0
        self.debug_iter_after_reboot_iv = 0
        self.debug_iter_py_now = 0
        self.debug_iter_after_reboot_py = 0
        self.reboot_both = False  # set when python failures demand restarting iverilog too
        # self.debug_iter_after_reboot = 0
        # pychecker related
        self.pychecker_en = self.TBgen.Pychecker_en
        self.PY_code_now = ""
        if self.pychecker_en:
            self.TBout_content = "" # will get after the last iverilog run
            self.PY_code_now = self.TBgen.Pychecker_code
            self.py_fail_reboot_both_iter = config.autoline.debug.py_rollback # will reboot both iv and py if python simulation failed xxx times
            self.py_debug_focus = self.TBgen.py_debug_focus
        # infos
        self.sim_pass = False # this should be com_pass, but it is too late to change it now
        self.py_pass = False
        self.Eval0_pass = False
        self.iverilog_info = None  # tuple from iverilog_call_and_save; [0] = pass flag, [-1] = error text
        self.reboot_both_times = 0
        self.iv_runing_time = 0.0 # the time of running the last iverilog
        self.py_runing_time = 0.0 # the time of running the last python
        self.tokens = {"prompt": 0, "completion": 0}  # LLM token usage accumulated by the debug calls
|
||||
|
||||
@log_localprefix("TBsim")
|
||||
def run(self):
|
||||
if not self.pychecker_en:
|
||||
self.run_iverilog()
|
||||
self.Eval0_pass = self.sim_pass
|
||||
else:
|
||||
exit_en = False
|
||||
while (not exit_en):
|
||||
self.run_iverilog()
|
||||
if self.sim_pass:
|
||||
self.run_python()
|
||||
# if (self.sim_pass and self.py_pass) or self.exceed_max_debug:
|
||||
if not self.reboot_both:
|
||||
exit_en = True
|
||||
else:
|
||||
exit_en = True
|
||||
self.Eval0_pass = False
|
||||
raise ValueError("TBsim: iverilog failed, python simulation is not allowed.")
|
||||
self.Eval0_pass = self.sim_pass and self.py_pass
|
||||
logger.info("TBsim finished : %s!"%(self.Eval0_pass))
|
||||
|
||||
    def run_iverilog(self):
        """
        - the main function of TBsim
        - compile/run the TB with iverilog and LLM-debug (or reboot) until it passes
          or ``debug_iter_max`` is reached; updates ``sim_pass`` from iverilog_info[0]
        """
        if not self.reboot_both:
            # this will only be called at the first time of running run_iverilog
            self._save_code_run_iverilog()
        self.sim_pass = self.iverilog_info[0]
        while (self.debug_iter_iv_now < self.debug_iter_max) and (not self.sim_pass):
            self.debug_iter_iv_now += 1
            if self.debug_iter_after_reboot_iv < self.debug_iter_to_reboot:
                # still inside the per-reboot debug budget: LLM-patch the TB
                self.debug_iter_after_reboot_iv += 1
                self._debug_iv()
            else:
                # per-reboot budget spent: regenerate the TB from scratch
                self._reboot_iv()
            self.sim_pass = self.iverilog_info[0]
            # any pending python-triggered joint reboot is consumed once we debug again
            self.reboot_both = False
        if self.reboot_both:
            # this means didn't enter the while, because debug_iter_max is already reached
            logger.info("iverilog compilation (reboot from python) : failed! iverilog exceeded max debug iteration (%s)"%(self.debug_iter_max))
        if self.sim_pass:
            logger.info("iverilog compilation : passed!")
        else:
            logger.info("iverilog compilation : failed! exceeded max debug iteration (%s)"%(self.debug_iter_max))
        # self.sim_out = self.iverilog_info[4]["out"] if self.iverilog_info[4] is not None else ""
        # self.sim_err = self.iverilog_info[-1]
        # clean .vcd wave files
        self.clean_vcd()
|
||||
|
||||
    def run_python(self):
        """Run the python checker against the iverilog TBout dump; LLM-debug/reboot
        until it passes, the budget is spent, or a joint reboot is requested.
        Updates ``py_pass``, ``py_out``, ``py_err`` and possibly ``reboot_both``."""
        # read the TBout.txt into TBout_content in working_dir
        with open(self.TBout_path, "r") as f:
            self.TBout_content = f.read()
        self.debug_iter_after_reboot_py = 0
        py_rollback = 0 # local variable: consecutive failures since entering this call
        self._save_code_run_python()
        # self.debug_iter_py_now
        while (self.debug_iter_py_now < self.debug_iter_max) and (not self.python_info[0]):
            if (not self.python_info[0]) and (py_rollback >= self.py_fail_reboot_both_iter):
                # +1: debug py fail + [generated py fail]
                # too many python failures in a row: escalate to rebooting both tools
                self.reboot_both = True
                break
            py_rollback += 1
            self.debug_iter_py_now += 1
            if self.debug_iter_after_reboot_py < self.debug_iter_to_reboot:
                self.debug_iter_after_reboot_py += 1
                self._debug_py()
            else:
                self._reboot_py()
            # self._reboot_py() # only reboot, no debugging because python debugging is much harder than verilog
            # currently debug_py doesn't support reboot
        if self.reboot_both:
            self.py_pass = False
            self.sim_pass = False
            # force run_iverilog's next debug step to be a reboot
            self.debug_iter_after_reboot_iv = self.debug_iter_to_reboot
            logger.info("python simulation : failed! will reboot both iverilog and python")
        elif self.python_info[0]:
            self.py_pass = True
            logger.info("python simulation : passed!")
        else:
            self.py_pass = False
            logger.info("python simulation : failed! exceeded max debug iteration (%s)"%(self.debug_iter_max))
        self.py_out = self.python_info[1]["out"] if self.python_info[1] is not None else ""
        self.py_err = self.python_info[-1]
|
||||
|
||||
    def _debug_iv(self):
        """One LLM debug round for the verilog TB: build the prompt, query the LLM,
        extract the fixed code, rerun iverilog in a fresh debug_<n>/ directory."""
        with Timer(print_en=False) as debug_time:
            logger.info("iverilog simulation failed! Debuging... (debug_iter = %s)"%(self.debug_iter_iv_now))
            # each debug round gets its own directory, named by the combined iv+py counter
            self.working_dir = self.task_dir + "debug_%s/" % (self.total_debug_iter_now)
            os.makedirs(self.working_dir, exist_ok=True)
            debug_prompt = self._debug_prompt_gen_iv()
            debug_message = [{"role": "user", "content": debug_prompt}]
            gpt_response, info = llm.llm_call(debug_message, self.config.gpt.model, self.config.gpt.key_path)
            debug_message = info["messages"]
            # keep the last fenced verilog block and strip the "N. " line marks we added
            self.TB_code_now = llm.extract_code(gpt_response, "verilog")[-1]
            self.TB_code_now = self.del_linemark(self.TB_code_now)
            self._save_code_run_iverilog()
        logger.info("%s: verilog DEBUG finished (%ss used)" % (self.debug_iter_info("iv"), round(debug_time.interval, 2)))
        self.tokens["prompt"] += info["usage"]["prompt_tokens"]
        self.tokens["completion"] += info["usage"]["completion_tokens"]
        ls.save_messages_to_txt(debug_message, self.working_dir+"debug_messages.txt")
|
||||
|
||||
    def _reboot_iv(self):
        """Discard the current TB and have TBgen regenerate it, then rerun iverilog."""
        # change TBgen's code dir
        with Timer (print_en=False) as reboot_time:
            logger.info("iverilog simulation failed! Rebooting... (debug_iter = %s)"%(self.debug_iter_iv_now))
            self.working_dir = self.task_dir + "debug_%s_reboot/" % (self.total_debug_iter_now)
            os.makedirs(self.working_dir, exist_ok=True)
            self.TBgen.run_reboot(self.working_dir, reboot_mode="TB")
            self.TB_code_now = self.TBgen.TB_code
            self._save_code_run_iverilog()
        logger.info("%s: verilog REBOOT finished (%ss used)" % (self.debug_iter_info("iv"), round(reboot_time.interval, 2)))
        # the tokens will be added into TBgen's tokens count, we don't count it again here.
        # reset reboot counter
        self.debug_iter_after_reboot_iv = 0
|
||||
|
||||
    def _debug_py(self):
        """One LLM debug round for the python checker: prompt, query, extract the fixed
        code (optionally re-attaching the out-of-focus tail), rerun the checker."""
        with Timer(print_en=False) as debug_time:
            logger.info("python compilation failed! Debuging python... (debug_iter = %s)"%(self.debug_iter_py_now))
            self.working_dir = self.task_dir + "debug_%s/" % (self.total_debug_iter_now)
            os.makedirs(self.working_dir, exist_ok=True)
            # run gpt
            debug_prompt = self._debug_prompt_gen_py()
            debug_message = [{"role": "user", "content": debug_prompt}]
            gpt_response, info = llm.llm_call(debug_message, self.config.gpt.model, self.config.gpt.key_path)
            debug_message = info["messages"]
            self.PY_code_now = llm.extract_code(gpt_response, "python")[-1]
            self.PY_code_now = self.del_linemark(self.PY_code_now)
            if self.py_debug_focus: # currently only support pychecker SEQ mode
                # the prompt only showed the focus region; glue the stashed tail back on
                self.PY_code_now = self._py_focus(self.PY_code_now, before=False)
            self._save_code_run_python()
        logger.info("%s: python DEBUG finished (%ss used)" % (self.debug_iter_info("py"), round(debug_time.interval, 2)))
        self.tokens["prompt"] += info["usage"]["prompt_tokens"]
        self.tokens["completion"] += info["usage"]["completion_tokens"]
        ls.save_messages_to_txt(debug_message, self.working_dir+"debug_messages.txt")
|
||||
|
||||
    def _reboot_py(self):
        """Discard the current python checker and have TBgen regenerate it, then rerun it."""
        # change TBgen's code dir
        with Timer (print_en=False) as reboot_time:
            logger.info("python compilation failed! Rebooting... (debug_iter = %s)"%(self.debug_iter_py_now))
            self.working_dir = self.task_dir + "debug_%s_reboot/" % (self.total_debug_iter_now)
            os.makedirs(self.working_dir, exist_ok=True)
            self.TBgen.run_reboot(self.working_dir, reboot_mode="PY")
            self.PY_code_now = self.TBgen.Pychecker_code
            self._save_code_run_python()
        logger.info("%s: python REBOOT finished (%ss used)" % (self.debug_iter_info("py"), round(reboot_time.interval, 2)))
        # the tokens will be added into TBgen's tokens count, we don't count it again here.
        # reset reboot counter
        self.debug_iter_after_reboot_py = 0
|
||||
|
||||
    def _save_code_run_iverilog(self):
        """Write the current TB + empty DUT shell into working_dir, then compile/run
        with iverilog; records timing and the latest error message."""
        with open(self.TB_path, "w") as f:
            f.write(self.TB_code_now)
        with open(self.DUT_path, "w") as f:
            f.write(self.DUT_code)
        with Timer(print_en=False) as iverilog_time:
            self.iverilog_info = iv.iverilog_call_and_save(self.working_dir, silent=True, timeout=self.proc_timeout)
        self.iv_runing_time = round(iverilog_time.interval, 2)
        self.error_message_now = self.iverilog_info[-1]
        if "program is timeout" in self.error_message_now:
            # if the error message is timeout, we will delete the TBout.txt
            # this is to avoid the situation that infinite loop produces a large TBout.txt
            if os.path.exists(self.TBout_path):
                os.remove(self.TBout_path)
            # NOTE(review): .vvp cleanup grouped with the timeout handling — confirm
            # it is not meant to run after every compile
            self.clean_vvp()
|
||||
|
||||
def _save_code_run_python(self):
|
||||
with open(self.PY_path, "w") as f:
|
||||
f.write(self.PY_code_now)
|
||||
with open(self.TBout_path, "w") as f:
|
||||
f.write(self.TBout_content)
|
||||
with Timer(print_en=False) as python_time:
|
||||
self.python_info = py.python_call_and_save(pypath=self.PY_path, silent=True, timeout=self.proc_timeout)
|
||||
self.py_runing_time = round(python_time.interval, 2)
|
||||
self.error_message_now = self.python_info[-1]
|
||||
|
||||
def _debug_prompt_gen_iv(self):
|
||||
debug_prompt = DEBUG_TEMPLATE + "\n previous testbench with error:\n" + self.add_linemark(self.TB_code_now) + "\n compiling error message:\n" + self.error_message_now
|
||||
return debug_prompt
|
||||
|
||||
def _debug_prompt_gen_py(self):
|
||||
if self.py_debug_focus:
|
||||
py_code = self._py_focus(self.PY_code_now, before=True)
|
||||
else:
|
||||
py_code = self.PY_code_now
|
||||
if not ("program is timeout" in self.error_message_now):
|
||||
self.error_message_now = self._py_error_message_simplify(self.error_message_now)
|
||||
debug_prompt = DEBUG_TEMPLATE_PY + "\n previous python code with error:\n" + self.add_linemark(py_code) + "\n compiling error message:\n" + self.error_message_now + DEBUG_FINAL_INSTR_PY
|
||||
return debug_prompt
|
||||
|
||||
def _py_focus(self, code:str, before:bool):
|
||||
"""
|
||||
- code: the code under debug / after debug
|
||||
- before: True, if before debug, will split the code; False, if after debug, will restore the code
|
||||
"""
|
||||
# KEY_WORD = "\ndef check_dut"
|
||||
KEY_WORDs_1 = "def check_dut(vectors_in):\n golden_dut = GoldenDUT()\n failed_scenarios = []"
|
||||
KEY_WORDs_2 = "\ndef SignalTxt_to_dictlist"
|
||||
if before:
|
||||
key_words = KEY_WORDs_1 if KEY_WORDs_1 in code else KEY_WORDs_2
|
||||
if key_words not in code:
|
||||
py_code_focus = code
|
||||
self.py_code_nofocus = ""
|
||||
else:
|
||||
py_code_focus = code.split(key_words)[0]
|
||||
self.py_code_nofocus = key_words + code.split(key_words)[1]
|
||||
return py_code_focus
|
||||
else:
|
||||
return code + self.py_code_nofocus
|
||||
|
||||
@staticmethod
|
||||
def _py_error_message_simplify(error_message:str, error_depth:int=1):
|
||||
"""
|
||||
- extract the key point of python error message
|
||||
- error_depth: how many (how deep, from bottom to top) error messages to extract
|
||||
"""
|
||||
msg_lines = error_message.split("\n")
|
||||
msg_out = ""
|
||||
for line in reversed(msg_lines):
|
||||
msg_out = line + "\n" + msg_out
|
||||
if "File" in line:
|
||||
error_depth -= 1
|
||||
if error_depth == 0:
|
||||
break
|
||||
return msg_out
|
||||
|
||||
    @property
    def exceed_max_debug(self) -> bool:
        """True once either the verilog or the python debug counter reached the budget."""
        return (self.debug_iter_iv_now >= self.debug_iter_max) or (self.debug_iter_py_now >= self.debug_iter_max)
|
||||
|
||||
    @property
    def total_debug_iter_now(self) -> int:
        """Total debug iterations so far (verilog + python); used to name debug_<n>/ dirs."""
        return self.debug_iter_iv_now + self.debug_iter_py_now
|
||||
|
||||
    @property
    def TB_path(self) -> str:
        """Path of the verilog testbench file in the current working directory."""
        return self.working_dir + self.task_id + "_tb.v"
|
||||
|
||||
    @property
    def DUT_path(self) -> str:
        """Path of the (empty-shell) DUT file in the current working directory."""
        return self.working_dir + self.task_id + ".v"
|
||||
|
||||
    @property
    def PY_path(self) -> str:
        """Path of the python checker script in the current working directory."""
        return self.working_dir + self.task_id + "_tb.py"
|
||||
|
||||
    @property
    def TBout_path(self) -> str:
        """Path of the TBout.txt signal dump in the current working directory."""
        return self.working_dir + "TBout.txt"
|
||||
|
||||
def debug_iter_info(self, type):
|
||||
"""return debug iter info string. Type: "iv" or "py" """
|
||||
if self.pychecker_en:
|
||||
if type == "iv":
|
||||
return "verilog iter - %d/%d, total - %d/%d"%(self.debug_iter_iv_now, self.debug_iter_max, self.total_debug_iter_now, self.debug_iter_max*2)
|
||||
elif type == "py":
|
||||
return "python tier - %d/%d, total - %d/%d"%(self.debug_iter_py_now, self.debug_iter_max, self.total_debug_iter_now, self.debug_iter_max*2)
|
||||
else:
|
||||
raise ValueError("TaskTBsim.debug_iter_info(type): type should be 'iv' or 'py'")
|
||||
else:
|
||||
# only iverilog
|
||||
return "debug iter %d/%d"%(self.debug_iter_iv_now, self.debug_iter_max)
|
||||
|
||||
@staticmethod
|
||||
def add_linemark(code: str):
|
||||
"""add the line mark (1., 2., ...) to the code at the beginning of each line"""
|
||||
code = code.split("\n")
|
||||
code = [str(i+1) + ". " + line for i, line in enumerate(code)]
|
||||
return "\n".join(code)
|
||||
|
||||
@staticmethod
|
||||
def del_linemark(code: str):
|
||||
"""delete the line mark at the begening of each line if line mark exists"""
|
||||
code = code.split("\n")
|
||||
if code[1].split(".")[0].isdigit(): # use code[1] in case the first line is empty
|
||||
code = [line.split(". ")[1:] for line in code]
|
||||
for i, line in enumerate(code):
|
||||
code[i] = ". ".join(line)
|
||||
return "\n".join(code)
|
||||
|
||||
def clean_vcd(self):
|
||||
"""clean the .vcd files in the task_dir"""
|
||||
clean_dir = self.task_dir[:-1] if self.task_dir.endswith("/") else self.task_dir
|
||||
for root, dirs, files in os.walk(clean_dir):
|
||||
for file in files:
|
||||
if file.endswith(".vcd"):
|
||||
os.remove(os.path.join(root, file))
|
||||
|
||||
def clean_vvp(self):
|
||||
"""clean the .vvp files in the task_dir"""
|
||||
clean_dir = self.task_dir[:-1] if self.task_dir.endswith("/") else self.task_dir
|
||||
for root, dirs, files in os.walk(clean_dir):
|
||||
for file in files:
|
||||
if file.endswith(".vvp"):
|
||||
os.remove(os.path.join(root, file))
|
||||
836
autoline/TB3_funccheck.py
Normal file
836
autoline/TB3_funccheck.py
Normal file
@@ -0,0 +1,836 @@
|
||||
"""
|
||||
Description : The functionality checking of the generated TB, the submodule of Autoline
|
||||
Author : Ruidi Qiu (r.qiu@tum.de)
|
||||
Time : 2024/7/22 10:36:06
|
||||
LastEdited : 2025/2/25 22:11:13
|
||||
"""
|
||||
|
||||
import os
|
||||
import LLM_call as llm
|
||||
import iverilog_call as iv
|
||||
import python_call as py
|
||||
import numpy as np
|
||||
import loader_saver as ls
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.colors as mcolors
|
||||
from loader_saver import autologger as logger
|
||||
from loader_saver import log_localprefix
|
||||
|
||||
|
||||
|
||||
class TaskTBcheck():
    """
    ### description
    - this is the self-checking stage of our pipeline; This is the main contribution of AutoBench2.
    - This stage is to check the functional correctness of the testbench generated by AutoBench.
    """

    def __init__(self, task_dir:str, task_id:str, description:str, module_header:str, TB_code_v:str, TB_code_py:str|None=None, rtl_list:list[str]|None=None, rtl_num:int=20, scenario_num:int|None=None, correct_max:int=3, runfiles_save:bool=True, discriminator_mode:str="col_full_wrong", corrector_mode:str="naive", circuit_type:str|None=None, rtl_compens_max_iter:int=3, rtl_compens_en:bool=True, desc_improve:bool=False, **LLM_kwargs) -> None:
        """
        - input:
            - task_dir: the root directory of the taskTBcheck
            - task_id: the name of the problem
            - description: the description of the problem
            - module_header: the header of the module
            - TB_code_v: the generated verilog testbench code
            - TB_code_py (opt.): the generated python testbench code, if None, the tb is in a pure verilog mode
            - rtl_list (opt.): the list of llm-generated RTL codes, if None, will generate 20 RTL codes using LLM
            - rtl_num (opt.): the number of RTL codes to generate, only used when rtl_list is None
            - scenario_num (opt.): the number of scenarios in the testbench, if None, will be calculated from the failed scenarios (not accurate but won't impact the results)
            - correct_max (opt.): the maximum number of correction iterations
            - runfiles_save (opt.): whether to save the compilation files in TB_discrimination
            - discriminator_mode (default: col_full_wrong): the mode of the discriminator
            - corrector_mode (default: naive): the mode of the corrector
            - circuit_type (opt.): the type of the circuit, used in the corrector (better performance if provided)
            - rtl_compens_max_iter (default: 3): the maximum number of iterations of RTL compensation
            - rtl_compens_en (default: True): whether to enable RTL compensation
            - **LLM_kwargs: the keyword arguments for LLM (used in corrector and rtl generation), including:
                - "main_model": the main llm name used in TB_generation and correction
                - "rtlgen_model": the llm name used in RTL generation
                - "api_key_path": the path of the api key
                - "temperature": the temperature of LLM
        """
        # problem inputs
        self.task_dir = task_dir
        self.working_dir = self.task_dir  # changes per discrimination round
        self.task_id = task_id
        self.description = description
        self.module_header = module_header
        self.TB_code_v = TB_code_v
        self.TB_code_py = TB_code_py
        self.pychecker_en = TB_code_py is not None  # python checker mode iff a py TB was supplied
        self.rtl_list = rtl_list
        self.scenario_num = scenario_num
        self.correct_max = correct_max
        self.runfiles_save = runfiles_save
        # LLM configuration
        self.main_model = LLM_kwargs.get("main_model", None)
        self.rtlgen_model = LLM_kwargs.get("rtlgen_model", None)
        self.llm_api_key_path = LLM_kwargs.get("api_key_path", "config/key_API.json")
        self.llm_temperature = LLM_kwargs.get("temperature", None)
        self.circuit_type = circuit_type
        self.rtl_compens_max_iter = rtl_compens_max_iter # see self.discriminate_TB for more info
        self.rtl_compens_en = rtl_compens_en
        self.desc_improve = desc_improve
        # how many consecutive non-improving rounds are tolerated before giving up
        self.tolerance_for_same_wrong_scen = 2
        self.same_wrong_scen_times = 0
        # discriminator and corrector
        self.discriminator_mode = discriminator_mode
        self.corrector_mode = corrector_mode
        self.discriminator = TB_discriminator(discriminator_mode)
        self.corrector = TB_corrector(self.corrector_mode, self.pychecker_en, self.circuit_type)
        self.improver = SPEC_improver(description, "naive", self.pychecker_en, self.main_model, self.circuit_type)
        # rtl list and number
        self.rtl_newly_gen_num = 0
        if self.rtl_list is None:
            # self.rtl_num = rtl_num
            self.set_rtl_num(rtl_num)
            self.rtl_list_gen()
        else:
            # self.rtl_num = len(self.rtl_list)
            self.set_rtl_num(len(self.rtl_list))
        self.scenario_matrix = None
        self.wrong_scen_num = 0 # a small number as default, will be replaced later
        self.previous_wrong_scen_num = 9999 # a very large number as default, will be replaced later
        self.TB_syntax_error = False
        # tb analysis results
        self.tb_pass = None
        self.wrong_col_index = None
        self.correct_col_index = None
        self.unsure_col_index = None
        # next_action
        self.next_action = None
        self.iter_now = 0
        self.corrected = False # this means the TB has been corrected
        if self.main_model is None:
            logger.warning("main_model not found, may have trouble while correcting tb")
        # record and runinfo
        self.op_record = [] # record the order of the operations, similar to the var in autoline; will be added to the funccheck's op_record in the final runinfo.
|
||||
|
||||
    @property
    def rtl_num(self) -> int:
        """protected attr. rtl_num is initialized at the beginning and will not be changed"""
        return self._rtl_num
|
||||
|
||||
@log_localprefix("TBcheck")
|
||||
def run(self):
|
||||
"""
|
||||
- the main function of TaskTBcheck
|
||||
- the TB check stage contains several sub-stages:
|
||||
- 1. TB discriminating
|
||||
- 2. TB correcting
|
||||
- output: will update the next action of the task, including:
|
||||
- "pass": the TB already passed the selfcheck, will go to the evaluation stage
|
||||
- "reboot": the whole pipeline will start from the very beginning
|
||||
- workflow: inital discrimination -> correction-discrimination lloop -> pass or reboot
|
||||
"""
|
||||
# TODO: if error occurs, go to reboot.
|
||||
# initial discrimination
|
||||
tolerance = 1
|
||||
syntax_error = False
|
||||
self.discriminate_TB()
|
||||
if self.TB_syntax_error:
|
||||
logger.negative("Testbench has syntax error, I give up. Reboot the whole process")
|
||||
syntax_error = True
|
||||
self.next_action = "reboot"
|
||||
elif self.tb_pass:
|
||||
logger.info("Testbench passed the funccheck")
|
||||
self.next_action = "pass"
|
||||
else:
|
||||
# enter the correction loop, the initial TB has no syntax error when entering
|
||||
if self.correct_max == 0:
|
||||
logger.negative("No correction is allowed, I give up. Reboot the whole autoline process")
|
||||
self.next_action = "reboot"
|
||||
for self.iter_now in range(1, self.correct_max+1):
|
||||
if (self.iter_now > 1) and (self.wrong_scen_num > self.previous_wrong_scen_num) and (not syntax_error):
|
||||
# give up, the correction makes it worse
|
||||
logger.negative(f"wrong scenarios increased ({self.wrong_scen_num} > {self.previous_wrong_scen_num}), I give up, quiting the funccheck stage...")
|
||||
self.next_action = "reboot"
|
||||
break
|
||||
elif (self.iter_now > 1) and (self.wrong_scen_num == self.previous_wrong_scen_num) and (not syntax_error):
|
||||
self.same_wrong_scen_times += 1
|
||||
if self.same_wrong_scen_times >= self.tolerance_for_same_wrong_scen:
|
||||
logger.info(f"wrong scenarios not decreased for {self.tolerance_for_same_wrong_scen} times ({self.wrong_scen_num} = {self.previous_wrong_scen_num}), I give up, quiting the funccheck stage...")
|
||||
self.next_action = "reboot"
|
||||
break
|
||||
else:
|
||||
logger.info(f"wrong scenarios not decreased for {self.same_wrong_scen_times} times ({self.wrong_scen_num} = {self.previous_wrong_scen_num}), continue the correction")
|
||||
self.correct_TB()
|
||||
self.discriminate_TB()
|
||||
if self.tb_pass:
|
||||
logger.info("Testbench passed the funccheck after correction")
|
||||
self.next_action = "pass"
|
||||
break
|
||||
elif self.TB_syntax_error:
|
||||
# if the syntax error is from the corrector, we should roll back before correction
|
||||
if tolerance > 0:
|
||||
logger.negative(f"Testbench has syntax error after correction, I still have tolerance for that (tolerance={tolerance}). roll back and retry the self correction.")
|
||||
self.TB_code_v, self.TB_code_py = self.TB_code_v_before_cor, self.TB_code_py_before_cor
|
||||
tolerance -= 1
|
||||
syntax_error = True
|
||||
else:
|
||||
logger.negative("Testbench has syntax error after correction, I don't have tolerance. I give up. Reboot the whole autoline process")
|
||||
self.next_action = "reboot"
|
||||
syntax_error = True
|
||||
break
|
||||
self.next_action = "reboot" if self.iter_now == self.correct_max else None
|
||||
if (self.next_action == "reboot") and (not syntax_error) and (self.desc_improve):
|
||||
# the desc improver does not work well so we do not use it in this work
|
||||
self.improve_SPEC()
|
||||
logger.info(f"self funccheck finished. Next Action: [{self.next_action}]")
|
||||
return self.next_action
|
||||
|
||||
    @log_localprefix("discriminator")
    def discriminate_TB(self, no_any_files:bool=False):
        """
        Run the current testbench against every candidate RTL and judge TB correctness.

        - check the correctness of the testbench and return the rtl analysis results in matrix form
        - important data: the rtl list, the TB code
        - update the following data: `scenario_matrix`, `tb_pass`, `wrong_col_index`, `correct_col_index`, `unsure_col_index`, `wrong_scen_num`
          (also `previous_wrong_scen_num`, `TB_syntax_error`, and `rtl_list` — RTLs whose run
          reported a syntax error are dropped and may be regenerated)

        Args:
            no_any_files: when True, suppress all on-disk artifacts (per-RTL run dirs,
                scenario_matrix.csv / .png).

        Returns:
            (tb_pass, wrong_col_index, correct_col_index, unsure_col_index), or None when
            every RTL run returned [-1] (i.e. the TB itself is assumed syntax-broken;
            `tb_pass` is set False and `TB_syntax_error` stays True in that case).
        """
        rtl_dir_prefix = "RTL_"
        self.op_record.append("discrim")
        self.working_dir = os.path.join(self.task_dir, f"discrim_{self.iter_now}")
        logger.info(f"Discriminating the testbench, NO.{self.iter_now} discrimination")
        for i in range(self.rtl_compens_max_iter):
            # the loop is for the case that too few RTL passed the syntax check, generate more rtls and recheck
            failed_scenario_matrix = []
            # for rtl_code in self.rtl_list:
            # assume the TB is syntax-broken until at least one RTL run succeeds
            self.TB_syntax_error = True
            syntax_error_rtl = []
            for rtl_idx, rtl_code in enumerate(self.rtl_list):
                rtl_dir = os.path.join(self.working_dir, f"{rtl_dir_prefix}{rtl_idx+1}")
                # run_testbench returns the list of failed scenario indexes, or [-1] on syntax error
                scenario_vector = self.run_testbench(rtl_dir, self.TB_code_v, rtl_code, self.TB_code_py, rtl_idx+1, self.runfiles_save and (not no_any_files))
                failed_scenario_matrix.append(scenario_vector) # like [[2, 5], [3, 4, 5]]
                if scenario_vector != [-1]:
                    self.TB_syntax_error = False
                else:
                    syntax_error_rtl.append(rtl_idx+1)
            if syntax_error_rtl != []:
                logger.info(f"RTL(s) {syntax_error_rtl} have syntax error during discrimination")

            if self.TB_syntax_error:
                # there are two cases for TB syntax error:
                # 1. this syntax error is from the previous stage, if so, we should directly reboot the whole autoline process
                # 2. this syntax error is from the corrector, if so, we roll back to the version before correction and retry the correction
                self.tb_pass = False
                return None

            # the len of each scenario vector should be the same, thus we transform each vector into a onehot vector [[1,1,0,1,1,0], [1,1,1,0,0,0]]
            self.scenario_matrix = self.failed_scenarios_to_onehot_array(failed_scenario_matrix, max_scen_idx=self.scenario_num, taskid=self.task_id)
            if not no_any_files:
                # save this matrix into the working dir in a human readable form
                np.savetxt(os.path.join(self.working_dir, "scenario_matrix.csv"), self.scenario_matrix, delimiter=",", fmt="%d")
                # save this matrix in a plot form
                self.draw_scenario_matrix(self.scenario_matrix, self.task_id, os.path.join(self.working_dir, "scenario_matrix.png"))

            # we delete the syntax errored rtl, if no syntax error in TB
            self.rtl_list = [rtl for rtl, scen in zip(self.rtl_list, failed_scenario_matrix) if scen != [-1]]
            failed_scenario_matrix = [scen for scen in failed_scenario_matrix if scen != [-1]]
            if len(self.rtl_list) < 0.5*self.rtl_num:
                # too few RTL passed the syntax check
                logger.info(f"too few RTL passed the syntax check ({len(self.rtl_list)}/{self.rtl_num}), I will generate more and recheck. This is not TB's fault.")
                self.rtl_list, gen_num = self.gen_rtl(self.rtl_num-len(self.rtl_list), self.description, self.module_header, self.rtlgen_model, self.rtl_list)
                self.rtl_newly_gen_num += gen_num
                # delete the previous rtl dirs (dir start with rtl_dir_prefix and under the working dir)
                for subdir in os.listdir(self.working_dir):
                    if subdir.startswith(rtl_dir_prefix) and os.path.isdir(os.path.join(self.working_dir, subdir)):
                        os.system(f"rm -rf {os.path.join(self.working_dir, subdir)}")
                logger.info(f"re-discriminate the testbench with updated RTL list")
                if i == self.rtl_compens_max_iter-1:
                    logger.info(f"no re-discrimination since the max iteration reached")
            else:
                break

        # discriminate the testbench according to the one hot matrix
        self.tb_pass, self.wrong_col_index, self.correct_col_index, self.unsure_col_index = self.discriminator.discriminate(self.scenario_matrix)
        # remember the previous count so the outer loop can detect "no progress"
        self.previous_wrong_scen_num = self.wrong_scen_num
        self.wrong_scen_num = len(self.wrong_col_index)
        return self.tb_pass, self.wrong_col_index, self.correct_col_index, self.unsure_col_index
||||
|
||||
@log_localprefix("corrector")
|
||||
def correct_TB(self):
|
||||
"""
|
||||
- correct the testbench by using the RTL analysis results
|
||||
"""
|
||||
self.op_record.append("correct")
|
||||
self.working_dir = os.path.join(self.task_dir, f"correct_{self.iter_now}")
|
||||
self.TB_code_v_before_cor, self.TB_code_py_before_cor = self.TB_code_v, self.TB_code_py
|
||||
self.TB_code_v, self.TB_code_py = self.corrector.correct(self.description, self.wrong_col_index, self.TB_code_v, self.main_model, self.TB_code_py, self.working_dir)
|
||||
self.corrected = True
|
||||
# self check if the testbench is corrected
|
||||
# self.working_dir = os.path.join(self.task_dir, f"TBcheck_after_correct")
|
||||
# self.discriminate_TB()
|
||||
# if self.tb_pass:
|
||||
# logger.info(f"[{self.task_id}] - Testbench passed the selfcheck after correction")
|
||||
# self.next_action = "pass"
|
||||
# else:
|
||||
# logger.warning(f"[{self.task_id}] - Testbench failed the selfcheck after correction, failed scenarios: {self.wrong_col_index}")
|
||||
|
||||
@log_localprefix("improver")
|
||||
def improve_SPEC(self):
|
||||
"""
|
||||
- improve the specification of the task according to the discrimination and correction results
|
||||
"""
|
||||
self.op_record.append("improve")
|
||||
self.working_dir = os.path.join(self.task_dir, "improve_Desc")
|
||||
self.description = self.improver.improve(self.wrong_col_index, self.correct_col_index, self.TB_code_v, self.TB_code_py, working_dir=self.working_dir)
|
||||
|
||||
|
||||
|
||||
    @staticmethod
    def run_testbench(dir, driver_code:str, DUT_code:str, checker_code:str, rtl_index:int, save_en:bool=True):
        """
        Compile & run one TB/DUT pair and parse the failed-scenario list from stdout.

        - modified from autoline.py TBEval.run_testbench
        - it has two mode: pychecker mode or verilog testbench mode
        - input:
            - dir: the dir to save the TB, DUT and pychecker code
            - driver_code: str; the testbench (stimulus driver) verilog code
            - DUT_code: str; the DUT verilog code
            - checker_code: str; the pychecker python code
            - rtl_index: int; index of the RTL under test (only used for messages)
            - save_en: bool; when False the run directory is deleted afterwards
        - output:
            - a list of failed scenario indexes, deduplicated (if the RTL/TB fails to
              compile or the checker crashes, return [-1]); [] means all scenarios passed
        """
        # iverilog part
        # save the TB and DUT
        os.makedirs(dir, exist_ok=True)
        v_driver_path = os.path.join(dir, "driver.v")
        py_checker_path = os.path.join(dir, "checker.py")
        dut_path = os.path.join(dir, "DUT.v")
        with open(v_driver_path, "w") as f:
            f.write(driver_code)
        with open(dut_path, "w") as f:
            f.write(DUT_code)
        # iv_run_info[0] is a success flag (falsy on compile failure)
        iv_run_info = iv.iverilog_call_and_save(dir, silent=True)
        if not iv_run_info[0]:
            # logger.trace(f"RTL index [{rtl_index}]: Iverilog Compilation Failed, the PREREQUISITE of 'Evaluation' is no syntactic error from Testbench!!!")
            # raise RuntimeError("Iverilog Compilation Failed")
            return [-1]
        with open(py_checker_path, "w") as f:
            f.write(checker_code)
        py_run_info = py.python_call_and_save(pypath=py_checker_path, silent=True)
        if not py_run_info[0]:
            # logger.trace(f"RTL index [{rtl_index}]: Iverilog Compilation Failed: the PREREQUISITE of 'Evaluation' is no syntactic error from Python code!!!")
            # raise RuntimeError("Python Compilation Failed")
            return [-1]
        python_info_out = py_run_info[1]["out"]
        python_info_out : str
        # the checker is expected to print the failed-scenario list as the last
        # [...] in its stdout; parse the last bracket pair
        # NOTE(review): if no brackets are printed at all, rfind returns -1 for both
        # and the slice below degrades silently — assumes the checker always prints
        # a list; TODO confirm against the checker template
        # find the last ] in the out
        last_bracket_end = python_info_out.rfind("]")
        # find the last [ in the out
        last_bracket_start = python_info_out.rfind("[")
        # if the last [..] is a [], return []
        if last_bracket_end == last_bracket_start+1:
            return []
        # extract the digits
        failed_scenarios = python_info_out[last_bracket_start+1:last_bracket_end].replace("'", "").split(",")
        # if the item is not pure digit such as 2b, then we only extract the digit part
        failed_scenarios = [int("".join([char for char in scenario if char.isdigit()])) for scenario in failed_scenarios]
        failed_scenarios = list(map(int, failed_scenarios))
        # if save_en false, we delete the dir
        if not save_en:
            os.system(f"rm -rf {dir}")
        return list(set(failed_scenarios))
||||
|
||||
def rtl_list_gen(self)->list[str]:
|
||||
"""
|
||||
- generate the RTL list using LLM, will empty the old rtl list
|
||||
- attr needed: description, module_header, rtl_num, llm_model
|
||||
- attr changed: rtl_list
|
||||
"""
|
||||
self.rtl_list = []
|
||||
logger.info(f"rtl list not found, generating naive rtls for testbench checking")
|
||||
self.rtl_list, gen_num = self.gen_rtl(self.rtl_num, self.description, self.module_header, self.rtlgen_model, self.rtl_list)
|
||||
self.rtl_newly_gen_num += gen_num
|
||||
# save the rtl list
|
||||
save_path = os.path.join(self.task_dir, "rtl_list.json")
|
||||
os.makedirs(self.task_dir, exist_ok=True)
|
||||
ls.save_json_lines([{"task_id": self.task_id, "llmgen_RTL": self.rtl_list}], save_path)
|
||||
|
||||
@staticmethod
|
||||
def gen_rtl(num:int, description:str, header:str, llm_mode:str, rtl_list:list=[]):
|
||||
"""
|
||||
- input:
|
||||
- num (int): the number of RTLs to generate
|
||||
- description (str): the description of the rtl problem
|
||||
- header (str): the header of the module
|
||||
- llm (str): the llm model to use (official model name)
|
||||
- rtl_list (list) [optional]: the newly generated RTLs will be appended to this list, can be empty
|
||||
- output:
|
||||
- rtl_list (list): the list of the newly generated RTLs (and the old ones, if have)
|
||||
"""
|
||||
rtl_gen_num = 0
|
||||
prompt = "Your task is to write a verilog RTL design according to the design specification. The infomation we have is the problem description that guides student to write the RTL code (DUT) and the header of the desired module. here is the problem description:\n"
|
||||
prompt += description
|
||||
prompt += "\nHere is the header of the desired module:\n"
|
||||
prompt += header
|
||||
prompt += "\nPlease only return the module code (header should be included) in verilog, please do not include any other words."
|
||||
for i in range(num):
|
||||
# call llm
|
||||
answer = llm.llm_call(prompt, llm_mode)[0]
|
||||
# extract the module code
|
||||
module_code = llm.extract_code(answer, "verilog")[0]
|
||||
# logger.trace(f"[{self.task_id}] - {i+1} RTLs generated")
|
||||
rtl_list.append(module_code)
|
||||
rtl_gen_num += 1
|
||||
logger.info("%d naive rtls generated"%(rtl_gen_num))
|
||||
return rtl_list, rtl_gen_num
|
||||
|
||||
@staticmethod
|
||||
def failed_scenarios_to_onehot_array(failed_scenarios:list[list], max_scen_idx:int|None=None, taskid:str=""):
|
||||
"""
|
||||
- input: [failed_scenarios:list[int]] (for example: [[1,2,3], [2,3,4], [1,3,4], [-1]]), if one failed scenario list is [-1], it means the rtl has syntax error, should be skipped
|
||||
- output (np.array): a onehot array (for example: [[0,0,0,1], [1,0,0,0], [0,1,0,0], [-1,-1,-1,-1]]) (1 denots pass, 0 denotes fail, -1 means syntax error)
|
||||
"""
|
||||
|
||||
# find the max scenario index
|
||||
listlen = len(failed_scenarios)
|
||||
max_idx_given = max_scen_idx if max_scen_idx is not None else 1
|
||||
# we calculate the max_index_cal, and define the final max scenario index using max(max_index_cal, max_index_given)
|
||||
max_idx_cal = max(map(lambda x: max(x) if x != [] else 0, failed_scenarios))
|
||||
if max_idx_cal in [-1, 0]:
|
||||
# -1: all the scenarios in this rtl are -1
|
||||
# 0: usually not possible because the scenario index is from 1, but if exists, set to 1.
|
||||
max_idx_cal = 1
|
||||
if failed_scenarios == list(map(lambda x: [], range(listlen))):
|
||||
# this means all rtl passed
|
||||
max_idx_cal = 1 # set to 1 otherwise the plot will be empty
|
||||
max_idx = max(max_idx_cal, max_idx_given)
|
||||
|
||||
# if the failed scenario list is [-1], then all the scenarios in this rtl are -1
|
||||
# create the onehot array
|
||||
grid_map = [[1]*max_idx for _ in range(listlen)]
|
||||
for rtl_idx, failed_scens in enumerate(failed_scenarios):
|
||||
if failed_scens == [-1]:
|
||||
grid_map[rtl_idx] = [-1]*max_idx
|
||||
continue
|
||||
for scen_idx in failed_scens:
|
||||
grid_map[rtl_idx][scen_idx-1] = 0
|
||||
return np.array(grid_map)
|
||||
|
||||
@staticmethod
|
||||
def draw_scenario_matrix(scenario_matrix:np.ndarray, task_id:str, saving_path:str):
|
||||
"""
|
||||
- draw the 2D failed scenario array. The element in the array can only be 0, 1, -1. We use red for 0, green for 1, and gray for -1.
|
||||
- if the scenario is empty, will return a gray color block.
|
||||
"""
|
||||
if len(scenario_matrix) == 0:
|
||||
scenario_matrix = np.array([[-1]])
|
||||
# if the element in data not in [0, 1, -1], change the element to -1
|
||||
scenario_matrix = np.where(np.logical_or(scenario_matrix == 0, np.logical_or(scenario_matrix == 1, scenario_matrix == -1)), scenario_matrix, -1)
|
||||
# get the RGB values for salmon, grey and mediumseagreen
|
||||
salmon = mcolors.to_rgb("salmon")
|
||||
grey = mcolors.to_rgb("grey")
|
||||
mediumseagreen = mcolors.to_rgb("mediumseagreen")
|
||||
color_mapping = {
|
||||
0: salmon,
|
||||
1: mediumseagreen,
|
||||
-1: grey
|
||||
}
|
||||
rgb_image = np.array([[color_mapping[value] for value in row] for row in scenario_matrix])
|
||||
# assign the color to the scenario_matrix
|
||||
for value, color in color_mapping.items():
|
||||
rgb_image[scenario_matrix == value] = color
|
||||
plt.imshow(rgb_image)
|
||||
plt.ylabel("RTL index")
|
||||
plt.xlabel("Scenario index")
|
||||
current_xticks = np.arange(scenario_matrix.shape[1])
|
||||
plt.xticks(current_xticks, current_xticks + 1)
|
||||
current_yticks = np.arange(scenario_matrix.shape[0])
|
||||
plt.yticks(current_yticks, current_yticks + 1)
|
||||
plt.title(f"[{task_id}] - Matrix of RTL-TB Scenario Correctness")
|
||||
plt.savefig(saving_path)
|
||||
plt.close()
|
||||
|
||||
    def update_description(self)->str:
        """
        - will modify the description of the task according to the discrimination and correction results
        - NOTE(review): currently a placeholder — it only logs and returns
          `self.description` unchanged; the log message claims an update that
          does not happen here. Confirm whether callers rely on this being a no-op.
        """
        logger.info("the description of the task is updated")
        return self.description
||||
|
||||
    def set_rtl_num(self, value:int):
        """Explicitly override the number of RTL candidates (backing field `_rtl_num`)."""
        self._rtl_num = value
||||
|
||||
    def __call__(self, *args, **kwds):
        """Make the task object callable; delegates to `run()` (all arguments are ignored)."""
        return self.run()
||||
|
||||
|
||||
class TB_discriminator():
    """
    Judge testbench correctness from the RTL-vs-scenario failing matrix.

    The mode string is parsed into thresholds instead of being matched against
    ~11 copy-pasted `case` bodies (the old implementation duplicated the same
    numpy logic per mode and used f-strings with nested same-type quotes, which
    are a SyntaxError before Python 3.12):

        - "col_full_wrong":          a scenario column is wrong only when every RTL fails it
        - "col_<P>_wrong":           a column is wrong when >= P% of the RTLs fail it
                                     (P = 40/50/60/70/80/90 were used historically)
        - optional "_row_<Q>_correct" suffix: the TB additionally passes outright
                                     when >= Q% of the RTLs pass every scenario
    """

    def __init__(self, mode:str) -> None:
        self.mode = mode
        # validate eagerly so a bad config is reported at construction time;
        # kept non-fatal here (discriminate() raises), matching the old behavior
        if self._parse_mode(mode) is None:
            logger.critical("class discriminator - mode not found!!!")

    @staticmethod
    def _parse_mode(mode:str):
        """
        Parse a mode string into thresholds.

        Returns:
            (col_frac, row_frac): col_frac is None for "full" (all RTLs must fail a
            column) or a fraction in (0, 1]; row_frac is None or a fraction.
            Returns None when the mode string is not recognized.
        """
        parts = mode.split("_")
        if len(parts) == 3 and parts[0] == "col" and parts[2] == "wrong":
            col_token, row_token = parts[1], None
        elif (len(parts) == 6 and parts[0] == "col" and parts[2] == "wrong"
              and parts[3] == "row" and parts[5] == "correct"):
            col_token, row_token = parts[1], parts[4]
        else:
            return None
        if col_token == "full":
            col_frac = None
        elif col_token.isdigit():
            col_frac = int(col_token) / 100
        else:
            return None
        if row_token is None:
            return col_frac, None
        if not row_token.isdigit():
            return None
        return col_frac, int(row_token) / 100

    def discriminate(self, failed_matrix:np.ndarray)->tuple[bool, list[int], list[int], list[int]]:
        """
        - input: the failed matrix of the testbench in onehot form
          (1 = pass, 0 = fail, a row of -1 = syntax-errored RTL)
        - output:
            - the indexes in the scenario lists are starting from 1
            - bool: whether the testbench is correct (None when every row is -1,
              i.e. the TB itself has a syntax error)
            - list[int]: the list of the wrong scenarios
            - list[int]: the list of the correct scenarios
            - list[int]: the list of the scenarios the discriminator is not sure about
        - rows that are entirely -1 are dropped before judging
          (see function 'failed_scenarios_to_onehot_array')

        Raises:
            RuntimeError: if the configured mode string is not recognized.
        """
        # first check if all the scenarios are -1, which means the tb has syntax error
        if np.all(failed_matrix == -1):
            return None, [], [], []
        failed_matrix = failed_matrix[~np.all(failed_matrix == -1, axis=1)]

        parsed = self._parse_mode(self.mode)
        if parsed is None:
            logger.critical("TB discriminator - mode not found!!!")
            raise RuntimeError("TB discriminator - mode not found!!!")
        col_frac, row_frac = parsed

        rtl_count = len(failed_matrix)
        # a column is "correct" when no surviving RTL fails it
        correct_col_index = np.where(np.all(np.isin(failed_matrix, [1, -1]), axis=0))[0] + 1
        if col_frac is None:
            # "full": a column is wrong only when every RTL fails the scenario
            wrong_col_index = np.where(np.all(np.isin(failed_matrix, [0, -1]), axis=0))[0] + 1
            unsure_col_index = np.where(np.any(failed_matrix == 0, axis=0) & np.any(failed_matrix == 1, axis=0))[0] + 1
        else:
            fail_counts = np.sum(failed_matrix == 0, axis=0)
            wrong_col_index = np.where(fail_counts >= col_frac * rtl_count)[0] + 1
            unsure_col_index = np.where(np.logical_and(fail_counts < col_frac * rtl_count, np.any(failed_matrix == 0, axis=0)))[0] + 1

        # loose criterion: the tb passes as long as no column is judged wrong
        tb_pass = len(wrong_col_index) == 0
        if row_frac is not None and np.sum(np.all(failed_matrix == 1, axis=1)) >= row_frac * rtl_count:
            # enough RTLs are fully correct -> trust the TB regardless of the columns
            tb_pass = True
        status = "passed" if tb_pass else "failed"
        logger.match_level(tb_pass, "positive", "negative", f"TB_discriminating finished, TB {status}, wrong scenarios: {wrong_col_index}, scenario pass ratio: {len(correct_col_index)}/{len(failed_matrix[0])}")
        return tb_pass, wrong_col_index, correct_col_index, unsure_col_index
||||
|
||||
# --- LLM prompt templates for the TB correction stage (consumed by TB_corrector) ---

# Opening prompt: states the correction task and frames the verilog driver +
# python checker split; instructs the model to diagnose only (no code yet).
COR_PROMPT_1 = """Your task is to correct the testbench according to the failing scenarios. the information we have is the failed/passed scenarios of the testbench, the problem description and the testbench code.
the testbench code is consisted of both verilog and python code. The verilog code aims to generate test stimulus (under test scenarios) and drive the DUT to generate the output signal; the python code aims to check if the output vector from the DUT is correct.
ATTENTION: The python code contains error, and your target is to find it and tell me how to correct it (you don't need to give me the code in this stage).
"""

# Hint block appended for SEQUENTIAL circuits: explains the GoldenDUT class
# contract where each load() call represents a new clock cycle.
HINT_SEQ = """
Hints - explaination of the given python code:
the python class "GoldenDUT": This python class can represent the golden DUT (the ideal one). In "GoldenDUT", following methods are defined:
- 1. a method "def __init__(self)": Set the inner states/values of the golden DUT. These values have suffix "_reg". The initial value of these inner values is "x", but later will be digits. The "__init__" method has no input parameters except "self".
- 2. a method "def load(self, signal_vector)": This method is to load the important input signals and the inner values of "GoldenDUT" shall change according to the input signals. There is no clock signal in the input signal vector, every time the "load" method is called, it means a new clock cycle. The initial values "x" should be changed according to the input signals. This method has no return value.
- 3. a method "def check(self, signal_vector)": This method is to determine the expected output values and compare them with output signals from DUT. It should return True or False only.
"""

# Hint block appended for COMBINATIONAL circuits: same GoldenDUT contract, but
# load() computes expected outputs directly (no clocked state).
HINT_CMB = """
Hints - explaination of the given python code:
The given python code contains one class "GoldenDUT". this python class can represent the golden DUT (the ideal one). By calling the inner method "check", the signal vector from DUT will be checked. The details of the golden DUT are as follows:

- a. a method "def __init__(self)". Set the inner states/values of the golden DUT. The "__init__" method has no input parameters except "self".
- b. a method "def load(self, signal_vector)". This method is to load the important input signals and get the expected output signals. it should return the expected output values. It can call other methods to help computing the expected output. It will be called by other inner methods later.
- c. a method "def check(self, signal_vector)". This method is to call "load" to get the expected output values, and compare them with output signals from DUT. It should return True or False only. It can call other methods to help checking.
- d. other methods, they can be called by "__init__", "load" or "check".
- e. the input of "load" and "check" is the signal vector. The signal vector is a dictionary, the key is the signal name, the value is the signal value.
"""

# Second-round prompt, part 1: constraints on the corrected python code
# (preserve the class/method skeleton, only fix implementations).
COR_PROMPT_2_PART1 = """
please correct the python code according to the following rules:

PYTHON code rule: please do not change the original high level structure of the python code. i.e., if python code only contains one class and several functions such as init, load, check and more, only modify the implementation of the function, but do not change the name or delete the functions/class methods. you can add new class methods or functions if needed. you can use python libraries such as numpy or math.

"""
# Second-round prompt, part 2: the expected response skeleton and the request
# for the complete corrected python code.
COR_PROMPT_2_PART2 = """
i.e., your python code format in response should still be like:

class <class_name>:
    def __init__(self):
        ...(omitted)

    def load(self, ...):
        ...

    def check(self, ...):
        ...

    def <other_functions>(self, ...):
        ...

ATTENTION: please give me the corrected python code according to our previous conversation and the hints above. please give me the corrected full python code (not the part but the whole python code like I give you in our previous conversation).
"""
|
||||
|
||||
class TB_corrector():
    """
    LLM-based corrector for a failing (Verilog TB + Python checker) pair.

    Given the indexes of test scenarios that failed the functional check, it
    runs a two-round LLM conversation: round 1 asks for a failure analysis in
    natural language, round 2 asks for a corrected python checker.  Only the
    ``"naive"`` mode is implemented, and ``pychecker_en`` must be True for
    :meth:`correct` to work.
    """

    def __init__(self, mode:str, pychecker_en:bool, circuit_type:str="") -> None:
        """
        - mode: correction strategy name; only "naive" is implemented.
        - pychecker_en: whether the python-checker flow is enabled (required by correct()).
        - circuit_type: "CMB"/"SEQ" as produced upstream; anything else maps to "unknown".
        """
        self.mode = mode
        self.pychecker_en = pychecker_en
        # map the short upstream tags to the wording used inside the prompts
        circuit_type_dict = {"CMB": "combinational", "SEQ": "sequential"}
        self.circuit_type = circuit_type_dict.get(circuit_type, "unknown")
        # logger.debug(f"TB_corrector class - mode: {self.mode}, pychecker_en: {self.pychecker_en}, circuit_type: {self.circuit_type}; The input circuit type is {circuit_type}")
        match self.mode:
            case "naive":
                # the most naive mode; nothing to set up
                pass
            case _:
                # NOTE(review): only logs here; the unknown mode is raised later in correct()
                logger.critical("TB_corrector class - mode not found!!!")

    def correct(self, description, failed_scenarios, TB_code_v:str, llm_model:str, TB_code_py:str|None=None, working_dir:str|None=None) -> tuple[str, str]:
        """
        Run one LLM correction round and return ``(TB_code_v, TB_code_py)``.

        - description: the problem description text.
        - failed_scenarios: scenario indexes that failed the functional check.
        - TB_code_v: the verilog testbench (returned unchanged; only the python
          checker is rewritten by the LLM in this mode).
        - llm_model: model name passed to llm.llm_call.
        - TB_code_py: the python checker under correction.
        - working_dir: if given, the conversation and resulting codes are saved there.

        Raises RuntimeError when pychecker is disabled or the mode is unknown.
        """
        match self.mode:
            case "naive":
                if self.pychecker_en:
                    # keep only the LLM-authored part of the checker; the fixed
                    # boilerplate tail is stashed and re-attached after correction
                    TB_code_py = self._py_focus(TB_code_py, before=True)
                    py_code_hint = HINT_CMB if self.circuit_type == "combinational" else HINT_SEQ
                    # ---- round 1: ask for a natural-language failure analysis ----
                    prompt = COR_PROMPT_1
                    prompt += "Here is the problem description:\n"
                    prompt += description
                    prompt += "\nHere is the testbench code:\n"
                    prompt += "ATTENTION: the following scenarios are wrong: " + str(failed_scenarios) + "\n"
                    # circuit type
                    prompt += f"the circuit type of this task is {self.circuit_type}\n"
                    prompt += "Here is the verilog code. it contains the meaning of each scenario. you can combine the wong scenario info above and the following code to better understand the reason of failing:\n"
                    prompt += TB_code_v
                    prompt += "\nHere is the python code, it contains error, please combine it with the wrong scenario info and the verilog code to understand:\n"
                    prompt += TB_code_py
                    prompt += "\nHere is some hints for your better understanding of the python codes above:"
                    prompt += py_code_hint
                    prompt += "\nplease reply me with the following steps:"
                    prompt += "\n1. please analyze the reason of the failed scenarios. If possible, please find the in common between the failed scenarios."
                    prompt += f"\n2. please analyze which part of the python code is related to the failed test scenarios ({str(failed_scenarios)})."
                    prompt += "\n3. please tell me how to correct the wrong part (in natural language, do not give me the complete code implementation. please explain it in English.)"
                    prompt += "\nhere is an example of the reply:"
                    prompt += "\n1. the failed scenarios are all related to the same signal x\n2. the mid part of the function_X is related to the failed scenarios\n3. the correct logic of signal x should be y."
                    logger.info("naive corrector mode begins")
                    answer = llm.llm_call(prompt, llm_model)[0]
                    # ---- round 2: ask for the corrected full python checker ----
                    prompt_2 = COR_PROMPT_2_PART1 + py_code_hint + COR_PROMPT_2_PART2
                    message = [{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}, {"role": "user", "content": prompt_2}]
                    answer_2, more_info = llm.llm_call(message, llm_model)
                    # if "VERILOG" in answer_2:
                    #     TB_code_v = llm.extract_code(answer_2, "verilog")[0]
                    # else:
                    #     TB_code_py = llm.extract_code(answer_2, "python")[0]
                    TB_code_py = llm.extract_code(answer_2, "python")[0]
                    # re-attach the boilerplate tail stashed by the before=True call
                    TB_code_py = self._py_focus(TB_code_py, before=False)
                    if working_dir is not None:
                        ls.save_messages_to_txt(more_info["messages"], os.path.join(working_dir, "conversation.txt"))
                        ls.save_code(TB_code_v, os.path.join(working_dir, "TB.v"))
                        if TB_code_py is not None:
                            ls.save_code(TB_code_py, os.path.join(working_dir, "TB.py"))
                    logger.info("naive corrector mode ends; conversation and codes saved")
                else:
                    logger.critical("TB_corrector - pychecker not enabled")
                    raise RuntimeError("TB_corrector - pychecker not enabled")
                return TB_code_v, TB_code_py
            case _:
                logger.critical("TB_corrector - mode not found!!!")
                raise RuntimeError("TB_corrector - mode not found!!!")

    def _py_focus(self, code:str, before:bool):
        """
        - migrated from TB2_syncheck.py
        - code: the code under debug / after debug
        - before: True, if before debug, will split the code; False, if after debug, will restore the code

        Splits the checker at a known boilerplate marker so the LLM only sees
        (and only rewrites) the task-specific front part; the tail is kept in
        ``self._py_code_nofocus`` and re-appended by the ``before=False`` call.
        NOTE(review): the marker strings below must match the generated
        checker byte-for-byte — confirm the exact whitespace against the
        generator templates.
        """
        KEY_WORDs_1 = "def check_dut(vectors_in):\n golden_dut = GoldenDUT()\n failed_scenarios = []"
        KEY_WORDs_2 = "\ndef SignalTxt_to_dictlist"
        if before:
            key_words = KEY_WORDs_1 if KEY_WORDs_1 in code else KEY_WORDs_2 # for compatibility with the old version
            if key_words not in code:
                # no marker found: treat the whole code as the focus part
                py_code_focus = code
                self._py_code_nofocus = ""
            else:
                py_code_focus = code.split(key_words)[0]
                self._py_code_nofocus = key_words + code.split(key_words)[1]
            return py_code_focus
        else:
            # restore: append the stashed tail (requires a prior before=True call)
            return code + self._py_code_nofocus
|
||||
|
||||
# Opening prompt for SPEC_improver: explains the description-improvement task
# and introduces the (imperfect) testbench evidence that follows.
IMPR_PROMPT_1 = """Your task is to improve the quality of an RTL problem description using the following given information. Our final target is using the description to generate a testbench for the RTL design. Currently we already have the testbench but it is not perfect correct. Now in this stage the target is to generate a better description.
The information we have the is original RTL description, the testbench code. The testbench code includes the verilog code for test scenario generation, and the python code for checking the output vector.The verilog code aims to generate test stimulus (under test scenarios) and drive the DUT to generate the output signal; the python code aims to check if the output vector from the DUT is correct.
Attention: the testbench we provide is not the perfect one, and it contains error. However, we already know the test scenarios where the testbench works correctly and the scenarios where the testbench has problem. We will provide you the scenarios index later.
Here, firstly, is the problem description to be imporved:
\n"""

# Sentinel markers: the improved description must be wrapped between these so
# SPEC_improver.improve() can extract it with str.split().
DESC_MARK_BEGIN = "***description begins***"
DESC_MARK_END = "***description ends***"

# Step-by-step answer format requested from the LLM (f-string: embeds the markers).
DESC_STEP_INSTRUCT = f"""
please reply me with the following steps:
1. please analyze which part of the testbench (especially the python checker code) is correct and can be used to improve the description.
2. please analyze how can we improve the descriptin. for example, which part of the technical details can be more detailed, which part can be more clear, which part can be more concise.
3. please provide the improved complete description. We will directly use it in the later stages.
the format of description should be like:
{DESC_MARK_BEGIN}
... (the improved description, should be complete)
{DESC_MARK_END}
"""

# Final guardrails: the rewrite must not change the described function, must
# trust the description over the testbench, and must never cite scenario indexes.
DESC_FINAL_INSTRUCT = f"""
ATTENTION: please know that the provided testbench is not perfect and may contains many errors. Thus, your modification on the description should not change the function of the original description. When there are conflicts between the testbench and the description, always believe the description is correct. Do not delete the information in the description, but you can rewrite it in a better way. You can also add more details to it. But NEVER mention any scenario index because the scenarios will not be the same at the next stage.
when you answer the last question (provide the improved complete description), the descriptino should start with "{DESC_MARK_BEGIN}" and end with "{DESC_MARK_END}". Only in this way can we recognize the improved description.
"""
|
||||
class SPEC_improver():
|
||||
def __init__(self, description, mode:str, pychecker_en:bool, llm_model:str, circuit_type:str="") -> None:
|
||||
self.description = description
|
||||
self.mode = mode
|
||||
self.pychecker_en = pychecker_en
|
||||
circuit_type_dict = {"CMB": "combinational", "SEQ": "sequential"}
|
||||
self.llm_model = llm_model
|
||||
self.circuit_type = circuit_type_dict.get(circuit_type, "unknown")
|
||||
|
||||
def improve(self, wrong_scenarios, correct_scenarios, TB_code_v:str, TB_code_py:str|None=None, working_dir:str|None=None) -> str:
|
||||
# not implemented yet
|
||||
match self.mode:
|
||||
case "naive":
|
||||
logger.info("naive description improver mode begins")
|
||||
prompt = ""
|
||||
prompt += IMPR_PROMPT_1
|
||||
prompt += DESC_MARK_BEGIN + "\n"
|
||||
prompt += self.description + "\n"
|
||||
prompt += DESC_MARK_END + "\n"
|
||||
prompt += "\nHere is the testbench codes:\n"
|
||||
prompt += "ATTENTION: the following scenarios are wrong: " + str(wrong_scenarios) + "\n"
|
||||
prompt += "ATTENTION: the following scenarios are correct, you can rely on these scenarios to improve the description: " + str(correct_scenarios) + "\n"
|
||||
prompt += TB_code_v + "\n"
|
||||
if self.pychecker_en:
|
||||
prompt += f"\nHere is the python code (python checker). please note that the python checker has correct function under the scenarios {str(correct_scenarios)}, but wrong under the scenarios {str(wrong_scenarios)}:\n"
|
||||
prompt += TB_code_py + "\n"
|
||||
prompt += DESC_STEP_INSTRUCT
|
||||
prompt += DESC_FINAL_INSTRUCT
|
||||
message = [{"role": "user", "content": prompt}]
|
||||
answer, more_info = llm.llm_call(message, self.llm_model)
|
||||
try:
|
||||
improved_description = answer.split(DESC_MARK_BEGIN)[1].split(DESC_MARK_END)[0]
|
||||
if improved_description == "":
|
||||
improved_description = self.description
|
||||
except:
|
||||
improved_description = self.description
|
||||
if working_dir is not None:
|
||||
ls.save_messages_to_txt(more_info["messages"], os.path.join(working_dir, "conversation.txt"))
|
||||
# save description
|
||||
with open(os.path.join(working_dir, "description.txt"), "w") as f:
|
||||
f.write(improved_description)
|
||||
logger.info("naive description improver mode ends")
|
||||
return improved_description
|
||||
case "hint":
|
||||
logger.info("description improver 'hint' mode begins")
|
||||
|
||||
return self.description
|
||||
|
||||
def test():
    """Ad-hoc smoke test: one-hot conversion of failed scenarios + discriminator."""
    from loguru import logger  # kept from the original ad-hoc test setup
    scenario_lists = [
        [1, 3, 5],
        [-1],
        [2, 3],
    ]
    scenario_count = 7
    onehot = TaskTBcheck.failed_scenarios_to_onehot_array(scenario_lists, scenario_count)
    matrix = np.array(onehot)
    print(matrix)
    discriminator = TB_discriminator("col_full_wrong")
    print(discriminator.discriminate(matrix))
|
||||
|
||||
245
autoline/TB4_eval.py
Normal file
245
autoline/TB4_eval.py
Normal file
@@ -0,0 +1,245 @@
|
||||
"""
|
||||
Description : This is the testbench eval stage in autoline
|
||||
Author : Ruidi Qiu (r.qiu@tum.de)
|
||||
Time : 2024/7/24 11:24:43
|
||||
LastEdited : 2024/8/28 21:08:21
|
||||
"""
|
||||
|
||||
|
||||
import os
|
||||
import iverilog_call as iv
|
||||
import python_call as py
|
||||
from loader_saver import autologger as logger
|
||||
from loader_saver import log_localprefix
|
||||
from utils.utils import Timer, get_time
|
||||
|
||||
# Substrings whose presence in the simulator stdout marks a full pass for a
# generated (LLM-written) verilog testbench.
TC_PASS_CHECK_LIST_TB_GEN = ["All test cases passed", "all test cases passed", "All Test Cases Passed"]
# Pass markers printed by the golden (reference) testbenches.
TC_PASS_CHECK_LIST_TB_GOLDEN = ['Mismatches: 0 in ', 'Hint: Total mismatched samples is 0 out of']
# Python-checker output ends with the failed-scenario list; "[]" means none failed.
# (Checked structurally in TC_pass_from_TC_out, not via this list.)
TC_PASS_CHECK_LIST_PYCHECKER = ["[]"]
|
||||
|
||||
class TaskTBeval():
    """
    ### description
    - this is the evaluation stage of our pipeline; the priority of this stage is that TB is generated and the empty DUT compilation is passed;
    - please use `try` to catch the exception of this function.
    - this module is independent from the previous modules.
    #### input
    - task_id: the name of the problem
    - root_dir: the dir of one problem
    - TB_gen: the testbench under evaluation (str)
    - TB_golden: the golden testbench (str)
    - DUT_golden: the golden RTL DUT (str)
    - DUT_mutant_list: the list of RTL DUT mutants modified from DUT_golden;[str]
    #### output
    - dict
    - "Eval1_pass" : bool (whether the golden RTL checking passed)
    - "Eval2_pass" : bool (whether the golden TB comparison on RTL mutants passed)
    - "Eval2_failed_mutant_idxes" : list of int (the index of the failed mutants)
    """
    """main structure: run(), run_Eval1(), run_Eval2()"""
    def __init__(self, task_id: str, task_dir: str, TB_gen: str, TB_golden:str|None=None, DUT_golden:str|None=None, DUT_mutant_list:list|None=None, DUT_gptgen_list:list|None = None, pychecker_en:bool = False, pychecker_code:str = "", runfiles_save:bool = True):
        self.task_id = task_id
        self.task_dir = task_dir
        self.TB_gen = TB_gen
        self.TB_golden = TB_golden
        self.DUT_golden = DUT_golden
        self.DUT_mutant_list = DUT_mutant_list
        self.DUT_gptgen_list = DUT_gptgen_list
        self.pychecker_en = pychecker_en
        self.save_en = runfiles_save
        # with pychecker the generated TB's verdict comes from the python
        # checker output instead of the verilog $display messages
        self.TB_gen_mode = "TB_gen" if not self.pychecker_en else "Pychecker"
        self.pychecker_code = pychecker_code
        # set per run_testbench() call; the *_path properties derive from it
        self.working_dir = ""
        # Eval1 related (generated TB vs golden DUT)
        self.Eval1_exist = False
        # self.Eval1_dir = task_dir + "eval1_GoldenRTL/"
        self.Eval1_dir = os.path.join(task_dir, "eval1_GoldenRTL")
        self.Eval1_results = None
        self.Eval1_pass = None
        # Eval2 related (generated TB vs golden TB on DUT mutants)
        self.Eval2_exist = False
        # self.Eval2_dir = task_dir + "eval2_GoldenTB_and_mutants/"
        self.Eval2_dir = os.path.join(task_dir, "eval2_GoldenTB_and_mutants")
        self.Eval2_pass = None
        self.Eval2_failed_mutant_idx = None
        self.Eval2_passed_mutant_idx = None
        # Eval2b related (same comparison, on GPT-generated DUTs)
        self.Eval2b_exist = False
        # self.Eval2b_dir = task_dir + "eval2b_GPTgenTB/"
        self.Eval2b_dir = os.path.join(task_dir, "eval2b_GPTgenTB")
        self.Eval2b_pass = None
        self.Eval2b_failed_mutant_idx = None
        self.Eval2b_passed_mutant_idx = None

    @log_localprefix("TBeval")
    def run(self):
        """Run Eval1 first; Eval2/2b only when Eval1 passed and inputs exist."""
        # Eval 1
        if self.DUT_golden is not None:
            self.run_Eval1()
        if self.Eval1_pass:
            # Eval 2
            if self.TB_golden is not None and self.DUT_mutant_list is not None:
                self.run_Eval2(mode="mutant")
            # Eval 2b
            if self.TB_golden is not None and self.DUT_gptgen_list is not None:
                self.run_Eval2(mode="gptgen")
        else:
            logger.info("[%s] Eval 2/2b is skipped because Eval 1 failed" % (self.task_id))
        self.clean_wave_vcd() # some golden TBs may generate wave.vcd files

    def run_Eval1(self):
        """Eval 1: the generated TB must pass against the golden DUT (raises on compile failure)."""
        silent = True  # NOTE(review): unused local
        ### Eval 1: Golden RTL checking
        logger.info("Eval 1: Golden RTL checking begins")
        self.Eval1_pass = self.run_testbench(self.Eval1_dir, self.TB_gen, self.DUT_golden, self.TB_gen_mode, self.pychecker_code, raise_when_fail=True, save_en=self.save_en)
        logger.match_level(self.Eval1_pass, "positive", "failed", "Eval 1: Golden RTL checking %s!" % ("passed" if self.Eval1_pass else "failed"))
        # my_log = logger.positive if self.Eval1_pass else logger.failed
        # my_log("[%s] Eval 1: Golden RTL checking %s!" % (self.task_id, "passed" if self.Eval1_pass else "failed"))
        self.Eval1_exist = True

    def run_Eval2(self, mode:str="mutant"):
        """ mode: "mutant" or "gptgen" """
        # For each mutated/generated DUT, the generated TB must agree with the
        # golden TB (both pass or both fail); disagreement marks the mutant failed.
        silent = True  # NOTE(review): unused local
        assert mode in ["mutant", "gptgen"], "Invalid mode in run_Eval2: " + mode
        if mode == "mutant": # Eval2
            print_str = "Eval 2: Golden TB checking on RTL mutants"
            mutant_subdir_name = "mutant"
            DUT_list = self.DUT_mutant_list
            eval_dir = self.Eval2_dir
        elif mode == "gptgen": # Eval2b
            print_str = "Eval 2b: Golden TB checking on GPT generated RTL codes"
            mutant_subdir_name = "gptgen_DUT"
            DUT_list = self.DUT_gptgen_list
            eval_dir = self.Eval2b_dir
        ### Eval 2: Golden TB comparison on RTL mutants
        logger.info(print_str)
        mutant_results = []
        for idx, DUT_mutant in enumerate(DUT_list):
            # mutant_subdir = eval_dir + "%s_%d/"%(mutant_subdir_name, idx+1)
            mutant_subdir = os.path.join(eval_dir, "%s_%d"%(mutant_subdir_name, idx+1))
            # GoldenTB_subsubdir = mutant_subdir + "GoldenTB/"
            GoldenTB_subsubdir = os.path.join(mutant_subdir, "GoldenTB")
            # GenedTB_subsubdir = mutant_subdir + "GeneratedTB/"
            GenedTB_subsubdir = os.path.join(mutant_subdir, "GeneratedTB")
            try: #in case the mutant has syntax error
                TBgolden_pass = self.run_testbench(GoldenTB_subsubdir, self.TB_golden, DUT_mutant, "TB_golden", save_en=self.save_en)
            except:
                # deliberate best-effort: a broken mutant counts as "failed"
                TBgolden_pass = False
            try:
                TBgen_pass = self.run_testbench(GenedTB_subsubdir, self.TB_gen, DUT_mutant, self.TB_gen_mode, self.pychecker_code, save_en=self.save_en)
            except:
                TBgen_pass = False
            # agreement (both verdicts equal) means the generated TB behaved
            # like the golden TB on this DUT
            if not TBgolden_pass and not TBgen_pass:
                mutant_pass = True
            elif TBgolden_pass and TBgen_pass:
                mutant_pass = True
            else:
                mutant_pass = False
            mutant_results.append(mutant_pass)
        eval_pass = all(mutant_results)
        failed_mutant_idx = [idx + 1 for idx, result in enumerate(mutant_results) if not result]
        passed_mutant_idx = [idx + 1 for idx, result in enumerate(mutant_results) if result]
        if mode == "mutant":
            self.Eval2_pass, self.Eval2_failed_mutant_idx, self.Eval2_passed_mutant_idx, self.Eval2_exist = eval_pass, failed_mutant_idx, passed_mutant_idx, True
        elif mode == "gptgen":
            self.Eval2b_pass, self.Eval2b_failed_mutant_idx, self.Eval2b_passed_mutant_idx, self.Eval2b_exist = eval_pass, failed_mutant_idx, passed_mutant_idx, True
        result = "perfectly passed" if eval_pass else ("finished (%d/%d)" % (len(passed_mutant_idx), len(mutant_results)))
        # >=80% agreement is still logged as a success
        my_log = logger.success if (eval_pass or (len(passed_mutant_idx)/len(mutant_results)>=0.8)) else logger.failed
        my_log("%s %s!" % (print_str, result))

    def run_testbench(self, dir, TB_code, DUT_code, TB_type, pychecker_code = "", raise_when_fail = False, save_en = True):
        """
        it has two mode: pychecker mode or verilog testbench mode
        -input:
        - dir: the dir to save the TB, DUT and pychecker code
        - TB_code: str; the testbench code
        - DUT_code: str; the DUT code
        - TB_type: str: TB_gen, TB_golden, Pychecker
        - pychecker_code: str; the pychecker code
        - output:
        - pass: bool; if the DUT passed the testbench

        NOTE(review): the parameter `dir` shadows the builtin; the asserts
        below are stripped under `python -O` — confirm the runner never uses -O.
        """
        # iverilog part
        # save the TB and DUT
        assert TB_type in ["TB_gen", "TB_golden", "Pychecker"], "Invalid TB_type in run_testbench: " + TB_type
        os.makedirs(dir, exist_ok=True)
        self.working_dir = dir  # the *_path properties now point into `dir`
        with open(self.TB_path, "w") as f:
            f.write(TB_code)
        with open(self.DUT_path, "w") as f:
            f.write(DUT_code)
        iv_run_info = iv.iverilog_call_and_save(dir, silent=True)
        if raise_when_fail:
            assert iv_run_info[0], "%s Iverilog Compilation Failed: the PREREQUISITE of 'Evaluation' is no syntactic error from Testbench!!!"%(TB_type)
        # pychecker part (if enabled)
        if TB_type == "Pychecker":
            with open(self.PY_path, "w") as f:
                f.write(pychecker_code)
            py_run_info = py.python_call_and_save(pypath=self.PY_path, silent=True)
            if raise_when_fail:
                assert py_run_info[0], "%s Python Compilation Failed: the PREREQUISITE of 'Evaluation' is no syntactic error from Python code!!!"%(TB_type)
            # check if the DUT passed the testbench
            # presumably py_run_info = (ok, {"out": ...}) — TODO confirm against python_call
            TC_pass = self.TC_pass_from_TC_out(sim_pass=True, sim_out=py_run_info[1]["out"], TB_type="Pychecker") & iv_run_info[0] & py_run_info[0]
        else:
            # presumably iv_run_info[4] holds the simulation output dict — TODO confirm against iverilog_call
            TC_pass = self.TC_pass_from_TC_out(sim_pass=True, sim_out=iv_run_info[4]["out"], TB_type=TB_type) & iv_run_info[0]
        if not save_en:
            # os.system(f"rm -rf {dir}")
            # keep only the run_info* files, delete everything else
            cmd = f"find {dir} -type f ! -name 'run_info*'" + r" -exec rm -f {} +"
            os.system(cmd)
        return TC_pass

    def clean_wave_vcd(self):
        """clean the .vcd files in the task_dir"""
        # clean_dir = self.task_dir[:-1] if self.task_dir.endswith("/") else self.task_dir
        clean_dir = self.task_dir
        for root, dirs, files in os.walk(clean_dir):
            for file in files:
                # clean wave.vcd
                if file.endswith(".vcd"):
                    os.remove(os.path.join(root, file))

    @property
    def TB_path(self):
        # testbench file inside the current working_dir
        # return self.working_dir + self.task_id + "_tb.v"
        return os.path.join(self.working_dir, self.task_id + "_tb.v")

    @property
    def DUT_path(self):
        # DUT file inside the current working_dir
        # return self.working_dir + self.task_id + ".v"
        return os.path.join(self.working_dir, self.task_id + ".v")

    @property
    def PY_path(self):
        # python checker file inside the current working_dir
        # return self.working_dir + self.task_id + "_tb.py"
        return os.path.join(self.working_dir, self.task_id + "_tb.py")

    @staticmethod
    def TC_pass_from_TC_out(sim_pass: bool, sim_out: str, TB_type="TB_gen"):
        """
        get the information if DUT passed all the test cases from the testbench
        #### input
        - sim_pass: bool; if TB passed the compilation. if not, will return False without check
        - sim_out: the simulation output message;
        - TB_ty: "TB_gen" or "TB_golden" or "Pychecker"; the type of the testbench
        """
        if not sim_pass:
            return False
        assert TB_type in ["TB_gen", "TB_golden", "Pychecker"], "Invalid TB_type during 'TC_pass_from_TC_out': " + TB_type
        tc_pass_check_list_dict = {"TB_gen": TC_PASS_CHECK_LIST_TB_GEN, "TB_golden": TC_PASS_CHECK_LIST_TB_GOLDEN, "Pychecker": TC_PASS_CHECK_LIST_PYCHECKER}
        tc_pass_check_list = tc_pass_check_list_dict[TB_type]
        if TB_type in ["TB_gen", "TB_golden"]:
            # verilog TBs: pass iff any known pass-marker appears in stdout
            for check_str in tc_pass_check_list:
                if check_str in sim_out:
                    return True
            return False
        elif TB_type in ['Pychecker']:
            # check if the last [] contains any element
            # find the last ] in the out
            last_bracket_end = sim_out.rfind("]")
            # find the last [ in the out
            last_bracket_start = sim_out.rfind("[")
            # check if the last bracket pair is "[]", containing no element
            if (last_bracket_end - last_bracket_start) == 1:
                return True
            else:
                return False
|
||||
537
autoline/TB_autoline.py
Normal file
537
autoline/TB_autoline.py
Normal file
@@ -0,0 +1,537 @@
|
||||
"""
|
||||
Description : The main function of autoline, originally the first part of autoline.py in AutoBench 1.0
|
||||
Author : Ruidi Qiu (r.qiu@tum.de)
|
||||
Time : 2024/7/24 11:44:15
|
||||
LastEdited : 2024/9/1 10:32:18
|
||||
"""
|
||||
import os
|
||||
import analyze as al
|
||||
import loader_saver as ls
|
||||
|
||||
import time
|
||||
|
||||
from config import Config
|
||||
from loader_saver import save_dict_json_form, log_localprefix
|
||||
from data.probset import HDLBitsProbset
|
||||
from loader_saver import autologger as logger
|
||||
from utils.utils import Timer
|
||||
from autoline.TB1_gen import TaskTBgen
|
||||
from autoline.TB2_syncheck import TaskTBsim
|
||||
from autoline.TB3_funccheck import TaskTBcheck
|
||||
from autoline.TB4_eval import TaskTBeval
|
||||
from prompt_scripts import BaseScript
|
||||
from LLM_call import llm_manager
|
||||
|
||||
# [Added] import the Coverage-Guided Agent module we just wrote
|
||||
from autoline.TB_cga import TaskTBCGA
|
||||
|
||||
|
||||
def run_autoline():
    """Entry point: build the autoline pipeline from the global Config and execute it."""
    pipeline = AutoLine(Config())
    pipeline()
|
||||
|
||||
class AutoLine():
    """the class of the autoline

    Drives the whole benchmark run: loads the problem set, executes one
    AutoLine_Task per problem, persists the accumulated run info after every
    task, and (when not in partial/onlyrun mode) runs the analyzer at the end.
    """
    def __init__(self, config: Config):
        self.config = config
        self.logger = logger
        # promptscript is mandatory: the TB generation stage is selected by it
        self.logger.assert_(config.get_item("autoline", "promptscript") is not None, "config.autoline.promptscript is None, please check the config file.")
        self.load_data()
        # set run info
        # self.run_info_path = config.save.root + "Chatbench_RunInfo.json"
        self.run_info_path = os.path.join(config.save.root, "Chatbench_RunInfo.json")
        self.run_info = []
        self.analyzer_en = (config.autoline.onlyrun is None) or (config.autoline.onlyrun == "TBgensimeval") # only run the analyzer when not in the onlyrun mode (partial run)

    def run(self):
        """Run every problem in the loaded probset sequentially."""
        for idx, probdata_single in enumerate(self.probset.data):
            task_id = probdata_single["task_id"]
            self.logger.info("")
            self.logger.info("######################### task %d/%d [%s] #########################" % (idx+1, self.probset.num, task_id))
            # run_info_single = pipeline_one_prob(probdata_single, self.config)
            one_task = AutoLine_Task(probdata_single, self.config)
            run_info_single = one_task.run()
            self.run_info.append(run_info_single)
            # save run info: (write to file every iteration and will overwrite the previous one)
            save_dict_json_form(self.run_info, self.run_info_path)
        if self.analyzer_en:
            self.run_analyzer()

    def __call__(self, *args, **kwargs):
        # convenience: AutoLine(config)() == AutoLine(config).run()
        return self.run(*args, **kwargs)

    def load_data(self):
        """Load the HDLBits problem set selected by config.autoline.probset."""
        cfg_probset = self.config.autoline.probset
        self.probset = HDLBitsProbset()
        self.probset.load_by_config(cfg_probset)

    def run_analyzer(self):
        """Aggregate the collected run info and log the analyzer summary."""
        analyzer = al.Analyzer(self.run_info, self.config.gpt.model)
        analyzer.run()
        logger.info(analyzer.messages)
|
||||
|
||||
|
||||
|
||||
class AutoLine_Task():
|
||||
    def __init__(self, prob_data:dict, config:Config):
        """
        Set up the per-problem pipeline state.

        - prob_data: one problem record (task_id, description, header, golden
          module code, optional golden TB / mutants / LLM-generated RTLs).
        - config: the global Config object.
        """
        # config:
        self.config = config
        # probdata:
        self.prob_data = prob_data
        self.main_model = self.config.gpt.model # The main llm model used in the autoline (generation, correction...)
        self.task_id = prob_data["task_id"]
        self.task_NO = prob_data["task_number"]
        self.prob_description = prob_data["description"]
        self.header = prob_data["header"]
        self.DUT_golden = prob_data['module_code']
        self.TB_golden = prob_data.get("testbench", None)
        self.mutant_list = prob_data.get("mutants", None)
        self.rtlgen_list = prob_data.get('llmgen_RTL', None)
        self.rtlgen_model = self.config.gpt.rtlgen_model # if llmgen_list is none, this will be used
        self.rtl_num = self.config.autoline.TBcheck.rtl_num # will be covered if llmgen_list is not None
        # system config:
        # self.task_dir = self.config.save.root + self.task_id + "/"
        self.task_dir = os.path.join(self.config.save.root, self.task_id)
        self.working_dir = self.task_dir
        os.makedirs(self.task_dir, exist_ok=True)
        # === [CGA Mod] Save DUT immediately to task dir for CGA access ===
        self.dut_path = os.path.join(self.task_dir, "DUT.v")
        ls.save_code(self.DUT_golden, self.dut_path)
        # ==============================================================
        self.update_desc = config.autoline.update_desc
        self.error_interuption = config.autoline.error_interruption # for debug'
        self.save_codes = config.autoline.save_finalcodes
        self.save_compile = self.config.autoline.save_compile # save the compiling codes in TBcheck and TBeval or not.
        # TBgen paras:
        self.TBgen_prompt_script = config.autoline.promptscript
        self.circuit_type = None
        self.scenario_dict = None
        self.scenario_num = None
        self.checklist_worked = None
        # TBcheck paras:
        self.TBcheck_correct_max = self.config.autoline.TBcheck.correct_max
        self.iter_max = config.autoline.itermax
        self.discrim_mode = config.autoline.TBcheck.discrim_mode
        self.correct_mode = config.autoline.TBcheck.correct_mode
        self.rtl_compens_en = config.autoline.TBcheck.rtl_compens_en
        self.rtl_compens_max_iter = config.autoline.TBcheck.rtl_compens_max_iter
        self.cga_enabled = config.autoline.cga.enabled
        # stages (populated as the pipeline advances):
        self.TBgen_manager:TaskTBgen|None = None
        self.TBgen:BaseScript|None = None
        self.TBsim:TaskTBsim|None = None
        self.TBcheck:TaskTBcheck|None = None
        self.TBeval:TaskTBeval|None = None
        self.stage_now = "initialization"
        # changing paras:
        self.autoline_iter_now = 0
        self.TB_code_v = None
        self.TB_code_py = None
        self.next_action = None
        # results:
        self.incomplete_running = True
        self.full_pass = False
        self.TB_corrected = False
        self.run_info = {}
        self.run_info_short = {}
        self.TBcheck_rtl_newly_gen_num = 0 # in autoline, "funccheck" = "TBcheck"
        self.op_record = [] # will record the order of each stage, for example: ["gen", "syncheck", "funccheck", "gen", "syncheck", "funccheck", "eval"]
        self.funccheck_op_record = []
        self.funccheck_iters = []
        # initialize the CGA coverage score
        self.cga_coverage = 0.0
        # === [CGA Mod] Initialize result dictionary for final reporting ===
        self.result_dict = {
            "task_id": self.task_id,
            "stage": "Init",
            "pass": False,
            "coverage": 0.0,
            "cga_enabled": self.cga_enabled
        }
        # =================================================================
        # renew current section of llm_manager and logger
        llm_manager.new_section()
        logger.set_temp_log()
|
||||
|
||||
    def run(self):
        """
        The main function of running the autoline for one problem

        Runs all pipeline stages under a task-id log prefix, updates and
        returns the run-info dict, optionally archives the final TB codes,
        and writes a per-task result JSON for the analyzer.
        """
        with log_localprefix(self.task_id):
            self.run_stages()
            self.runinfo_update()
            if self.save_codes:
                self.save_TB_codes()
            # === [CGA Mod] Save Result JSON for Analyzer ===
            self.result_dict['stage'] = self.stage_now

            try:
                result_save_path = self.config.autoline.result_path
            except AttributeError:
                # the config object has no such attribute (or it is a dict
                # lacking this key) — fall back to a default directory
                result_save_path = "results"

            # make sure the directory exists (absolute, or relative to the project root)
            if not os.path.exists(result_save_path):
                os.makedirs(result_save_path, exist_ok=True)
            ls.save_dict_json_form(self.result_dict, os.path.join(result_save_path, f"{self.task_id}.json"))
            # ===============================================
            return self.run_info
|
||||
|
||||
    def run_TBgen(self, subdir:str=None):
        """Run the TB generation stage and copy its outputs into the task state."""
        # TODO: export the circuit type and scenario number
        self.op_record.append("gen")
        working_dir = os.path.join(self.task_dir, subdir) if subdir is not None else self.task_dir
        self.stage_now = "TBgen"
        self.TBgen_manager = TaskTBgen(self.prob_data, self.TBgen_prompt_script, working_dir, self.config)
        # the manager wraps the actual prompt-script workflow; keep a direct handle
        self.TBgen = self.TBgen_manager.workflow
        with log_localprefix("TBgen"):
            self.TBgen()
        # pull generated artifacts; get_attr returns None for missing attributes
        self.TB_code_v = self.TBgen.get_attr("TB_code_v")
        self.TB_code_py = self.TBgen.get_attr("TB_code_py")
        self.scenario_dict = self.TBgen.get_attr("scenario_dict")
        self.scenario_num = self.TBgen.get_attr("scenario_num")
        self.circuit_type = self.TBgen.get_attr("circuit_type")
        self.checklist_worked = self.TBgen.get_attr("checklist_worked")
        self.incomplete_running = True
        self._blank_log()
|
||||
|
||||
def run_TBsim(self, subdir:str=None):
|
||||
self.op_record.append("syncheck")
|
||||
working_dir = os.path.join(self.task_dir, subdir) if subdir is not None else self.task_dir
|
||||
self.stage_now = "TBsim"
|
||||
self.TBsim = TaskTBsim(
|
||||
self.TBgen,
|
||||
self.TBgen.TB_code,
|
||||
self.header,
|
||||
working_dir,
|
||||
self.task_id,
|
||||
self.config
|
||||
)
|
||||
self.TBsim.run()
|
||||
self.TB_code_v = self.TBsim.TB_code_now
|
||||
self.TB_code_py = self.TBsim.PY_code_now
|
||||
self._blank_log()
|
||||
|
||||
    def run_TBcheck(self, subdir:str=None):
        """
        Run the functional-check stage (TBcheck): cross-check the TB against
        reference RTLs, possibly correcting the TB, and copy results back.
        """
        self.op_record.append("funccheck")
        working_dir = os.path.join(self.task_dir, subdir) if subdir is not None else self.task_dir
        self.stage_now = "TBcheck"
        self.TBcheck = TaskTBcheck(
            task_dir = working_dir,
            task_id = self.task_id,
            description = self.prob_description,
            module_header = self.header,
            TB_code_v = self.TB_code_v,
            TB_code_py = self.TB_code_py,
            rtl_list = self.rtlgen_list,
            rtl_num = self.rtl_num,
            scenario_num = self.scenario_num,
            correct_max = self.TBcheck_correct_max,
            runfiles_save=self.save_compile,
            discriminator_mode=self.discrim_mode,
            corrector_mode=self.correct_mode,
            circuit_type=self.circuit_type,
            rtl_compens_en=self.rtl_compens_en,
            rtl_compens_max_iter=self.rtl_compens_max_iter,
            main_model = self.main_model,
            rtlgen_model = self.rtlgen_model,
            desc_improve=self.update_desc
        )
        # keep the (possibly compensated/extended) RTL list for later iterations
        self.rtlgen_list = self.TBcheck.rtl_list
        self.TBcheck.run()
        # adopt the (possibly corrected) TB codes and check bookkeeping
        self.TB_code_v = self.TBcheck.TB_code_v
        self.TB_code_py = self.TBcheck.TB_code_py
        self.TB_corrected = self.TBcheck.corrected
        self.funccheck_op_record.append(self.TBcheck.op_record)
        self.funccheck_iters.append(self.TBcheck.iter_now)
        self.TBcheck_rtl_newly_gen_num += self.TBcheck.rtl_newly_gen_num
        self.next_action = self.TBcheck.next_action
        if self.update_desc:
            # the check stage may rewrite the problem description (SPEC improver)
            self.prob_data['description'] = self.TBcheck.update_description()
            self.prob_description = self.prob_data['description']
        self._blank_log()
|
||||
|
||||
def run_TBeval(self, subdir:str=None):
    """TB evaluation stage: run the generated TB against the golden DUT,
    the golden TB, and the mutant list via TaskTBeval.

    Args:
        subdir: optional sub-directory of the task dir to work in; when None
            the task dir itself is used.
    """
    self.op_record.append("eval")
    working_dir = os.path.join(self.task_dir, subdir) if subdir is not None else self.task_dir
    self.stage_now = "TBeval"
    self.TBeval = TaskTBeval(
        self.task_id,
        working_dir,
        TB_gen=self.TB_code_v,
        TB_golden=self.TB_golden,
        DUT_golden=self.DUT_golden,
        DUT_mutant_list=self.mutant_list,
        # attention: the rtls in DUT_gptgen_list are not the same as the rtls
        # used in TBcheck, so currently we just block this feature
        DUT_gptgen_list=None,
        pychecker_en=self.TBsim.pychecker_en,
        pychecker_code=self.TB_code_py,
        runfiles_save=self.save_compile
    )
    try:
        self.TBeval.run()
    except Exception as e:
        # BUG FIX: was a bare `except:`, which also swallowed SystemExit /
        # KeyboardInterrupt and discarded the failure reason. Now we catch
        # Exception only and log the message, consistent with the other stages.
        logger.failed("error when running TBeval, the autoline for this task stopped. error message: %s" % str(e))
        self.incomplete_running = True
    self._blank_log()
|
||||
# 在 run_TB4_eval 或其他方法旁边添加这个新方法
|
||||
def run_TBCGA(self, work_subdir="CGA", optimize=True, op_name="cga"):
    """
    Coverage-Guided Agent (CGA) stage.

    Runs TaskTBCGA to (optionally) optimize the current TB for coverage,
    then adopts the resulting TB and records the coverage score.

    Args:
        work_subdir: sub-directory under the task dir for CGA artifacts.
        optimize: when False, max_iter is forced to 0, i.e. coverage is only
            measured, not optimized (baseline mode).
        op_name: label appended to op_record for this stage.
    """
    self.stage_now = "TBCGA"
    self.op_record.append(op_name)

    cga = TaskTBCGA(
        task_dir=self.task_dir,
        task_id=self.task_id,
        header=self.header,
        DUT_code=self.DUT_golden,
        TB_code=self.TB_code_v,
        config=self.config,
        work_subdir=work_subdir,
        max_iter=(self.config.autoline.cga.max_iter if optimize else 0)
    )

    # run() returns both the (possibly improved) TB and its coverage score.
    final_tb, final_score = cga.run()

    self.cga_coverage = final_score

    # Adopt the CGA result as the current TB and record the score.
    self.TB_code_v = final_tb
    self.result_dict['coverage'] = final_score

    # Always archive final_TB.v into the task dir for later inspection.
    final_tb_path = os.path.join(self.task_dir, "final_TB.v")
    ls.save_code(final_tb, final_tb_path)
    logger.info(f"Saved optimized TB to: {final_tb_path}")
|
||||
|
||||
def run_stages(self):
    """Run the pipeline core under a wall-clock timer, applying the
    configured error policy.

    BUG FIX: the previous logic was inverted — when ``error_interuption``
    was False the core ran *unguarded* (so any exception escaped anyway),
    while the guarded branch's ``if self.error_interuption: raise`` was
    always true. Now the core is always wrapped: on failure the task is
    marked incomplete and logged, and the exception propagates only in
    stop-on-error mode.
    """
    with Timer(print_en=False) as self.running_time:
        try:
            self.run_stages_core()
        except Exception as e:
            self.incomplete_running = True
            logger.error("error when running %s, the autoline for this task stopped. error message: %s" % (self.stage_now, str(e)))
            if self.error_interuption:
                # stop-on-error: propagate and halt the whole pipeline
                raise
        # NOTE(review): the original also assigned incomplete_running = False
        # at an ambiguous indentation here; it is omitted because
        # run_stages_core already manages that flag on its success paths —
        # confirm against the full file.
|
||||
|
||||
def run_stages_core(self):
    """Dispatch pipeline stages according to ``config.autoline.onlyrun``.

    Modes:
        "TBgen"        : generation only.
        "TBgensim"     : generation + syntactic simulation.
        "TBgensimeval" : gen + sim + evaluation, no functional-check loop.
        default        : full gen/sim/check loop with reboots, then the CGA
                         stage and final evaluation.
    """
    match self.config.autoline.onlyrun:
        case "TBgen":
            self.run_TBgen()
        case "TBgensim":
            self.run_TBgen()
            self.run_TBsim()
        # case _: # default, run all
        case "TBgensimeval":
            try:
                self.run_TBgen("1_TBgen")
                self.run_TBsim("2_TBsim")
                self.run_TBeval("3_TBeval")
            except Exception as e:
                self.incomplete_running = True
                logger.error("error when running %s, the autoline for this task stopped. error message: %s"%(self.stage_now, str(e)))
            else:
                self.incomplete_running = False
        case _: # default, run all
            for i in range(self.iter_max):
                self.autoline_iter_now = i
                try:
                    self.run_TBgen(f"{i+1}_1_TBgen")
                    self.run_TBsim(f"{i+1}_2_TBsim")
                    self.run_TBcheck(f"{i+1}_3_TBcheck")
                except Exception as e:
                    # logger.error(f"error when running {self.stage_now}, current pipeline iter: {i+1}, will {"REBOOT" if i<self.iter_max-1 else "go to NEXT STAGE"}. error message: {str(e)}")
                    # self.next_action = "reboot"
                    # continue
                    err_msg = str(e)
                    logger.error(f"Error when running {self.stage_now}, iter: {i+1}. Message: {err_msg}")

                    # === API cool-down period ===
                    # After an iverilog failure or API timeout, force a 15s
                    # pause; this avoids HTTP 429 / connection resets from
                    # the LLM endpoint.
                    logger.warning("⚠️ Pipeline interrupted. Cooling down for 15s to avoid API Rate Limit...")
                    time.sleep(15)
                    # ============================

                    # Fail-fast if the config requests it.
                    if getattr(self.config.autoline, 'error_interruption', False):
                        raise e

                    # Otherwise mark incomplete and reboot into the next iteration.
                    self.next_action = "reboot"
                    self.incomplete_running = True
                    continue

                match self.next_action:
                    case "pass":
                        break
                    case "reboot":
                        continue

            # === [CGA insertion point START] ===
            # Run CGA only when the check loop ended cleanly ("pass").
            if self.next_action == "pass":
                # Mark state complete before CGA so inner logic does not misjudge.
                self.incomplete_running = False
                try:
                    if self.cga_enabled:
                        self.run_TBCGA(work_subdir="CGA", optimize=True, op_name="cga")
                    else:
                        self.run_TBCGA(work_subdir="CGA_baseline", optimize=False, op_name="coverage_eval")
                except Exception as e:
                    logger.error(f"CGA Stage Failed: {e}. Fallback to original TB.")
                    self.result_dict['error'] = str(e)
            # === [CGA insertion point END] ===

            try:
                self.run_TBeval(f"{self.autoline_iter_now+1}_4_TBeval")
            except Exception as e:
                self.incomplete_running = True
                logger.error("error when running %s, the autoline for this task stopped. error message: %s"%(self.stage_now, str(e)))
|
||||
|
||||
def runinfo_update(self):
    """Assemble per-stage statistics into ``self.run_info`` and
    ``self.result_dict``, persist both JSON reports, and dump the task log.

    Returns:
        dict: the full ``run_info`` dictionary.
    """
    # general — token and cost figures come from the global llm_manager section counters
    self.run_info = {
        "task_id": self.task_id,
        "task_number": self.task_NO,
        "time": round(self.running_time.interval, 2),
        "prompt_tokens": llm_manager.tokens_in_section,
        "completion_tokens": llm_manager.tokens_out_section,
        "token_cost": llm_manager.cost_section,
        "ERROR(incomplete)": self.incomplete_running,
        "op_record": self.op_record,
        "reboot_times": self.autoline_iter_now,
        "max_iter": self.iter_max,

        # coverage reported by the CGA stage (default value when CGA
        # did not run — TODO confirm the attribute's initial value)
        "coverage": self.cga_coverage
    }
    # token and cost from llm_manager

    # TBgen
    if self.TBgen is not None:
        # self.run_info["prompt_tokens"] += self.TBgen.tokens["prompt"]
        # self.run_info["completion_tokens"] += self.TBgen.tokens["completion"]
        self.run_info["circuit_type"] = self.circuit_type
        self.run_info["checklist_worked"] = self.checklist_worked
        self.run_info["scenario_num"] = self.scenario_num
    # TBsim
    if self.TBsim is not None:
        # self.run_info["prompt_tokens"] += self.TBsim.tokens["prompt"]
        # self.run_info["completion_tokens"] += self.TBsim.tokens["completion"]
        self.run_info.update({
            "Eval0_pass": self.TBsim.Eval0_pass,
            "Eval0_iv_pass": self.TBsim.sim_pass,
            "debug_iter_iv": self.TBsim.debug_iter_iv_now,
            "iv_runing_time": self.TBsim.iv_runing_time
        })
        if self.TBsim.pychecker_en:
            self.run_info.update({
                "Eval0_py_pass": self.TBsim.py_pass,
                "debug_iter_py": self.TBsim.debug_iter_py_now,
                "py_runing_time": self.TBsim.py_runing_time
            })
    # TODO: TBcheck runinfo update
    if self.TBcheck is not None:
        self.run_info.update({
            "TB_corrected": self.TB_corrected,
            "TBcheck_oprecord": self.funccheck_op_record,
            "rtl_num_newly_gen": self.TBcheck_rtl_newly_gen_num
        })
    # TBeval — Eval1: golden DUT; Eval2: mutants; Eval2b: GPT-generated RTLs
    if self.TBeval is not None:
        if self.TBeval.Eval1_exist:
            self.run_info.update({"Eval1_pass": self.TBeval.Eval1_pass})
            self.result_dict["Eval1_pass"] = self.TBeval.Eval1_pass
        if self.TBeval.Eval2_exist:
            self.run_info.update({
                "Eval2_pass": self.TBeval.Eval2_pass,
                "Eval2_ratio": "%d/%d"%(len(self.TBeval.Eval2_passed_mutant_idx), len(self.prob_data['mutants'])),
                "Eval2_failed_mutant_idxes": self.TBeval.Eval2_failed_mutant_idx
            })
            self.result_dict.update({
                "Eval2_pass": self.TBeval.Eval2_pass,
                "Eval2_ratio": "%d/%d"%(len(self.TBeval.Eval2_passed_mutant_idx), len(self.prob_data['mutants'])),
                "Eval2_failed_mutant_idxes": self.TBeval.Eval2_failed_mutant_idx
            })
        if self.TBeval.Eval2b_exist:
            self.run_info.update({
                "Eval2b_pass": self.TBeval.Eval2b_pass,
                "Eval2b_ratio": "%d/%d"%(len(self.TBeval.Eval2b_passed_mutant_idx), len(self.prob_data['gptgen_RTL'])),
                "Eval2b_failed_mutant_idxes": self.TBeval.Eval2b_failed_mutant_idx
            })
            self.result_dict.update({
                "Eval2b_pass": self.TBeval.Eval2b_pass,
                "Eval2b_ratio": "%d/%d"%(len(self.TBeval.Eval2b_passed_mutant_idx), len(self.prob_data['gptgen_RTL'])),
                "Eval2b_failed_mutant_idxes": self.TBeval.Eval2b_failed_mutant_idx
            })
    # full pass: all of sim, Eval1 and Eval2 must pass
    if not self.incomplete_running:
        self.full_pass = self.TBsim.sim_pass and self.TBeval.Eval1_pass and self.TBeval.Eval2_pass
        self.run_info.update({
            "full_pass": self.full_pass
        })
        self.result_dict["full_pass"] = self.full_pass
        self.result_dict["pass"] = self.full_pass
    else:
        self.result_dict["full_pass"] = False
        self.result_dict["pass"] = False
        # record which stage the task died in
        self.result_dict["stage"] = self.stage_now
    self.result_dict["coverage"] = self.cga_coverage
    save_dict_json_form(self.run_info, os.path.join(self.task_dir, "run_info.json"))

    # short run info — condensed evaluation progress label
    if "Eval2_ratio" in self.run_info.keys():
        eval_progress = "Eval2 - " + self.run_info["Eval2_ratio"]
    elif "Eval1_pass" in self.run_info.keys() and self.run_info["Eval1_pass"]:
        eval_progress = "Eval1 - passed"
    elif "Eval0_pass" in self.run_info.keys() and self.run_info["Eval0_pass"]:
        eval_progress = "Eval1 - failed"
    elif "Eval0_pass" in self.run_info.keys() and not self.run_info["Eval0_pass"]:
        eval_progress = "Eval0 - failed"
    else:
        eval_progress = "Eval0 - not found"
    self.run_info_short = {
        "task_id": self.run_info.get("task_id", None),
        "eval_progress": eval_progress,
        "TB_corrected": self.run_info.get("TB_corrected", None),
        "reboot_times": self.run_info.get("reboot_times", None),
        "time": self.run_info.get("time", None),
        "cost": self.run_info.get("token_cost", None),
    }
    save_dict_json_form(self.run_info_short, os.path.join(self.task_dir, "run_info_short.json"))

    # run log — flush the in-memory temp log into task_log.log
    running_log = logger.reset_temp_log()
    tasklog_path = os.path.join(self.task_dir, "task_log.log")
    os.makedirs(os.path.dirname(tasklog_path), exist_ok=True)
    with open(tasklog_path, "w") as f:
        f.write(running_log)

    return self.run_info
|
||||
|
||||
def save_TB_codes(self):
    """Archive the final Verilog and Python TB code into the task directory,
    writing a placeholder string when a piece of code is missing."""
    tb_v = self.TB_code_v
    tb_py = self.TB_code_py
    if not isinstance(tb_v, str):
        tb_v = "// TB code (Verilog) unavailable"
    if not isinstance(tb_py, str):
        tb_py = "## TB code (Python) unavailable"
    target_dir = self.task_dir
    ls.save_code(tb_v, os.path.join(target_dir, "final_TB.v"))
    ls.save_code(tb_py, os.path.join(target_dir, "final_TB.py"))
|
||||
|
||||
@staticmethod
def _blank_log():
    """Emit an empty info line to visually separate stages in the log."""
    blank = ""
    logger.info(blank)
|
||||
|
||||
def __call__(self, *args, **kwargs):
    """Make the task object callable; delegates directly to :meth:`run`."""
    result = self.run(*args, **kwargs)
    return result
|
||||
1453
autoline/TB_cga.py
Normal file
1453
autoline/TB_cga.py
Normal file
File diff suppressed because it is too large
Load Diff
22
autoline/__init__.py
Normal file
22
autoline/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""
|
||||
Description : Automatic pipeline of Chatbench: from HDLBits problem to simulation
|
||||
Author : Ruidi Qiu (r.qiu@tum.de)
|
||||
Time : 2023/12/7 15:13:00
|
||||
LastEdited : 2024/8/16 13:37:31
|
||||
autoline.py (c) 2023
|
||||
"""
|
||||
|
||||
from autoline.TB_autoline import run_autoline
|
||||
|
||||
from autoline.TB1_gen import TaskTBgen
|
||||
from autoline.TB2_syncheck import TaskTBsim
|
||||
from autoline.TB3_funccheck import TaskTBcheck, TB_corrector, TB_discriminator
|
||||
from autoline.TB4_eval import TaskTBeval
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # This module is import-only; refuse direct execution.
    raise RuntimeError("you cannot run autoline.py directly!")
    # Unreachable example usage, kept for reference:
    # probset = Probset("data/HDLBits/HDLBits_data.jsonl", "data/HDLBits/HDLBits_data_miniset_mutants.jsonl", "data/HDLBits/HDLBits_circuit_type.jsonl", exclude_tasks=['rule110'], filter_content={'circuit_type': 'SEQ'})
    # print(probset.num)
|
||||
|
||||
BIN
autoline/__pycache__/TB1_gen.cpython-312.pyc
Normal file
BIN
autoline/__pycache__/TB1_gen.cpython-312.pyc
Normal file
Binary file not shown.
BIN
autoline/__pycache__/TB2_syncheck.cpython-312.pyc
Normal file
BIN
autoline/__pycache__/TB2_syncheck.cpython-312.pyc
Normal file
Binary file not shown.
BIN
autoline/__pycache__/TB3_funccheck.cpython-312.pyc
Normal file
BIN
autoline/__pycache__/TB3_funccheck.cpython-312.pyc
Normal file
Binary file not shown.
BIN
autoline/__pycache__/TB4_eval.cpython-312.pyc
Normal file
BIN
autoline/__pycache__/TB4_eval.cpython-312.pyc
Normal file
Binary file not shown.
BIN
autoline/__pycache__/TB_autoline.cpython-312.pyc
Normal file
BIN
autoline/__pycache__/TB_autoline.cpython-312.pyc
Normal file
Binary file not shown.
BIN
autoline/__pycache__/TB_cga.cpython-312.pyc
Normal file
BIN
autoline/__pycache__/TB_cga.cpython-312.pyc
Normal file
Binary file not shown.
BIN
autoline/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
autoline/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
autoline/__pycache__/cga_utils.cpython-312.pyc
Normal file
BIN
autoline/__pycache__/cga_utils.cpython-312.pyc
Normal file
Binary file not shown.
BIN
autoline/__pycache__/diversity_injector.cpython-312.pyc
Normal file
BIN
autoline/__pycache__/diversity_injector.cpython-312.pyc
Normal file
Binary file not shown.
BIN
autoline/__pycache__/energy_allocator.cpython-312.pyc
Normal file
BIN
autoline/__pycache__/energy_allocator.cpython-312.pyc
Normal file
Binary file not shown.
BIN
autoline/__pycache__/quality_evaluator.cpython-312.pyc
Normal file
BIN
autoline/__pycache__/quality_evaluator.cpython-312.pyc
Normal file
Binary file not shown.
BIN
autoline/__pycache__/semantic_analyzer.cpython-312.pyc
Normal file
BIN
autoline/__pycache__/semantic_analyzer.cpython-312.pyc
Normal file
Binary file not shown.
BIN
autoline/__pycache__/test_history.cpython-312.pyc
Normal file
BIN
autoline/__pycache__/test_history.cpython-312.pyc
Normal file
Binary file not shown.
5444
autoline/cga_utils.py
Normal file
5444
autoline/cga_utils.py
Normal file
File diff suppressed because it is too large
Load Diff
601
autoline/diversity_injector.py
Normal file
601
autoline/diversity_injector.py
Normal file
@@ -0,0 +1,601 @@
|
||||
"""
|
||||
Description : Diversity Constraint Injector (Layer 1)
|
||||
- Analyze existing test sequences
|
||||
- Detect overused patterns
|
||||
- Generate diversity constraints for Prompt
|
||||
- Recommend new test scenarios
|
||||
Author : CGA Enhancement Project
|
||||
Time : 2026/03/16
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import List, Dict, Optional, Any, Tuple, Set
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
|
||||
# 支持两种导入方式:包导入和直接加载
|
||||
try:
|
||||
from .test_history import (
|
||||
TestHistoryManager,
|
||||
TestRecord,
|
||||
InputSequence,
|
||||
SequencePattern
|
||||
)
|
||||
|
||||
except ImportError:
|
||||
from test_history import (
|
||||
TestHistoryManager,
|
||||
TestRecord,
|
||||
InputSequence,
|
||||
SequencePattern
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 配置常量
|
||||
# ============================================================================
|
||||
|
||||
class DiversityConfig:
    """Tuning knobs for diversity-constraint generation."""

    # Minimum edit distance a new test sequence must keep from recent tests.
    MIN_EDIT_DISTANCE = 3

    # A pattern reused at least this many times counts as overused.
    OVERUSE_THRESHOLD = 3

    # Number of new scenarios to recommend per round.
    NEW_SCENARIO_COUNT = 3

    # Sequence-length cap used when generating constraints.
    MAX_SEQUENCE_LENGTH = 10

    # Weights for the combined diversity score.
    PATTERN_WEIGHT = 0.4
    EDIT_DISTANCE_WEIGHT = 0.3
    COVERAGE_WEIGHT = 0.3
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 约束类型定义
|
||||
# ============================================================================
|
||||
|
||||
class ConstraintType(Enum):
    """Kinds of diversity constraints that can be injected into a prompt."""
    FORBID_SEQUENCE = "forbid_sequence"      # forbid a specific sequence
    MIN_EDIT_DISTANCE = "min_edit_distance"  # minimum edit distance to past tests
    AVOID_PATTERN = "avoid_pattern"          # avoid an overused pattern
    TRY_SCENARIO = "try_scenario"            # try a new test scenario
    EXPLORE_RANGE = "explore_range"          # explore a value range
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 约束数据结构
|
||||
# ============================================================================
|
||||
|
||||
@dataclass
class DiversityConstraint:
    """
    A single diversity constraint to be rendered into a generation prompt.

    Attributes:
        constraint_type: kind of constraint (see ConstraintType).
        description: human-readable summary; used as the fallback prompt text.
        details: type-specific payload (signal, pattern, min/max, count, ...).
        priority: 1-5, 5 is highest.
    """
    constraint_type: ConstraintType
    description: str
    details: Dict[str, Any] = field(default_factory=dict)
    priority: int = 3

    def to_prompt_text(self) -> str:
        """Render this constraint as one prompt bullet line."""
        if self.constraint_type == ConstraintType.FORBID_SEQUENCE:
            return f"- AVOID using this sequence pattern: {self.details.get('pattern', 'unknown')}"

        elif self.constraint_type == ConstraintType.MIN_EDIT_DISTANCE:
            return f"- Your test sequence MUST differ from previous tests (edit distance >= {self.details.get('min_distance', 3)})"

        elif self.constraint_type == ConstraintType.AVOID_PATTERN:
            signal = self.details.get('signal', '')
            pattern = self.details.get('pattern', '')
            return f"- AVOID the pattern '{pattern}' for signal '{signal}' (already used {self.details.get('count', 0)} times)"

        elif self.constraint_type == ConstraintType.TRY_SCENARIO:
            return f"- TRY this new approach: {self.details.get('scenario', 'unknown')}"

        elif self.constraint_type == ConstraintType.EXPLORE_RANGE:
            return f"- EXPLORE values in range [{self.details.get('min', 0)}, {self.details.get('max', 255)}] for {self.details.get('signal', 'signal')}"

        # Unknown type: fall back to the plain description.
        return f"- {self.description}"
|
||||
# ============================================================================
|
||||
# 序列分析器
|
||||
# ============================================================================
|
||||
|
||||
class SequenceAnalyzer:
    """
    Sequence analyzer.

    Static helpers that characterize extracted input sequences: numeric value
    range, transition pattern, and an estimate of a stimulus sequence length.
    """

    @staticmethod
    def extract_value_range(values: List[Tuple[int, Any]]) -> Tuple[Any, Any]:
        """Return ``(min, max)`` over the numeric interpretations of *values*.

        Args:
            values: list of ``(time, value)`` pairs; values may be ints,
                floats, scalar bit strings ('0'/'1'/'x'/'z'), or sized
                Verilog literals such as "8'hff" / "4'b1010" / "3'd5".

        Returns:
            ``(0, 0)`` when nothing numeric can be extracted.
        """
        if not values:
            return (0, 0)

        numeric_values = []
        for _, v in values:
            if isinstance(v, (int, float)):
                numeric_values.append(v)
            elif isinstance(v, str):
                # Scalar bit literals; unknown/high-Z ('x'/'z') count as 0.
                if v in ('0', '1', 'x', 'z'):
                    numeric_values.append(int(v) if v.isdigit() else 0)
                    continue
                # Sized Verilog literals.
                # BUG FIX: the old pattern's character class was garbled
                # ("[0-9a-fA-fA-F...]") and the digits were always parsed as
                # base 16, even for 'b (binary) and 'd (decimal) literals.
                match = re.match(r"(\d+)'([bdh])([0-9a-fA-FxXzZ_]+)", v)
                if match:
                    base = {'b': 2, 'd': 10, 'h': 16}[match.group(2)]
                    try:
                        numeric_values.append(int(match.group(3), base))
                    except ValueError:
                        # x/z digits (or digits invalid for the radix) are
                        # not numeric; skip them.
                        pass

        if numeric_values:
            return (min(numeric_values), max(numeric_values))
        return (0, 0)

    @staticmethod
    def detect_transition_pattern(values: List[Tuple[int, Any]]) -> str:
        """Classify the value sequence as one of:
        "single", "incremental", "decremental", "alternating", "pulse",
        or "random".
        """
        if len(values) < 2:
            return "single"

        # Extract the bare value sequence.
        val_seq = [v for _, v in values]

        # NOTE(review): ordering checks compare str() representations, so
        # multi-digit ints are ordered lexicographically ('10' < '9') —
        # kept as-is to preserve existing classification behavior.
        # Monotonically non-decreasing -> incremental.
        if all(str(val_seq[i]) <= str(val_seq[i+1]) for i in range(len(val_seq)-1)):
            return "incremental"

        # Monotonically non-increasing -> decremental.
        if all(str(val_seq[i]) >= str(val_seq[i+1]) for i in range(len(val_seq)-1)):
            return "decremental"

        # ABAB prefix -> alternating.
        if len(val_seq) >= 4:
            if val_seq[0] == val_seq[2] and val_seq[1] == val_seq[3]:
                return "alternating"

        # Single excursion that returns to the start value -> pulse.
        if len(val_seq) == 3 and val_seq[0] == val_seq[2] != val_seq[1]:
            return "pulse"

        return "random"

    @staticmethod
    def calculate_sequence_length(code: str) -> int:
        """Estimate the stimulus length of *code*: number of blocking/NB
        assignment statements plus the total cycle count of ``repeat(n)``
        loops."""
        # Count assignment statements.
        assignments = len(re.findall(r'\w+\s*=\s*\S+\s*;', code))
        # Count cycles contributed by repeat loops.
        repeats = re.findall(r'repeat\s*\(\s*(\d+)\s*\)', code)
        repeat_cycles = sum(int(r) for r in repeats)

        return assignments + repeat_cycles
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 场景推荐器
|
||||
# ============================================================================
|
||||
|
||||
class ScenarioRecommender:
    """
    Scenario recommender.

    Suggests new test scenarios based on the test history and the list of
    not-yet-covered function points.
    """

    # Scenario templates, keyed by function-point type. "{...}" placeholders
    # are filled by _fill_template.
    SCENARIO_TEMPLATES = {
        'fsm': [
            "Test state transition from {state_a} to {state_b}",
            "Test illegal state transition handling",
            "Test state machine reset behavior",
            "Test state holding under stable inputs"
        ],
        'counter': [
            "Test counter overflow behavior (count to max value)",
            "Test counter underflow (if applicable)",
            "Test counter reset during counting",
            "Test counter enable/disable control"
        ],
        'branch': [
            "Test boundary condition: {condition} at threshold",
            "Test all branches of nested if-else",
            "Test case statement with all possible values"
        ],
        'protocol': [
            "Test handshake timeout scenario",
            "Test back-to-back transactions",
            "Test protocol violation handling"
        ],
        'general': [
            "Apply random input patterns for extended duration",
            "Test with boundary values (all 0s, all 1s)",
            "Test rapid signal transitions",
            "Test power-on/reset sequence variations"
        ]
    }

    def __init__(self, history_manager: TestHistoryManager):
        # Shared history used to avoid recommending already-used patterns.
        self.history = history_manager

    def recommend_scenarios(self,
                            uncovered_functions: List[Dict],
                            covered_patterns: Set[str] = None) -> List[str]:
        """
        Recommend new test scenarios.

        Args:
            uncovered_functions: list of not-yet-covered function points.
            covered_patterns: set of already-covered scenario strings.

        Returns:
            Up to NEW_SCENARIO_COUNT recommended scenario strings.
        """
        recommendations = []
        covered_patterns = covered_patterns or set()

        # Recommend from uncovered function points (one template each).
        for func in uncovered_functions[:3]:
            func_type = func.get('type', 'general')
            func_name = func.get('name', '')

            templates = self.SCENARIO_TEMPLATES.get(func_type, self.SCENARIO_TEMPLATES['general'])

            for template in templates[:1]:  # one template per function point
                scenario = self._fill_template(template, func)
                if scenario not in covered_patterns:
                    recommendations.append(scenario)

        # Recommend based on history analysis.
        if self.history.records:
            # Collect transition-pattern types already exercised.
            used_patterns = set()
            for record in self.history.records:
                for seq in record.input_sequences:
                    pattern = SequenceAnalyzer.detect_transition_pattern(seq.values)
                    used_patterns.add(pattern)

            # Suggest a pattern type not yet used.
            all_patterns = {'incremental', 'decremental', 'alternating', 'pulse', 'random'}
            unused_patterns = all_patterns - used_patterns

            if unused_patterns:
                recommendations.append(f"Try {list(unused_patterns)[0]} input pattern (different from your usual approach)")

        # Pad up to the configured count.
        while len(recommendations) < DiversityConfig.NEW_SCENARIO_COUNT:
            recommendations.append("Explore a completely different input sequence than before")

        return recommendations[:DiversityConfig.NEW_SCENARIO_COUNT]

    def _fill_template(self, template: str, func: Dict) -> str:
        """Fill a scenario template's placeholders from the function point."""
        result = template

        # Substitute FSM state placeholders.
        if '{state_a}' in template or '{state_b}' in template:
            states = func.get('states', ['STATE_A', 'STATE_B'])
            if len(states) >= 2:
                result = result.replace('{state_a}', states[0])
                result = result.replace('{state_b}', states[1])

        # Substitute branch-condition placeholder.
        if '{condition}' in template:
            condition = func.get('condition', 'signal')
            result = result.replace('{condition}', condition)

        return result
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 约束生成器
|
||||
# ============================================================================
|
||||
|
||||
class ConstraintGenerator:
    """
    Constraint generator.

    Produces diversity constraints from analysis of the test history.
    """

    def __init__(self, history_manager: TestHistoryManager):
        self.history = history_manager
        self.analyzer = SequenceAnalyzer()

    def generate_constraints(self,
                             target_function: str = None,
                             uncovered_functions: List[Dict] = None) -> List[DiversityConstraint]:
        """
        Generate diversity constraints.

        Args:
            target_function: current target function point (unused so far,
                kept for interface stability).
            uncovered_functions: list of not-yet-covered function points.

        Returns:
            Constraints sorted by descending priority; empty when there is
            no test history yet.
        """
        constraints = []

        if not self.history.records:
            return constraints

        # 1. Constraints for overused patterns (top 3).
        overused = self.history.get_overused_patterns(DiversityConfig.OVERUSE_THRESHOLD)
        for pattern in overused[:3]:
            constraints.append(DiversityConstraint(
                constraint_type=ConstraintType.AVOID_PATTERN,
                description=f"Avoid overused pattern for {pattern.signal_name}",
                details={
                    'signal': pattern.signal_name,
                    'pattern': pattern.pattern,
                    'count': pattern.count
                },
                priority=5
            ))

        # 2. Edit-distance constraint against the most recent tests.
        recent_count = min(5, len(self.history.records))
        # FIX: was `> 00` — a legal-but-confusing literal that reads like
        # legacy octal; records is non-empty here so this is always true.
        if recent_count > 0:
            constraints.append(DiversityConstraint(
                constraint_type=ConstraintType.MIN_EDIT_DISTANCE,
                description="Maintain minimum edit distance from recent tests",
                details={
                    'min_distance': DiversityConfig.MIN_EDIT_DISTANCE,
                    'reference_count': recent_count
                },
                priority=4
            ))

        # 3. Value-range exploration constraints for uncovered counters.
        if uncovered_functions:
            for func in uncovered_functions[:2]:
                if func.get('type') == 'counter':
                    max_val = func.get('max_value', 255)
                    constraints.append(DiversityConstraint(
                        constraint_type=ConstraintType.EXPLORE_RANGE,
                        # FIX: dropped pointless f-prefix on a literal with
                        # no placeholders (string content unchanged).
                        description="Explore counter boundary values",
                        details={
                            'signal': func.get('name', 'counter'),
                            'min': 0,
                            'max': max_val
                        },
                        priority=3
                    ))

        # Highest priority first.
        constraints.sort(key=lambda c: c.priority, reverse=True)

        return constraints

    def generate_forbidden_sequence_prompt(self) -> str:
        """Render the overused patterns as a prompt section; empty string
        when none are overused."""
        overused = self.history.get_overused_patterns(DiversityConfig.OVERUSE_THRESHOLD)

        if not overused:
            return ""

        lines = ["[DIVERSITY CONSTRAINTS - AVOID THESE OVERUSED PATTERNS]"]

        for i, pattern in enumerate(overused[:5], 1):
            lines.append(f"{i}. Signal '{pattern.signal_name}': {pattern.pattern[:50]}")
            lines.append(f"   (This pattern has been used {pattern.count} times already)")

        lines.append("\nPlease create a DIFFERENT input sequence to improve test diversity.")

        return "\n".join(lines)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 多样性约束注入器(主入口)
|
||||
# ============================================================================
|
||||
|
||||
class DiversityInjector:
    """
    Diversity-constraint injector — main entry point of Layer 1.

    Integrates sequence analysis, pattern detection, and constraint
    generation behind a single interface for prompt augmentation.
    """

    def __init__(self, history_manager: TestHistoryManager = None):
        """
        Args:
            history_manager: test-history manager; a fresh one is created
                when omitted.
        """
        self.history = history_manager or TestHistoryManager()
        self.constraint_generator = ConstraintGenerator(self.history)
        self.scenario_recommender = ScenarioRecommender(self.history)

    def inject_diversity_constraints(self,
                                     prompt: str,
                                     target_function: str = None,
                                     uncovered_functions: List[Dict] = None) -> str:
        """
        Inject diversity constraints into a prompt.

        Args:
            prompt: original prompt.
            target_function: current target function point.
            uncovered_functions: list of not-yet-covered function points.

        Returns:
            The prompt with a diversity-constraint section inserted; the
            original prompt when there is no history yet.
        """
        if not self.history.records:
            return prompt  # nothing to constrain against without history

        # Generate constraints from history.
        constraints = self.constraint_generator.generate_constraints(
            target_function=target_function,
            uncovered_functions=uncovered_functions
        )

        # Generate recommended new scenarios.
        recommendations = self.scenario_recommender.recommend_scenarios(
            uncovered_functions=uncovered_functions or []
        )

        # Build the constraint text block.
        constraint_text = self._build_constraint_section(constraints, recommendations)

        # Insert just before the [OUTPUT REQUIREMENTS] marker when present.
        insert_marker = "[OUTPUT REQUIREMENTS"
        if insert_marker in prompt:
            parts = prompt.split(insert_marker, 1)
            enhanced_prompt = parts[0] + constraint_text + "\n\n" + insert_marker + parts[1]
        else:
            # Marker missing: append at the end.
            enhanced_prompt = prompt + "\n\n" + constraint_text

        return enhanced_prompt

    def _build_constraint_section(self,
                                  constraints: List[DiversityConstraint],
                                  recommendations: List[str]) -> str:
        """Assemble the constraint + recommendation prompt section."""
        lines = []
        lines.append("[DIVERSITY CONSTRAINTS - CRITICAL]")
        lines.append("To improve test effectiveness, follow these diversity requirements:")
        lines.append("")

        # Constraint bullets.
        for constraint in constraints:
            lines.append(constraint.to_prompt_text())

        lines.append("")

        # Recommended scenarios.
        if recommendations:
            lines.append("[RECOMMENDED NEW APPROACHES]")
            for i, rec in enumerate(recommendations, 1):
                lines.append(f"{i}. {rec}")

        lines.append("")
        lines.append("IMPORTANT: Repeated test patterns reduce coverage improvement efficiency.")
        lines.append("Generate a DISTINCTLY DIFFERENT test sequence from previous attempts.")

        return "\n".join(lines)

    def get_diversity_context(self) -> str:
        """Return a short diversity-context summary for prompt use; empty
        string when there is no history."""
        if not self.history.records:
            return ""

        stats = self.history.get_statistics()
        overused = self.history.get_overused_patterns(DiversityConfig.OVERUSE_THRESHOLD)

        context_lines = []
        context_lines.append(f"Test History: {stats['total_tests']} tests generated")
        context_lines.append(f"Unique Patterns: {stats['total_patterns']}")

        if overused:
            context_lines.append(f"Overused Patterns: {len(overused)} (avoid these)")

        return "\n".join(context_lines)

    def evaluate_diversity(self,
                           new_code: str,
                           known_signals: List[str] = None) -> Dict[str, float]:
        """
        Evaluate the diversity of newly generated test code.

        Args:
            new_code: newly generated test code.
            known_signals: known signal names for sequence extraction.

        Returns:
            Dict with 'sequence_diversity', 'edit_distance_diversity' and
            the weighted 'overall_diversity'.
        """
        results = {}

        # 1. Sequence diversity.
        # NOTE(review): reconstructed nesting — only set_known_signals is
        # conditional; extraction and scoring always run, otherwise
        # 'sequence_diversity' would be missing below. Confirm against the
        # original file's indentation.
        if known_signals:
            self.history.sequence_extractor.set_known_signals(known_signals)
        new_sequences = self.history.sequence_extractor.extract(new_code)
        results['sequence_diversity'] = self.history.calculate_sequence_diversity(new_sequences)

        # 2. Edit-distance diversity.
        results['edit_distance_diversity'] = self.history.calculate_edit_distance_diversity(new_code)

        # 3. Weighted overall score.
        results['overall_diversity'] = (
            DiversityConfig.PATTERN_WEIGHT * results['sequence_diversity'] +
            DiversityConfig.EDIT_DISTANCE_WEIGHT * results['edit_distance_diversity']
        )

        return results

    def record_test(self,
                    code: str,
                    target_function: str = "",
                    coverage_score: float = 0.0,
                    success: bool = False,
                    iteration: int = 0,
                    known_signals: List[str] = None) -> TestRecord:
        """
        Record a new test case in the history.

        Args:
            code: test code.
            target_function: target function point.
            coverage_score: coverage score achieved.
            success: whether the test succeeded.
            iteration: iteration index.
            known_signals: known signal names.

        Returns:
            The stored test record.
        """
        return self.history.add_record(
            code=code,
            target_function=target_function,
            coverage_score=coverage_score,
            success=success,
            iteration=iteration,
            known_signals=known_signals
        )

    def get_statistics(self) -> Dict[str, Any]:
        """Return history statistics (delegates to the history manager)."""
        return self.history.get_statistics()

    def generate_diversity_report(self) -> str:
        """Return a textual diversity report (delegates to the history manager)."""
        return self.history.get_diversity_report()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 便捷函数
|
||||
# ============================================================================
|
||||
|
||||
def create_diversity_injector(history_file: str = None) -> DiversityInjector:
    """Create a DiversityInjector backed by a fresh TestHistoryManager.

    Args:
        history_file: Optional path to a persisted history file; when given,
            previously stored records are loaded from it.

    Returns:
        A fully initialised diversity-constraint injector.
    """
    manager = TestHistoryManager(history_file=history_file)
    return DiversityInjector(history_manager=manager)
|
||||
787
autoline/energy_allocator.py
Normal file
787
autoline/energy_allocator.py
Normal file
@@ -0,0 +1,787 @@
|
||||
"""
|
||||
Description : Energy Allocation Layer (Layer 4)
|
||||
- Adaptive Resource Scheduling
|
||||
- Dynamic energy distribution based on function point importance
|
||||
Author : CGA Enhancement Project
|
||||
Time : 2026/03/11
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Dict, Optional, Any, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 数据结构定义
|
||||
# ============================================================================
|
||||
|
||||
class EnergyState(Enum):
    """Lifecycle states of an energy allocation for one function point."""

    ACTIVE = "active"        # still has energy to spend
    DEPLETED = "depleted"    # budget exhausted without success
    COMPLETED = "completed"  # target function point covered
    SUSPENDED = "suspended"  # paused after too many consecutive failures
|
||||
|
||||
|
||||
@dataclass
class EnergyAllocation:
    """Per-function-point energy bookkeeping record.

    Attributes:
        function_point: Name of the targeted function point.
        importance: Importance score in [0.0, 1.0].
        allocated: Total energy granted to this target.
        consumed: Energy already spent on it.
        remaining: Energy still available.
        consecutive_failures: Current streak of failed attempts.
        state: Current lifecycle state of the allocation.
        total_attempts: Number of generation attempts made so far.
        successful_attempts: Number of attempts that covered the target.
    """
    function_point: str
    importance: float
    allocated: float = 0.0
    consumed: float = 0.0
    remaining: float = 0.0
    consecutive_failures: int = 0
    state: EnergyState = EnergyState.ACTIVE
    total_attempts: int = 0
    successful_attempts: int = 0
|
||||
|
||||
|
||||
@dataclass
class GenerationResult:
    """Outcome of a single test-generation attempt.

    Attributes:
        function_point: Function point the attempt targeted.
        success: Whether the target was covered.
        coverage_delta: Change in coverage achieved by the attempt.
        energy_cost: Energy consumed by the attempt.
        code_generated: The generated test code.
        quality_score: Quality score of the generated code.
    """
    function_point: str
    success: bool
    coverage_delta: float = 0.0
    energy_cost: float = 1.0
    code_generated: str = ""
    quality_score: float = 0.0
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 能量初始化器
|
||||
# ============================================================================
|
||||
|
||||
class EnergyInitializer:
    """Distribute the total energy budget over uncovered function points.

    Energy is split proportionally to each point's importance score, with a
    per-point floor and a buffer fraction held back for later redistribution.
    """

    # Default configuration.
    DEFAULT_TOTAL_ENERGY = 10.0   # default budget (mirrors max iterations)
    MIN_ENERGY_PER_FP = 1.0       # floor granted to every function point
    ENERGY_BUFFER_RATIO = 0.1     # fraction reserved for redistribution

    def __init__(self,
                 total_energy: float = None,
                 min_energy: float = None,
                 buffer_ratio: float = None):
        """
        Args:
            total_energy: Overall budget (defaults to DEFAULT_TOTAL_ENERGY,
                typically overridden by max_iter at initialize() time).
            min_energy: Minimum energy granted per function point.
            buffer_ratio: Fraction of the budget reserved as a buffer.
        """
        self.total_energy = total_energy or self.DEFAULT_TOTAL_ENERGY
        self.min_energy = min_energy or self.MIN_ENERGY_PER_FP
        self.buffer_ratio = buffer_ratio or self.ENERGY_BUFFER_RATIO

    def initialize(self,
                   function_points: List[Dict],
                   max_iterations: int = None) -> Dict[str, EnergyAllocation]:
        """Create the initial energy allocation table.

        Args:
            function_points: Dicts carrying ``name``, ``importance`` and
                optionally ``covered``.
            max_iterations: When given, overrides the total energy budget.

        Returns:
            Mapping of function-point name to its EnergyAllocation; empty
            when every point is already covered.
        """
        # The iteration cap doubles as the energy budget when provided.
        if max_iterations:
            self.total_energy = float(max_iterations)

        # Only uncovered points receive energy.
        pending = [fp for fp in function_points if not fp.get('covered', False)]
        if not pending:
            logger.info("All function points are covered. No energy allocation needed.")
            return {}

        importance_sum = sum(fp.get('importance', 0.5) for fp in pending)

        # Hold back a slice of the budget for later redistribution.
        buffer_energy = self.total_energy * self.buffer_ratio
        available_energy = self.total_energy - buffer_energy

        allocations: Dict[str, EnergyAllocation] = {}
        for fp in pending:
            name = fp.get('name', 'unknown')
            importance = fp.get('importance', 0.5)

            # Proportional share (even split if all importances are zero),
            # never below the per-point floor.
            if importance_sum > 0:
                share = (importance / importance_sum) * available_energy
            else:
                share = available_energy / len(pending)
            granted = max(self.min_energy, share)

            allocations[name] = EnergyAllocation(
                function_point=name,
                importance=importance,
                allocated=granted,
                consumed=0.0,
                remaining=granted,
                consecutive_failures=0,
                state=EnergyState.ACTIVE,
                total_attempts=0,
                successful_attempts=0,
            )

        logger.info(f"Energy initialized: total={self.total_energy:.1f}, "
                    f"allocated={sum(a.allocated for a in allocations.values()):.1f}, "
                    f"buffer={buffer_energy:.1f}, "
                    f"targets={len(allocations)}")

        return allocations
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 目标选择器
|
||||
# ============================================================================
|
||||
|
||||
class TargetSelector:
    """Pick the next function point to target for test generation.

    Priority = importance x (remaining/allocated energy) x a penalty that
    shrinks with the current streak of consecutive failures.

    Improvement: the priority formula was duplicated verbatim inside
    ``select_next_target`` and ``get_top_candidates``; it now lives in one
    shared helper (``_priority``), as does the candidate filter.
    """

    # Streak length after which a target counts as failing hard.
    MAX_CONSECUTIVE_FAILURES = 3

    def __init__(self, allocations: Dict[str, EnergyAllocation]):
        """
        Args:
            allocations: Mapping of function-point name to its allocation.
        """
        self.allocations = allocations

    @staticmethod
    def _priority(alloc: EnergyAllocation) -> float:
        """Selection priority of a single allocation (higher is better)."""
        energy_ratio = alloc.remaining / alloc.allocated if alloc.allocated > 0 else 0
        failure_penalty = 1.0 / (1.0 + alloc.consecutive_failures * 0.5)
        return alloc.importance * energy_ratio * failure_penalty

    def _active_candidates(self) -> List[EnergyAllocation]:
        """Allocations that are ACTIVE and still have energy to spend."""
        return [
            alloc for alloc in self.allocations.values()
            if alloc.state == EnergyState.ACTIVE and alloc.remaining > 0
        ]

    def select_next_target(self) -> Optional[EnergyAllocation]:
        """Return the highest-priority active allocation, or None.

        Returns:
            The selected EnergyAllocation, or None when no target is
            active with remaining energy.
        """
        candidates = self._active_candidates()
        if not candidates:
            logger.info("No active targets with remaining energy.")
            return None

        candidates.sort(key=self._priority, reverse=True)
        selected = candidates[0]
        logger.debug(f"Selected target: {selected.function_point} "
                     f"(importance={selected.importance:.2f}, "
                     f"remaining={selected.remaining:.1f}, "
                     f"failures={selected.consecutive_failures})")
        return selected

    def get_candidates_count(self) -> int:
        """Number of allocations currently eligible for selection."""
        return len(self._active_candidates())

    def get_top_candidates(self, n: int = 3) -> List[EnergyAllocation]:
        """Return the *n* highest-priority eligible allocations."""
        candidates = self._active_candidates()
        candidates.sort(key=self._priority, reverse=True)
        return candidates[:n]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 能量消耗跟踪器
|
||||
# ============================================================================
|
||||
|
||||
class EnergyConsumptionTracker:
    """Track per-attempt energy consumption and update allocation state."""

    # Multiplier applied to remaining energy after repeated failures.
    ENERGY_DECAY_FACTOR = 0.7

    def __init__(self, allocations: Dict[str, EnergyAllocation]):
        """
        Args:
            allocations: Mapping of function-point name to its allocation.
        """
        self.allocations = allocations
        self.history: List[GenerationResult] = []

    def record_generation(self, result: GenerationResult) -> Dict[str, Any]:
        """Record one generation attempt and update the target's state.

        Args:
            result: Outcome of the attempt.

        Returns:
            A status dict describing the resulting state transition
            ('completed', 'suspended', 'depleted', 'failed' or 'unknown').
        """
        self.history.append(result)

        fp_name = result.function_point
        alloc = self.allocations.get(fp_name)
        if alloc is None:
            logger.warning(f"Unknown function point: {fp_name}")
            return {'status': 'unknown', 'message': 'Unknown function point'}

        alloc.total_attempts += 1

        # Charge the attempt's energy cost (remaining is clamped at zero).
        alloc.consumed += result.energy_cost
        alloc.remaining = max(0, alloc.remaining - result.energy_cost)

        if result.success:
            # Success: clear the failure streak and mark the target done.
            alloc.consecutive_failures = 0
            alloc.successful_attempts += 1
            alloc.state = EnergyState.COMPLETED

            logger.info(f"[SUCCESS] Target covered: {fp_name} (attempts={alloc.total_attempts}, "
                        f"energy_used={alloc.consumed:.1f})")

            return {
                'status': 'completed',
                'function_point': fp_name,
                'attempts': alloc.total_attempts,
                'energy_used': alloc.consumed
            }

        # Failure path: extend the streak.
        alloc.consecutive_failures += 1

        if alloc.consecutive_failures >= 3:
            # Repeated failures: shrink the remaining budget.
            old_remaining = alloc.remaining
            alloc.remaining *= self.ENERGY_DECAY_FACTOR

            logger.warning(f"Consecutive failures for {fp_name}: {alloc.consecutive_failures}. "
                           f"Energy reduced: {old_remaining:.1f} -> {alloc.remaining:.1f}")

            if alloc.remaining < 0.5:
                # Too little energy left to keep trying: suspend the target.
                alloc.state = EnergyState.SUSPENDED
                logger.warning(f"Target suspended due to low energy: {fp_name}")

                return {
                    'status': 'suspended',
                    'function_point': fp_name,
                    'consecutive_failures': alloc.consecutive_failures,
                    'remaining_energy': alloc.remaining
                }

        if alloc.remaining <= 0:
            alloc.state = EnergyState.DEPLETED
            logger.warning(f"Target depleted: {fp_name}")

            return {
                'status': 'depleted',
                'function_point': fp_name,
                'total_attempts': alloc.total_attempts
            }

        return {
            'status': 'failed',
            'function_point': fp_name,
            'consecutive_failures': alloc.consecutive_failures,
            'remaining_energy': alloc.remaining
        }

    def get_statistics(self) -> Dict[str, Any]:
        """Aggregate attempt counts, success rate and per-target energy use."""
        total = len(self.history)
        successful = sum(1 for r in self.history if r.success)

        energy_by_fp: Dict[str, Dict[str, Any]] = {}
        for result in self.history:
            entry = energy_by_fp.setdefault(
                result.function_point,
                {'consumed': 0, 'attempts': 0, 'success': False},
            )
            entry['consumed'] += result.energy_cost
            entry['attempts'] += 1
            if result.success:
                entry['success'] = True

        return {
            'total_attempts': total,
            'successful_attempts': successful,
            'success_rate': successful / total if total > 0 else 0,
            'energy_by_function_point': energy_by_fp
        }
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 能量重分配器
|
||||
# ============================================================================
|
||||
|
||||
class EnergyRedistributor:
    """Reassign leftover energy from finished targets to active ones."""

    def __init__(self, allocations: Dict[str, EnergyAllocation]):
        """
        Args:
            allocations: Mapping of function-point name to its allocation.
        """
        self.allocations = allocations

    def redistribute(self, completed_fp: str) -> Dict[str, float]:
        """Spread the remaining energy of *completed_fp* over active targets.

        Args:
            completed_fp: Name of the function point that finished.

        Returns:
            Mapping of recipient name to the energy it gained; empty when
            there is nothing to recover or nobody to give it to.
        """
        completed_alloc = self.allocations.get(completed_fp)
        if completed_alloc is None:
            return {}

        # Reclaim whatever the finished target did not spend.
        recovered = completed_alloc.remaining
        if recovered <= 0:
            logger.debug(f"No remaining energy to recover from {completed_fp}")
            return {}

        # Active targets other than the donor are eligible.
        recipients = [
            alloc for alloc in self.allocations.values()
            if alloc.state == EnergyState.ACTIVE and alloc.function_point != completed_fp
        ]
        if not recipients:
            logger.info(f"No active targets to redistribute energy to.")
            return {}

        # Split proportionally to importance (evenly if all zero).
        importance_sum = sum(a.importance for a in recipients)
        gains: Dict[str, float] = {}
        for alloc in recipients:
            if importance_sum > 0:
                gain = (alloc.importance / importance_sum) * recovered
            else:
                gain = recovered / len(recipients)
            alloc.allocated += gain
            alloc.remaining += gain
            gains[alloc.function_point] = gain

        # The donor keeps nothing.
        completed_alloc.remaining = 0

        logger.info(f"Redistributed {recovered:.1f} energy from {completed_fp} "
                    f"to {len(gains)} targets")

        return gains

    def redistribute_all(self) -> Dict[str, Dict[str, float]]:
        """Redistribute leftovers of every completed or suspended target.

        Returns:
            Mapping of donor name to its per-recipient redistribution detail.
        """
        all_redistributions: Dict[str, Dict[str, float]] = {}

        donors = [
            name for name, alloc in self.allocations.items()
            if alloc.state in [EnergyState.COMPLETED, EnergyState.SUSPENDED]
            and alloc.remaining > 0
        ]
        for name in donors:
            detail = self.redistribute(name)
            if detail:
                all_redistributions[name] = detail

        return all_redistributions

    def revive_suspended(self, min_energy: float = 1.0) -> List[str]:
        """Reactivate suspended targets while recovered energy lasts.

        Args:
            min_energy: Energy granted to each revived target.

        Returns:
            Names of the targets that were revived.
        """
        revived: List[str] = []

        # Budget available for revival, taken from completed targets.
        # NOTE(review): this sums the donors' remaining energy but never
        # decrements those allocations, so it may double count with
        # redistribute()/redistribute_all() — confirm intended semantics.
        available = sum(
            alloc.remaining for alloc in self.allocations.values()
            if alloc.state == EnergyState.COMPLETED and alloc.remaining > 0
        )

        suspended = [
            alloc for alloc in self.allocations.values()
            if alloc.state == EnergyState.SUSPENDED
        ]
        for alloc in suspended:
            if available >= min_energy:
                # Bring the target back with a fresh minimal budget.
                alloc.state = EnergyState.ACTIVE
                alloc.remaining = min_energy
                alloc.allocated += min_energy
                alloc.consecutive_failures = 0
                available -= min_energy
                revived.append(alloc.function_point)

                logger.info(f"Revived suspended target: {alloc.function_point}")

        return revived
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 能量分配器(主入口)
|
||||
# ============================================================================
|
||||
|
||||
class EnergyAllocator:
    """Layer-4 entry point for adaptive energy management.

    Wires together initialisation, target selection, consumption tracking
    and redistribution behind a single interface.
    """

    def __init__(self,
                 max_iterations: int = 5,
                 total_energy: float = None):
        """
        Args:
            max_iterations: Iteration cap of the surrounding loop.
            total_energy: Energy budget (defaults to *max_iterations*).
        """
        self.max_iterations = max_iterations
        self.total_energy = total_energy or float(max_iterations)

        # Sub-components; selector/tracker/redistributor are created in
        # initialize() once the allocation table exists.
        self.initializer = EnergyInitializer(total_energy=self.total_energy)
        self.allocations: Dict[str, EnergyAllocation] = {}
        self.selector: Optional[TargetSelector] = None
        self.tracker: Optional[EnergyConsumptionTracker] = None
        self.redistributor: Optional[EnergyRedistributor] = None

        # Runtime state.
        self.initialized = False
        self.current_target: Optional[EnergyAllocation] = None

    def initialize(self, function_points: List[Dict]) -> Dict[str, Any]:
        """Build the allocation table and wire up the sub-components.

        Args:
            function_points: Function-point descriptor dicts.

        Returns:
            Summary of the initial allocation.
        """
        self.allocations = self.initializer.initialize(
            function_points,
            max_iterations=self.max_iterations,
        )

        self.selector = TargetSelector(self.allocations)
        self.tracker = EnergyConsumptionTracker(self.allocations)
        self.redistributor = EnergyRedistributor(self.allocations)
        self.initialized = True

        return {
            'total_energy': self.total_energy,
            'targets': len(self.allocations),
            'allocation_details': {
                name: {
                    'importance': alloc.importance,
                    'allocated': alloc.allocated,
                    'state': alloc.state.value,
                }
                for name, alloc in self.allocations.items()
            },
        }

    def select_next_target(self) -> Optional[str]:
        """Pick the next generation target.

        Returns:
            The selected function-point name, or None when nothing is
            selectable (or the allocator was never initialised).
        """
        if not self.initialized:
            logger.warning("Energy allocator not initialized.")
            return None

        self.current_target = self.selector.select_next_target()
        return self.current_target.function_point if self.current_target else None

    def record_generation(self,
                          success: bool,
                          coverage_delta: float = 0.0,
                          energy_cost: float = 1.0,
                          quality_score: float = 0.0) -> Dict[str, Any]:
        """Record a generation attempt against the current target.

        Args:
            success: Whether the target was covered.
            coverage_delta: Change in coverage.
            energy_cost: Energy consumed by the attempt.
            quality_score: Quality score of the generated code.

        Returns:
            Tracker update result, or an error dict when no target is set.
        """
        if not self.current_target:
            return {'status': 'error', 'message': 'No current target'}

        outcome = GenerationResult(
            function_point=self.current_target.function_point,
            success=success,
            coverage_delta=coverage_delta,
            energy_cost=energy_cost,
            quality_score=quality_score,
        )
        update_result = self.tracker.record_generation(outcome)

        # Successful coverage frees up leftover energy for other targets.
        if success:
            self.redistributor.redistribute(self.current_target.function_point)

        return update_result

    def get_status(self) -> Dict[str, Any]:
        """Return a snapshot of the allocator's current state."""
        if not self.initialized:
            return {'initialized': False}

        active_count = sum(1 for a in self.allocations.values()
                           if a.state == EnergyState.ACTIVE and a.remaining > 0)
        completed_count = sum(1 for a in self.allocations.values()
                              if a.state == EnergyState.COMPLETED)

        return {
            'initialized': True,
            'total_energy': self.total_energy,
            'total_targets': len(self.allocations),
            'active_targets': active_count,
            'completed_targets': completed_count,
            'current_target': self.current_target.function_point if self.current_target else None,
            'statistics': self.tracker.get_statistics() if self.tracker else None,
        }

    def get_target_context(self, target_name: str = None) -> str:
        """Render prompt-ready context for a target function point.

        Args:
            target_name: Target name (defaults to the current target).

        Returns:
            Multi-line context string; empty when the target is unknown.
        """
        if not target_name and self.current_target:
            target_name = self.current_target.function_point

        if not target_name or target_name not in self.allocations:
            return ""

        alloc = self.allocations[target_name]

        lines = [
            f"[TARGET: {target_name}]",
            f"Importance: {alloc.importance:.2f}",
            f"Remaining Energy: {alloc.remaining:.1f} / {alloc.allocated:.1f}",
            f"Previous Attempts: {alloc.total_attempts}",
        ]
        # Nudge the generator toward a new strategy after failures.
        if alloc.consecutive_failures > 0:
            lines.append(f"Warning: {alloc.consecutive_failures} consecutive failures")
            lines.append("Consider a different approach or sequence")

        return "\n".join(lines)

    def mark_targets_completed(self, function_names: List[str]) -> Dict[str, str]:
        """Mark already-confirmed-covered function points as completed.

        Used for baseline synchronisation or when a single iteration covers
        multiple function points, so completion does not depend solely on
        the current target's score-increase signal.

        Returns:
            Per-target status ('completed' or 'already_completed').
        """
        if not self.initialized:
            return {}

        updates: Dict[str, str] = {}
        for name in function_names:
            alloc = self.allocations.get(name)
            if alloc is None:
                continue

            if alloc.state == EnergyState.COMPLETED:
                updates[name] = "already_completed"
                continue

            alloc.state = EnergyState.COMPLETED
            alloc.consecutive_failures = 0
            alloc.remaining = 0.0
            updates[name] = "completed"
            self.redistributor.redistribute(name)

        return updates

    def generate_report(self) -> str:
        """Render a human-readable energy-allocation report."""
        if not self.initialized:
            return "Energy allocator not initialized."

        lines = []
        lines.append("=" * 60)
        lines.append("ENERGY ALLOCATION REPORT")
        lines.append("=" * 60)
        lines.append(f"Total Energy: {self.total_energy:.1f}")
        lines.append(f"Max Iterations: {self.max_iterations}")
        lines.append("")

        lines.append("FUNCTION POINT STATUS:")
        lines.append("-" * 60)

        # Most important targets first.
        ranked = sorted(self.allocations.items(),
                        key=lambda x: x[1].importance, reverse=True)
        for name, alloc in ranked:
            status_icon = {
                EnergyState.ACTIVE: "🔄",
                EnergyState.COMPLETED: "✅",
                EnergyState.DEPLETED: "❌",
                EnergyState.SUSPENDED: "⏸️"
            }.get(alloc.state, "❓")

            efficiency = (alloc.successful_attempts / alloc.total_attempts * 100
                          if alloc.total_attempts > 0 else 0)

            lines.append(f"{status_icon} {name}")
            lines.append(f"  Importance: {alloc.importance:.2f} | "
                         f"Energy: {alloc.remaining:.1f}/{alloc.allocated:.1f} | "
                         f"Efficiency: {efficiency:.0f}%")
            lines.append(f"  Attempts: {alloc.total_attempts} | "
                         f"Consecutive Failures: {alloc.consecutive_failures}")

        lines.append("")
        lines.append("SUMMARY:")
        lines.append("-" * 60)
        stats = self.tracker.get_statistics()
        lines.append(f"Total Attempts: {stats['total_attempts']}")
        lines.append(f"Successful: {stats['successful_attempts']}")
        lines.append(f"Success Rate: {stats['success_rate']*100:.1f}%")

        completed = sum(1 for a in self.allocations.values()
                        if a.state == EnergyState.COMPLETED)
        lines.append(f"Targets Covered: {completed} / {len(self.allocations)}")

        lines.append("=" * 60)

        return "\n".join(lines)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 便捷函数
|
||||
# ============================================================================
|
||||
|
||||
def create_energy_allocator(function_points: List[Dict],
                            max_iterations: int = 5) -> EnergyAllocator:
    """Create and initialise an EnergyAllocator in one step.

    Args:
        function_points: Function-point descriptor dicts.
        max_iterations: Iteration cap (also the energy budget).

    Returns:
        A fully initialised energy allocator.
    """
    allocator = EnergyAllocator(max_iterations=max_iterations)
    allocator.initialize(function_points)
    return allocator
|
||||
1039
autoline/quality_evaluator.py
Normal file
1039
autoline/quality_evaluator.py
Normal file
File diff suppressed because it is too large
Load Diff
1029
autoline/semantic_analyzer.py
Normal file
1029
autoline/semantic_analyzer.py
Normal file
File diff suppressed because it is too large
Load Diff
580
autoline/test_history.py
Normal file
580
autoline/test_history.py
Normal file
@@ -0,0 +1,580 @@
|
||||
"""
|
||||
Description : Test History Manager (Layer 1 Support Module)
|
||||
- Store and manage test case history
|
||||
- Support sequence pattern analysis
|
||||
- Provide diversity statistics
|
||||
Author : CGA Enhancement Project
|
||||
Time : 2026/03/16
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import logging
|
||||
from typing import List, Dict, Optional, Any, Tuple, Set
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from datetime import datetime
|
||||
from collections import defaultdict
|
||||
import hashlib
|
||||
import re
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 数据结构定义
|
||||
# ============================================================================
|
||||
|
||||
@dataclass
class InputSequence:
    """Recorded assignment sequence for one input signal.

    Attributes:
        signal_name: Name of the signal.
        values: Ordered (time, value) assignment pairs.
    """
    signal_name: str
    values: List[Tuple[int, Any]] = field(default_factory=list)

    def to_pattern_string(self) -> str:
        """Join the assigned values into a ``v1->v2->...`` pattern string."""
        return "->".join(str(pair[1]) for pair in self.values)

    def get_hash(self) -> str:
        """Return a short (8 hex chars) MD5 digest of the pattern string."""
        return hashlib.md5(self.to_pattern_string().encode()).hexdigest()[:8]
|
||||
|
||||
|
||||
@dataclass
class TestRecord:
    """A single generated test case and its measured outcome.

    Attributes:
        test_id: Identifier of the test.
        code: The generated test code.
        input_sequences: Extracted per-signal input sequences.
        target_function: Function point the test targeted.
        covered_lines: Source lines the test covered.
        covered_functions: Function points the test covered.
        coverage_score: Coverage score achieved.
        diversity_scores: Named diversity scores for this test.
        iteration: Iteration in which the test was generated.
        timestamp: ISO-format creation time.
        success: Whether the test succeeded.
    """
    test_id: str
    code: str = ""
    input_sequences: List[InputSequence] = field(default_factory=list)
    target_function: str = ""
    covered_lines: List[int] = field(default_factory=list)
    covered_functions: List[str] = field(default_factory=list)
    coverage_score: float = 0.0
    diversity_scores: Dict[str, float] = field(default_factory=dict)
    iteration: int = 0
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
    success: bool = False

    def get_sequence_patterns(self) -> Dict[str, str]:
        """Map each input signal name to its pattern string."""
        return {seq.signal_name: seq.to_pattern_string() for seq in self.input_sequences}
|
||||
|
||||
|
||||
@dataclass
class SequencePattern:
    """Usage statistics for one observed input-sequence pattern.

    Attributes:
        pattern: The pattern string.
        count: How many times it has appeared.
        signal_name: Signal the pattern belongs to.
        test_ids: IDs of the tests that used it.
    """
    pattern: str
    count: int = 0
    signal_name: str = ""
    test_ids: List[str] = field(default_factory=list)

    def is_overused(self, threshold: int = 3) -> bool:
        """True when the pattern has appeared at least *threshold* times."""
        return self.count >= threshold
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 序列提取器
|
||||
# ============================================================================
|
||||
|
||||
class SequenceExtractor:
    """Extract per-signal input assignment sequences from Verilog TB code.

    Bug fix: the third ``ASSIGNMENT_PATTERNS`` entry matches exactly the
    same text as the first one for plain decimal values, so every such
    assignment used to be recorded twice (and the virtual time advanced
    twice). ``extract`` now de-duplicates matches by their span in the line.
    """

    # Signal assignment statements.
    ASSIGNMENT_PATTERNS = [
        # Blocking assignment: signal = value;
        r'(\w+)\s*=\s*([0-9]+\'[bdh][0-9a-fA-FxXzZ_]+|\d+|x|z)\s*;',
        # Non-blocking assignment: signal <= value;
        r'(\w+)\s*<=\s*([0-9]+\'[bdh][0-9a-fA-FxXzZ_]+|\d+|x|z)\s*;',
        # Simple assignment (no bit width). Subsumed by the first pattern;
        # kept for interface compatibility, duplicates filtered in extract().
        r'(\w+)\s*=\s*(\d+)\s*;',
    ]

    # Delay statement: #N;
    DELAY_PATTERN = r'#\s*(\d+)\s*;'

    # Clock-cycle wait: repeat (N) @(posedge clk)
    CLOCK_WAIT_PATTERN = r'repeat\s*\(\s*(\d+)\s*\)\s*@\s*\(\s*posedge\s+(\w+)\s*\)'

    def __init__(self):
        # When non-empty, only these signals are extracted.
        self.known_signals: Set[str] = set()

    def set_known_signals(self, signals: List[str]):
        """Restrict extraction to the given signal names."""
        self.known_signals = set(signals)

    def extract(self, code: str) -> List[InputSequence]:
        """Parse *code* and collect assignment sequences per signal.

        Args:
            code: Verilog testbench code.

        Returns:
            One InputSequence per assigned (and not filtered-out) signal.
        """
        sequences: Dict[str, InputSequence] = {}
        current_time = 0

        for line in code.split('\n'):
            line = line.strip()

            # Skip blank lines and line comments.
            if not line or line.startswith('//'):
                continue

            # A delay statement advances the virtual time.
            delay_match = re.search(self.DELAY_PATTERN, line)
            if delay_match:
                current_time += int(delay_match.group(1))
                continue

            # Clock waits advance time too (assume 10 time units per cycle).
            clock_match = re.search(self.CLOCK_WAIT_PATTERN, line, re.IGNORECASE)
            if clock_match:
                current_time += int(clock_match.group(1)) * 10
                continue

            # Assignment statements; remember match spans so overlapping
            # patterns do not record the same assignment twice.
            seen_spans = set()
            for pattern in self.ASSIGNMENT_PATTERNS:
                for match in re.finditer(pattern, line, re.IGNORECASE):
                    if match.span() in seen_spans:
                        continue
                    seen_spans.add(match.span())

                    signal = match.group(1)
                    value = match.group(2)

                    # Filter out signals we were not asked about.
                    if self.known_signals and signal not in self.known_signals:
                        continue

                    # Skip obvious loop/temp variables.
                    if signal.lower() in ['i', 'j', 'k', 'cnt', 'count', 'temp']:
                        continue

                    if signal not in sequences:
                        sequences[signal] = InputSequence(signal_name=signal)

                    sequences[signal].values.append((current_time, value))
                    current_time += 1  # an assignment consumes one time unit

        return list(sequences.values())
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 测试历史管理器
|
||||
# ============================================================================
|
||||
|
||||
class TestHistoryManager:
    """
    Test history manager.

    Manages the history of generated test cases, supporting:
    - test-case storage and retrieval
    - statistical analysis of input-sequence patterns
    - diversity distribution statistics
    """

    def __init__(self, history_file: str = None):
        """
        Args:
            history_file: Path to the history file (optional). When the
                file exists, previous records are loaded on construction.
        """

        # history_file must be stored first, otherwise save() cannot find
        # the file path.
        self.history_file = history_file

        self.records: List[TestRecord] = []
        self.patterns: Dict[str, SequencePattern] = {}  # pattern_hash -> SequencePattern
        self.signal_patterns: Dict[str, List[str]] = defaultdict(list)  # signal_name -> [pattern_hashes]
        self.sequence_extractor = SequenceExtractor()

        # Aggregate statistics over all stored records.
        # NOTE(review): 'avg_diversity' is seeded here but _update_stats()
        # writes 'avg_coverage' instead and never touches 'avg_diversity' —
        # confirm which key consumers actually expect.
        self.stats = {
            'total_tests': 0,
            'successful_tests': 0,
            'total_coverage': 0.0,
            'avg_diversity': 0.0
        }

        if history_file and os.path.exists(history_file):
            self.load(history_file)

    # ==================== Record management ====================

    def add_record(self,
                   code: str,
                   test_id: str = None,
                   target_function: str = "",
                   covered_lines: List[int] = None,
                   covered_functions: List[str] = None,
                   coverage_score: float = 0.0,
                   iteration: int = 0,
                   success: bool = False,
                   known_signals: List[str] = None) -> TestRecord:
        """
        Add a test record.

        Extracts input sequences from the code, stores the record, and
        refreshes pattern and aggregate statistics.

        Args:
            code: Test code.
            test_id: Test ID (auto-generated if not provided).
            target_function: Target function point.
            covered_lines: Covered code lines.
            covered_functions: Covered function points.
            coverage_score: Coverage score.
            iteration: Iteration count.
            success: Whether the test succeeded.
            known_signals: List of known signals (whitelist for sequence
                extraction).

        Returns:
            The created test record.
        """
        if test_id is None:
            test_id = f"test_{len(self.records)}_{datetime.now().strftime('%H%M%S')}"

        # Extract input sequences from the test code.
        if known_signals:
            self.sequence_extractor.set_known_signals(known_signals)
        input_sequences = self.sequence_extractor.extract(code)

        # Create the record.
        record = TestRecord(
            test_id=test_id,
            code=code,
            input_sequences=input_sequences,
            target_function=target_function,
            covered_lines=covered_lines or [],
            covered_functions=covered_functions or [],
            coverage_score=coverage_score,
            iteration=iteration,
            success=success
        )

        self.records.append(record)

        # Update pattern statistics.
        self._update_patterns(record)

        # Update aggregate statistics.
        self._update_stats()

        logger.debug(f"Added test record: {test_id}, sequences: {len(input_sequences)}")

        return record

    def get_record(self, test_id: str) -> Optional[TestRecord]:
        """Return the record with the given ID, or None if not found."""
        for record in self.records:
            if record.test_id == test_id:
                return record
        return None

    def get_recent_records(self, n: int = 10) -> List[TestRecord]:
        """Return the most recent n records (all records if fewer exist)."""
        return self.records[-n:] if len(self.records) >= n else self.records

    def get_successful_records(self) -> List[TestRecord]:
        """Return all records flagged as successful."""
        return [r for r in self.records if r.success]

    # ==================== Pattern analysis ====================

    def _update_patterns(self, record: TestRecord):
        """Update sequence-pattern statistics with one record's sequences."""
        for seq in record.input_sequences:
            pattern_str = seq.to_pattern_string()
            pattern_hash = seq.get_hash()

            if pattern_hash not in self.patterns:
                self.patterns[pattern_hash] = SequencePattern(
                    pattern=pattern_str,
                    count=1,
                    signal_name=seq.signal_name,
                    test_ids=[record.test_id]
                )
            else:
                self.patterns[pattern_hash].count += 1
                self.patterns[pattern_hash].test_ids.append(record.test_id)

            # Index the pattern by signal name.
            if pattern_hash not in self.signal_patterns[seq.signal_name]:
                self.signal_patterns[seq.signal_name].append(pattern_hash)

    def get_overused_patterns(self, threshold: int = 3) -> List[SequencePattern]:
        """
        Return the patterns that are overused.

        Args:
            threshold: Overuse threshold.

        Returns:
            List of overused patterns.
        """
        return [p for p in self.patterns.values() if p.is_overused(threshold)]

    def get_common_patterns(self, top_n: int = 5) -> List[Tuple[str, int]]:
        """
        Return the most common patterns.

        Args:
            top_n: Number of patterns to return.

        Returns:
            [(pattern, count), ...] sorted by descending count.
        """
        sorted_patterns = sorted(
            self.patterns.items(),
            key=lambda x: x[1].count,
            reverse=True
        )
        return [(p[1].pattern, p[1].count) for p in sorted_patterns[:top_n]]

    def get_pattern_for_signal(self, signal_name: str) -> List[SequencePattern]:
        """Return all patterns recorded for a specific signal."""
        pattern_hashes = self.signal_patterns.get(signal_name, [])
        return [self.patterns[h] for h in pattern_hashes if h in self.patterns]

    # ==================== Diversity analysis ====================

    def calculate_sequence_diversity(self, new_sequences: List[InputSequence]) -> float:
        """
        Compute the diversity score of new sequences versus history.

        Args:
            new_sequences: List of new input sequences.

        Returns:
            Diversity score (0.0 - 1.0): the fraction of the new
            sequences' pattern hashes that were never seen before.
        """
        if not self.records:
            return 1.0  # With no history, everything counts as diverse.

        if not new_sequences:
            return 0.0  # No sequences means zero diversity.

        # Pattern-level repetition check.
        new_patterns = {seq.get_hash() for seq in new_sequences}
        total_patterns = len(new_patterns)

        if total_patterns == 0:
            return 0.0

        # Fraction of patterns not yet in the history.
        new_pattern_count = sum(1 for h in new_patterns if h not in self.patterns)
        pattern_diversity = new_pattern_count / total_patterns

        return pattern_diversity

    def calculate_edit_distance_diversity(self, new_code: str) -> float:
        """
        Compute diversity based on edit distance.

        Uses a simplified (truncated) edit-distance computation against
        the 5 most recent records; returns the normalized minimum
        distance in [0, 1].
        """
        if not self.records:
            return 1.0

        # Use the most recent records as reference.
        recent_records = self.get_recent_records(5)

        min_distance = float('inf')
        for record in recent_records:
            distance = self._levenshtein_distance(new_code, record.code)
            min_distance = min(min_distance, distance)

        # Normalize to [0, 1].
        max_len = max(len(new_code), max(len(r.code) for r in recent_records))
        if max_len == 0:
            return 0.0

        return min_distance / max_len

    def _levenshtein_distance(self, s1: str, s2: str) -> int:
        """Compute the Levenshtein edit distance (simplified version).

        Inputs longer than 500 characters are truncated, so the result
        is approximate for large code bodies.
        """
        if len(s1) < len(s2):
            return self._levenshtein_distance(s2, s1)

        if len(s2) == 0:
            return len(s1)

        # Simplified computation: truncate long inputs.
        if len(s1) > 500:
            s1 = s1[:500]
        if len(s2) > 500:
            s2 = s2[:500]

        previous_row = range(len(s2) + 1)
        for i, c1 in enumerate(s1):
            current_row = [i + 1]
            for j, c2 in enumerate(s2):
                insertions = previous_row[j + 1] + 1
                deletions = current_row[j] + 1
                substitutions = previous_row[j] + (c1 != c2)
                current_row.append(min(insertions, deletions, substitutions))
            previous_row = current_row

        return previous_row[-1]

    # ==================== Statistics ====================

    def _update_stats(self):
        """Recompute aggregate statistics from the stored records."""
        self.stats['total_tests'] = len(self.records)
        self.stats['successful_tests'] = sum(1 for r in self.records if r.success)

        if self.records:
            self.stats['total_coverage'] = sum(r.coverage_score for r in self.records)
            # NOTE(review): writes 'avg_coverage' while __init__ seeds
            # 'avg_diversity' — the latter is never updated; confirm intent.
            self.stats['avg_coverage'] = self.stats['total_coverage'] / len(self.records)

    def get_statistics(self) -> Dict[str, Any]:
        """Return aggregate statistics plus pattern/signal counts."""
        return {
            **self.stats,
            'total_patterns': len(self.patterns),
            'overused_patterns': len(self.get_overused_patterns()),
            'unique_signals': len(self.signal_patterns)
        }

    def get_diversity_report(self) -> str:
        """Generate a human-readable diversity report."""
        lines = []
        lines.append("=" * 50)
        lines.append("TEST HISTORY DIVERSITY REPORT")
        lines.append("=" * 50)
        lines.append(f"Total Tests: {self.stats['total_tests']}")
        lines.append(f"Successful Tests: {self.stats['successful_tests']}")
        lines.append(f"Total Patterns: {len(self.patterns)}")
        lines.append("")

        # Most common patterns.
        lines.append("TOP 5 COMMON PATTERNS:")
        common = self.get_common_patterns(5)
        for i, (pattern, count) in enumerate(common, 1):
            lines.append(f"  {i}. {pattern[:40]}... (x{count})")

        # Overused patterns.
        overused = self.get_overused_patterns()
        if overused:
            lines.append("")
            lines.append("OVERUSED PATTERNS (need diversification):")
            for p in overused[:5]:
                lines.append(f"  - {p.signal_name}: {p.pattern[:30]}... (used {p.count} times)")

        lines.append("=" * 50)
        return "\n".join(lines)

    # ==================== Persistence ====================

    def save(self, filepath: str = None):
        """Save the history to a JSON file (defaults to self.history_file)."""
        filepath = filepath or self.history_file
        if not filepath:
            return

        # Build a JSON-serializable structure by hand.
        # NOTE(review): json serializes the (time, value) tuples in
        # seq.values as lists, so a save()/load() round-trip changes tuple
        # -> list; if get_hash()/to_pattern_string() are sensitive to that,
        # pattern hashes may differ after reload — confirm.
        records_data = []
        for r in self.records:
            record_dict = {
                'test_id': r.test_id,
                'code': r.code,
                'input_sequences': [],
                'target_function': r.target_function,
                'covered_lines': r.covered_lines,
                'covered_functions': r.covered_functions,
                'coverage_score': r.coverage_score,
                'diversity_scores': r.diversity_scores,
                'iteration': r.iteration,
                'timestamp': r.timestamp,
                'success': r.success
            }
            # Convert InputSequence objects by hand.
            for seq in r.input_sequences:
                record_dict['input_sequences'].append({
                    'signal_name': seq.signal_name,
                    'values': seq.values
                })
            records_data.append(record_dict)

        data = {
            'records': records_data,
            'stats': self.stats
        }

        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2, ensure_ascii=False)

        logger.info(f"Test history saved to {filepath}")

    def load(self, filepath: str):
        """Load history records from a JSON file (no-op if it is missing)."""
        if not os.path.exists(filepath):
            return

        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)

        self.records = []
        for r in data.get('records', []):
            sequences = [
                InputSequence(**s) for s in r.get('input_sequences', [])
            ]
            # NOTE(review): 'diversity_scores' written by save() is not
            # restored here — TODO confirm this is intentional.
            record = TestRecord(
                test_id=r['test_id'],
                code=r['code'],
                input_sequences=sequences,
                target_function=r.get('target_function', ''),
                covered_lines=r.get('covered_lines', []),
                covered_functions=r.get('covered_functions', []),
                coverage_score=r.get('coverage_score', 0.0),
                iteration=r.get('iteration', 0),
                timestamp=r.get('timestamp', ''),
                success=r.get('success', False)
            )
            self.records.append(record)
            self._update_patterns(record)

        self.stats = data.get('stats', self.stats)
        logger.info(f"Loaded {len(self.records)} test records from {filepath}")
|
||||
|
||||
|
||||
# ============================================================================
# Convenience functions
# ============================================================================
|
||||
|
||||
def create_test_history(history_file: str = None) -> TestHistoryManager:
    """Convenience factory for a TestHistoryManager.

    Args:
        history_file: Optional path to a JSON history file; when given
            and present on disk, the manager loads it on construction.

    Returns:
        A new TestHistoryManager instance.
    """
    manager = TestHistoryManager(history_file=history_file)
    return manager
|
||||
Reference in New Issue
Block a user