"""
Description : RTL summarization tool for large RTL files.
              Generates concise summaries to reduce LLM prompt size while
              preserving key information.
Author      : CorrectBench
Time        : 2026/04/19
"""
|
|
|
|
import re
|
|
from typing import Optional, Dict, List, Tuple
|
|
|
|
|
|
def extract_module_name(rtl_code: str) -> str:
    """Return the name of the first module declared in ``rtl_code``.

    Comments are removed before searching so that prose such as
    ``// top module for ...`` is not mistaken for a declaration, and the
    ``module`` keyword must start on a word boundary so the tail of
    ``endmodule`` is never treated as a declaration keyword.

    Args:
        rtl_code: Verilog/SystemVerilog source text.

    Returns:
        The module identifier, or ``"unknown"`` when no declaration is found.
    """
    # Drop // line comments and /* ... */ block comments.
    code = re.sub(r'//[^\n]*|/\*.*?\*/', ' ', rtl_code, flags=re.DOTALL)
    # An optional lifetime qualifier (``module automatic foo``) is skipped.
    match = re.search(r'\bmodule\s+(?:(?:static|automatic)\s+)?(\w+)', code)
    return match.group(1) if match else "unknown"
|
|
|
|
|
|
def extract_module_header(rtl_code: str) -> Optional[str]:
    """Extract the parameter list and port list of the first module header.

    The parenthesised groups are located by counting parentheses, so values
    such as ``$clog2(N)`` inside the parameter or port list do not cut the
    header short.

    Args:
        rtl_code: Verilog/SystemVerilog source text.

    Returns:
        A string of the form ``"module (...) <#(params)>\\n(<ports>)"`` where
        the parameter part is empty when the module has no ``#(...)`` list,
        or ``None`` when no ``module <name> ... (ports)`` header is found.
    """

    def closing_paren(open_idx: int) -> int:
        """Return the index of the ')' matching the '(' at ``open_idx``, or -1."""
        depth = 0
        for idx in range(open_idx, len(rtl_code)):
            char = rtl_code[idx]
            if char == '(':
                depth += 1
            elif char == ')':
                depth -= 1
                if depth == 0:
                    return idx
        return -1

    hash_open = re.compile(r'#\s*\(')
    port_open = re.compile(r'\s*\(')

    # Try every "module <name>" occurrence until one is followed by a port list.
    for head in re.finditer(r'\bmodule\s+\w+\s*', rtl_code):
        pos = head.end()
        params = ""

        param_start = hash_open.match(rtl_code, pos)
        if param_start:
            param_close = closing_paren(param_start.end() - 1)
            if param_close < 0:
                continue
            # Keep the '#(' ... ')' wrapper as part of the parameter text.
            params = rtl_code[pos:param_close + 1]
            pos = param_close + 1

        port_start = port_open.match(rtl_code, pos)
        if not port_start:
            continue
        port_close = closing_paren(port_start.end() - 1)
        if port_close < 0:
            continue

        ports = rtl_code[port_start.end():port_close]
        return f"module (...) {params.strip()}\n({ports.strip()})"

    return None
|
|
|
|
|
|
def extract_interface_signals(rtl_code: str) -> List[Dict[str, str]]:
    """Extract the ANSI-style port declarations of the first module header.

    Comments are stripped first so that text such as ``// input clock`` does
    not create phantom ports. The port list is located by counting
    parentheses, then split on commas; an entry without a direction keyword
    (``input [7:0] a, b``) inherits the direction and width of the previous
    declaration.

    Args:
        rtl_code: Verilog/SystemVerilog source text.

    Returns:
        A list of dicts with keys ``"name"``, ``"direction"`` (``"input"``,
        ``"output"`` or ``"inout"``) and ``"width"`` (``"[msb:lsb]"`` or
        ``""`` for a scalar). Empty when no header or no directed port is
        found (for example a non-ANSI header ``module m(a, b);``).
    """
    signals: List[Dict[str, str]] = []

    # Drop // line comments and /* ... */ block comments.
    code = re.sub(r'//[^\n]*|/\*.*?\*/', ' ', rtl_code, flags=re.DOTALL)

    def closing_paren(open_idx: int) -> int:
        """Return the index of the ')' matching the '(' at ``open_idx``, or -1."""
        depth = 0
        for idx in range(open_idx, len(code)):
            char = code[idx]
            if char == '(':
                depth += 1
            elif char == ')':
                depth -= 1
                if depth == 0:
                    return idx
        return -1

    hash_open = re.compile(r'#\s*\(')
    port_open = re.compile(r'\s*\(')

    port_section = None
    for head in re.finditer(r'\bmodule\s+\w+\s*', code):
        pos = head.end()
        param_start = hash_open.match(code, pos)
        if param_start:
            param_close = closing_paren(param_start.end() - 1)
            if param_close < 0:
                continue
            pos = param_close + 1
        port_start = port_open.match(code, pos)
        if not port_start:
            continue
        port_close = closing_paren(port_start.end() - 1)
        if port_close < 0:
            continue
        port_section = code[port_start.end():port_close]
        break

    if port_section is None:
        return signals

    # direction [net type] [signedness] [msb:lsb] name
    # The \b after each keyword keeps names such as "reg_en" or "wire_out"
    # from losing their keyword-looking prefix.
    decl_re = re.compile(
        r'\b(input|output|inout)\b\s*'
        r'(?:(?:wire|reg|logic|tri)\b\s*)?'
        r'(?:(?:signed|unsigned)\b\s*)?'
        r'(?:\[([^:\]]+):([^\]]+)\]\s*)?'
        r'(\w+)'
    )

    direction = None
    width = ""
    for entry in port_section.split(','):
        decl = decl_re.search(entry)
        if decl:
            direction = decl.group(1)
            width_high, width_low = decl.group(2), decl.group(3)
            width = f"[{width_high}:{width_low}]" if width_high and width_low else ""
            name = decl.group(4)
        elif direction and re.fullmatch(r'\w+', entry.strip()):
            # Bare identifier: shares the previous declaration's direction/width.
            name = entry.strip()
        else:
            continue

        signals.append({
            "name": name,
            "direction": direction,
            "width": width
        })

    return signals
|
|
|
|
|
|
def detect_circuit_type(rtl_code: str) -> str:
    """Classify the RTL with keyword heuristics.

    Classification rules, in priority order:

    1. A state register (``reg``/``logic`` whose name contains "state",
       with or without a range) together with ``case (<...state...>)``
       gives ``"SEQUENTIAL (FSM)"`` or ``"SEQUENTIAL (FSM + PROTOCOL)"``.
    2. ``always_ff`` or an edge-triggered ``always @(posedge|negedge ...)``
       gives ``"SEQUENTIAL"`` or ``"SEQUENTIAL (PROTOCOL)"``.
    3. ``always_comb``, ``always_latch`` or a level-sensitive ``always @``
       (for example ``always @(*)``) gives ``"COMBINATIONAL"``.
    4. Protocol keywords alone give ``"PROTOCOL"``.
    5. Anything else gives ``"COMBINATIONAL"``.

    The protocol check is a case-insensitive substring search over the whole
    text (handshake, SPI and I2C names), so it is deliberately loose.

    Args:
        rtl_code: Verilog/SystemVerilog source text.

    Returns:
        One of the strings listed above.
    """
    # always_comb / always_latch take no event control, so no '@' is required.
    has_always_ff = bool(re.search(r'\balways_ff\b', rtl_code))
    has_always_latch = bool(re.search(r'\balways_latch\b', rtl_code))
    has_always_comb = bool(re.search(r'\balways_comb\b', rtl_code))

    # Split plain "always @" blocks by sensitivity: edge-triggered blocks are
    # sequential, everything else (e.g. @(*), @(a or b)) is combinational.
    edge_list = r'\s*\(\s*(?:posedge|negedge)\b'
    has_edge_always = bool(re.search(r'\balways\s*@' + edge_list, rtl_code))
    has_level_always = bool(re.search(r'\balways\s*@(?!' + edge_list + r')', rtl_code))

    # Check for FSM patterns; the optional [..] lets "reg [2:0] state" match.
    has_state_reg = bool(re.search(
        r'\b(?:reg|logic)\b\s*(?:\[[^\]]*\]\s*)?\w*state\w*', rtl_code, re.IGNORECASE))
    has_case_state = bool(re.search(
        r'\bcase[xz]?\s*\(\s*\w*state\w*\s*\)', rtl_code, re.IGNORECASE))
    is_fsm = has_state_reg and has_case_state

    # Check for protocol interfaces
    has_handshake = bool(re.search(r'(valid|ready|ack|request|grant)', rtl_code, re.IGNORECASE))
    has_spi = bool(re.search(r'(spi|miso|mosi|sclk)', rtl_code, re.IGNORECASE))
    has_i2c = bool(re.search(r'(i2c|sda|scl)', rtl_code, re.IGNORECASE))
    is_protocol = has_handshake or has_spi or has_i2c

    if is_fsm:
        return "SEQUENTIAL (FSM + PROTOCOL)" if is_protocol else "SEQUENTIAL (FSM)"
    if has_always_ff or has_edge_always:
        return "SEQUENTIAL (PROTOCOL)" if is_protocol else "SEQUENTIAL"
    if has_always_comb or has_always_latch or has_level_always:
        return "COMBINATIONAL"
    if is_protocol:
        return "PROTOCOL"
    return "COMBINATIONAL"
|
|
|
|
|
|
def extract_fsm_info(rtl_code: str) -> Optional[Dict]:
    """Extract the state variable, its width and the state names of an FSM.

    The state variable is the first ``reg``/``logic`` whose name contains
    "state" and that is used as the selector of a ``case`` statement; if no
    candidate is used in a ``case``, the first candidate is taken.

    State names are read from the case-item labels of every
    ``case (<state variable>)`` block. A label is only accepted at the start
    of a case item (start of the body, after a newline, ``;`` or ``end``), so
    operands of a ``?:`` expression on the same line are not mistaken for
    states. When no case labels exist, ``localparam``/``parameter`` names
    containing "ST" are used as a fallback heuristic.

    Args:
        rtl_code: Verilog/SystemVerilog source text.

    Returns:
        ``None`` when no state register is declared, otherwise a dict with:
        ``"state_variable"`` (str), ``"state_bits"`` (register width in bits
        for a numeric ``[msb:lsb]`` range, ``0`` when there is no range or
        it is not numeric), ``"states"`` (at most 20 names, declaration
        order, no duplicates) and ``"state_count"`` (total names found).
    """
    candidate_re = re.compile(
        r'\b(?:reg|logic)\b\s*(?:\[([^\]:]+):([^\]]+)\]\s*)?(\w*state\w*)',
        re.IGNORECASE,
    )
    candidates = list(candidate_re.finditer(rtl_code))
    if not candidates:
        return None

    def case_bodies(variable: str) -> List[str]:
        """Return the bodies of all case statements selecting on ``variable``."""
        return re.findall(
            r'\bcase[xz]?\s*\(\s*' + re.escape(variable) + r'\s*\)(.*?)\bendcase\b',
            rtl_code,
            re.DOTALL | re.IGNORECASE,
        )

    # Prefer a candidate that actually drives a case statement, so that a
    # "next_state" declared before "state" does not hide the real selector.
    chosen = candidates[0]
    bodies: List[str] = []
    for candidate in candidates:
        found = case_bodies(candidate.group(3))
        if found:
            chosen, bodies = candidate, found
            break

    msb, lsb, state_var = chosen.group(1), chosen.group(2), chosen.group(3)
    state_bits = 0
    if msb and lsb and msb.strip().isdigit() and lsb.strip().isdigit():
        # [2:0] is three bits wide, not two.
        state_bits = abs(int(msb) - int(lsb)) + 1

    # One or more comma-separated labels (identifiers or literals such as
    # 2'd0) followed by a single ':' at the start of a case item.
    label_re = re.compile(
        r"(?:\A|[;\n]|\bend\b)\s*([\w']+(?:\s*,\s*[\w']+)*)\s*:(?!:)"
    )
    not_states = {'default', 'begin', 'end'}

    states: List[str] = []
    seen = set()
    for body in bodies:
        for item in label_re.finditer(body):
            for label in item.group(1).split(','):
                label = label.strip()
                if label.lower() in not_states or label in seen:
                    continue
                seen.add(label)
                states.append(label)

    # Fallback: no usable case statement, guess from constant names.
    if not states:
        param_re = r'\b(?:localparam|parameter)\b\s*(?:\[[^\]]*\]\s*)?(\w+)\s*='
        for param in re.finditer(param_re, rtl_code):
            name = param.group(1)
            if 'ST' in name.upper() and name not in seen:
                seen.add(name)
                states.append(name)

    return {
        "state_variable": state_var,
        "state_bits": state_bits,
        "states": states[:20],  # Limit to 20 states
        "state_count": len(states)
    }
|
|
|
|
|
|
def extract_counters(rtl_code: str) -> List[Dict]:
    """Extract counter-like registers and their widths.

    A counter is any ``reg`` whose name contains ``count``, ``cnt``,
    ``counter`` or ``rel`` (case-insensitive). Declarations with any
    ``[msb:lsb]`` range are recognised, including parametric ones such as
    ``reg [WIDTH-1:0] counter`` which a numeric-only pattern would skip.

    Args:
        rtl_code: Verilog/SystemVerilog source text.

    Returns:
        At most five dicts with ``"name"`` and ``"width"``. ``width`` is an
        ``int`` for a scalar (1) or a numeric range, and a ``str`` expression
        for a symbolic range (``"WIDTH"`` for ``[WIDTH-1:0]``, otherwise
        ``"(msb)-(lsb)+1"``).
    """

    def range_width(msb: Optional[str], lsb: Optional[str]):
        """Turn the captured range bounds into a bit width (int or expression)."""
        if msb is None or lsb is None:
            return 1
        msb, lsb = msb.strip(), lsb.strip()
        if msb.isdigit() and lsb.isdigit():
            return abs(int(msb) - int(lsb)) + 1
        minus_one = re.fullmatch(r'(.*\S)\s*-\s*1', msb)
        if lsb == '0' and minus_one:
            return minus_one.group(1)
        return f"({msb})-({lsb})+1"

    # Match counter registers: reg [signed] [msb:lsb] counter_name
    counter_pattern = (
        r'\breg\b\s*(?:(?:signed|unsigned)\b\s*)?'
        r'(?:\[([^\]:]+):([^\]]+)\]\s*)?'
        r'(\w*(?:count|cnt|counter|rel)\w*)'
    )
    counters = [
        {"name": found.group(3), "width": range_width(found.group(1), found.group(2))}
        for found in re.finditer(counter_pattern, rtl_code, re.IGNORECASE)
    ]

    return counters[:5]  # Limit to 5 counters
|
|
|
|
|
|
def extract_fifos(rtl_code: str) -> List[Dict]:
    """Extract FIFO/queue/buffer signals and parameterised FIFO instances.

    Two sources are scanned:

    * ``reg``/``wire`` declarations whose name contains ``fifo``, ``queue``
      or ``buffer`` (case-insensitive), with any ``[msb:lsb]`` range;
    * parameterised instantiations ``<module_type> #( ... ) <inst> ( ... )``
      whose module type contains "fifo". A ``module <name> #(`` declaration
      header is not counted as an instantiation.

    Args:
        rtl_code: Verilog/SystemVerilog source text.

    Returns:
        At most five dicts, declarations first. Declarations carry
        ``"name"`` and ``"width"`` (``int`` for a scalar or numeric range,
        ``str`` expression for a symbolic range); instantiations carry
        ``"name"`` (the module type) and ``"type": "instantiated"``.
    """

    def range_width(msb: Optional[str], lsb: Optional[str]):
        """Turn the captured range bounds into a bit width (int or expression)."""
        if msb is None or lsb is None:
            # No range means a scalar net/variable, i.e. one bit.
            return 1
        msb, lsb = msb.strip(), lsb.strip()
        if msb.isdigit() and lsb.isdigit():
            return abs(int(msb) - int(lsb)) + 1
        minus_one = re.fullmatch(r'(.*\S)\s*-\s*1', msb)
        if lsb == '0' and minus_one:
            return minus_one.group(1)
        return f"({msb})-({lsb})+1"

    fifos: List[Dict] = []

    # Match FIFO-like structures
    fifo_pattern = (
        r'\b(?:reg|wire)\b\s*(?:\[([^\]:]+):([^\]]+)\]\s*)?'
        r'(\w*(?:fifo|queue|buffer)\w*)'
    )
    for found in re.finditer(fifo_pattern, rtl_code, re.IGNORECASE):
        fifos.append({
            "name": found.group(3),
            "width": range_width(found.group(1), found.group(2))
        })

    # Also look for FIFO instantiation patterns: the parameter override
    # list is written "#(", so that is what follows the module type.
    fifo_inst_pattern = r'(\bmodule\s+)?\b(\w+)\s*#\s*\('
    for found in re.finditer(fifo_inst_pattern, rtl_code):
        if found.group(1):
            continue  # "module foo #(" is a declaration, not an instance
        name = found.group(2)
        if 'fifo' in name.lower():
            fifos.append({"name": name, "type": "instantiated"})

    return fifos[:5]  # Limit to 5 FIFOs
|
|
|
|
|
|
def extract_key_registers(rtl_code: str) -> List[str]:
    """List the names of data registers, skipping infrastructure signals.

    Every ``reg`` declaration is considered, with or without a range and
    with an optional ``signed``/``unsigned`` qualifier. Comments are
    stripped first so prose containing the word "reg" is not parsed as a
    declaration. Names containing ``clk``, ``rst``, ``reset``, ``state``,
    ``cnt``, ``count``, ``fifo`` or ``buffer`` (case-insensitive substring)
    are skipped.

    Args:
        rtl_code: Verilog/SystemVerilog source text.

    Returns:
        Up to ten register names in declaration order, without duplicates.
    """
    # Drop // line comments and /* ... */ block comments.
    code = re.sub(r'//[^\n]*|/\*.*?\*/', ' ', rtl_code, flags=re.DOTALL)

    # Skip common non-data signals
    skipped_keywords = ('clk', 'rst', 'reset', 'state', 'cnt', 'count', 'fifo', 'buffer')

    # Match reg declarations; \breg\b keeps names like "shift_reg" from
    # being read as the keyword.
    reg_pattern = r'\breg\b\s*(?:(?:signed|unsigned)\b\s*)?(?:\[[^\]]*\]\s*)?(\w+)'

    registers: List[str] = []
    for found in re.finditer(reg_pattern, code):
        name = found.group(1)
        lowered = name.lower()
        if any(keyword in lowered for keyword in skipped_keywords):
            continue
        if name not in registers:
            registers.append(name)

    return registers[:10]  # Limit to 10 registers
|
|
|
|
|
|
def extract_parameters(rtl_code: str) -> List[Dict]:
    """Extract ``parameter`` declarations and their default values.

    Handles both header parameter lists (``#(parameter A = 1, parameter B =
    2)``) and body declarations (``parameter A = 1, B = 2;``). Each value
    ends at the first ``,`` or ``;`` outside brackets, or at the ``)`` that
    closes the header list, so a header parameter no longer absorbs the
    declarations and port list that follow it. An optional data type
    and/or range before the name is skipped. ``localparam`` declarations
    are not reported. Comments are removed before parsing.

    Args:
        rtl_code: Verilog/SystemVerilog source text.

    Returns:
        Up to ten dicts ``{"name": ..., "value": ...}`` in source order,
        with ``value`` as stripped source text.
    """
    # Drop // line comments and /* ... */ block comments.
    code = re.sub(r'//[^\n]*|/\*.*?\*/', ' ', rtl_code, flags=re.DOTALL)

    decl_re = re.compile(
        r'\bparameter\b\s*'
        r'(?:(?:integer|real|realtime|time|int|shortint|longint|byte|bit|logic'
        r'|string|signed|unsigned)\b\s*)*'
        r'(?:\[[^\]]*\]\s*)?'
    )
    # "name =" but not "name ==".
    assign_re = re.compile(r'\s*(\w+)\s*=(?!=)')

    params: List[Dict] = []
    for decl in decl_re.finditer(code):
        pos = decl.end()
        # One keyword may introduce several comma-separated assignments.
        while True:
            assignment = assign_re.match(code, pos)
            if not assignment:
                break
            value, stop = _scan_parameter_value(code, assignment.end())
            params.append({"name": assignment.group(1), "value": value})
            if stop >= len(code) or code[stop] != ',':
                break
            pos = stop + 1

    return params[:10]  # Limit to 10 parameters


def _scan_parameter_value(text: str, start: int) -> Tuple[str, int]:
    """Return (value text, stop index) for the expression starting at ``start``.

    Scanning stops at a ``,`` or ``;`` outside any bracket pair, at a
    closing bracket that was not opened inside the value, or at the end of
    ``text``. Double-quoted strings are skipped as a unit.
    """
    depth = 0
    idx = start
    length = len(text)
    while idx < length:
        char = text[idx]
        if char == '"':
            closing = text.find('"', idx + 1)
            if closing == -1:
                idx = length
                break
            idx = closing
        elif char in '([{':
            depth += 1
        elif char in ')]}':
            if depth == 0:
                break
            depth -= 1
        elif char in ',;' and depth == 0:
            break
        idx += 1
    return text[start:idx].strip(), idx
|
|
|
|
|
|
def is_large_rtl(rtl_code: str, threshold: int = 5000) -> bool:
    """Tell whether the RTL text has reached the size threshold.

    Args:
        rtl_code: RTL source text.
        threshold: Size limit in characters.

    Returns:
        ``True`` when ``rtl_code`` is at least ``threshold`` characters
        long, ``False`` otherwise.
    """
    char_count = len(rtl_code)
    if char_count < threshold:
        return False
    return True
|
|
|
|
|
|
def summarize_rtl(rtl_code: str, include_critical_section: bool = True) -> str:
    """
    Generate a concise summary of RTL code for LLM prompts.

    The summary contains, in order and only when non-empty: module name and
    circuit type, interface ports grouped by direction, parameters, FSM
    details, counters, FIFOs/buffers, key registers and an optional excerpt
    of critical logic.

    Args:
        rtl_code: Full RTL Verilog code
        include_critical_section: If True, append an excerpt (at most 800
            characters) of the FSM ``case`` body; when no such body is found
            the first ``always ... begin ... end`` block is used instead.

    Returns:
        Formatted summary string
    """
    excerpt_limit = 800

    lines = [
        "=" * 60,
        "RTL SUMMARY (Generated by rtl_summarizer)",
        "=" * 60,
    ]

    # Module name and type
    module_name = extract_module_name(rtl_code)
    circuit_type = detect_circuit_type(rtl_code)
    lines.append(f"\nModule: {module_name}")
    lines.append(f"Type: {circuit_type}")

    # Interface signals, grouped by direction with one shared format
    signals = extract_interface_signals(rtl_code)
    if signals:
        lines.append("\nInterface:")
        for direction, label in (("input", "Inputs"), ("output", "Outputs"), ("inout", "Inouts")):
            group = [s for s in signals if s["direction"] == direction]
            if group:
                lines.append(f" {label}: {_format_port_group(group)}")

    # Parameters
    params = extract_parameters(rtl_code)
    if params:
        lines.append("\nParameters:")
        for p in params[:10]:
            lines.append(f" {p['name']} = {p['value']}")

    # FSM info
    fsm_info = extract_fsm_info(rtl_code)
    if fsm_info and fsm_info['states']:
        lines.append("\nFSM:")
        lines.append(f" State Variable: {fsm_info['state_variable']}")
        if fsm_info['state_bits']:
            lines.append(f" State Bits: {fsm_info['state_bits']}")
        state_list = ", ".join(fsm_info['states'][:15])
        if fsm_info['state_count'] > 15:
            state_list += f", ... (+{fsm_info['state_count']-15} more states)"
        lines.append(f" States: {state_list}")

    # Counters
    counters = extract_counters(rtl_code)
    if counters:
        lines.append("\nCounters:")
        for c in counters:
            lines.append(f" - {c['name']} ({c.get('width', 1)}-bit)")

    # FIFOs/Buffers
    fifos = extract_fifos(rtl_code)
    if fifos:
        lines.append("\nFIFOs/Buffers:")
        for f in fifos:
            lines.append(f" - {f['name']}")

    # Key registers
    registers = extract_key_registers(rtl_code)
    if registers:
        lines.append("\nKey Registers:")
        lines.append(f" {', '.join(registers)}")

    # Critical section (FSM case body or main always block)
    if include_critical_section:
        title = None
        excerpt = ""

        # Try to extract FSM case body
        if fsm_info:
            case_match = re.search(
                r'case\s*\(\s*' + re.escape(fsm_info['state_variable']) + r'\s*\)(.*?)endcase',
                rtl_code,
                re.DOTALL | re.IGNORECASE
            )
            if case_match:
                title = "FSM Case Body (for reference):"
                excerpt = case_match.group(1)

        # Otherwise fall back to the main always block
        if title is None:
            always_block = _first_always_block(rtl_code)
            if always_block is not None:
                title = "Main Always Block (for reference):"
                excerpt = always_block

        if title is not None:
            lines.append("\n" + "-" * 40)
            lines.append(title)
            lines.append("-" * 40)
            lines.append(excerpt[:excerpt_limit])
            if len(excerpt) > excerpt_limit:
                lines.append("... (truncated)")

    lines.append("\n" + "=" * 60)
    lines.append("Note: This is a SUMMARY. Full RTL code is available in the reference file.")
    lines.append("=" * 60)

    return "\n".join(lines)


def _format_port_group(ports: List[Dict[str, str]], limit: int = 15) -> str:
    """Render ``name[width]`` entries for one direction, capped at ``limit``."""
    rendered = ", ".join(f"{port['name']}{port['width']}" for port in ports[:limit])
    hidden = len(ports) - limit
    if hidden > 0:
        rendered += f", ... (+{hidden} more)"
    return rendered


def _first_always_block(rtl_code: str) -> Optional[str]:
    """Return the first ``always ... begin ... end`` block, or ``None``.

    ``begin``/``end`` keywords are counted as whole words so the block ends
    at the ``end`` that balances its first ``begin`` rather than at the
    first "end" substring (for example the one inside ``endcase``). If the
    block is never closed, the text up to the end of input is returned.
    """
    # [^;] keeps the "begin" within the same statement as the "always".
    head = re.search(r'\balways(?:_ff|_comb|_latch)?\b[^;]*?\bbegin\b', rtl_code)
    if not head:
        return None

    depth = 1
    for token in re.compile(r'\b(begin|end)\b').finditer(rtl_code, head.end()):
        depth += 1 if token.group(1) == 'begin' else -1
        if depth == 0:
            return rtl_code[head.start():token.end()]
    return rtl_code[head.start():]
|
|
|
|
|
|
if __name__ == "__main__":
    # Test with a sample: a small SPI controller that exercises every
    # extractor (header parameters, ANSI ports, localparam state names, a
    # state register with two case blocks, counters and plain registers).
    sample_rtl = """
module spi_controller #(
    parameter CLK_DIV = 16,
    parameter DATA_WIDTH = 8
) (
    input wire clk,
    input wire rst_n,
    input wire [7:0] tx_data,
    input wire tx_valid,
    output reg [7:0] rx_data,
    output reg rx_valid,
    output wire spi_clk,
    output wire spi_mosi,
    input wire spi_miso
);

    localparam IDLE = 3'd0;
    localparam TX_START = 3'd1;
    localparam TX_WAIT = 3'd2;
    localparam RX_START = 3'd3;
    localparam RX_WAIT = 3'd4;
    localparam DONE = 3'd5;

    reg [2:0] state;
    reg [2:0] next_state;
    reg [7:0] shift_reg;
    reg [3:0] bit_count;
    reg [15:0] clk_count;

    always @(posedge clk or negedge rst_n) begin
        if (!rst_n)
            state <= IDLE;
        else
            state <= next_state;
    end

    always @(*) begin
        case (state)
            IDLE: next_state = tx_valid ? TX_START : IDLE;
            TX_START: next_state = TX_WAIT;
            TX_WAIT: next_state = (bit_count == 7) ? RX_START : TX_WAIT;
            RX_START: next_state = RX_WAIT;
            RX_WAIT: next_state = (bit_count == 7) ? DONE : RX_WAIT;
            DONE: next_state = IDLE;
            default: next_state = IDLE;
        endcase
    end

    always @(posedge clk or negedge rst_n) begin
        if (!rst_n) begin
            shift_reg <= 0;
            bit_count <= 0;
        end else begin
            case (state)
                TX_START: begin
                    shift_reg <= tx_data;
                    bit_count <= 0;
                end
                TX_WAIT: begin
                    shift_reg <= {shift_reg[6:0], 1'b0};
                    bit_count <= bit_count + 1;
                end
                RX_WAIT: begin
                    rx_data[7-bit_count] <= spi_miso;
                    bit_count <= bit_count + 1;
                end
                DONE: begin
                    rx_valid <= 1;
                end
            endcase
        end
    end

    assign spi_clk = clk_count >= CLK_DIV[15:0] && state != IDLE;
    assign spi_mosi = shift_reg[7];

endmodule
"""

    # Print the generated summary for manual inspection.
    summary = summarize_rtl(sample_rtl)
    print(summary)

    # Show the size check at two thresholds.
    print("\n" + "=" * 60)
    print(f"is_large_rtl (threshold=5000): {is_large_rtl(sample_rtl, 5000)}")
    print(f"is_large_rtl (threshold=2000): {is_large_rtl(sample_rtl, 2000)}")
|