Files
TBgen_App/prompt_scripts/base_script.py
2026-03-30 16:46:48 +08:00

319 lines
12 KiB
Python

"""
Description : the base script for prompt scripts
Author : Ruidi Qiu (r.qiu@tum.de)
Time : 2024/3/22 10:59:34
LastEdited : 2024/8/12 23:34:35
"""
from LLM_call import llm_call, extract_code, message_to_conversation
from utils.utils import Timer, get_time
from loader_saver import autologger
import os
import loader_saver as ls
import copy
# Default system prompt used by BaseScriptStage when no "system_message" kwarg is given.
# NOTE: this text is sent verbatim at runtime; the spelling is intentionally left
# untouched so that prompts stay byte-identical with earlier runs.
DEFAULT_SYSMESSAGE = (
    "You are the strongest AI in the world. "
    "You alraedy have the knowledge of verilog, python and hardware designing. "
    "Do not save words by discarding information. "
    "I will tip you 200$ if you can fullfill the tasks I give you."
)

# Markers that delimit a verilog code block in the testbench template below.
IDENTIFIER = dict(
    tb_start="```verilog",
    tb_end="```",
)

# Skeleton of a fenced verilog testbench: start marker, body lines, end marker.
TESTBENCH_TEMPLATE = "\n".join(
    [
        IDENTIFIER["tb_start"],
        "`timescale 1ns / 1ps",
        "(more verilog testbench code here...)",
        "endmodule",
        IDENTIFIER["tb_end"],
    ]
)
# Public API of this module: only the two base classes are exported by
# `from <module> import *`; the constants above must be imported by name.
__all__ = ["BaseScriptStage", "BaseScript"]
class BaseScriptStage:
    """
    The base stage for prompt scripts (one prompt -> one LLM call).

    Calling a stage instance runs, in this order:
    - make_prompt: build ``self.prompt`` (must be implemented by subclasses)
    - call_gpt: send the prompt through ``llm_call``
    - postprocessing: post-process the response (default is empty)
    and afterwards stores the elapsed time and a summary line (see ``record``).

    Keyword arguments (``gptkwargs``):
    - gpt_model: the model name (required)
    - api_key_path: the path of the API key (required)
    - system_message: (can be ignored) the system message; defaults to DEFAULT_SYSMESSAGE
    - json_mode: (can be ignored) forwarded to llm_call only when not None
    - temperature: (can be ignored) forwarded to llm_call only when not None
    """

    def __init__(self, stage_name, **gptkwargs) -> None:
        self.stage_name = stage_name
        # required settings: a missing key raises KeyError here
        self.gpt_model = gptkwargs["gpt_model"]
        self.api_key_path = gptkwargs["api_key_path"]
        # optional settings
        self.system_message = gptkwargs.get("system_message", DEFAULT_SYSMESSAGE)
        self.json_mode = gptkwargs.get("json_mode", None)
        self.temperature = gptkwargs.get("temperature", None)
        # run-time state, filled while the stage runs
        self.time = 0.0
        self.prompt = ""
        self.response = ""
        self.print_message = ""
        self.gptinfo = {}
        self.conversation_message = ""
        self.conversation_file_suffix = ".txt"
        self.reboot = False
        self.circuit_type = None  # "CMB" or "SEQ"; pychecker will use this; you should set it in make_and_run_stages;

    @property
    def will_gen_TB(self):
        """True when this stage exposes a ``TB_code_out`` attribute."""
        return hasattr(self, "TB_code_out")

    @property
    def will_gen_Pychecker(self):
        """True when this stage exposes a ``Pychecker_code_out`` attribute."""
        return hasattr(self, "Pychecker_code_out")

    def __call__(self, *args, **kwargs):
        """Run the stage, store the elapsed time in ``self.time`` and record the summary line."""
        with Timer(print_en=False) as t:
            self.run(*args, **kwargs)
        self.time = t.interval
        self.record()

    def run(self):
        """Execute the three steps of a stage: make_prompt -> call_gpt -> postprocessing."""
        self.make_prompt()
        self.call_gpt()
        self.postprocessing()

    def make_prompt(self):
        """Build ``self.prompt``; subclasses must override this."""
        raise NotImplementedError

    def call_gpt(self):
        """
        actually it should be call_llm, but I dont want to modify the old code

        Sends ``self.prompt`` as a single user message, stores the result in
        ``self.response`` / ``self.gptinfo`` and appends the exchanged messages
        to ``self.conversation_message``.
        """
        gpt_messages = [{"role": "user", "content": self.prompt}]
        # only forward the optional settings that were explicitly configured
        optional = {"temperature": self.temperature, "json_mode": self.json_mode}
        other_kwargs = {key: value for key, value in optional.items() if value is not None}
        self.response, self.gptinfo = llm_call(
            input_messages=gpt_messages,
            model=self.gpt_model,
            api_key_path=self.api_key_path,
            system_message=self.system_message,
            **other_kwargs,
        )
        self.conversation_message += message_to_conversation(self.gptinfo["messages"])

    def postprocessing(self):
        """empty function"""
        pass

    def record(self):
        """Compose the one-line summary (stage name and elapsed seconds) in ``self.print_message``."""
        self.print_message = "%s ends (%.2fs used)" % (self.stage_name, self.time)

    # tools
    def save_log(self, config: object):
        """Write ``self.print_message`` to the log through ``loader_saver.save_log_line``."""
        ls.save_log_line(self.print_message, config)

    def save_conversation(self, save_dir: str):
        """
        This function will save the conversation to a file in save_dir.
        It will be called in stage_operation of BaseScript.

        The file is named ``<stage_name><conversation_file_suffix>`` and is
        overwritten if it already exists.
        """
        # os.path.join copes with a trailing separator, so no manual stripping is needed
        path = os.path.join(save_dir, self.stage_name + self.conversation_file_suffix)
        # LLM conversations may contain non-ASCII text; do not depend on the locale encoding
        with open(path, "w", encoding="utf-8") as f:
            f.write(self.conversation_message)

    def extract_code(self, text, code_type):
        """
        #### function:
        - extract code from text
        #### input:
        - text: str, gpt's response
        - code_type: str, like "verilog"
        #### output:
        - list of found code blocks
        """
        return extract_code(text=text, code_type=code_type)

    def update_tokens(self, tokens):
        """
        Add this stage's token usage to ``tokens`` (keys "prompt" and "completion").

        The dict is updated in place and also returned. Missing usage info
        (e.g. gpt has not been called) counts as 0.
        """
        usage = self.gptinfo.get("usage", {})  # in case gpt has not been called
        tokens["prompt"] += usage.get("prompt_tokens", 0)
        tokens["completion"] += usage.get("completion_tokens", 0)
        return tokens

    def add_prompt_line(self, prompt):
        """Append ``prompt`` plus a newline to ``self.prompt``."""
        self.prompt += prompt + "\n"
class BaseScript:
    """
    the base class for prompt scripts

    Calling an instance (or ``run``) executes, in this order:
    - make_and_run_stages: make and run stages (must be implemented)
    - compute_time_tokens: accumulate time / token usage of ``self.stages``
    - postprocessing: postprocessing the response (default is empty)
    - save_codes: save the generated codes

    ``run_reboot`` regenerates the TB code through ``make_and_run_reboot_stages``
    (must be implemented by scripts that support rebooting).
    """

    def __init__(self, prob_data: dict, task_dir: str, config: object) -> None:
        """
        - prob_data: dict; the keys "task_id" and "header" are read here
        - task_dir: str; one trailing "/" is dropped
        - config: object; ``config.gpt.model`` and ``config.gpt.key_path`` are read here

        Creates the directory ``<task_dir>/TBgen_codes`` if it does not exist.
        """
        self.stages = []
        self.task_dir = task_dir[:-1] if task_dir.endswith("/") else task_dir
        self.config = config
        self.gptkwargs = {
            "gpt_model": self.config.gpt.model,
            "api_key_path": self.config.gpt.key_path,
        }
        self.prob_data = prob_data
        self.TB_code = ""
        self.TB_code_dir = os.path.join(self.task_dir, "TBgen_codes")  # pychecker codes will be saved in the same directory
        self.TB_code_name = prob_data["task_id"] + "_tb.v"
        self.Pychecker_code = ""  # only for pychecker scripts
        self.Pychecker_code_name = prob_data["task_id"] + "_tb.py"
        self.empty_DUT_name = prob_data["task_id"] + ".v"
        self.empty_DUT = prob_data["header"] + "\n\nendmodule\n"
        self.stages_gencode = []  # includes all the stages that generate code, will be used in rebooting generation; see more in stage operation
        self.tokens = {"prompt": 0, "completion": 0}
        self.time = 0.0
        self.reboot_idx = -1  # start from 0; -1 means no reboot. the reboot index will be increased by 1 after each reboot
        self.reboot_stages = []  # [reboot_stages_iter_0, reboot_stages_iter_1, ...]
        self.reboot_mode = "TB"  # "TB" or "PY"; modified by run_reboot; checked in make_and_run_reboot_stages if needed
        self.py_debug_focus = False  # if True, when debug python, will only send the upper part (no check_dut); specific for pychecker
        self.checklist_worked = False  # if True, the scenario checklist did help our work. This is a flag for further analysis
        os.makedirs(self.TB_code_dir, exist_ok=True)
        self.scenario_num = None
        self.scenario_dict = None

    @property
    def Pychecker_en(self):
        """True once a non-empty pychecker code has been stored in ``self.Pychecker_code``."""
        return self.Pychecker_code != ""

    @property
    def Pychecker_code_dir(self):
        """Pychecker codes live in the same directory as the TB codes."""
        return self.TB_code_dir

    def __call__(self, *args, **kwargs):
        """Calling the script is the same as calling ``run``."""
        self.run(*args, **kwargs)

    def run(self):
        """Run the whole script: stages -> time/token statistics -> postprocessing -> save codes."""
        self.make_and_run_stages()
        self.compute_time_tokens()
        self.postprocessing()
        self.save_codes()

    def make_and_run_stages(self):
        """
        - in this function, you should make stages and run them
        - for example:
        ::
            stage1 = Stage1(**kwargs)
            self.stage_operation(stage1)
        """
        raise NotImplementedError("No make_and_run_stages: You should implement this function in your own script")

    def make_and_run_reboot_stages(self, debug_dir):
        """Make and run the stages of one reboot iteration; must be implemented to use ``run_reboot``."""
        raise NotImplementedError("No reboot settings: You should implement this function in your own script")

    def postprocessing(self):
        """empty function"""
        pass

    @staticmethod
    def _write_text(path: str, text: str) -> None:
        """Write ``text`` to ``path`` (overwriting), always encoded as UTF-8."""
        # explicit encoding: the generated .py file is read by Python as UTF-8,
        # so the result must not depend on the platform's locale encoding
        with open(path, "w", encoding="utf-8") as f:
            f.write(text)

    def save_codes(self, codes_dir: str = None):
        """
        Save the generated codes into ``codes_dir`` (default: ``self.TB_code_dir``).

        Always writes the TB code and the empty DUT; the pychecker code is
        written only if there is one. The directory is created when missing.
        """
        if codes_dir is None:
            codes_dir = self.TB_code_dir
        os.makedirs(codes_dir, exist_ok=True)
        self._write_text(os.path.join(codes_dir, self.TB_code_name), self.TB_code)
        self._write_text(os.path.join(codes_dir, self.empty_DUT_name), self.empty_DUT)
        # save pychecker code if have
        if self.Pychecker_en:
            self._write_text(os.path.join(codes_dir, self.Pychecker_code_name), self.Pychecker_code)

    def stage_operation(self, stage: "BaseScriptStage", conversation_dir: str = None, reboot_en: bool = False):
        """
        - what to do on a stage after making it; will be called in make_and_run_stages
        - run, save stages and renew the generated codes of current wf

        - stage: the stage to run
        - conversation_dir: where the conversation file is saved (default: ``self.task_dir``)
        - reboot_en: if True, the stage is stored in the current reboot iteration
          (``self.reboot_stages[self.reboot_idx]``) instead of ``self.stages``
        """
        if conversation_dir is None:
            conversation_dir = self.task_dir
        stage()
        if reboot_en:
            self.reboot_stages[self.reboot_idx].append(stage)
        else:
            self.stages.append(stage)
        stage.save_conversation(conversation_dir)
        stage.save_log(self.config)
        # renew the generated codes with the newest outputs
        if stage.will_gen_TB:
            self.TB_code = stage.TB_code_out
        if stage.will_gen_Pychecker:
            self.Pychecker_code = stage.Pychecker_code_out
        # for checklist:
        if hasattr(stage, "TB_modified"):  # this attr is in checklist_stage
            self.checklist_worked = stage.TB_modified
        # NOTE: self.stages_gencode is not populated here; the former
        # auto-registration of code-generating stages is disabled.

    def run_reboot(self, debug_dir, reboot_mode="TB"):
        """
        - regenerate the TB code

        - debug_dir: directory for this reboot iteration (one trailing "/" is dropped);
          it is passed to make_and_run_reboot_stages and the codes are saved there
        - reboot_mode: "TB" or "PY"; stored in ``self.reboot_mode``
        """
        self.reboot_idx += 1
        self.reboot_stages.append([])
        self.TB_code = ""
        self.reboot_mode = reboot_mode  # this will be checked in make_and_run_reboot_stages if needed
        debug_dir = debug_dir[:-1] if debug_dir.endswith("/") else debug_dir
        # make and run stages
        self.make_and_run_reboot_stages(debug_dir)
        # postprocessing
        self.compute_time_tokens(self.reboot_stages[self.reboot_idx])
        self.postprocessing()
        # save codes
        self.save_codes(debug_dir)

    def clear_time_tokens(self):
        """Reset the accumulated time and token counters."""
        self.time = 0.0
        self.tokens = {"prompt": 0, "completion": 0}

    def compute_time_tokens(self, stages=None):
        """
        Add the time and token usage of ``stages`` (default: ``self.stages``)
        to ``self.time`` / ``self.tokens``. The counters are not reset first.
        """
        if stages is None:
            stages = self.stages
        for stage in stages:
            self.time += stage.time
            self.tokens = stage.update_tokens(self.tokens)

    def save_log(self, line):
        """Log ``line`` at info level through ``autologger``."""
        autologger.info(line)

    def get_attr(self, attr_name: str):
        """Return the attribute named ``attr_name``, or None if it does not exist."""
        return getattr(self, attr_name, None)