""" Description : RTL summarization tool for large RTL files Generates concise summaries to reduce LLM prompt size while preserving key information Author : CorrectBench Time : 2026/04/19 """ import re from typing import Optional, Dict, List, Tuple def extract_module_name(rtl_code: str) -> str: """Extract module name from RTL code""" match = re.search(r'module\s+(\w+)', rtl_code) return match.group(1) if match else "unknown" def extract_module_header(rtl_code: str) -> Optional[str]: """Extract module header (ports and parameters)""" match = re.search( r'module\s+\w+\s*(#\([^)]*\))?\s*\(([^)]*)\)', rtl_code, re.DOTALL ) if match: params = match.group(1) or "" ports = match.group(2) or "" return f"module (...) {params.strip()}\n({ports.strip()})" return None def extract_interface_signals(rtl_code: str) -> List[Dict[str, str]]: """ Extract interface signals from module header Returns list of dicts with name, direction, width """ signals = [] # Match input/output declarations in the module header header_match = re.search(r'module\s+\w+\s*(?:#\([^)]*\))?\s*\(([^)]*)\)', rtl_code, re.DOTALL) if not header_match: return signals port_section = header_match.group(1) # Match input/output/inout with optional width # Pattern: input/output/inout [wire/reg] [signed] [width] name pattern = r'(input|output|inout)\s+(?:wire|reg)?\s*(?:signed)?\s*(?:\[([^:]+):([^\]]+)\])?\s*(\w+)' for match in re.finditer(pattern, port_section): direction = match.group(1) width_high = match.group(2) width_low = match.group(3) name = match.group(4) width = "" if width_high and width_low: width = f"[{width_high}:{width_low}]" signals.append({ "name": name, "direction": direction, "width": width }) return signals def detect_circuit_type(rtl_code: str) -> str: """ Detect circuit type: FSM, COMBINATIONAL, SEQUENTIAL, PROTOCOL, or MIXED """ has_always_ff = bool(re.search(r'always_ff\s*@', rtl_code)) has_always_latch = bool(re.search(r'always_latch\s*@', rtl_code)) has_always_comb = bool(re.search(r'always_comb\s*@', rtl_code)) has_regular_always = bool(re.search(r'always\s*@', rtl_code)) # Check for FSM patterns has_state_reg = bool(re.search(r'reg\s+\w*state\w*', rtl_code, re.IGNORECASE)) has_case_state = bool(re.search(r'case\s*\(\s*\w*state\w*\s*\)', rtl_code, re.IGNORECASE)) is_fsm = has_state_reg and has_case_state # Check for protocol interfaces has_handshake = bool(re.search(r'(valid|ready|ack|request|grant)', rtl_code, re.IGNORECASE)) has_spi = bool(re.search(r'(spi|miso|mosi|sclk)', rtl_code, re.IGNORECASE)) has_i2c = bool(re.search(r'(i2c|sda|scl)', rtl_code, re.IGNORECASE)) is_protocol = has_handshake or has_spi or has_i2c if is_fsm: if is_protocol: return "SEQUENTIAL (FSM + PROTOCOL)" return "SEQUENTIAL (FSM)" elif has_always_ff or (has_regular_always and not has_always_comb): if is_protocol: return "SEQUENTIAL (PROTOCOL)" return "SEQUENTIAL" elif has_always_comb or has_always_latch: return "COMBINATIONAL" elif is_protocol: return "PROTOCOL" else: return "COMBINATIONAL" def extract_fsm_info(rtl_code: str) -> Optional[Dict]: """ Extract FSM information: states, state variable, transitions """ # Find state register state_var_match = re.search(r'reg\s+(?:\[(\d+):0\])?\s*(\w*state\w*)', rtl_code, re.IGNORECASE) if not state_var_match: return None state_bits = state_var_match.group(1) state_var = state_var_match.group(2) # Extract state names from localparam or parameter states = [] param_pattern = r'(?:localparam|parameter)\s+(?:(\w+)\s*=\s*(\d+)|(\w+)\s*=\s*(\w+))' for match in re.finditer(param_pattern, rtl_code): if match: name = match.group(1) or match.group(3) if name and any(kw in name.upper() for kw in ['STATE', 'ST']): states.append(name) # If no explicit state names, try to extract from case statements if not states: case_match = re.search(r'case\s*\(\s*' + re.escape(state_var) + r'\s*\)(.*?)endcase', rtl_code, re.DOTALL | re.IGNORECASE) if case_match: case_body = case_match.group(1) # Find all state names (words ending with colon) state_names = re.findall(r'(\w+)\s*:', case_body) states = [s for s in state_names if s.upper() not in ['NEXT', 'DEFAULT']] return { "state_variable": state_var, "state_bits": int(state_bits) if state_bits else 0, "states": states[:20] if states else [], # Limit to 20 states "state_count": len(states) if states else 0 } def extract_counters(rtl_code: str) -> List[Dict]: """ Extract counter/relator information """ counters = [] # Match counter registers: reg [N:0] counter_name counter_pattern = r'reg\s+(?:\[(\d+):0\])?\s*(\w*(?:count|cnt|counter|rel)\w*)' for match in re.finditer(counter_pattern, rtl_code, re.IGNORECASE): width = match.group(1) name = match.group(2) counters.append({ "name": name, "width": int(width) + 1 if width else 1 }) return counters[:5] # Limit to 5 counters def extract_fifos(rtl_code: str) -> List[Dict]: """ Extract FIFO buffer information """ fifos = [] # Match FIFO-like structures fifo_pattern = r'(?:reg|wire)\s+(?:\[(\d+):0\])?\s*(\w*(?:fifo|queue|buffer)\w*)' for match in re.finditer(fifo_pattern, rtl_code, re.IGNORECASE): width = match.group(1) name = match.group(2) fifos.append({ "name": name, "width": int(width) + 1 if width else 8 }) # Also look for FIFO instantiation patterns fifo_inst_pattern = r'(\w+)\s+#\.\w+\s*\(.*?\)' for match in re.finditer(fifo_inst_pattern, rtl_code, re.DOTALL): name = match.group(1) if 'fifo' in name.lower(): fifos.append({"name": name, "type": "instantiated"}) return fifos[:5] # Limit to 5 FIFOs def extract_key_registers(rtl_code: str) -> List[str]: """ Extract key registered signals (not just clk/rst) """ registers = [] # Match reg declarations reg_pattern = r'reg\s+(?:\[(\d+):0\])?\s*(\w+)' for match in re.finditer(reg_pattern, rtl_code): width = match.group(1) name = match.group(2) # Skip common non-data signals if any(kw in name.lower() for kw in ['clk', 'rst', 'reset', 'state', 'cnt', 'count', 'fifo', 'buffer']): continue registers.append(name) return registers[:10] # Limit to 10 registers def extract_parameters(rtl_code: str) -> List[Dict]: """ Extract module parameters """ params = [] param_pattern = r'parameter\s+(\w+)\s*=\s*([^;]+)' for match in re.finditer(param_pattern, rtl_code): name = match.group(1) value = match.group(2).strip() params.append({"name": name, "value": value}) return params[:10] # Limit to 10 parameters def is_large_rtl(rtl_code: str, threshold: int = 5000) -> bool: """Check if RTL code exceeds size threshold""" return len(rtl_code) >= threshold def summarize_rtl(rtl_code: str, include_critical_section: bool = True) -> str: """ Generate a concise summary of RTL code for LLM prompts. Args: rtl_code: Full RTL Verilog code include_critical_section: If True, include a sample of critical logic (FSM case, etc.) Returns: Formatted summary string """ lines = [] lines.append("=" * 60) lines.append("RTL SUMMARY (Generated by rtl_summarizer)") lines.append("=" * 60) # Module name and type module_name = extract_module_name(rtl_code) circuit_type = detect_circuit_type(rtl_code) lines.append(f"\nModule: {module_name}") lines.append(f"Type: {circuit_type}") # Interface signals signals = extract_interface_signals(rtl_code) if signals: lines.append("\nInterface:") # Group by direction inputs = [s for s in signals if s["direction"] == "input"] outputs = [s for s in signals if s["direction"] == "output"] inouts = [s for s in signals if s["direction"] == "inout"] if inputs: input_str = ", ".join([f"{s['name']}{s['width']}" for s in inputs[:15]]) if len(inputs) > 15: input_str += f", ... (+{len(inputs)-15} more)" lines.append(f" Inputs: {input_str}") if outputs: output_str = ", ".join([f"{s['name']}{s['width']}" for s in outputs[:15]]) if len(outputs) > 15: output_str += f", ... (+{len(outputs)-15} more)" lines.append(f" Outputs: {output_str}") if inouts: lines.append(f" Inouts: {', '.join([s['name'] for s in inouts])}") # Parameters params = extract_parameters(rtl_code) if params: lines.append("\nParameters:") for p in params[:10]: lines.append(f" {p['name']} = {p['value']}") # FSM info fsm_info = extract_fsm_info(rtl_code) if fsm_info and fsm_info['states']: lines.append(f"\nFSM:") lines.append(f" State Variable: {fsm_info['state_variable']}") if fsm_info['state_bits']: lines.append(f" State Bits: {fsm_info['state_bits']}") state_list = ", ".join(fsm_info['states'][:15]) if fsm_info['state_count'] > 15: state_list += f", ... (+{fsm_info['state_count']-15} more states)" lines.append(f" States: {state_list}") # Counters counters = extract_counters(rtl_code) if counters: lines.append("\nCounters:") for c in counters: lines.append(f" - {c['name']} ({c.get('width', 1)}-bit)") # FIFOs/Buffers fifos = extract_fifos(rtl_code) if fifos: lines.append("\nFIFOs/Buffers:") for f in fifos: lines.append(f" - {f['name']}") # Key registers registers = extract_key_registers(rtl_code) if registers: lines.append("\nKey Registers:") lines.append(f" {', '.join(registers)}") # Critical section (FSM case body or main always block) if include_critical_section: # Try to extract FSM case body if fsm_info: case_match = re.search( r'case\s*\(\s*' + re.escape(fsm_info['state_variable']) + r'\s*\)(.*?)endcase', rtl_code, re.DOTALL | re.IGNORECASE ) if case_match: lines.append("\n" + "-" * 40) lines.append("FSM Case Body (for reference):") lines.append("-" * 40) case_body = case_match.group(1)[:800] # Limit to 800 chars lines.append(case_body) if len(case_match.group(1)) > 800: lines.append("... (truncated)") # Try to extract main always block else: always_match = re.search(r'(always\s*@.*?begin.*?end)', rtl_code, re.DOTALL) if always_match: lines.append("\n" + "-" * 40) lines.append("Main Always Block (for reference):") lines.append("-" * 40) always_body = always_match.group(1)[:800] lines.append(always_body) if len(always_match.group(1)) > 800: lines.append("... (truncated)") lines.append("\n" + "=" * 60) lines.append("Note: This is a SUMMARY. Full RTL code is available in the reference file.") lines.append("=" * 60) return "\n".join(lines) if __name__ == "__main__": # Test with a sample sample_rtl = """ module spi_controller #( parameter CLK_DIV = 16, parameter DATA_WIDTH = 8 ) ( input wire clk, input wire rst_n, input wire [7:0] tx_data, input wire tx_valid, output reg [7:0] rx_data, output reg rx_valid, output wire spi_clk, output wire spi_mosi, input wire spi_miso ); localparam IDLE = 3'd0; localparam TX_START = 3'd1; localparam TX_WAIT = 3'd2; localparam RX_START = 3'd3; localparam RX_WAIT = 3'd4; localparam DONE = 3'd5; reg [2:0] state; reg [2:0] next_state; reg [7:0] shift_reg; reg [3:0] bit_count; reg [15:0] clk_count; always @(posedge clk or negedge rst_n) begin if (!rst_n) state <= IDLE; else state <= next_state; end always @(*) begin case (state) IDLE: next_state = tx_valid ? TX_START : IDLE; TX_START: next_state = TX_WAIT; TX_WAIT: next_state = (bit_count == 7) ? RX_START : TX_WAIT; RX_START: next_state = RX_WAIT; RX_WAIT: next_state = (bit_count == 7) ? DONE : RX_WAIT; DONE: next_state = IDLE; default: next_state = IDLE; endcase end always @(posedge clk or negedge rst_n) begin if (!rst_n) begin shift_reg <= 0; bit_count <= 0; end else begin case (state) TX_START: begin shift_reg <= tx_data; bit_count <= 0; end TX_WAIT: begin shift_reg <= {shift_reg[6:0], 1'b0}; bit_count <= bit_count + 1; end RX_WAIT: begin rx_data[7-bit_count] <= spi_miso; bit_count <= bit_count + 1; end DONE: begin rx_valid <= 1; end endcase end end assign spi_clk = clk_count >= CLK_DIV[15:0] && state != IDLE; assign spi_mosi = shift_reg[7]; endmodule """ summary = summarize_rtl(sample_rtl) print(summary) print("\n" + "=" * 60) print(f"is_large_rtl (threshold=5000): {is_large_rtl(sample_rtl, 5000)}") print(f"is_large_rtl (threshold=2000): {is_large_rtl(sample_rtl, 2000)}")