"""
Description : RTL Mutation Operators for TB Discriminator
              Provides FSM and data-path mutation operators used to
              generate controllably faulty RTL.
Author      : CorrectBench
"""
|
||
|
||
import re
|
||
import random
|
||
|
||
|
||
class FSMMutationOperators:
    """Mutation operators aimed at finite-state-machine (FSM) logic.

    Each public operator is a static method that takes RTL source text and
    returns it with at most one mutation applied. When no suitable site is
    found the input text is returned unchanged.
    """

    @staticmethod
    def _apply_first_rule(code: str, rules) -> str:
        """Apply the first rule whose pattern occurs in ``code``.

        Args:
            code: RTL source text.
            rules: Iterable of ``(regex, replacement)`` pairs. They are tried
                in list order, so earlier rules take priority regardless of
                where their match sits in the text.

        Returns:
            ``code`` with the first occurrence of the winning pattern
            rewritten, or ``code`` unchanged when no rule matches.
        """
        for pattern, replacement in rules:
            mutated, hits = re.subn(pattern, replacement, code, count=1)
            if hits:
                return mutated
        return code

    @staticmethod
    def flip_state_condition(code: str) -> str:
        """Flip one (in)equality test in a parenthesised condition.

        Examples:
            if (state == SEND)  ->  if (state != SEND)
            if (bit_cnt == 0)   ->  if (bit_cnt != 0)

        All ``==`` rules are tried before any ``!=`` rule. The ``!=`` rules
        only cover ``state`` and the bit counters compared against 0.
        """
        rules = [
            # "==" -> "!=": specific names first, then any `(a == b)`.
            (r'\(state\s*==\s*(\w+)\)', r'(state != \1)'),
            (r'\(bit_cnt\s*==\s*0\)', r'(bit_cnt != 0)'),
            (r'\(bit_count\s*==\s*0\)', r'(bit_count != 0)'),
            (r'\((\w+)\s*==\s*(\w+)\)', r'(\1 != \2)'),
            # "!=" -> "==".
            (r'\(state\s*!=\s*(\w+)\)', r'(state == \1)'),
            (r'\(bit_cnt\s*!=\s*0\)', r'(bit_cnt == 0)'),
            (r'\(bit_count\s*!=\s*0\)', r'(bit_count == 0)'),
        ]
        return FSMMutationOperators._apply_first_rule(code, rules)

    @staticmethod
    def swap_next_state(code: str) -> str:
        """Swap the assignments of the two branches of an if/else.

        Only the simple shape
        ``if (cond) begin a <= X; end else begin b <= Y;`` is recognised.
        One matching site is picked at random and rewritten in place; the
        rewritten fragment uses single-space formatting.
        """
        branch_pair = re.compile(
            r'if\s*\(([^)]+)\)\s*begin\s*(\w+)\s*<=\s*(\w+);'
            r'\s*end\s*else\s*begin\s*(\w+)\s*<=\s*(\w+);'
        )
        candidates = list(branch_pair.finditer(code))
        if not candidates:
            return code

        chosen = random.choice(candidates)
        condition, then_lhs, then_rhs, else_lhs, else_rhs = chosen.groups()
        swapped = (
            f"if ({condition}) begin {else_lhs} <= {else_rhs}; "
            f"end else begin {then_lhs} <= {then_rhs};"
        )
        # Splice by position so the randomly chosen site is the one that is
        # mutated, even when identical text occurs earlier in the file.
        return code[:chosen.start()] + swapped + code[chosen.end():]

    @staticmethod
    def invert_counter_update(code: str) -> str:
        """Reverse the direction of a counter update.

        Examples:
            bit_cnt <= bit_cnt - 1  ->  bit_cnt <= bit_cnt + 1
            bit_cnt <= bit_cnt + 1  ->  bit_cnt <= bit_cnt - 1

        Only ``bit_cnt``, ``bit_count`` and ``counter`` are considered, and a
        decrement is preferred over an increment when both are present.
        """
        counter = r'\b(bit_cnt|bit_count|counter)'
        # The same counter-specific pattern is used for detection and for
        # substitution, so an unrelated `x <= x - 1` is never rewritten.
        rules = [
            (counter + r'\s*<=\s*\1\s*-\s*1', r'\1 <= \1 + 1'),
            (counter + r'\s*<=\s*\1\s*\+\s*1', r'\1 <= \1 - 1'),
        ]
        return FSMMutationOperators._apply_first_rule(code, rules)

    @staticmethod
    def stuck_at_state(code: str) -> str:
        """Make the FSM stick in a state by disabling a ``start`` condition.

        The first ``if (start)`` becomes ``if (start && 0)``, which can never
        be true.
        """
        rules = [(r'if\s*\(\s*start\s*\)', 'if (start && 0)')]
        return FSMMutationOperators._apply_first_rule(code, rules)

    @staticmethod
    def flip_comparison_operator(code: str) -> str:
        """Replace one relational comparison against 0 with its complement.

        Examples:
            if (bit_cnt > 0)   ->  if (bit_cnt <= 0)
            if (bit_cnt < 0)   ->  if (bit_cnt >= 0)
            if (bit_cnt >= 0)  ->  if (bit_cnt < 0)
            if (bit_cnt <= 0)  ->  if (bit_cnt > 0)

        Rules are tried in the order ``>``, ``<``, ``>=``, ``<=``.
        """
        rules = [
            # Look-behinds keep shift operators (`>>`, `<<`) and the second
            # character of two-character operators from matching.
            (r'(?<![<>=])>\s*0\b', '<= 0'),
            (r'(?<![<=])<\s*0\b', '>= 0'),
            (r'(?<!>)>=\s*0\b', '< 0'),
            # `<=` is also the non-blocking assignment operator; a match that
            # is directly followed by `;` is an assignment, so skip it.
            (r'(?<!<)<=\s*0\b(?!\s*;)', '> 0'),
        ]
        return FSMMutationOperators._apply_first_rule(code, rules)
|
||
|
||
|
||
class DataPathMutationOperators:
    """Mutation operators aimed at data-path logic.

    Each public operator is a static method that takes RTL source text and
    returns it with at most one mutation applied. When no suitable site is
    found the input text is returned unchanged.
    """

    @staticmethod
    def _apply_first_rule(code: str, rules) -> str:
        """Apply the first rule whose pattern occurs in ``code``.

        Args:
            code: RTL source text.
            rules: Iterable of ``(regex, replacement)`` pairs, tried in list
                order; ``replacement`` may be a template string or a callable
                accepted by ``re.sub``.

        Returns:
            ``code`` with the first occurrence of the winning pattern
            rewritten, or ``code`` unchanged when no rule matches.
        """
        for pattern, replacement in rules:
            mutated, hits = re.subn(pattern, replacement, code, count=1)
            if hits:
                return mutated
        return code

    @staticmethod
    def width_mismatch_assignment(code: str) -> str:
        """Assign a 1-bit constant to a multi-bit output register.

        The first ``output reg [N:0] name`` declaration selects the signal.
        The first ``name <= <literal-or-identifier>`` then has its leading
        right-hand-side token replaced by ``1'b0``, e.g.
        ``data_out <= 8'hFF`` -> ``data_out <= 1'b0``.
        """
        declaration = re.search(r'output\s+reg\s+\[(\d+):0\]\s+(\w+)', code)
        if declaration is None:
            return code

        signal_name = re.escape(declaration.group(2))
        # The right-hand side is matched as one whole token: either a sized
        # literal (8'hFF, 3'd0, 8'b0000_1111) or a plain identifier/number.
        # Matching only a prefix would leave the rest of the token behind.
        assignment = (
            rf"(\b{signal_name}\s*<=\s*)"
            r"(\d+'[sS]?[bBoOdDhH][0-9a-fA-FxXzZ_?]+|\w+)"
        )
        mutated, hits = re.subn(assignment, r"\g<1>1'b0", code, count=1)
        return mutated if hits else code

    @staticmethod
    def missing_assignment(code: str) -> str:
        """Comment out the assignment to one key signal.

        The signals ``data_out``, ``rx_data`` and ``shift_reg`` are tried in
        that order. The first ``signal <= ...;`` found is replaced by
        ``// <statement> // MUTATION: missing assignment``.
        """
        for signal in ('data_out', 'rx_data', 'shift_reg'):
            found = re.search(rf'\b{signal}\s*<=\s*[^;]+;', code)
            if found is None:
                continue

            # Collapse a multi-line statement so that one `//` covers it all.
            statement = ' '.join(found.group(0).split())
            commented = f'// {statement} // MUTATION: missing assignment'

            remainder = code[found.end():]
            # Code following on the same line must not be swallowed by the
            # line comment, so move it onto its own line.
            if remainder.split('\n', 1)[0].strip():
                commented += '\n'
            return code[:found.start()] + commented + remainder

        return code

    @staticmethod
    def swap_shift_direction(code: str) -> str:
        """Reverse the direction of a concatenation-based shift register.

        Examples:
            shift_reg <= {shift_reg[6:0], miso}
                ->  shift_reg <= {miso, shift_reg[7:1]}
            shift_reg <= {miso, shift_reg[7:1]}
                ->  shift_reg <= {shift_reg[6:0], miso}
        """

        def to_right_shift(match):
            # {reg[N:0], bit} -> {bit, reg[N+1:1]}: same width, other direction.
            target, source, msb, new_bit = match.groups()
            return f'{target} <= {{{new_bit}, {source}[{int(msb) + 1}:1]}}'

        def to_left_shift(match):
            # {bit, reg[N:1]} -> {reg[N-1:0], bit}
            target, new_bit, source, msb = match.groups()
            if int(msb) < 1:
                # No valid lower slice exists; leave the text as it is.
                return match.group(0)
            return f'{target} <= {{{source}[{int(msb) - 1}:0], {new_bit}}}'

        rules = [
            (r'(\w+)\s*<=\s*\{\s*(\w+)\[(\d+):0\],?\s*(\w+)\s*\}',
             to_right_shift),
            (r'(\w+)\s*<=\s*\{\s*(\w+),\s*(\w+)\[(\d+):1\]\s*\}',
             to_left_shift),
        ]
        return DataPathMutationOperators._apply_first_rule(code, rules)

    @staticmethod
    def invert_bit_select(code: str) -> str:
        """Invert a single-bit select on the right-hand side.

        Example:
            mosi <= shift_reg[7]  ->  mosi <= ~shift_reg[7]

        One matching assignment is picked at random.
        """
        candidates = list(re.finditer(r'(\w+)\s*<=\s*(\w+)\[(\d+)\]', code))
        if not candidates:
            return code

        chosen = random.choice(candidates)
        target, source, index = chosen.groups()
        inverted = f'{target} <= ~{source}[{index}]'
        # Splice by position so the chosen site is the one that changes, even
        # when the same text appears earlier in the file.
        return code[:chosen.start()] + inverted + code[chosen.end():]

    @staticmethod
    def wrong_bit_index(code: str) -> str:
        """Use a neighbouring bit index in a bit-select assignment.

        Example:
            mosi <= shift_reg[7]  ->  mosi <= shift_reg[6]

        Handles the index pairs 7<->6 and 0<->1, tried in the order
        7, 6, 0, 1.
        """
        rules = [
            (r'(\w+)\s*<=\s*(\w+)\[7\]', r'\1 <= \2[6]'),
            (r'(\w+)\s*<=\s*(\w+)\[6\]', r'\1 <= \2[7]'),
            (r'(\w+)\s*<=\s*(\w+)\[0\]', r'\1 <= \2[1]'),
            (r'(\w+)\s*<=\s*(\w+)\[1\]', r'\1 <= \2[0]'),
        ]
        return DataPathMutationOperators._apply_first_rule(code, rules)

    @staticmethod
    def swap_registers(code: str) -> str:
        """Swap the right-hand sides of two adjacent register assignments.

        Example:
            temp_a <= xxx; temp_b <= yyy;  ->  temp_a <= yyy; temp_b <= xxx;

        One adjacent pair whose right-hand sides differ is picked at random.
        Targets, operators and surrounding whitespace are left in place.
        """
        # A right-hand side may not contain `<=` itself; otherwise a
        # comparison such as `if (a <= b) c <= d;` would be read as an
        # assignment to `a`.
        rhs = r'((?:(?!<=)[^;])+)'
        adjacent_pair = re.compile(
            r'\b(\w+)(\s*<=\s*)' + rhs + r'(;\s*)(\w+)(\s*<=\s*)' + rhs + r';'
        )
        candidates = [
            found for found in adjacent_pair.finditer(code)
            # Swapping identical values would change nothing.
            if found.group(3).strip() != found.group(7).strip()
        ]
        if not candidates:
            return code

        chosen = random.choice(candidates)
        (first_lhs, first_op, first_rhs, separator,
         second_lhs, second_op, second_rhs) = chosen.groups()
        swapped = (
            f'{first_lhs}{first_op}{second_rhs}{separator}'
            f'{second_lhs}{second_op}{first_rhs};'
        )
        return code[:chosen.start()] + swapped + code[chosen.end():]
|
||
|
||
|
||
class ConservativeMutationOperators:
    """Conservative mutation operators that make only small perturbations.

    The overall RTL structure is left intact, so the mutated design is more
    likely to still behave partly like the reference RTL. Each public
    operator is a static method that takes RTL source text and returns it
    with at most one mutation applied; the input is returned unchanged when
    no suitable site is found.
    """

    @staticmethod
    def _apply_first_rule(code: str, rules) -> str:
        """Apply the first rule whose pattern occurs in ``code``.

        Args:
            code: RTL source text.
            rules: Iterable of ``(regex, replacement)`` pairs, tried in list
                order; ``replacement`` may be a template string or a callable
                accepted by ``re.sub``.

        Returns:
            ``code`` with the first occurrence of the winning pattern
            rewritten, or ``code`` unchanged when no rule matches.
        """
        for pattern, replacement in rules:
            mutated, hits = re.subn(pattern, replacement, code, count=1)
            if hits:
                return mutated
        return code

    @staticmethod
    def single_bit_flip(code: str) -> str:
        """Flip the value assigned to one single-bit signal.

        Examples:
            spi_clk <= 1'b0  ->  spi_clk <= 1'b1
            busy <= 1        ->  busy <= 0

        Rules are tried in the order ``1'b0``, ``1'b1``, ``1``, ``0``.
        """
        rules = [
            (r"(\w+)\s*<=\s*1'b0", r"\1 <= 1'b1"),
            (r"(\w+)\s*<=\s*1'b1", r"\1 <= 1'b0"),
            # The `(?!')` guard keeps a bare 1/0 rule from rewriting the
            # width prefix of a sized literal such as 1'h1.
            (r"(\w+)\s*<=\s*1\b(?!')", r"\1 <= 0"),
            (r"(\w+)\s*<=\s*0\b(?!')", r"\1 <= 1"),
        ]
        return ConservativeMutationOperators._apply_first_rule(code, rules)

    @staticmethod
    def tiny_offset(code: str) -> str:
        """Nudge a decimal constant assigned to a counter by one.

        Example:
            bit_count <= 3'd4  ->  bit_count <= 3'd5  or  bit_count <= 3'd3

        Only assignments to ``bit_count``, ``bit_cnt`` or ``counter`` with a
        sized decimal literal are touched, so ``localparam`` definitions are
        not affected. The new value comes from the module-level
        ``_offset_value`` helper.
        """

        def shift_literal(match):
            name, width_text, value_text = match.groups()
            value = int(value_text)
            largest = (1 << int(width_text)) - 1
            shifted = _offset_value(value)
            # Stepping up from the largest representable value would
            # overflow the declared width, so step down instead.
            if 0 < value <= largest < shifted:
                shifted = value - 1
            return f"{name} <= {width_text}'d{shifted}"

        rules = [
            (r"(bit_count|bit_cnt|counter)\s*<=\s*(\d+)'d(\d+)",
             shift_literal),
        ]
        return ConservativeMutationOperators._apply_first_rule(code, rules)

    @staticmethod
    def bit_index_shift(code: str) -> str:
        """Move one bit index by a single position.

        Example:
            shift_reg[7]  ->  shift_reg[6]

        Only the index pair 7<->6 on ``shift_reg``, ``data_in`` and
        ``data_out`` is handled; an index of 7 is preferred over 6.
        """
        rules = [
            (r'(shift_reg|data_in|data_out)\[7\]', r'\1[6]'),
            (r'(shift_reg|data_in|data_out)\[6\]', r'\1[7]'),
        ]
        return ConservativeMutationOperators._apply_first_rule(code, rules)

    @staticmethod
    def constant_slight_change(code: str) -> str:
        """Toggle the least-significant bit of an assigned 8-bit hex constant.

        Examples:
            data_in <= 8'hAA  ->  data_in <= 8'hAB
            data_in <= 8'h00  ->  data_in <= 8'h01

        Only non-blocking assignments of ``8'h`` literals with one or two
        hex digits are touched, so ``localparam`` definitions are not
        affected.
        """

        def toggle_lsb(match):
            prefix, last_digit = match.groups()
            toggled = format(int(last_digit, 16) ^ 1, 'x')
            # Keep the letter case of the original digit.
            if last_digit.isupper():
                toggled = toggled.upper()
            return prefix + toggled

        rules = [
            (r"(\w+\s*<=\s*8'h[0-9a-fA-F]?)([0-9a-fA-F])\b", toggle_lsb),
        ]
        return ConservativeMutationOperators._apply_first_rule(code, rules)

    @staticmethod
    def invert_single_signal(code: str) -> str:
        """Invert the bit driven onto ``mosi``.

        Example:
            mosi <= shift_reg[7]  ->  mosi <= ~shift_reg[7]

        One matching assignment is picked at random.
        """
        candidates = list(re.finditer(r'(mosi)\s*<=\s*(\w+\[)(\d+)\]', code))
        if not candidates:
            return code

        chosen = random.choice(candidates)
        target, source_open, index = chosen.groups()
        inverted = f'{target} <= ~{source_open}{index}]'
        # Splice by position so the chosen site is the one that changes, even
        # when the same text appears earlier in the file.
        return code[:chosen.start()] + inverted + code[chosen.end():]
|
||
|
||
|
||
def _offset_value(val: int) -> int:
|
||
"""对数值进行微小偏移,只在有效范围内操作"""
|
||
# 只偏移 1,且只在 1-14 范围内
|
||
if 1 <= val <= 14:
|
||
return val + random.choice([-1, 1])
|
||
elif val > 14:
|
||
return val - 1
|
||
elif val == 0:
|
||
return 1
|
||
return val
|
||
|
||
|
||
class MutationOperatorRegistry:
    """Registry of the available mutation operators.

    Operators are stored as ``(name, callable)`` pairs in two pools:
    conservative and aggressive.
    """

    # Conservative operators are favoured by the default weight because they
    # produce RTL that stays closer to the reference design.
    CONSERVATIVE_OPERATORS = [
        ("single_bit_flip", ConservativeMutationOperators.single_bit_flip),
        ("bit_index_shift", ConservativeMutationOperators.bit_index_shift),
        ("constant_slight_change", ConservativeMutationOperators.constant_slight_change),
        ("tiny_offset", ConservativeMutationOperators.tiny_offset),
        ("invert_single_signal", ConservativeMutationOperators.invert_single_signal),
    ]

    # Aggressive operators are chosen less often because they may break the
    # RTL structure entirely.
    AGGRESSIVE_OPERATORS = [
        ("flip_state_condition", FSMMutationOperators.flip_state_condition),
        ("invert_counter", FSMMutationOperators.invert_counter_update),
        ("stuck_at_state", FSMMutationOperators.stuck_at_state),
        ("flip_comparison", FSMMutationOperators.flip_comparison_operator),
        ("width_mismatch", DataPathMutationOperators.width_mismatch_assignment),
        ("missing_assignment", DataPathMutationOperators.missing_assignment),
        ("swap_shift", DataPathMutationOperators.swap_shift_direction),
        ("invert_bit", DataPathMutationOperators.invert_bit_select),
        ("wrong_bit_index", DataPathMutationOperators.wrong_bit_index),
        ("swap_registers", DataPathMutationOperators.swap_registers),
    ]

    # Probability of drawing from the conservative pool when the caller does
    # not supply a weight.
    DEFAULT_CONSERVATIVE_WEIGHT = 0.7

    @classmethod
    def get_random_operator(cls, conservative_weight: float = None):
        """Pick one mutation operator at random.

        Args:
            conservative_weight: Probability (0.0-1.0) of drawing from the
                conservative pool. ``None`` selects
                ``DEFAULT_CONSERVATIVE_WEIGHT``.

        Returns:
            A ``(name, callable)`` pair from the selected pool.
        """
        if conservative_weight is None:
            conservative_weight = cls.DEFAULT_CONSERVATIVE_WEIGHT

        # The module-level `random` import is used here; no function-level
        # re-import is needed.
        if random.random() < conservative_weight:
            pool = cls.CONSERVATIVE_OPERATORS
        else:
            pool = cls.AGGRESSIVE_OPERATORS
        return random.choice(pool)

    @classmethod
    def get_all_operators(cls):
        """Return a new list holding the conservative then aggressive operators."""
        return cls.CONSERVATIVE_OPERATORS + cls.AGGRESSIVE_OPERATORS

    @classmethod
    def get_conservative_operators(cls):
        """Return the list of conservative operators."""
        return cls.CONSERVATIVE_OPERATORS

    @classmethod
    def get_aggressive_operators(cls):
        """Return the list of aggressive operators."""
        return cls.AGGRESSIVE_OPERATORS
|