Files
CGA-bench/autoline/mutation_operators.py

489 lines
16 KiB
Python
Raw Normal View History

2026-05-22 10:02:42 +08:00
"""
Description : RTL Mutation Operators for TB Discriminator
提供 FSM DataPath mutation 算子用于生成可控的错误 RTL
Author : CorrectBench
"""
import re
import random
class FSMMutationOperators:
    """Mutation operators that target FSM control logic.

    Every operator takes Verilog source text and returns the mutated text, or the
    input unchanged when no applicable site is found.
    """

    @staticmethod
    def flip_state_condition(code: str) -> str:
        """Negate the first matching equality/inequality test in a condition.

        Examples:
            if (state == SEND)  -> if (state != SEND)
            if (bit_cnt == 0)   -> if (bit_cnt != 0)
            if (state != IDLE)  -> if (state == IDLE)   (only when no ``==`` site matches)

        ``==`` patterns are tried before ``!=`` patterns, and the more specific
        patterns come first. Only the first occurrence is rewritten.
        """
        patterns_eq = [
            (r'\(state\s*==\s*(\w+)\)', r'(state != \1)'),
            (r'\(bit_cnt\s*==\s*0\)', r'(bit_cnt != 0)'),
            (r'\(bit_count\s*==\s*0\)', r'(bit_count != 0)'),
            (r'\((\w+)\s*==\s*(\w+)\)', r'(\1 != \2)'),
        ]
        patterns_ne = [
            (r'\(state\s*!=\s*(\w+)\)', r'(state == \1)'),
            (r'\(bit_cnt\s*!=\s*0\)', r'(bit_cnt == 0)'),
            (r'\(bit_count\s*!=\s*0\)', r'(bit_count == 0)'),
        ]
        for pattern, replacement in patterns_eq + patterns_ne:
            mutated, n_subs = re.subn(pattern, replacement, code, count=1)
            if n_subs:
                return mutated
        return code  # no condition to flip

    @staticmethod
    def swap_next_state(code: str) -> str:
        """Swap the assignments of the then/else branches of a simple if-else.

        Matches single-assignment branches of the form
            if (<cond>) begin a <= X; end else begin b <= Y;
        picks one such site at random, and rewrites it to
            if (<cond>) begin b <= Y; end else begin a <= X;
        The condition text is kept verbatim. The rewritten site is emitted with
        single spaces.
        """
        pattern = (
            r'if\s*\(([^)]+)\)\s*begin\s*(\w+)\s*<=\s*(\w+);'
            r'\s*end\s*else\s*begin\s*(\w+)\s*<=\s*(\w+);'
        )
        matches = list(re.finditer(pattern, code))
        if not matches:
            return code
        m = random.choice(matches)
        cond, then_lhs, then_rhs, else_lhs, else_rhs = m.groups()
        swapped = (
            f"if ({cond}) begin {else_lhs} <= {else_rhs}; "
            f"end else begin {then_lhs} <= {then_rhs};"
        )
        # Splice at the chosen match's span so exactly the randomly picked site changes.
        return code[:m.start()] + swapped + code[m.end():]

    @staticmethod
    def invert_counter_update(code: str) -> str:
        """Turn a counter decrement into an increment, or vice versa.

        Examples:
            bit_cnt <= bit_cnt - 1  -> bit_cnt <= bit_cnt + 1
            bit_cnt = bit_cnt + 1   -> bit_cnt = bit_cnt - 1

        Only text of the form ``<name> <op> <name> -/+ 1`` is touched, where
        ``<name>`` is ``bit_cnt``, ``bit_count`` or ``counter`` and ``<op>`` is the
        non-blocking ``<=`` or the blocking ``=``. The assignment operator is
        preserved. Decrements are tried before increments, and only the first
        occurrence is rewritten.
        """
        counter = r'(bit_cnt|bit_count|counter)\s*(<=|=)\s*\1\s*'
        rewrites = [
            (counter + r'-\s*1', r'\1 \2 \1 + 1'),
            (counter + r'\+\s*1', r'\1 \2 \1 - 1'),
        ]
        for pattern, replacement in rewrites:
            # Detect and rewrite with the SAME pattern. Otherwise an unrelated
            # `x <= x - 1` earlier in the file could be the site that gets changed.
            mutated, n_subs = re.subn(pattern, replacement, code, count=1)
            if n_subs:
                return mutated
        return code

    @staticmethod
    def stuck_at_state(code: str) -> str:
        """Make a start-triggered transition condition permanently false.

        Rewrites the first ``if (start)`` into ``if (start && 0)``. This is
        intended to leave the FSM stuck in the state that waits for ``start``.
        Code without such a condition is returned unchanged.
        """
        return re.sub(r'if\s*\(\s*start\s*\)', 'if (start && 0)', code, count=1)

    @staticmethod
    def flip_comparison_operator(code: str) -> str:
        """Negate a relational comparison against zero.

        Examples:
            if (bit_cnt > 0)  -> if (bit_cnt <= 0)
            if (bit_cnt >= 0) -> if (bit_cnt < 0)
            if (bit_cnt <= 0) -> if (bit_cnt > 0)
            if (bit_cnt < 0)  -> if (bit_cnt > 0)

        Only comparisons inside a parenthesised expression (with no nested
        parentheses or ``;`` before the operator) are considered. This keeps a
        non-blocking assignment such as ``cnt <= 0;`` from being treated as a
        ``<=`` comparison. Shift operators and compound tokens (``<<``, ``>>``,
        ``==``, ``!=``) are excluded. Operators are tried in the order shown, and
        only the first occurrence of the first operator found is rewritten.
        """
        flips = [
            ('>', '<= 0'),
            ('<', '> 0'),
            ('>=', '< 0'),
            ('<=', '> 0'),
        ]
        for op, replacement in flips:
            pattern = (
                r'(\([^();]*?)'   # an open paren and the operand text before the operator
                r'(?<![<>=!])'    # operator is not the tail of <<, >>, ==, !=, <=, >=
                + op +
                r'(?![<>=])'      # operator is not the head of <=, >=, <<, >>
                r'\s*0\b'
            )
            mutated, n_subs = re.subn(
                pattern, lambda m: m.group(1) + replacement, code, count=1
            )
            if n_subs:
                return mutated
        return code
class DataPathMutationOperators:
    """Mutation operators that target the datapath (assignments, shifts, bit selects).

    Every operator takes Verilog source text and returns the mutated text, or the
    input unchanged when no applicable site is found.
    """

    @staticmethod
    def width_mismatch_assignment(code: str) -> str:
        """Replace an assignment to a multi-bit output with a 1-bit constant.

        Finds the first ``output reg [N:0] <name>`` declaration. It then rewrites
        the right-hand side of one ``<name> <= ...;`` statement to ``1'b0``, e.g.
        ``data_out <= 8'hFF;`` -> ``data_out <= 1'b0;``.

        Assignments whose whole right-hand side is a numeric literal are
        preferred; otherwise the first assignment to ``<name>`` is used. The
        entire right-hand side is replaced, so the result stays syntactically
        valid.
        """
        decl = re.search(r'output\s+reg\s+\[\s*\d+\s*:\s*0\s*\]\s*(\w+)', code)
        if not decl:
            return code
        name = re.escape(decl.group(1))
        candidates = [
            # Whole RHS is a literal such as 8'hFF, 8'b0000_0000, 'sd3 or 255.
            rf"(\b{name}\s*<=\s*)(?:\d*\s*'[sS]?[bBoOdDhH]\s*[0-9a-fA-FxXzZ_?]+|\d[\d_]*)(?=\s*;)",
            # Any other RHS, up to the terminating semicolon.
            rf"(\b{name}\s*<=\s*)[^;]+(?=;)",
        ]
        for pattern in candidates:
            mutated, n_subs = re.subn(pattern, r"\g<1>1'b0", code, count=1)
            if n_subs:
                return mutated
        return code

    @staticmethod
    def missing_assignment(code: str) -> str:
        """Disable the first assignment to a key datapath signal.

        Signals are tried in the order ``data_out``, ``rx_data``, ``shift_reg``
        (whole-word matches only). The first ``<signal> <= ...;`` statement found
        is replaced by a null statement ``;`` followed by a block comment holding
        the original text, e.g.
            data_out <= x;  ->  ; /* MUTATION: missing assignment: data_out <= x; */

        The null statement keeps ``if (c) data_out <= x;`` and
        ``begin data_out <= x; end`` well-formed. The block comment (rather than
        ``//``) leaves any code after the statement on the same line active.
        """
        for signal in ('data_out', 'rx_data', 'shift_reg'):
            m = re.search(rf'\b{signal}\s*<=\s*[^;]+;', code)
            if m:
                # A stray "*/" inside the statement would terminate the comment early.
                disabled = m.group(0).replace('*/', '* /')
                marker = f"; /* MUTATION: missing assignment: {disabled} */"
                return code[:m.start()] + marker + code[m.end():]
        return code

    @staticmethod
    def swap_shift_direction(code: str) -> str:
        """Reverse the direction of a concatenation-based shift register update.

        shift_reg <= {shift_reg[6:0], miso} -> shift_reg <= {miso, shift_reg[7:1]}
        shift_reg <= {miso, shift_reg[7:1]} -> shift_reg <= {shift_reg[6:0], miso}

        The part-select bounds are recomputed so the concatenation keeps its
        width. The left-shift form is tried first, and only the first occurrence
        is rewritten. A right-shift source of ``[0:1]`` has no valid reversed
        form and is left unchanged.
        """
        def _left_to_right(m):
            dst, src, hi, bit = m.groups()
            return f"{dst} <= {{{bit}, {src}[{int(hi) + 1}:1]}}"

        def _right_to_left(m):
            dst, bit, src, hi = m.groups()
            if int(hi) < 1:
                return m.group(0)
            return f"{dst} <= {{{src}[{int(hi) - 1}:0], {bit}}}"

        rewrites = [
            (r'(\w+)\s*<=\s*\{\s*(\w+)\[(\d+):0\]\s*,?\s*(\w+)\s*\}', _left_to_right),
            (r'(\w+)\s*<=\s*\{\s*(\w+)\s*,\s*(\w+)\[(\d+):1\]\s*\}', _right_to_left),
        ]
        for pattern, replacement in rewrites:
            mutated, n_subs = re.subn(pattern, replacement, code, count=1)
            if n_subs:
                return mutated
        return code

    @staticmethod
    def invert_bit_select(code: str) -> str:
        """Invert a single-bit select on the right-hand side of an assignment.

        mosi <= shift_reg[7] -> mosi <= ~shift_reg[7]

        One ``x <= y[<idx>]`` site is chosen at random, and exactly that site is
        rewritten.
        """
        matches = list(re.finditer(r'(\w+)\s*<=\s*(\w+)\[(\d+)\]', code))
        if not matches:
            return code
        m = random.choice(matches)
        signal, source, bit_idx = m.groups()
        return code[:m.start()] + f'{signal} <= ~{source}[{bit_idx}]' + code[m.end():]

    @staticmethod
    def wrong_bit_index(code: str) -> str:
        """Point a single-bit select at a neighbouring bit.

        mosi <= shift_reg[7] -> mosi <= shift_reg[6]

        Rewrites are tried in order [7]->[6], [6]->[7], [0]->[1], [1]->[0], and
        only the first occurrence of the first matching rewrite is applied.
        """
        rewrites = [
            (r'(\w+)\s*<=\s*(\w+)\[7\]', r'\1 <= \2[6]'),
            (r'(\w+)\s*<=\s*(\w+)\[6\]', r'\1 <= \2[7]'),
            (r'(\w+)\s*<=\s*(\w+)\[0\]', r'\1 <= \2[1]'),
            (r'(\w+)\s*<=\s*(\w+)\[1\]', r'\1 <= \2[0]'),
        ]
        for pattern, replacement in rewrites:
            mutated, n_subs = re.subn(pattern, replacement, code, count=1)
            if n_subs:
                return mutated
        return code

    @staticmethod
    def swap_registers(code: str) -> str:
        """Exchange the right-hand sides of two adjacent non-blocking assignments.

        temp_a <= xxx; temp_b <= yyy;  ->  temp_a <= yyy; temp_b <= xxx;

        One adjacent pair is picked at random. Left-hand sides, the whitespace
        between the statements, and statement order are preserved; only the two
        values move. A right-hand side containing ``<=`` is rejected, so a
        comparison such as ``if (a <= b)`` is not mistaken for the first
        assignment.
        """
        rhs = r'((?:(?!<=)[^;])+)'
        pattern = r'(\w+\s*<=\s*)' + rhs + r';(\s*)(\w+\s*<=\s*)' + rhs + ';'
        matches = list(re.finditer(pattern, code))
        if not matches:
            return code
        m = random.choice(matches)
        lhs_a, rhs_a, gap, lhs_b, rhs_b = m.groups()
        swapped = f"{lhs_a}{rhs_b};{gap}{lhs_b}{rhs_a};"
        return code[:m.start()] + swapped + code[m.end():]
class ConservativeMutationOperators:
    """Conservative mutation operators: small, local perturbations.

    Each operator changes a single token and leaves the overall RTL structure
    intact. The mutated design is therefore more likely to still partially match
    the reference RTL's behaviour.
    """

    @staticmethod
    def single_bit_flip(code: str) -> str:
        """Flip the value of one single-bit assignment.

        spi_clk <= 1'b0 -> spi_clk <= 1'b1
        busy <= 1       -> busy <= 0

        Sized ``1'b0`` / ``1'b1`` literals are tried first, then bare ``1`` / ``0``.
        A bare digit that is really the size prefix of a literal (``1'd5``) is not
        treated as a value, so such literals are never corrupted.
        """
        # Lambdas avoid quote-escaping issues in the replacement text.
        flips = [
            (r"(\w+)\s*<=\s*1'b0", lambda m: f"{m.group(1)} <= 1'b1"),
            (r"(\w+)\s*<=\s*1'b1", lambda m: f"{m.group(1)} <= 1'b0"),
            (r"(\w+)\s*<=\s*1\b(?!\s*')", lambda m: f"{m.group(1)} <= 0"),
            (r"(\w+)\s*<=\s*0\b(?!\s*')", lambda m: f"{m.group(1)} <= 1"),
        ]
        for pattern, replacement in flips:
            mutated, n_subs = re.subn(pattern, replacement, code, count=1)
            if n_subs:
                return mutated
        return code

    @staticmethod
    def tiny_offset(code: str) -> str:
        """Nudge a sized decimal counter assignment by a small offset.

        bit_count <= 3'd0 -> bit_count <= 3'd1
        bit_cnt   <= 3'd5 -> bit_cnt   <= 3'd4 or 3'd6

        Only ``bit_count`` / ``bit_cnt`` / ``counter`` non-blocking assignments of
        the form ``<N>'d<value>`` are touched, so definitions such as
        ``localparam`` are left alone. The new value comes from ``_offset_value``.
        If that value needs more than N bits, ``value - 1`` (floored at 0) is used
        instead, so an in-range literal stays in range.
        """
        def _nudge(m):
            width, value = int(m.group(2)), int(m.group(3))
            new_value = _offset_value(value)
            if new_value.bit_length() > width:
                new_value = max(value - 1, 0)
            return f"{m.group(1)} <= {m.group(2)}'d{new_value}"

        mutated, n_subs = re.subn(
            r"(bit_count|bit_cnt|counter)\s*<=\s*(\d+)'d(\d+)", _nudge, code, count=1
        )
        return mutated if n_subs else code

    @staticmethod
    def bit_index_shift(code: str) -> str:
        """Move a bit index by one position (7 <-> 6) on a known datapath register.

        shift_reg[7] -> shift_reg[6]

        Only ``shift_reg``, ``data_in`` and ``data_out`` are considered. ``[7]`` is
        tried before ``[6]``, and only the first occurrence is rewritten.
        """
        rewrites = [
            (r'(shift_reg|data_in|data_out)\[7\]', r'\1[6]'),
            (r'(shift_reg|data_in|data_out)\[6\]', r'\1[7]'),
        ]
        for pattern, replacement in rewrites:
            mutated, n_subs = re.subn(pattern, replacement, code, count=1)
            if n_subs:
                return mutated
        return code

    @staticmethod
    def constant_slight_change(code: str) -> str:
        """Flip the least-significant bit of one 8-bit hex constant in an assignment.

        data_in <= 8'hAA -> data_in <= 8'hAB
        data_in <= 8'h10 -> data_in <= 8'h11

        Only ``<signal> <= 8'h..`` with at least two hex digits is matched, so
        ``localparam`` definitions (which use ``=``) are untouched. The LSB of the
        last hex digit is toggled, and the digit's letter case is preserved.
        """
        def _toggle_lsb(m):
            digit = m.group(2)
            toggled = format(int(digit, 16) ^ 1, 'x')
            return m.group(1) + (toggled.upper() if digit.isupper() else toggled)

        mutated, n_subs = re.subn(
            r"(\w+\s*<=\s*8'h[0-9a-fA-F]{1,2})([0-9a-fA-F])", _toggle_lsb, code, count=1
        )
        return mutated if n_subs else code

    @staticmethod
    def invert_single_signal(code: str) -> str:
        """Invert the bit select that drives ``mosi``.

        mosi <= shift_reg[7] -> mosi <= ~shift_reg[7]

        One ``mosi <= <reg>[<idx>]`` site is chosen at random, and only that site is
        rewritten. The pattern cannot match an already-inverted select (``~`` is
        not a word character), so no extra guard is needed.
        """
        matches = list(re.finditer(r'(mosi)\s*<=\s*(\w+\[)(\d+)\]', code))
        if not matches:
            return code
        m = random.choice(matches)
        signal, bracket_part, idx = m.groups()
        return code[:m.start()] + f'{signal} <= ~{bracket_part}{idx}]' + code[m.end():]
def _offset_value(val: int) -> int:
"""对数值进行微小偏移,只在有效范围内操作"""
# 只偏移 1且只在 1-14 范围内
if 1 <= val <= 14:
return val + random.choice([-1, 1])
elif val > 14:
return val - 1
elif val == 0:
return 1
return val
class MutationOperatorRegistry:
    """Registry of the available mutation operators.

    Each entry is a ``(name, operator)`` tuple. ``operator`` is a static method
    that maps RTL source text to mutated RTL source text.
    """

    # Conservative operators. Sampled more often (see DEFAULT_CONSERVATIVE_WEIGHT)
    # because they produce RTL closer to the reference.
    CONSERVATIVE_OPERATORS = [
        ("single_bit_flip", ConservativeMutationOperators.single_bit_flip),
        ("bit_index_shift", ConservativeMutationOperators.bit_index_shift),
        ("constant_slight_change", ConservativeMutationOperators.constant_slight_change),
        ("tiny_offset", ConservativeMutationOperators.tiny_offset),
        ("invert_single_signal", ConservativeMutationOperators.invert_single_signal),
    ]

    # Aggressive operators. Sampled less often because they may break the RTL
    # structure entirely.
    # NOTE(review): FSMMutationOperators.swap_next_state is not registered here;
    # confirm whether that omission is intentional.
    AGGRESSIVE_OPERATORS = [
        ("flip_state_condition", FSMMutationOperators.flip_state_condition),
        ("invert_counter", FSMMutationOperators.invert_counter_update),
        ("stuck_at_state", FSMMutationOperators.stuck_at_state),
        ("flip_comparison", FSMMutationOperators.flip_comparison_operator),
        ("width_mismatch", DataPathMutationOperators.width_mismatch_assignment),
        ("missing_assignment", DataPathMutationOperators.missing_assignment),
        ("swap_shift", DataPathMutationOperators.swap_shift_direction),
        ("invert_bit", DataPathMutationOperators.invert_bit_select),
        ("wrong_bit_index", DataPathMutationOperators.wrong_bit_index),
        ("swap_registers", DataPathMutationOperators.swap_registers),
    ]

    # Default probability (70%) of drawing from the conservative pool.
    DEFAULT_CONSERVATIVE_WEIGHT = 0.7

    @classmethod
    def get_random_operator(cls, conservative_weight: float = None):
        """Return a random ``(name, operator)`` pair.

        Args:
            conservative_weight: Probability (0.0-1.0) of drawing from the
                conservative pool instead of the aggressive one. ``None`` uses
                ``DEFAULT_CONSERVATIVE_WEIGHT`` (0.7).
        """
        if conservative_weight is None:
            conservative_weight = cls.DEFAULT_CONSERVATIVE_WEIGHT
        # Uses the module-level `random`; no local re-import is needed.
        if random.random() < conservative_weight:
            return random.choice(cls.CONSERVATIVE_OPERATORS)
        return random.choice(cls.AGGRESSIVE_OPERATORS)

    @classmethod
    def get_all_operators(cls):
        """Return a new list of all operators, conservative ones first."""
        return cls.CONSERVATIVE_OPERATORS + cls.AGGRESSIVE_OPERATORS

    @classmethod
    def get_conservative_operators(cls):
        """Return a copy of the conservative operator list.

        A copy is returned so callers cannot accidentally mutate the shared
        class-level registry.
        """
        return list(cls.CONSERVATIVE_OPERATORS)

    @classmethod
    def get_aggressive_operators(cls):
        """Return a copy of the aggressive operator list.

        A copy is returned so callers cannot accidentally mutate the shared
        class-level registry.
        """
        return list(cls.AGGRESSIVE_OPERATORS)