""" Description : RTL Mutation Operators for TB Discriminator 提供 FSM 和 DataPath 的 mutation 算子,用于生成可控的错误 RTL Author : CorrectBench """ import re import random class FSMMutationOperators: """针对 FSM 的 mutation 算子""" @staticmethod def flip_state_condition(code: str) -> str: """ 翻转状态转移条件中的比较运算符 if (state == SEND) -> if (state != SEND) if (bit_cnt == 0) -> if (bit_cnt != 0) """ result = code # 匹配 == 的情况 patterns_eq = [ (r'\(state\s*==\s*(\w+)\)', r'(state != \1)'), (r'\(bit_cnt\s*==\s*0\)', r'(bit_cnt != 0)'), (r'\(bit_count\s*==\s*0\)', r'(bit_count != 0)'), (r'\((\w+)\s*==\s*(\w+)\)', r'(\1 != \2)'), ] for pattern, replacement in patterns_eq: if re.search(pattern, result): result = re.sub(pattern, replacement, result, count=1) return result # 匹配 != 的情况 patterns_ne = [ (r'\(state\s*!=\s*(\w+)\)', r'(state == \1)'), (r'\(bit_cnt\s*!=\s*0\)', r'(bit_cnt == 0)'), (r'\(bit_count\s*!=\s*0\)', r'(bit_count == 0)'), ] for pattern, replacement in patterns_ne: if re.search(pattern, result): result = re.sub(pattern, replacement, result, count=1) return result return code # 没有找到可翻转的条件 @staticmethod def swap_next_state(code: str) -> str: """ 交换两个分支的下一个状态 适用于 if-else 结构交换两个分支 """ result = code # 匹配 if-else 结构并交换分支 # 简化版:找到相邻的两个不同状态转移并交换 if_blocks = re.finditer( r'if\s*\([^)]+\)\s*begin\s*(\w+)\s*<=\s*(\w+);\s*end\s*else\s*begin\s*(\w+)\s*<=\s*(\w+);', result, re.DOTALL ) matches = list(if_blocks) if matches: m = random.choice(matches) full_match = m.group(0) # 交换 then 和 else 分支的赋值 new_match = re.sub( r'if\s*\([^)]+\)\s*begin\s*(\w+)\s*<=\s*(\w+);\s*end\s*else\s*begin\s*(\w+)\s*<=\s*(\w+);', lambda mo: f"if ({m.group(0).split('(')[1].split(')')[0]}) begin {mo.group(3)} <= {mo.group(4)}; end else begin {mo.group(1)} <= {mo.group(2)};", full_match, count=1 ) result = result.replace(full_match, new_match, 1) return result return code @staticmethod def invert_counter_update(code: str) -> str: """ 反转计数器更新 bit_cnt = bit_cnt - 1 -> bit_cnt = bit_cnt + 1 bit_cnt <= bit_cnt + 1 -> bit_cnt <= bit_cnt - 1 """ result = code # 匹配减量 dec_pattern = r'(bit_cnt|bit_count|counter)\s*<=\s*\1\s*-\s*1' if re.search(dec_pattern, result): result = re.sub(r'(\w+)\s*<=\s*\1\s*-\s*1', r'\1 <= \1 + 1', result, count=1) return result # 匹配增量 inc_pattern = r'(bit_cnt|bit_count|counter)\s*<=\s*\1\s*\+\s*1' if re.search(inc_pattern, result): result = re.sub(r'(\w+)\s*<=\s*\1\s*\+\s*1', r'\1 <= \1 - 1', result, count=1) return result return code @staticmethod def stuck_at_state(code: str) -> str: """ 让状态机卡在某个状态 在状态转移条件中加入 && 0 使条件永远为假 """ result = code # 找到 if (start) 这样的条件,在其中加入 && 0 if re.search(r'if\s*\(\s*start\s*\)', result): result = re.sub(r'if\s*\(\s*start\s*\)', 'if (start && 0)', result, count=1) return result return code @staticmethod def flip_comparison_operator(code: str) -> str: """ 翻转比较运算符 > < >= <= if (bit_cnt > 0) -> if (bit_cnt <= 0) """ result = code comparisons = [ (r'\b>\s*0\b', '<= 0'), (r'\b<\s*0\b', '> 0'), (r'\b>=\s*0\b', '< 0'), (r'\b<=\s*0\b', '> 0'), ] for pattern, replacement in comparisons: if re.search(pattern, result): result = re.sub(pattern, replacement, result, count=1) return result return code class DataPathMutationOperators: """数据通路 mutation 算子""" @staticmethod def width_mismatch_assignment(code: str) -> str: """ 生成位宽不匹配的赋值 对于 output [7:0] data_out,赋值时用 1'b0 而不是 8'b0 """ result = code # 查找 output reg [7:0] data_out 这样的声明 width_match = re.search(r'output\s+reg\s+\[(\d+):0\]\s+(\w+)', result) if width_match: width = int(width_match.group(1)) signal_name = width_match.group(2) # 替换赋值为 1-bit 赋值(故意产生 width mismatch) # 查找 data_out <= xxx; 并将 xxx 替换为 1-bit pattern = rf'({signal_name}\s*<=\s*)(\d+\'[bB]\d+|[a-fA-F0-9]+h?[a-fA-F0-9]*)' if re.search(pattern, result): result = re.sub(pattern, r'\g<1>1\'b0', result, count=1) return result # 或者直接替换赋值为 1'b0 pattern2 = rf'({signal_name}\s*<=\s*)(\w+)' if re.search(pattern2, result): result = re.sub(pattern2, r'\g<1>1\'b0', result, count=1) return result return code @staticmethod def missing_assignment(code: str) -> str: """ 移除对某个关键信号的赋值(设为注释) 通常移除 data_out 的赋值会让 RTL 功能出错 """ result = code # 移除 data_out 或 data 的赋值 signals_to_remove = ['data_out', 'rx_data', 'shift_reg'] for signal in signals_to_remove: # 匹配 signal <= xxx; 格式 pattern = rf'(\s*{signal}\s*<=\s*[^;]+;)' if re.search(pattern, result): # 替换为空(注释掉) result = re.sub(pattern, r'// \1 // MUTATION: missing assignment', result, count=1) return result return code @staticmethod def swap_shift_direction(code: str) -> str: """ 翻转移位方向 shift_reg <= {shift_reg[6:0], miso} -> shift_reg <= {miso, shift_reg[7:1]} """ result = code # 匹配 {a, b} 格式的位移 patterns = [ # MSB first: {reg[N-1:0], new_bit} -> {new_bit, reg[N-2:0]} (r'(\w+)\s*<=\s*\{\s*(\w+)\[(\d+):0\],?\s*(\w+)\s*\}', r'\1 <= {\4, \2[\3-1:0]}'), # LSB first 变体 (r'(\w+)\s*<=\s*\{\s*(\w+),\s*(\w+)\[(\d+):1\]\s*\}', r'\1 <= {\3[\4:1], \2}'), ] for pattern, replacement in patterns: if re.search(pattern, result): result = re.sub(pattern, replacement, result, count=1) return result return code @staticmethod def invert_bit_select(code: str) -> str: """ 反转位选择 mosi <= shift_reg[7] -> mosi <= ~shift_reg[7] """ result = code # 匹配 signal[bit_idx] 格式的选择 pattern = r'(\w+)\s*<=\s*(\w+)\[(\d+)\]' matches = list(re.finditer(pattern, result)) if matches: m = random.choice(matches) signal = m.group(1) bit_select = m.group(2) bit_idx = m.group(3) original = m.group(0) replacement = f'{signal} <= ~{bit_select}[{bit_idx}]' result = result.replace(original, replacement, 1) return result return code @staticmethod def wrong_bit_index(code: str) -> str: """ 使用错误的位索引 mosi <= shift_reg[7] -> mosi <= shift_reg[6] """ result = code # 匹配 signal[7] 格式并改成 [6] patterns = [ (r'(\w+)\s*<=\s*(\w+)\[7\]', r'\1 <= \2[6]'), (r'(\w+)\s*<=\s*(\w+)\[6\]', r'\1 <= \2[7]'), (r'(\w+)\s*<=\s*(\w+)\[0\]', r'\1 <= \2[1]'), (r'(\w+)\s*<=\s*(\w+)\[1\]', r'\1 <= \2[0]'), ] for pattern, replacement in patterns: if re.search(pattern, result): result = re.sub(pattern, replacement, result, count=1) return result return code @staticmethod def swap_registers(code: str) -> str: """ 交换两个寄存器的赋值 temp_a <= xxx; temp_b <= yyy; -> temp_a <= yyy; temp_b <= xxx; """ result = code # 查找两个相邻的寄存器赋值 pattern = r'(\s*(\w+)\s*<=\s*([^;]+);)\s*(\s*(\w+)\s*<=\s*([^;]+);)' matches = list(re.finditer(pattern, result)) if matches: m = random.choice(matches) # 交换两个赋值 new_text = m.group(1).replace(m.group(2), m.group(5)).replace(m.group(3), m.group(6)) new_text2 = m.group(4).replace(m.group(5), m.group(2)).replace(m.group(6), m.group(3)) result = result[:m.start()] + new_text + ' ' + new_text2 + result[m.end():] return result return code class ConservativeMutationOperators: """ 保守 mutation 算子 - 只做微小扰动,保持 RTL 整体结构不变 这些算子产生轻微错误,生成的 RTL 更可能与 reference RTL 有部分匹配的行为 """ @staticmethod def single_bit_flip(code: str) -> str: """ 只翻转一个单 bit 信号的值 spi_clk <= 1'b0 -> spi_clk <= 1'b1 busy <= 1 -> busy <= 0 """ result = code # 匹配单 bit 赋值: signal <= 1 或 signal <= 1'b0 或 signal <= 1'b1 # 使用 lambda 避免引号转义问题 patterns = [ (r'(\w+)\s*<=\s*1\'b0', lambda m: f'{m.group(1)} <= 1\'b1'), (r'(\w+)\s*<=\s*1\'b1', lambda m: f'{m.group(1)} <= 1\'b0'), (r'(\w+)\s*<=\s*1\b', lambda m: f'{m.group(1)} <= 0'), (r'(\w+)\s*<=\s*0\b', lambda m: f'{m.group(1)} <= 1'), ] for pattern, replacement in patterns: if re.search(pattern, result): result = re.sub(pattern, replacement, result, count=1) return result return code @staticmethod def tiny_offset(code: str) -> str: """ 给数值加/减一个微小 offset bit_count <= 8 -> bit_count <= 9 或 bit_count <= 7 注意:只对赋值语句中的数值操作,不改变 localparam 等定义 """ result = code # 匹配赋值语句中的小数值(排除 localparam 等定义) # 例如: bit_count <= 3'd0 -> bit_count <= 3'd1 patterns = [ # 3'd0 -> 3'd1 或 3'd7 (r'(bit_count|bit_cnt|counter)\s*<=\s*(\d+)\'d(\d+)', lambda m: f'{m.group(1)} <= {m.group(2)}\'d{_offset_value(int(m.group(3)))}'), ] for pattern, replacement in patterns: if re.search(pattern, result): result = re.sub(pattern, replacement, result, count=1) return result return code @staticmethod def bit_index_shift(code: str) -> str: """ 只改变 bit index 1 位(最保守的改变) shift_reg[7] -> shift_reg[6] 或 shift_reg[8] (如果存在) 注意:只偏移1位,不翻转整体行为 """ result = code # 只偏移 7<->6 这种相邻的 index (针对 SPI 的 MSB first 特点) patterns = [ (r'(shift_reg|data_in|data_out)\[7\]', r'\1[6]'), (r'(shift_reg|data_in|data_out)\[6\]', r'\1[7]'), ] for pattern, replacement in patterns: if re.search(pattern, result): result = re.sub(pattern, replacement, result, count=1) return result return code @staticmethod def constant_slight_change(code: str) -> str: """ 微调常数值(仅针对赋值语句,不改变 localparam) 只改 LSB 一位 """ result = code # 8'hXX -> 8'hXY (只改最低位,在赋值语句中) # 匹配 data_in = 8'hAA 格式 hex_pattern = r'(\w+\s*<=\s*8\'h[0-9a-fA-F]{1,2})([0-9a-fA-F])' if re.search(hex_pattern, result): def hex_replacer(m): prefix = m.group(1) last_digit = m.group(2) # 翻转最后一位 new_digit = '1' if last_digit == '0' else '0' if last_digit == '1' else last_digit return prefix + new_digit result = re.sub(hex_pattern, hex_replacer, result, count=1) return result return code @staticmethod def invert_single_signal(code: str) -> str: """ 只取反单个信号的选择 mosi <= shift_reg[7] -> mosi <= ~shift_reg[7] """ result = code # 匹配 mosi <= shift_reg[7] 格式并取反 pattern = r'(mosi)\s*<=\s*(\w+\[)(\d+)\]' matches = list(re.finditer(pattern, result)) if matches: m = random.choice(matches) signal = m.group(1) bracket_part = m.group(2) idx = m.group(3) original = m.group(0) # 替换 mosi <= xxx[7] 为 mosi <= ~xxx[7] if '~' not in original: replacement = f'{signal} <= ~{bracket_part}{idx}]' result = result.replace(original, replacement, 1) return result return code def _offset_value(val: int) -> int: """对数值进行微小偏移,只在有效范围内操作""" # 只偏移 1,且只在 1-14 范围内 if 1 <= val <= 14: return val + random.choice([-1, 1]) elif val > 14: return val - 1 elif val == 0: return 1 return val class MutationOperatorRegistry: """Mutation 算子注册表""" # 保守算子(权重更高,因为它们产生更接近 reference 的 RTL) CONSERVATIVE_OPERATORS = [ ("single_bit_flip", ConservativeMutationOperators.single_bit_flip), ("bit_index_shift", ConservativeMutationOperators.bit_index_shift), ("constant_slight_change", ConservativeMutationOperators.constant_slight_change), ("tiny_offset", ConservativeMutationOperators.tiny_offset), ("invert_single_signal", ConservativeMutationOperators.invert_single_signal), ] # 激进算子(权重较低,因为它们可能完全破坏 RTL 结构) AGGRESSIVE_OPERATORS = [ ("flip_state_condition", FSMMutationOperators.flip_state_condition), ("invert_counter", FSMMutationOperators.invert_counter_update), ("stuck_at_state", FSMMutationOperators.stuck_at_state), ("flip_comparison", FSMMutationOperators.flip_comparison_operator), ("width_mismatch", DataPathMutationOperators.width_mismatch_assignment), ("missing_assignment", DataPathMutationOperators.missing_assignment), ("swap_shift", DataPathMutationOperators.swap_shift_direction), ("invert_bit", DataPathMutationOperators.invert_bit_select), ("wrong_bit_index", DataPathMutationOperators.wrong_bit_index), ("swap_registers", DataPathMutationOperators.swap_registers), ] # 默认权重:保守算子出现概率更高 DEFAULT_CONSERVATIVE_WEIGHT = 0.7 # 70% 概率选择保守算子 @classmethod def get_random_operator(cls, conservative_weight: float = None): """ 随机获取一个 mutation 算子 Args: conservative_weight: 保守算子的权重 (0.0-1.0),默认 0.7 """ if conservative_weight is None: conservative_weight = cls.DEFAULT_CONSERVATIVE_WEIGHT import random if random.random() < conservative_weight: return random.choice(cls.CONSERVATIVE_OPERATORS) else: return random.choice(cls.AGGRESSIVE_OPERATORS) @classmethod def get_all_operators(cls): """获取所有算子""" return cls.CONSERVATIVE_OPERATORS + cls.AGGRESSIVE_OPERATORS @classmethod def get_conservative_operators(cls): """只获取保守算子""" return cls.CONSERVATIVE_OPERATORS @classmethod def get_aggressive_operators(cls): """只获取激进算子""" return cls.AGGRESSIVE_OPERATORS