Files
TBgen_App/autoline/TB3_funccheck.py

837 lines
53 KiB
Python
Raw Normal View History

2026-03-30 16:46:48 +08:00
"""
Description : The functionality checking of the generated TB, the submodule of Autoline
Author : Ruidi Qiu (r.qiu@tum.de)
Time : 2024/7/22 10:36:06
LastEdited : 2025/2/25 22:11:13
"""
import os
import shutil

import LLM_call as llm
import iverilog_call as iv
import python_call as py
import numpy as np
import loader_saver as ls
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from loader_saver import autologger as logger
from loader_saver import log_localprefix
class TaskTBcheck():
    """
    ### description
    - this is the self-checking stage of our pipeline; This is the main contribution of AutoBench2
    - This stage is to check the functional correctness of the testbench generated by AutoBench.
    - the TB is simulated against a set of RTL designs; the resulting RTL x scenario pass/fail matrix is judged by
      `TB_discriminator`, and a failing TB is repaired by `TB_corrector` (correction-discrimination loop, see `run`)
    """
    def __init__(self, task_dir:str, task_id:str, description:str, module_header:str, TB_code_v:str, TB_code_py:str|None=None, rtl_list:list[str]|None=None, rtl_num:int=20, scenario_num:int|None=None, correct_max:int=3, runfiles_save:bool=True, discriminator_mode:str="col_full_wrong", corrector_mode:str="naive", circuit_type:str|None=None, rtl_compens_max_iter:int=3, rtl_compens_en:bool=True, desc_improve:bool=False, **LLM_kwargs) -> None:
        """
        - input:
            - task_dir: the root directory of the taskTBcheck
            - task_id: the name of the problem
            - description: the description of the problem
            - module_header: the header of the module
            - TB_code_v: the generated verilog testbench code
            - TB_code_py (opt.): the generated python testbench code, if None, the tb is in a pure verilog mode
            - rtl_list (opt.): the list of llm-generated RTL codes, if None, `rtl_num` RTL codes are generated using LLM
            - rtl_num (opt.): the number of RTL codes to generate, only used when rtl_list is None
            - scenario_num (opt.): the number of scenarios in the testbench, if None, will be calculated from the failed scenarios (not accurate but won't impact the results)
            - correct_max (opt.): the maximum number of correction iterations
            - runfiles_save (opt.): whether to save the compilation files in TB_discrimination
            - discriminator_mode (default: col_full_wrong): the mode of the discriminator
            - corrector_mode (default: naive): the mode of the corrector
            - circuit_type (opt.): the type of the circuit, used in the corrector (better performance if provided)
            - rtl_compens_max_iter (default: 3): the maximum number of discrimination rounds used for RTL compensation (at least one round always runs)
            - rtl_compens_en (default: True): whether to enable RTL compensation, i.e. regenerating RTLs when fewer than half of them compile with the TB
            - desc_improve (default: False): whether to run the SPEC improver when the funccheck ends with "reboot" (not caused by a syntax error)
            - **LLM_kwargs: the keyword arguments for LLM (used in corrector and rtl generation), including:
                - "main_model": the main llm name used in TB_generation and correction
                - "rtlgen_model": the llm name used in RTL generation
                - "api_key_path": the path of the api key
                - "temperature": the temperature of LLM
        - note: with rtl_list=None the constructor immediately generates the RTLs via LLM (see `rtl_list_gen`)
        """
        self.task_dir = task_dir
        self.working_dir = self.task_dir
        self.task_id = task_id
        self.description = description
        self.module_header = module_header
        self.TB_code_v = TB_code_v
        self.TB_code_py = TB_code_py
        self.pychecker_en = TB_code_py is not None
        self.rtl_list = rtl_list
        self.scenario_num = scenario_num
        self.correct_max = correct_max
        self.runfiles_save = runfiles_save
        self.main_model = LLM_kwargs.get("main_model", None)
        self.rtlgen_model = LLM_kwargs.get("rtlgen_model", None)
        self.llm_api_key_path = LLM_kwargs.get("api_key_path", "config/key_API.json")
        self.llm_temperature = LLM_kwargs.get("temperature", None)
        self.circuit_type = circuit_type
        self.rtl_compens_max_iter = rtl_compens_max_iter # see self.discriminate_TB for more info
        self.rtl_compens_en = rtl_compens_en
        self.desc_improve = desc_improve
        # the correction loop gives up once the wrong-scenario count has stayed unchanged this many times
        self.tolerance_for_same_wrong_scen = 2
        self.same_wrong_scen_times = 0
        # discriminator and corrector
        self.discriminator_mode = discriminator_mode
        self.corrector_mode = corrector_mode
        self.discriminator = TB_discriminator(discriminator_mode)
        self.corrector = TB_corrector(self.corrector_mode, self.pychecker_en, self.circuit_type)
        self.improver = SPEC_improver(description, "naive", self.pychecker_en, self.main_model, self.circuit_type)
        # rtl list and number
        self.rtl_newly_gen_num = 0
        if self.rtl_list is None:
            self.set_rtl_num(rtl_num)
            self.rtl_list_gen()
        else:
            self.set_rtl_num(len(self.rtl_list))
        self.scenario_matrix = None
        self.wrong_scen_num = 0 # a small number as default, will be replaced later
        self.previous_wrong_scen_num = 9999 # a very large number as default, will be replaced later
        self.TB_syntax_error = False
        # TB code before the latest correction; `run` restores it when the corrector introduces a syntax error
        self.TB_code_v_before_cor = None
        self.TB_code_py_before_cor = None
        # tb analysis results
        self.tb_pass = None
        self.wrong_col_index = None
        self.correct_col_index = None
        self.unsure_col_index = None
        # next_action
        self.next_action = None
        self.iter_now = 0
        self.corrected = False # this means the TB has been corrected
        if self.main_model is None:
            logger.warning("main_model not found, may have trouble while correcting tb")
        # record and runinfo
        self.op_record = [] # record the order of the operations, similar to the var in autoline; will be added to the funccheck's op_record in the final runinfo.

    @property
    def rtl_num(self):
        """protected attr. rtl_num is initialized at the beginning and will not be changed"""
        return self._rtl_num

    @log_localprefix("TBcheck")
    def run(self):
        """
        - the main function of TaskTBcheck
        - the TB check stage contains several sub-stages:
            - 1. TB discriminating
            - 2. TB correcting
        - output: will update (and return) the next action of the task, including:
            - "pass": the TB already passed the selfcheck, will go to the evaluation stage
            - "reboot": the whole pipeline will start from the very beginning
        - workflow: initial discrimination -> correction-discrimination loop -> pass or reboot
        - the loop gives up early (reboot) when the number of wrong scenarios increases, has stayed unchanged
          `tolerance_for_same_wrong_scen` times, or the corrector produces a syntax error a second time
        """
        # TODO: if error occurs, go to reboot.
        # how many corrector-induced syntax errors may be rolled back before giving up
        tolerance = 1
        # True while the most recent discrimination hit a TB syntax error. discriminate_TB does not refresh
        # wrong_scen_num/previous_wrong_scen_num in that case, so the progress comparison must be skipped.
        syntax_error = False
        # initial discrimination
        self.discriminate_TB()
        if self.TB_syntax_error:
            logger.negative("Testbench has syntax error, I give up. Reboot the whole process")
            syntax_error = True
            self.next_action = "reboot"
        elif self.tb_pass:
            logger.info("Testbench passed the funccheck")
            self.next_action = "pass"
        else:
            # enter the correction loop, the initial TB has no syntax error when entering
            if self.correct_max == 0:
                logger.negative("No correction is allowed, I give up. Reboot the whole autoline process")
                self.next_action = "reboot"
            for self.iter_now in range(1, self.correct_max+1):
                if (self.iter_now > 1) and (self.wrong_scen_num > self.previous_wrong_scen_num) and (not syntax_error):
                    # give up, the correction makes it worse
                    logger.negative(f"wrong scenarios increased ({self.wrong_scen_num} > {self.previous_wrong_scen_num}), I give up, quiting the funccheck stage...")
                    self.next_action = "reboot"
                    break
                elif (self.iter_now > 1) and (self.wrong_scen_num == self.previous_wrong_scen_num) and (not syntax_error):
                    self.same_wrong_scen_times += 1
                    if self.same_wrong_scen_times >= self.tolerance_for_same_wrong_scen:
                        logger.info(f"wrong scenarios not decreased for {self.tolerance_for_same_wrong_scen} times ({self.wrong_scen_num} = {self.previous_wrong_scen_num}), I give up, quiting the funccheck stage...")
                        self.next_action = "reboot"
                        break
                    else:
                        logger.info(f"wrong scenarios not decreased for {self.same_wrong_scen_times} times ({self.wrong_scen_num} = {self.previous_wrong_scen_num}), continue the correction")
                self.correct_TB()
                self.discriminate_TB()
                if self.tb_pass:
                    logger.info("Testbench passed the funccheck after correction")
                    self.next_action = "pass"
                    break
                elif self.TB_syntax_error:
                    # if the syntax error is from the corrector, we should roll back before correction
                    if tolerance > 0:
                        logger.negative(f"Testbench has syntax error after correction, I still have tolerance for that (tolerance={tolerance}). roll back and retry the self correction.")
                        self.TB_code_v, self.TB_code_py = self.TB_code_v_before_cor, self.TB_code_py_before_cor
                        tolerance -= 1
                        syntax_error = True
                    else:
                        logger.negative("Testbench has syntax error after correction, I don't have tolerance. I give up. Reboot the whole autoline process")
                        self.next_action = "reboot"
                        syntax_error = True
                        break
                else:
                    # a clean discrimination refreshed the wrong-scenario counters; comparisons are valid again
                    syntax_error = False
                self.next_action = "reboot" if self.iter_now == self.correct_max else None
        if (self.next_action == "reboot") and (not syntax_error) and (self.desc_improve):
            # the desc improver does not work well so we do not use it in this work
            self.improve_SPEC()
        logger.info(f"self funccheck finished. Next Action: [{self.next_action}]")
        return self.next_action

    @log_localprefix("discriminator")
    def discriminate_TB(self, no_any_files:bool=False):
        """
        - check the correctness of the testbench and return the rtl analysis results in matrix form
        - important data: the rtl list, the TB code
        - update the following data: `scenario_matrix`, `tb_pass`, `wrong_col_index`, `correct_col_index`, `unsure_col_index`, `wrong_scen_num`
        - RTLs that fail to compile/run with the TB are dropped from `rtl_list`; if fewer than half of `rtl_num`
          remain and `rtl_compens_en` is set, new RTLs are generated and the discrimination is repeated
          (at most `rtl_compens_max_iter` rounds, at least one round)
        - returns None when every RTL run fails (TB syntax error), otherwise the discriminator's 4-tuple
        """
        rtl_dir_prefix = "RTL_"
        self.op_record.append("discrim")
        self.working_dir = os.path.join(self.task_dir, f"discrim_{self.iter_now}")
        logger.info(f"Discriminating the testbench, NO.{self.iter_now} discrimination")
        # at least one round is required, otherwise scenario_matrix is never computed
        max_iter = max(1, self.rtl_compens_max_iter)
        for i in range(max_iter):
            # the loop is for the case that too few RTL passed the syntax check, generate more rtls and recheck
            failed_scenario_matrix = []
            self.TB_syntax_error = True
            syntax_error_rtl = []
            for rtl_idx, rtl_code in enumerate(self.rtl_list):
                rtl_dir = os.path.join(self.working_dir, f"{rtl_dir_prefix}{rtl_idx+1}")
                scenario_vector = self.run_testbench(rtl_dir, self.TB_code_v, rtl_code, self.TB_code_py, rtl_idx+1, self.runfiles_save and (not no_any_files))
                failed_scenario_matrix.append(scenario_vector) # like [[2, 5], [3, 4, 5]]
                if scenario_vector != [-1]:
                    self.TB_syntax_error = False
                else:
                    syntax_error_rtl.append(rtl_idx+1)
            if syntax_error_rtl:
                logger.info(f"RTL(s) {syntax_error_rtl} have syntax error during discrimination")
            if self.TB_syntax_error:
                # there are two cases for TB syntax error:
                # 1. this syntax error is from the previous stage, if so, we should directly reboot the whole autoline process
                # 2. this syntax error is from the corrector, if so, we roll back to the version before correction and retry the correction
                self.tb_pass = False
                return None
            # the len of each scenario vector should be the same, thus we transform each vector into a onehot vector [[1,1,0,1,1,0], [1,1,1,0,0,0]]
            self.scenario_matrix = self.failed_scenarios_to_onehot_array(failed_scenario_matrix, max_scen_idx=self.scenario_num, taskid=self.task_id)
            if not no_any_files:
                # save this matrix into the working dir in a human readable form
                np.savetxt(os.path.join(self.working_dir, "scenario_matrix.csv"), self.scenario_matrix, delimiter=",", fmt="%d")
                # save this matrix in a plot form
                self.draw_scenario_matrix(self.scenario_matrix, self.task_id, os.path.join(self.working_dir, "scenario_matrix.png"))
            # we delete the syntax errored rtl, if no syntax error in TB
            self.rtl_list = [rtl for rtl, scen in zip(self.rtl_list, failed_scenario_matrix) if scen != [-1]]
            if not (self.rtl_compens_en and len(self.rtl_list) < 0.5*self.rtl_num):
                break
            # too few RTL passed the syntax check
            logger.info(f"too few RTL passed the syntax check ({len(self.rtl_list)}/{self.rtl_num}), I will generate more and recheck. This is not TB's fault.")
            self.rtl_list, gen_num = self.gen_rtl(self.rtl_num-len(self.rtl_list), self.description, self.module_header, self.rtlgen_model, self.rtl_list)
            self.rtl_newly_gen_num += gen_num
            # delete the previous rtl dirs (dir start with rtl_dir_prefix and under the working dir)
            for subdir in os.listdir(self.working_dir):
                subdir_path = os.path.join(self.working_dir, subdir)
                if subdir.startswith(rtl_dir_prefix) and os.path.isdir(subdir_path):
                    shutil.rmtree(subdir_path, ignore_errors=True)
            logger.info("re-discriminate the testbench with updated RTL list")
            if i == max_iter-1:
                logger.info("no re-discrimination since the max iteration reached")
        # discriminate the testbench according to the one hot matrix
        self.tb_pass, self.wrong_col_index, self.correct_col_index, self.unsure_col_index = self.discriminator.discriminate(self.scenario_matrix)
        self.previous_wrong_scen_num = self.wrong_scen_num
        self.wrong_scen_num = len(self.wrong_col_index)
        return self.tb_pass, self.wrong_col_index, self.correct_col_index, self.unsure_col_index

    @log_localprefix("corrector")
    def correct_TB(self):
        """
        - correct the testbench by using the RTL analysis results (`wrong_col_index`) of the last discrimination
        - the pre-correction code is kept in `TB_code_v_before_cor` / `TB_code_py_before_cor` so `run` can roll back
        """
        self.op_record.append("correct")
        self.working_dir = os.path.join(self.task_dir, f"correct_{self.iter_now}")
        self.TB_code_v_before_cor, self.TB_code_py_before_cor = self.TB_code_v, self.TB_code_py
        self.TB_code_v, self.TB_code_py = self.corrector.correct(self.description, self.wrong_col_index, self.TB_code_v, self.main_model, self.TB_code_py, self.working_dir)
        self.corrected = True

    @log_localprefix("improver")
    def improve_SPEC(self):
        """
        - improve the specification of the task according to the discrimination and correction results
        - replaces `self.description` with the improver's output
        """
        self.op_record.append("improve")
        self.working_dir = os.path.join(self.task_dir, "improve_Desc")
        self.description = self.improver.improve(self.wrong_col_index, self.correct_col_index, self.TB_code_v, self.TB_code_py, working_dir=self.working_dir)

    @staticmethod
    def run_testbench(dir, driver_code:str, DUT_code:str, checker_code:str, rtl_index:int, save_en:bool=True):
        """
        - modified from autoline.py TBEval.run_testbench
        - it has two mode: pychecker mode or verilog testbench mode
        - input:
            - dir: the dir to save the TB, DUT and pychecker code
            - driver_code: str; the testbench code
            - DUT_code: str; the DUT code
            - checker_code: str; the pychecker code (NOTE(review): must be a str, None fails when writing checker.py)
            - rtl_index: index of the RTL, only kept for logging purposes
            - save_en: if False, `dir` is deleted before returning (on every return path)
        - output:
            - a sorted list of failed scenarios (if rtl has syntax error or the checker prints no list, return [-1])
        """
        os.makedirs(dir, exist_ok=True)
        try:
            # iverilog part: save the TB and DUT, then compile & simulate
            with open(os.path.join(dir, "driver.v"), "w") as f:
                f.write(driver_code)
            with open(os.path.join(dir, "DUT.v"), "w") as f:
                f.write(DUT_code)
            iv_run_info = iv.iverilog_call_and_save(dir, silent=True)
            if not iv_run_info[0]:
                return [-1]
            # python checker part
            py_checker_path = os.path.join(dir, "checker.py")
            with open(py_checker_path, "w") as f:
                f.write(checker_code)
            py_run_info = py.python_call_and_save(pypath=py_checker_path, silent=True)
            if not py_run_info[0]:
                return [-1]
            return TaskTBcheck._parse_failed_scenarios(py_run_info[1]["out"])
        finally:
            if not save_en:
                shutil.rmtree(dir, ignore_errors=True)

    @staticmethod
    def _parse_failed_scenarios(checker_output:str) -> list[int]:
        """
        - extract the failed-scenario list printed last by the python checker, e.g. "... [2, '5b', 3]" -> [2, 3, 5]
        - an item that is not pure digits (such as 2b) keeps only its digit part; items without any digit are ignored
        - returns [] for an empty list "[]" and [-1] when the output contains no "[...]" list at all
        """
        last_bracket_end = checker_output.rfind("]")
        last_bracket_start = checker_output.rfind("[")
        if last_bracket_start == -1 or last_bracket_end < last_bracket_start:
            # no list printed: the checker run is unusable, report it like a failed run
            return [-1]
        if last_bracket_end == last_bracket_start+1:
            return []
        failed_scenarios = set()
        for item in checker_output[last_bracket_start+1:last_bracket_end].replace("'", "").split(","):
            digits = "".join(char for char in item if char.isdecimal())
            if digits:
                failed_scenarios.add(int(digits))
        return sorted(failed_scenarios)

    def rtl_list_gen(self)->list[str]:
        """
        - generate the RTL list using LLM, will empty the old rtl list
        - attr needed: description, module_header, rtl_num, rtlgen_model
        - attr changed: rtl_list, rtl_newly_gen_num
        - the generated list is also saved to <task_dir>/rtl_list.json and returned
        """
        self.rtl_list = []
        logger.info(f"rtl list not found, generating naive rtls for testbench checking")
        self.rtl_list, gen_num = self.gen_rtl(self.rtl_num, self.description, self.module_header, self.rtlgen_model, self.rtl_list)
        self.rtl_newly_gen_num += gen_num
        # save the rtl list
        save_path = os.path.join(self.task_dir, "rtl_list.json")
        os.makedirs(self.task_dir, exist_ok=True)
        ls.save_json_lines([{"task_id": self.task_id, "llmgen_RTL": self.rtl_list}], save_path)
        return self.rtl_list

    @staticmethod
    def gen_rtl(num:int, description:str, header:str, llm_mode:str, rtl_list:list|None=None):
        """
        - input:
            - num (int): the number of RTLs to generate
            - description (str): the description of the rtl problem
            - header (str): the header of the module
            - llm_mode (str): the llm model to use (official model name)
            - rtl_list (list) [optional]: the newly generated RTLs will be appended (in place) to this list; a new list is used if omitted
        - output:
            - rtl_list (list): the list of the newly generated RTLs (and the old ones, if have)
            - rtl_gen_num (int): how many RTLs were generated in this call
        """
        # a fresh list per call; a `[]` default would be shared between calls
        if rtl_list is None:
            rtl_list = []
        rtl_gen_num = 0
        prompt = "Your task is to write a verilog RTL design according to the design specification. The infomation we have is the problem description that guides student to write the RTL code (DUT) and the header of the desired module. here is the problem description:\n"
        prompt += description
        prompt += "\nHere is the header of the desired module:\n"
        prompt += header
        prompt += "\nPlease only return the module code (header should be included) in verilog, please do not include any other words."
        for _ in range(num):
            answer = llm.llm_call(prompt, llm_mode)[0]
            module_code = llm.extract_code(answer, "verilog")[0]
            rtl_list.append(module_code)
            rtl_gen_num += 1
        logger.info("%d naive rtls generated"%(rtl_gen_num))
        return rtl_list, rtl_gen_num

    @staticmethod
    def failed_scenarios_to_onehot_array(failed_scenarios:list[list], max_scen_idx:int|None=None, taskid:str=""):
        """
        - input: [failed_scenarios:list[int]] (for example: [[1,2,3], [2,3,4], [1,3,4], [-1]]), if one failed scenario list is [-1], it means the rtl has syntax error
        - output (np.array): a onehot array (for example: [[0,0,0,1], [1,0,0,0], [0,1,0,0], [-1,-1,-1,-1]]) (1 denotes pass, 0 denotes fail, -1 means syntax error)
        - the number of columns is max(largest reported scenario index, max_scen_idx, 1)
        - scenario indexes start from 1; non-positive indexes do not map to any column and are ignored
        - an empty input gives an array with 0 rows
        """
        max_idx_given = max_scen_idx if max_scen_idx is not None else 1
        if not failed_scenarios:
            return np.ones((0, max(1, max_idx_given)), dtype=int)
        # largest reported index; rows of only [-1] / [] give values < 1, which become 1 so the plot is never empty
        max_idx_cal = max((max(scens) for scens in failed_scenarios if scens), default=0)
        if max_idx_cal < 1:
            max_idx_cal = 1
        max_idx = max(max_idx_cal, max_idx_given)
        grid_map = np.ones((len(failed_scenarios), max_idx), dtype=int)
        for rtl_idx, failed_scens in enumerate(failed_scenarios):
            if failed_scens == [-1]:
                # if the failed scenario list is [-1], then all the scenarios in this rtl are -1
                grid_map[rtl_idx, :] = -1
                continue
            for scen_idx in failed_scens:
                # index 0 would otherwise wrap around to the last column
                if scen_idx >= 1:
                    grid_map[rtl_idx, scen_idx-1] = 0
        return grid_map

    @staticmethod
    def draw_scenario_matrix(scenario_matrix:np.ndarray, task_id:str, saving_path:str):
        """
        - draw the 2D failed scenario array. The element in the array can only be 0, 1, -1. We use red (salmon) for 0, green for 1, and gray for -1.
        - if the scenario is empty, will return a gray color block.
        - any other value is drawn like -1; the figure is saved to `saving_path` and closed
        """
        if len(scenario_matrix) == 0:
            scenario_matrix = np.array([[-1]])
        # if the element in data not in [0, 1, -1], change the element to -1
        scenario_matrix = np.where(np.isin(scenario_matrix, [0, 1, -1]), scenario_matrix, -1)
        color_mapping = {
            0: mcolors.to_rgb("salmon"),
            1: mcolors.to_rgb("mediumseagreen"),
            -1: mcolors.to_rgb("grey"),
        }
        rgb_image = np.zeros(scenario_matrix.shape + (3,))
        for value, color in color_mapping.items():
            rgb_image[scenario_matrix == value] = color
        # a dedicated figure, so nothing leaks into/out of global pyplot state
        fig, ax = plt.subplots()
        try:
            ax.imshow(rgb_image)
            ax.set_ylabel("RTL index")
            ax.set_xlabel("Scenario index")
            xticks = np.arange(scenario_matrix.shape[1])
            ax.set_xticks(xticks)
            ax.set_xticklabels(xticks + 1)
            yticks = np.arange(scenario_matrix.shape[0])
            ax.set_yticks(yticks)
            ax.set_yticklabels(yticks + 1)
            ax.set_title(f"[{task_id}] - Matrix of RTL-TB Scenario Correctness")
            fig.savefig(saving_path)
        finally:
            plt.close(fig)

    def update_description(self)->str:
        """
        - intended to modify the description of the task according to the discrimination and correction results
        - NOTE(review): currently it only logs and returns the unchanged `self.description`
        """
        logger.info("the description of the task is updated")
        return self.description

    def set_rtl_num(self, value:int):
        """set the protected `rtl_num` attribute (only meant to be called during initialization)"""
        self._rtl_num = value

    def __call__(self, *args, **kwds):
        """alias of `run`; all arguments are ignored"""
        return self.run()
class TB_discriminator():
    """
    this class is used to discriminate the testbench according to the failing matrix
    - "col_full_wrong": the TB is wrong if any scenario column fails for every RTL (the most naive, loose criterion)
    - "col_<X>_wrong": a scenario column is wrong if at least X% of the RTLs fail it; the TB is wrong if any column is wrong
    - "col_<X>_wrong_row_<Y>_correct": like col_<X>_wrong, but the TB is judged correct anyway if at least Y% of
      the RTLs pass every scenario
    """
    # mode -> (ratio of failing RTLs that makes a column wrong, ratio of fully-passing RTLs that forces a pass or None)
    THRESHOLD_MODES = {
        "col_90_wrong": (0.9, None),
        "col_80_wrong": (0.8, None),
        "col_70_wrong": (0.7, None),
        "col_60_wrong": (0.6, None),
        "col_50_wrong": (0.5, None),
        "col_40_wrong": (0.4, None),
        "col_70_wrong_row_25_correct": (0.7, 0.25),
        "col_70_wrong_row_10_correct": (0.7, 0.1),
        "col_70_wrong_row_1_correct": (0.7, 0.01),
        "col_50_wrong_row_25_correct": (0.5, 0.25),
    }

    def __init__(self, mode:str) -> None:
        """
        - mode: "col_full_wrong" or one of the keys of `THRESHOLD_MODES`
        - an unknown mode is only logged here; `discriminate` raises RuntimeError for it
        """
        self.mode = mode
        if self.mode != "col_full_wrong" and self.mode not in self.THRESHOLD_MODES:
            logger.critical("class discriminator - mode not found!!!")

    def discriminate(self, failed_matrix:np.ndarray)->tuple[bool|None, np.ndarray, np.ndarray, np.ndarray]:
        """
        - input: the failed matrix of the testbench in onehot form (1 pass, 0 fail, -1 syntax error; see 'failed_scenarios_to_onehot_array')
        - output:
            - the indexes in the scen arrays are starting from 1
            - bool: whether the testbench is correct (None if every entry is -1, i.e. the TB has syntax error; the lists are then empty)
            - np.ndarray: the wrong scenarios
            - np.ndarray: the correct scenarios (passed by every RTL)
            - np.ndarray: the scenarios that the discriminator is not sure about
        - rows that are entirely -1 are dropped before judging
        - raises RuntimeError for an unknown mode
        """
        # first check if all the scenarios are [-1], which means the tb has syntax error
        if np.all(failed_matrix == -1):
            return None, [], [], []
        failed_matrix = failed_matrix[~np.all(failed_matrix == -1, axis=1)]
        rtl_count = len(failed_matrix)
        correct_col_index = np.where(np.all(np.isin(failed_matrix, [1, -1]), axis=0))[0] + 1
        if self.mode == "col_full_wrong":
            # check which column of the matrix is fully wrong (0)
            wrong_col_index = np.where(np.all(np.isin(failed_matrix, [0, -1]), axis=0))[0] + 1
            unsure_col_index = np.where(np.any(failed_matrix == 0, axis=0) & np.any(failed_matrix == 1, axis=0))[0] + 1
            tb_pass = len(wrong_col_index) == 0 # as long as there is no fully wrong column, the tb is correct (loose criterion)
        elif self.mode in self.THRESHOLD_MODES:
            col_ratio, row_ratio = self.THRESHOLD_MODES[self.mode]
            fail_count = np.sum(failed_matrix == 0, axis=0)
            wrong_col_index = np.where(fail_count >= col_ratio*rtl_count)[0] + 1
            unsure_col_index = np.where(np.logical_and(fail_count < col_ratio*rtl_count, np.any(failed_matrix == 0, axis=0)))[0] + 1
            tb_pass = len(wrong_col_index) == 0
            # enough RTLs passing every scenario overrule the column criterion
            if row_ratio is not None and np.sum(np.all(failed_matrix == 1, axis=1)) >= row_ratio*rtl_count:
                tb_pass = True
        else:
            logger.critical("TB discriminator - mode not found!!!")
            raise RuntimeError("TB discriminator - mode not found!!!")
        # the status word is computed outside the f-string: nested same-type quotes need Python >= 3.12
        status = "passed" if tb_pass else "failed"
        logger.match_level(tb_pass, "positive", "negative", f"TB_discriminating finished, TB {status}, wrong scenarios: {wrong_col_index}, scenario pass ratio: {len(correct_col_index)}/{len(failed_matrix[0])}")
        return tb_pass, wrong_col_index, correct_col_index, unsure_col_index
# Prompt text for the TB corrector (see TB_corrector.correct). These literals are sent to the LLM verbatim,
# so every byte (typos included) is runtime behavior.
# COR_PROMPT_1: opening of the correction prompt; asks the LLM to locate the error in the python checker
# (not to write code yet).
COR_PROMPT_1 = """Your task is to correct the testbench according to the failing scenarios. the information we have is the failed/passed scenarios of the testbench, the problem description and the testbench code.
the testbench code is consisted of both verilog and python code. The verilog code aims to generate test stimulus (under test scenarios) and drive the DUT to generate the output signal; the python code aims to check if the output vector from the DUT is correct.
ATTENTION: The python code contains error, and your target is to find it and tell me how to correct it (you don't need to give me the code in this stage).
"""
# HINT_SEQ: explanation of the GoldenDUT checker structure; TB_corrector.correct selects it unless the
# circuit type is "combinational" (so also for unknown circuit types).
HINT_SEQ = """
Hints - explaination of the given python code:
the python class "GoldenDUT": This python class can represent the golden DUT (the ideal one). In "GoldenDUT", following methods are defined:
- 1. a method "def __init__(self)": Set the inner states/values of the golden DUT. These values have suffix "_reg". The initial value of these inner values is "x", but later will be digits. The "__init__" method has no input parameters except "self".
- 2. a method "def load(self, signal_vector)": This method is to load the important input signals and the inner values of "GoldenDUT" shall change according to the input signals. There is no clock signal in the input signal vector, every time the "load" method is called, it means a new clock cycle. The initial values "x" should be changed according to the input signals. This method has no return value.
- 3. a method "def check(self, signal_vector)": This method is to determine the expected output values and compare them with output signals from DUT. It should return True or False only.
"""
# HINT_CMB: explanation of the GoldenDUT checker structure, selected when the circuit type is "combinational".
HINT_CMB = """
Hints - explaination of the given python code:
The given python code contains one class "GoldenDUT". this python class can represent the golden DUT (the ideal one). By calling the inner method "check", the signal vector from DUT will be checked. The details of the golden DUT are as follows:
- a. a method "def __init__(self)". Set the inner states/values of the golden DUT. The "__init__" method has no input parameters except "self".
- b. a method "def load(self, signal_vector)". This method is to load the important input signals and get the expected output signals. it should return the expected output values. It can call other methods to help computing the expected output. It will be called by other inner methods later.
- c. a method "def check(self, signal_vector)". This method is to call "load" to get the expected output values, and compare them with output signals from DUT. It should return True or False only. It can call other methods to help checking.
- d. other methods, they can be called by "__init__", "load" or "check".
- e. the input of "load" and "check" is the signal vector. The signal vector is a dictionary, the key is the signal name, the value is the signal value.
"""
# COR_PROMPT_2_PART1 / COR_PROMPT_2_PART2: rules and expected response format for the code-writing request.
# NOTE(review): presumably used in the corrector's second LLM round; that usage lies beyond the visible code.
COR_PROMPT_2_PART1 = """
please correct the python code according to the following rules:
PYTHON code rule: please do not change the original high level structure of the python code. i.e., if python code only contains one class and several functions such as init, load, check and more, only modify the implementation of the function, but do not change the name or delete the functions/class methods. you can add new class methods or functions if needed. you can use python libraries such as numpy or math.
"""
COR_PROMPT_2_PART2 = """
i.e., your python code format in response should still be like:
class <class_name>:
def __init__(self):
...(omitted)
def load(self, ...):
...
def check(self, ...):
...
def <other_functions>(self, ...):
...
ATTENTION: please give me the corrected python code according to our previous conversation and the hints above. please give me the corrected full python code (not the part but the whole python code like I give you in our previous conversation).
"""
class TB_corrector():
def __init__(self, mode:str, pychecker_en:bool, circuit_type:str="") -> None:
self.mode = mode
self.pychecker_en = pychecker_en
circuit_type_dict = {"CMB": "combinational", "SEQ": "sequential"}
self.circuit_type = circuit_type_dict.get(circuit_type, "unknown")
# logger.debug(f"TB_corrector class - mode: {self.mode}, pychecker_en: {self.pychecker_en}, circuit_type: {self.circuit_type}; The input circuit type is {circuit_type}")
match self.mode:
case "naive":
# the most naive mode;
pass
case _:
logger.critical("TB_corrector class - mode not found!!!")
def correct(self, description, failed_scenarios, TB_code_v:str, llm_model:str, TB_code_py:str|None=None, working_dir:str=None) -> tuple[str, str]:
match self.mode:
case "naive":
if self.pychecker_en:
TB_code_py = self._py_focus(TB_code_py, before=True)
py_code_hint = HINT_CMB if self.circuit_type == "combinational" else HINT_SEQ
prompt = COR_PROMPT_1
prompt += "Here is the problem description:\n"
prompt += description
prompt += "\nHere is the testbench code:\n"
prompt += "ATTENTION: the following scenarios are wrong: " + str(failed_scenarios) + "\n"
# circuit type
prompt += f"the circuit type of this task is {self.circuit_type}\n"
prompt += "Here is the verilog code. it contains the meaning of each scenario. you can combine the wong scenario info above and the following code to better understand the reason of failing:\n"
prompt += TB_code_v
prompt += "\nHere is the python code, it contains error, please combine it with the wrong scenario info and the verilog code to understand:\n"
prompt += TB_code_py
prompt += "\nHere is some hints for your better understanding of the python codes above:"
prompt += py_code_hint
prompt += "\nplease reply me with the following steps:"
prompt += "\n1. please analyze the reason of the failed scenarios. If possible, please find the in common between the failed scenarios."
prompt += f"\n2. please analyze which part of the python code is related to the failed test scenarios ({str(failed_scenarios)})."
prompt += "\n3. please tell me how to correct the wrong part (in natural language, do not give me the complete code implementation. please explain it in English.)"
prompt += "\nhere is an example of the reply:"
prompt += "\n1. the failed scenarios are all related to the same signal x\n2. the mid part of the function_X is related to the failed scenarios\n3. the correct logic of signal x should be y."
logger.info("naive corrector mode begins")
answer = llm.llm_call(prompt, llm_model)[0]
prompt_2 = COR_PROMPT_2_PART1 + py_code_hint + COR_PROMPT_2_PART2
message = [{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}, {"role": "user", "content": prompt_2}]
answer_2, more_info = llm.llm_call(message, llm_model)
# if "VERILOG" in answer_2:
# TB_code_v = llm.extract_code(answer_2, "verilog")[0]
# else:
# TB_code_py = llm.extract_code(answer_2, "python")[0]
TB_code_py = llm.extract_code(answer_2, "python")[0]
TB_code_py = self._py_focus(TB_code_py, before=False)
if working_dir is not None:
ls.save_messages_to_txt(more_info["messages"], os.path.join(working_dir, "conversation.txt"))
ls.save_code(TB_code_v, os.path.join(working_dir, "TB.v"))
if TB_code_py is not None:
ls.save_code(TB_code_py, os.path.join(working_dir, "TB.py"))
logger.info("naive corrector mode ends; conversation and codes saved")
else:
logger.critical("TB_corrector - pychecker not enabled")
raise RuntimeError("TB_corrector - pychecker not enabled")
return TB_code_v, TB_code_py
case _:
logger.critical("TB_corrector - mode not found!!!")
raise RuntimeError("TB_corrector - mode not found!!!")
def _py_focus(self, code:str, before:bool):
"""
- imigrated from TB2_syncheck.py
- code: the code under debug / after debug
- before: True, if before debug, will split the code; False, if after debug, will restore the code
"""
KEY_WORDs_1 = "def check_dut(vectors_in):\n golden_dut = GoldenDUT()\n failed_scenarios = []"
KEY_WORDs_2 = "\ndef SignalTxt_to_dictlist"
if before:
key_words = KEY_WORDs_1 if KEY_WORDs_1 in code else KEY_WORDs_2 # for compatibility with the old version
if key_words not in code:
py_code_focus = code
self._py_code_nofocus = ""
else:
py_code_focus = code.split(key_words)[0]
self._py_code_nofocus = key_words + code.split(key_words)[1]
return py_code_focus
else:
return code + self._py_code_nofocus
# Opening of the description-improvement prompt built by SPEC_improver.improve; the original
# description is appended right after it, wrapped in DESC_MARK_BEGIN / DESC_MARK_END.
# NOTE(review): the prompt texts below contain typos (e.g. "imporved", "descriptin", "descriptino");
# left untouched because they are sent to the LLM verbatim and editing them changes behavior.
IMPR_PROMPT_1 = """Your task is to improve the quality of an RTL problem description using the following given information. Our final target is using the description to generate a testbench for the RTL design. Currently we already have the testbench but it is not perfect correct. Now in this stage the target is to generate a better description.
The information we have the is original RTL description, the testbench code. The testbench code includes the verilog code for test scenario generation, and the python code for checking the output vector.The verilog code aims to generate test stimulus (under test scenarios) and drive the DUT to generate the output signal; the python code aims to check if the output vector from the DUT is correct.
Attention: the testbench we provide is not the perfect one, and it contains error. However, we already know the test scenarios where the testbench works correctly and the scenarios where the testbench has problem. We will provide you the scenarios index later.
Here, firstly, is the problem description to be imporved:
\n"""
# Markers that wrap a description both in the prompt and in the expected LLM reply;
# SPEC_improver.improve extracts the improved description from between them.
DESC_MARK_BEGIN = "***description begins***"
DESC_MARK_END = "***description ends***"
# Step-by-step reply instructions; the last step asks for the complete description wrapped in the markers.
DESC_STEP_INSTRUCT = f"""
please reply me with the following steps:
1. please analyze which part of the testbench (especially the python checker code) is correct and can be used to improve the description.
2. please analyze how can we improve the descriptin. for example, which part of the technical details can be more detailed, which part can be more clear, which part can be more concise.
3. please provide the improved complete description. We will directly use it in the later stages.
the format of description should be like:
{DESC_MARK_BEGIN}
... (the improved description, should be complete)
{DESC_MARK_END}
"""
# Closing constraints: keep the original functionality, trust the description over the testbench,
# never mention scenario indexes, and wrap the final answer in the markers.
DESC_FINAL_INSTRUCT = f"""
ATTENTION: please know that the provided testbench is not perfect and may contains many errors. Thus, your modification on the description should not change the function of the original description. When there are conflicts between the testbench and the description, always believe the description is correct. Do not delete the information in the description, but you can rewrite it in a better way. You can also add more details to it. But NEVER mention any scenario index because the scenarios will not be the same at the next stage.
when you answer the last question (provide the improved complete description), the descriptino should start with "{DESC_MARK_BEGIN}" and end with "{DESC_MARK_END}". Only in this way can we recognize the improved description.
"""
class SPEC_improver():
def __init__(self, description, mode:str, pychecker_en:bool, llm_model:str, circuit_type:str="") -> None:
self.description = description
self.mode = mode
self.pychecker_en = pychecker_en
circuit_type_dict = {"CMB": "combinational", "SEQ": "sequential"}
self.llm_model = llm_model
self.circuit_type = circuit_type_dict.get(circuit_type, "unknown")
def improve(self, wrong_scenarios, correct_scenarios, TB_code_v:str, TB_code_py:str|None=None, working_dir:str|None=None) -> str:
# not implemented yet
match self.mode:
case "naive":
logger.info("naive description improver mode begins")
prompt = ""
prompt += IMPR_PROMPT_1
prompt += DESC_MARK_BEGIN + "\n"
prompt += self.description + "\n"
prompt += DESC_MARK_END + "\n"
prompt += "\nHere is the testbench codes:\n"
prompt += "ATTENTION: the following scenarios are wrong: " + str(wrong_scenarios) + "\n"
prompt += "ATTENTION: the following scenarios are correct, you can rely on these scenarios to improve the description: " + str(correct_scenarios) + "\n"
prompt += TB_code_v + "\n"
if self.pychecker_en:
prompt += f"\nHere is the python code (python checker). please note that the python checker has correct function under the scenarios {str(correct_scenarios)}, but wrong under the scenarios {str(wrong_scenarios)}:\n"
prompt += TB_code_py + "\n"
prompt += DESC_STEP_INSTRUCT
prompt += DESC_FINAL_INSTRUCT
message = [{"role": "user", "content": prompt}]
answer, more_info = llm.llm_call(message, self.llm_model)
try:
improved_description = answer.split(DESC_MARK_BEGIN)[1].split(DESC_MARK_END)[0]
if improved_description == "":
improved_description = self.description
except:
improved_description = self.description
if working_dir is not None:
ls.save_messages_to_txt(more_info["messages"], os.path.join(working_dir, "conversation.txt"))
# save description
with open(os.path.join(working_dir, "description.txt"), "w") as f:
f.write(improved_description)
logger.info("naive description improver mode ends")
return improved_description
case "hint":
logger.info("description improver 'hint' mode begins")
return self.description
def test():
    """
    Manual smoke check: convert per-RTL failed-scenario lists into a one-hot array and feed it to
    the "col_full_wrong" discriminator. Results are printed; nothing is asserted.
    """
    # (the unused local ``from loguru import logger`` was removed: it required loguru and shadowed the module logger)
    failed_scenarios = [
        [1, 3, 5],
        [-1],
        [2, 3],
    ]
    max_scenario = 7
    onehot_array = np.array(TaskTBcheck.failed_scenarios_to_onehot_array(failed_scenarios, max_scenario))
    print(onehot_array)
    my_disc = TB_discriminator("col_full_wrong")
    print(my_disc.discriminate(onehot_array))