Files
CGA-bench/prompt_scripts/rtl_summarizer.py
2026-05-22 10:02:42 +08:00

438 lines
14 KiB
Python

"""
Description : RTL summarization tool for large RTL files
Generates concise summaries to reduce LLM prompt size while preserving key information
Author : CorrectBench
Time : 2026/04/19
"""
import re
from typing import Optional, Dict, List, Tuple
def extract_module_name(rtl_code: str) -> str:
    """Return the name of the first module declared in *rtl_code*.

    Comments are removed before searching so that prose such as
    ``// this module implements ...`` cannot be mistaken for a declaration,
    and the keyword is anchored on a word boundary so identifiers that merely
    end in ``module`` (``top_module``, ``endmodule``) are ignored.

    Args:
        rtl_code: Verilog/SystemVerilog source text.

    Returns:
        The module identifier, or ``"unknown"`` when no declaration is found.
    """
    # Blank out // line comments and /* block */ comments.
    code = re.sub(r'//[^\n]*|/\*.*?\*/', ' ', rtl_code, flags=re.DOTALL)
    # An optional lifetime qualifier may sit between the keyword and the name.
    match = re.search(r'\bmodule\s+(?:(?:static|automatic)\s+)?(\w+)', code)
    return match.group(1) if match else "unknown"
def extract_module_header(rtl_code: str) -> Optional[str]:
    """Return a condensed rendering of the first module header.

    The result has the form ``"module (...) <params>\\n(<ports>)"`` where
    ``<params>`` is the raw ``#(...)`` parameter list (empty when absent) and
    ``<ports>`` is the raw port list, both stripped of surrounding whitespace.
    Returns ``None`` when no header with a parenthesised port list is found.
    """
    header_re = r'module\s+\w+\s*(#\([^)]*\))?\s*\(([^)]*)\)'
    found = re.search(header_re, rtl_code, re.DOTALL)
    if found is None:
        return None

    # An unmatched optional group is None; normalise both parts to strings.
    param_text, port_text = (part or "" for part in found.groups())
    return "module (...) " + param_text.strip() + "\n(" + port_text.strip() + ")"
def extract_interface_signals(rtl_code: str) -> List[Dict[str, str]]:
    """Extract the ANSI-style port declarations from the first module header.

    Args:
        rtl_code: Verilog/SystemVerilog source text.

    Returns:
        One dict per declared port, in source order, with the keys
        ``"name"``, ``"direction"`` (``input``/``output``/``inout``) and
        ``"width"`` (``"[hi:lo]"`` as written, or ``""`` for scalar ports).
        An empty list is returned when no module header is found.

    Limitations:
        Only the first name after each direction keyword is captured, so
        ``input a, b`` reports ``a`` only.
    """
    # Comments can contain ")" (which would truncate the header match) or
    # direction keywords (which would be reported as ports), so drop them.
    code = re.sub(r'//[^\n]*|/\*.*?\*/', ' ', rtl_code, flags=re.DOTALL)

    # The parameter list tolerates one level of nested parentheses so values
    # such as $clog2(DEPTH) do not defeat the match.
    header_match = re.search(
        r'\bmodule\s+\w+\s*(?:#\((?:[^()]|\([^()]*\))*\))?\s*\(([^)]*)\)',
        code,
        re.DOTALL,
    )
    if not header_match:
        return []
    port_section = header_match.group(1)

    # direction [wire|reg|logic] [signed|unsigned] [hi:lo] name
    # Every keyword is closed with \b so that a port called "wired_sig" or
    # "signed_val" is not mistaken for a keyword followed by a shorter name.
    port_pattern = re.compile(
        r'\b(input|output|inout)\b\s*'
        r'(?:(?:wire|reg|logic)\b\s*)?'
        r'(?:(?:signed|unsigned)\b\s*)?'
        r'(?:\[([^:\]]+):([^\]]+)\]\s*)?'
        r'(\w+)'
    )

    signals = []
    for match in port_pattern.finditer(port_section):
        direction, width_high, width_low, name = match.groups()
        width = f"[{width_high}:{width_low}]" if width_high and width_low else ""
        signals.append({
            "name": name,
            "direction": direction,
            "width": width,
        })
    return signals
def detect_circuit_type(rtl_code: str) -> str:
"""
Detect circuit type: FSM, COMBINATIONAL, SEQUENTIAL, PROTOCOL, or MIXED
"""
has_always_ff = bool(re.search(r'always_ff\s*@', rtl_code))
has_always_latch = bool(re.search(r'always_latch\s*@', rtl_code))
has_always_comb = bool(re.search(r'always_comb\s*@', rtl_code))
has_regular_always = bool(re.search(r'always\s*@', rtl_code))
# Check for FSM patterns
has_state_reg = bool(re.search(r'reg\s+\w*state\w*', rtl_code, re.IGNORECASE))
has_case_state = bool(re.search(r'case\s*\(\s*\w*state\w*\s*\)', rtl_code, re.IGNORECASE))
is_fsm = has_state_reg and has_case_state
# Check for protocol interfaces
has_handshake = bool(re.search(r'(valid|ready|ack|request|grant)', rtl_code, re.IGNORECASE))
has_spi = bool(re.search(r'(spi|miso|mosi|sclk)', rtl_code, re.IGNORECASE))
has_i2c = bool(re.search(r'(i2c|sda|scl)', rtl_code, re.IGNORECASE))
is_protocol = has_handshake or has_spi or has_i2c
if is_fsm:
if is_protocol:
return "SEQUENTIAL (FSM + PROTOCOL)"
return "SEQUENTIAL (FSM)"
elif has_always_ff or (has_regular_always and not has_always_comb):
if is_protocol:
return "SEQUENTIAL (PROTOCOL)"
return "SEQUENTIAL"
elif has_always_comb or has_always_latch:
return "COMBINATIONAL"
elif is_protocol:
return "PROTOCOL"
else:
return "COMBINATIONAL"
def extract_fsm_info(rtl_code: str) -> Optional[Dict]:
    """Extract the state register and state names of an FSM.

    The state register is the first ``reg`` whose name contains ``state``
    (case-insensitive). State names are taken from the labels of every
    ``case (<state register>)`` statement; when some of those labels are
    declared as ``localparam``/``parameter`` constants only the declared ones
    are kept. If the state register has no case statement, declared constants
    whose name contains ``ST`` are used instead.

    Args:
        rtl_code: Verilog/SystemVerilog source text.

    Returns:
        ``None`` when no state register is found, otherwise a dict with:
        ``state_variable`` (register name), ``state_bits`` (bit count of a
        ``[N:0]`` range, i.e. ``N + 1``; ``0`` when no range is written),
        ``states`` (at most 20 names) and ``state_count`` (number of names
        found before truncation).
    """
    state_var_match = re.search(
        r'\breg\s+(?:\[(\d+):0\])?\s*(\w*state\w*)', rtl_code, re.IGNORECASE
    )
    if not state_var_match:
        return None
    msb, state_var = state_var_match.groups()

    # Names of all declared constants, including ranged declarations
    # (localparam [2:0] A = ...) and comma-separated lists (A = 0, B = 1).
    declared = []
    for decl in re.finditer(r'\b(?:localparam|parameter)\b([^;)]*)', rtl_code):
        declared.extend(re.findall(r'(\w+)\s*=(?!=)', decl.group(1)))

    # Case labels: an identifier followed by ":" that begins a case item, i.e.
    # sits at the start of the case body or right after ";" / "end". This
    # skips ternary branches ("? A : B") and bit-slices ("[7:0]").
    case_re = re.compile(
        r'\bcase[xz]?\s*\(\s*' + re.escape(state_var) + r'\s*\)(.*?)\bendcase\b',
        re.DOTALL | re.IGNORECASE,
    )
    label_re = re.compile(r'(?:\A|;|\bend\b)\s*(\w+)\s*:(?!:)')
    labels = []
    for case_block in case_re.finditer(rtl_code):
        labels.extend(label_re.findall(case_block.group(1)))
    # De-duplicate while preserving first-seen order.
    labels = [
        label for label in dict.fromkeys(labels)
        if label.upper() not in ('NEXT', 'DEFAULT')
    ]

    if labels:
        declared_names = set(declared)
        states = [label for label in labels if label in declared_names] or labels
    else:
        # No case on the state register: fall back to the naming heuristic.
        states = [name for name in dict.fromkeys(declared) if 'ST' in name.upper()]

    return {
        "state_variable": state_var,
        # [N:0] is N + 1 bits wide.
        "state_bits": int(msb) + 1 if msb else 0,
        "states": states[:20],  # Limit to 20 states
        "state_count": len(states),
    }
def extract_counters(rtl_code: str) -> List[Dict]:
    """List counter-like registers declared in the RTL.

    A register qualifies when its name contains ``count``, ``cnt``,
    ``counter`` or ``rel`` (case-insensitive). Each entry carries the
    register ``name`` and its ``width`` in bits: ``N + 1`` for a ``[N:0]``
    range, ``1`` when no numeric range precedes the name. At most five
    entries are returned.
    """
    counter_pattern = r'reg\s+(?:\[(\d+):0\])?\s*(\w*(?:count|cnt|counter|rel)\w*)'
    found = [
        {
            "name": hit.group(2),
            "width": int(hit.group(1)) + 1 if hit.group(1) else 1,
        }
        for hit in re.finditer(counter_pattern, rtl_code, re.IGNORECASE)
    ]
    return found[:5]  # Limit to 5 counters
def extract_fifos(rtl_code: str) -> List[Dict]:
    """List FIFO-like storage declared or instantiated in the RTL.

    Two sources are scanned:

    * ``reg``/``wire`` declarations whose name contains ``fifo``, ``queue``
      or ``buffer`` (case-insensitive) -> ``{"name", "width"}`` where width is
      ``N + 1`` for a ``[N:0]`` range and defaults to ``8`` when no numeric
      range is written.
    * Parameterised module instantiations ``<type> #(...) <inst> (`` whose
      module type contains ``fifo`` -> ``{"name": <type>, "type": "instantiated"}``.

    Args:
        rtl_code: Verilog/SystemVerilog source text.

    Returns:
        At most five entries, declarations first.
    """
    fifos = []

    fifo_pattern = r'(?:reg|wire)\s+(?:\[(\d+):0\])?\s*(\w*(?:fifo|queue|buffer)\w*)'
    for match in re.finditer(fifo_pattern, rtl_code, re.IGNORECASE):
        width, name = match.groups()
        fifos.append({
            "name": name,
            "width": int(width) + 1 if width else 8,
        })

    # Verilog parameter overrides are written "#(.NAME(value), ...)". The
    # override list may nest one level of parentheses. Requiring an instance
    # identifier before the port list keeps module *declarations*
    # ("module x #(...) (ports)") from being reported as instantiations.
    fifo_inst_pattern = r'\b(\w+)\s*#\s*\((?:[^()]|\([^()]*\))*\)\s*\w+\s*\('
    for match in re.finditer(fifo_inst_pattern, rtl_code):
        module_type = match.group(1)
        if 'fifo' in module_type.lower():
            fifos.append({"name": module_type, "type": "instantiated"})

    return fifos[:5]  # Limit to 5 FIFOs
def extract_key_registers(rtl_code: str) -> List[str]:
    """Return the names of up to ten ``reg`` signals worth highlighting.

    Registers whose lower-cased name contains any of ``clk``, ``rst``,
    ``reset``, ``state``, ``cnt``, ``count``, ``fifo`` or ``buffer`` are left
    out. Names are returned in declaration order.
    """
    skip_words = ('clk', 'rst', 'reset', 'state', 'cnt', 'count', 'fifo', 'buffer')
    reg_pattern = r'reg\s+(?:\[(\d+):0\])?\s*(\w+)'

    declared_names = (hit.group(2) for hit in re.finditer(reg_pattern, rtl_code))
    kept = [
        reg_name
        for reg_name in declared_names
        if not any(word in reg_name.lower() for word in skip_words)
    ]
    return kept[:10]  # Limit to 10 registers
def _parameter_value_end(code: str, start: int) -> int:
    """Return the index of the character that terminates a parameter value.

    Scanning begins at *start*. The value ends at the first ``,`` or ``;``
    outside any bracket, or at a closing bracket that was not opened inside
    the value (the ``)`` that closes an ANSI ``#( ... )`` list). String
    literals are skipped whole. ``len(code)`` is returned when no terminator
    is found.
    """
    depth = 0
    index = start
    length = len(code)
    while index < length:
        char = code[index]
        if char == '"':
            closing = code.find('"', index + 1)
            if closing == -1:
                return length
            index = closing
        elif char in '([{':
            depth += 1
        elif char in ')]}':
            if depth == 0:
                return index
            depth -= 1
        elif char in ',;' and depth == 0:
            return index
        index += 1
    return length


def extract_parameters(rtl_code: str) -> List[Dict]:
    """Extract ``parameter`` declarations and their default values.

    Handles both ANSI header lists (``#(parameter A = 1, parameter B = 2)``)
    and body declarations (``parameter A = 1, B = 2;``), optional type or
    range specifiers (``parameter integer W = 8``, ``parameter [3:0] M = 0``)
    and values containing brackets or commas inside brackets
    (``$clog2(DEPTH)``, ``{A, B}``). ``localparam`` is not included.

    Args:
        rtl_code: Verilog/SystemVerilog source text.

    Returns:
        At most ten ``{"name": ..., "value": ...}`` dicts in source order;
        values are the stripped source text of the default expression.
    """
    # Remove comments but leave string literals (which may contain "//") alone.
    code = re.sub(
        r'"(?:\\.|[^"\\\n])*"|//[^\n]*|/\*.*?\*/',
        lambda m: m.group(0) if m.group(0).startswith('"') else ' ',
        rtl_code,
        flags=re.DOTALL,
    )

    declaration = re.compile(
        r'\bparameter\b\s*'
        r'(?:(?:integer|real|realtime|time|int|logic|bit|reg|signed|unsigned|string)\b\s*)*'
        r'(?:\[[^\]]*\]\s*)?'
        r'(\w+)\s*=(?!=)'
    )
    # A further "NAME = value" after a comma in the same declaration.
    continuation = re.compile(r'\s*(\w+)\s*=(?!=)')

    params = []
    for match in declaration.finditer(code):
        name, value_start = match.group(1), match.end()
        while True:
            value_end = _parameter_value_end(code, value_start)
            params.append({"name": name, "value": code[value_start:value_end].strip()})
            if value_end >= len(code) or code[value_end] != ',':
                break
            more = continuation.match(code, value_end + 1)
            if more is None:
                # The next item starts with its own keyword (or is not an
                # assignment); the outer scan picks it up if relevant.
                break
            name, value_start = more.group(1), more.end()
    return params[:10]  # Limit to 10 parameters
def is_large_rtl(rtl_code: str, threshold: int = 5000) -> bool:
    """Tell whether the RTL text is at least *threshold* characters long."""
    char_count = len(rtl_code)
    return char_count >= threshold
def _format_port_group(ports: List[Dict[str, str]], limit: int = 15) -> str:
    """Join up to *limit* ports as ``name[width]`` and note how many were left out."""
    text = ", ".join(f"{port['name']}{port['width']}" for port in ports[:limit])
    if len(ports) > limit:
        text += f", ... (+{len(ports) - limit} more)"
    return text


def summarize_rtl(rtl_code: str, include_critical_section: bool = True) -> str:
    """
    Generate a concise summary of RTL code for LLM prompts.

    The summary lists, in order: module name and detected circuit type,
    interface ports grouped by direction, parameters, FSM information,
    counters, FIFOs/buffers and key registers. Sections with nothing to
    report are omitted.

    Args:
        rtl_code: Full RTL Verilog code
        include_critical_section: If True, append an excerpt (at most 800
            characters) of the FSM case body, or of the first always block
            when no FSM case body is available.
    Returns:
        Formatted summary string
    """
    banner = "=" * 60
    lines = [banner, "RTL SUMMARY (Generated by rtl_summarizer)", banner]

    # Module name and type
    lines.append(f"\nModule: {extract_module_name(rtl_code)}")
    lines.append(f"Type: {detect_circuit_type(rtl_code)}")

    # Interface signals, grouped by direction
    signals = extract_interface_signals(rtl_code)
    if signals:
        lines.append("\nInterface:")
        for label, direction in (("Inputs", "input"), ("Outputs", "output")):
            group = [s for s in signals if s["direction"] == direction]
            if group:
                lines.append(f" {label}: {_format_port_group(group)}")
        inouts = [s for s in signals if s["direction"] == "inout"]
        if inouts:
            lines.append(f" Inouts: {', '.join([s['name'] for s in inouts])}")

    # Parameters
    params = extract_parameters(rtl_code)
    if params:
        lines.append("\nParameters:")
        for p in params[:10]:
            lines.append(f" {p['name']} = {p['value']}")

    # FSM info
    fsm_info = extract_fsm_info(rtl_code)
    if fsm_info and fsm_info['states']:
        lines.append("\nFSM:")
        lines.append(f" State Variable: {fsm_info['state_variable']}")
        if fsm_info['state_bits']:
            lines.append(f" State Bits: {fsm_info['state_bits']}")
        state_list = ", ".join(fsm_info['states'][:15])
        if fsm_info['state_count'] > 15:
            state_list += f", ... (+{fsm_info['state_count']-15} more states)"
        lines.append(f" States: {state_list}")

    # Counters
    counters = extract_counters(rtl_code)
    if counters:
        lines.append("\nCounters:")
        for c in counters:
            lines.append(f" - {c['name']} ({c.get('width', 1)}-bit)")

    # FIFOs/Buffers
    fifos = extract_fifos(rtl_code)
    if fifos:
        lines.append("\nFIFOs/Buffers:")
        for fifo in fifos:
            lines.append(f" - {fifo['name']}")

    # Key registers
    registers = extract_key_registers(rtl_code)
    if registers:
        lines.append("\nKey Registers:")
        lines.append(f" {', '.join(registers)}")

    # Critical section (FSM case body or main always block)
    if include_critical_section:
        excerpt = None
        if fsm_info:
            case_match = re.search(
                r'case\s*\(\s*' + re.escape(fsm_info['state_variable']) + r'\s*\)(.*?)endcase',
                rtl_code,
                re.DOTALL | re.IGNORECASE
            )
            if case_match:
                excerpt = ("FSM Case Body (for reference):", case_match.group(1))
        if excerpt is None:
            # Fall back to the first always block whenever no FSM case body
            # was emitted. Keywords are matched as whole words so "end" inside
            # an identifier such as "pending" does not cut the excerpt short.
            always_match = re.search(
                r'(\balways(?:_ff)?\s*@.*?\bbegin\b.*?\bend\b)', rtl_code, re.DOTALL
            )
            if always_match:
                excerpt = ("Main Always Block (for reference):", always_match.group(1))
        if excerpt is not None:
            title, text = excerpt
            lines.append("\n" + "-" * 40)
            lines.append(title)
            lines.append("-" * 40)
            lines.append(text[:800])  # Limit to 800 chars
            if len(text) > 800:
                lines.append("... (truncated)")

    lines.append("\n" + banner)
    lines.append("Note: This is a SUMMARY. Full RTL code is available in the reference file.")
    lines.append(banner)
    return "\n".join(lines)
if __name__ == "__main__":
    # Smoke test: summarise a small SPI controller and show how the size
    # check behaves at two thresholds.
    sample_rtl = """
module spi_controller #(
parameter CLK_DIV = 16,
parameter DATA_WIDTH = 8
) (
input wire clk,
input wire rst_n,
input wire [7:0] tx_data,
input wire tx_valid,
output reg [7:0] rx_data,
output reg rx_valid,
output wire spi_clk,
output wire spi_mosi,
input wire spi_miso
);
localparam IDLE = 3'd0;
localparam TX_START = 3'd1;
localparam TX_WAIT = 3'd2;
localparam RX_START = 3'd3;
localparam RX_WAIT = 3'd4;
localparam DONE = 3'd5;
reg [2:0] state;
reg [2:0] next_state;
reg [7:0] shift_reg;
reg [3:0] bit_count;
reg [15:0] clk_count;
always @(posedge clk or negedge rst_n) begin
if (!rst_n)
state <= IDLE;
else
state <= next_state;
end
always @(*) begin
case (state)
IDLE: next_state = tx_valid ? TX_START : IDLE;
TX_START: next_state = TX_WAIT;
TX_WAIT: next_state = (bit_count == 7) ? RX_START : TX_WAIT;
RX_START: next_state = RX_WAIT;
RX_WAIT: next_state = (bit_count == 7) ? DONE : RX_WAIT;
DONE: next_state = IDLE;
default: next_state = IDLE;
endcase
end
always @(posedge clk or negedge rst_n) begin
if (!rst_n) begin
shift_reg <= 0;
bit_count <= 0;
end else begin
case (state)
TX_START: begin
shift_reg <= tx_data;
bit_count <= 0;
end
TX_WAIT: begin
shift_reg <= {shift_reg[6:0], 1'b0};
bit_count <= bit_count + 1;
end
RX_WAIT: begin
rx_data[7-bit_count] <= spi_miso;
bit_count <= bit_count + 1;
end
DONE: begin
rx_valid <= 1;
end
endcase
end
end
assign spi_clk = clk_count >= CLK_DIV[15:0] && state != IDLE;
assign spi_mosi = shift_reg[7];
endmodule
"""
    divider = "\n" + "=" * 60

    print(summarize_rtl(sample_rtl))
    print(divider)
    print(f"is_large_rtl (threshold=5000): {is_large_rtl(sample_rtl, 5000)}")
    print(f"is_large_rtl (threshold=2000): {is_large_rtl(sample_rtl, 2000)}")