Files
CGA-bench/autoline/mutation_operators.py
2026-05-22 10:02:42 +08:00

489 lines
16 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Description : RTL Mutation Operators for TB Discriminator
提供 FSM 和 DataPath 的 mutation 算子,用于生成可控的错误 RTL
Author : CorrectBench
"""
import re
import random
class FSMMutationOperators:
    """Mutation operators that target finite-state-machine style RTL.

    Every operator takes Verilog source text and returns a new string with at
    most one mutation applied. When no applicable construct is found, the
    input text is returned unchanged.
    """

    @staticmethod
    def flip_state_condition(code: str) -> str:
        """Flip the first ``==`` / ``!=`` test found in a parenthesised condition.

        Examples:
            if (state == SEND)  ->  if (state != SEND)
            if (bit_cnt == 0)   ->  if (bit_cnt != 0)

        Args:
            code: Verilog source text.

        Returns:
            The mutated text, or ``code`` unchanged if nothing matched.
        """
        # Tried in order; the first pattern that matches anywhere wins.
        rewrites = [
            # '==' forms, most specific first, generic fallback last.
            (r'\(state\s*==\s*(\w+)\)', r'(state != \1)'),
            (r'\(bit_cnt\s*==\s*0\)', r'(bit_cnt != 0)'),
            (r'\(bit_count\s*==\s*0\)', r'(bit_count != 0)'),
            (r'\((\w+)\s*==\s*(\w+)\)', r'(\1 != \2)'),
            # '!=' forms are only considered when no '==' form exists.
            (r'\(state\s*!=\s*(\w+)\)', r'(state == \1)'),
            (r'\(bit_cnt\s*!=\s*0\)', r'(bit_cnt == 0)'),
            (r'\(bit_count\s*!=\s*0\)', r'(bit_count == 0)'),
        ]
        for pattern, replacement in rewrites:
            mutated, hits = re.subn(pattern, replacement, code, count=1)
            if hits:
                return mutated
        return code  # no flippable condition found

    @staticmethod
    def swap_next_state(code: str) -> str:
        """Swap the assignments of the two branches of a simple if/else.

        Only the shape ``if (<cond>) begin a <= X; end else begin b <= Y;`` is
        recognised. One such occurrence is picked at random and rewritten (on
        a single line) with the two assignments exchanged.

        Args:
            code: Verilog source text.

        Returns:
            The mutated text, or ``code`` unchanged if nothing matched.
        """
        branch_re = re.compile(
            r'if\s*\(([^)]+)\)\s*begin\s*(\w+)\s*<=\s*(\w+);'
            r'\s*end\s*else\s*begin\s*(\w+)\s*<=\s*(\w+);'
        )
        candidates = list(branch_re.finditer(code))
        if not candidates:
            return code

        chosen = random.choice(candidates)
        cond, then_lhs, then_rhs, else_lhs, else_rhs = chosen.groups()
        swapped = (
            f"if ({cond}) begin {else_lhs} <= {else_rhs}; "
            f"end else begin {then_lhs} <= {then_rhs};"
        )
        # Splice by position so the occurrence that was actually chosen is the
        # one rewritten, even if identical text appears earlier in the file.
        return code[:chosen.start()] + swapped + code[chosen.end():]

    @staticmethod
    def invert_counter_update(code: str) -> str:
        """Invert the direction of a counter's +/-1 non-blocking update.

        Examples:
            bit_cnt <= bit_cnt - 1  ->  bit_cnt <= bit_cnt + 1
            bit_cnt <= bit_cnt + 1  ->  bit_cnt <= bit_cnt - 1

        Only registers named ``bit_cnt``, ``bit_count`` or ``counter`` are
        touched. A decrement is preferred over an increment when both exist.

        Args:
            code: Verilog source text.

        Returns:
            The mutated text, or ``code`` unchanged if nothing matched.
        """
        names = r'(bit_cnt|bit_count|counter)'
        for operator, flipped in (('-', '+'), ('+', '-')):
            # The same name-restricted pattern is used for both detection and
            # substitution, so an unrelated register that happens to appear
            # earlier (e.g. "timer <= timer - 1") is never mutated by mistake.
            pattern = rf'\b{names}\s*<=\s*\1\s*{re.escape(operator)}\s*1'
            mutated, hits = re.subn(
                pattern, rf'\1 <= \1 {flipped} 1', code, count=1
            )
            if hits:
                return mutated
        return code

    @staticmethod
    def stuck_at_state(code: str) -> str:
        """Make the FSM stick by forcing the first ``if (start)`` to be false.

        The condition is rewritten to ``if (start && 0)``.

        Args:
            code: Verilog source text.

        Returns:
            The mutated text, or ``code`` unchanged if nothing matched.
        """
        return re.sub(
            r'if\s*\(\s*start\s*\)', 'if (start && 0)', code, count=1
        )

    @staticmethod
    def flip_comparison_operator(code: str) -> str:
        """Flip the first relational comparison against the literal ``0``.

        Rewrites, tried in this order:
            ``> 0``  -> ``<= 0``
            ``< 0``  -> ``> 0``
            ``>= 0`` -> ``< 0``
            ``<= 0`` -> ``> 0``

        Example:
            if (bit_cnt > 0)  ->  if (bit_cnt <= 0)

        Args:
            code: Verilog source text.

        Returns:
            The mutated text, or ``code`` unchanged if nothing matched.
        """
        comparisons = [
            # Look-behinds keep shifts (>>, <<), '=>' and '->' out of scope.
            # No word boundary is required before the operator, so the usual
            # "a > 0" spelling (with a space) is matched as well.
            (r'(?<![<>=!\-])>\s*0\b', '<= 0'),
            (r'(?<!<)<\s*0\b', '> 0'),
            (r'(?<![<>=!])>=\s*0\b', '< 0'),
            # '<=' is also the non-blocking assignment operator; only treat it
            # as a comparison when it is followed by ')' , '&&' or '||'.
            (r'(?<!<)<=\s*0\b(?=\s*(?:\)|&&|\|\|))', '> 0'),
        ]
        for pattern, replacement in comparisons:
            mutated, hits = re.subn(pattern, replacement, code, count=1)
            if hits:
                return mutated
        return code
class DataPathMutationOperators:
    """Mutation operators that target datapath logic in RTL.

    Every operator takes Verilog source text and returns a new string with at
    most one mutation applied. When no applicable construct is found, the
    input text is returned unchanged.
    """

    @staticmethod
    def width_mismatch_assignment(code: str) -> str:
        """Assign a 1-bit literal to a multi-bit output register.

        The first ``output reg [N:0] <name>`` declaration selects the target.
        The first operand of the first ``<name> <= ...`` assignment (a sized
        literal, a plain number or an identifier) is replaced by ``1'b0``.

        Args:
            code: Verilog source text.

        Returns:
            The mutated text, or ``code`` unchanged if there is no such
            declaration/assignment or the register is only one bit wide.
        """
        declaration = re.search(r'output\s+reg\s+\[(\d+):0\]\s+(\w+)', code)
        if declaration is None:
            return code

        msb = int(declaration.group(1))
        signal_name = declaration.group(2)
        if msb == 0:
            # A [0:0] register is already 1 bit wide; 1'b0 is no mismatch.
            return code

        # Sized literal first (8'hFF, 4'b1010, 'd3, ...), otherwise one whole
        # word. Matching whole tokens avoids cutting an identifier such as
        # "data_in" in the middle just because it starts with hex digits.
        operand = (
            rf"(\b{re.escape(signal_name)}\s*<=\s*)"
            r"(\d*'[sS]?[bBoOdDhH][0-9a-fA-FxXzZ_?]+|\w+)"
        )
        # A callable replacement keeps the apostrophe out of re's template
        # escape handling, which would otherwise leave a stray backslash.
        return re.sub(operand, lambda m: m.group(1) + "1'b0", code, count=1)

    @staticmethod
    def missing_assignment(code: str) -> str:
        """Comment out the first assignment to a key data signal.

        Signals are tried in the order ``data_out``, ``rx_data``,
        ``shift_reg``. The statement is replaced in place by
        ``// <statement> // MUTATION: missing assignment``.

        Args:
            code: Verilog source text.

        Returns:
            The mutated text, or ``code`` unchanged if none of the signals is
            assigned with ``<=``.
        """
        for signal in ('data_out', 'rx_data', 'shift_reg'):
            found = re.search(rf'\b{signal}\s*<=\s*[^;]+;', code)
            if found is None:
                continue

            # Collapse internal whitespace so that a statement spanning
            # several lines is fully contained in the single line comment.
            statement = ' '.join(found.group(0).split())
            tail = code[found.end():]
            # Anything else on the same line (e.g. "end") must stay live, so
            # push it onto its own line instead of into the comment.
            rest_of_line = tail.split('\n', 1)[0]
            separator = '\n' if rest_of_line.strip() else ''
            return (
                f"{code[:found.start()]}// {statement} "
                f"// MUTATION: missing assignment{separator}{tail}"
            )
        return code

    @staticmethod
    def swap_shift_direction(code: str) -> str:
        """Reverse the direction of a concatenation-based shift register.

        Examples:
            shift_reg <= {shift_reg[6:0], miso} -> shift_reg <= {miso, shift_reg[7:1]}
            shift_reg <= {miso, shift_reg[7:1]} -> shift_reg <= {shift_reg[6:0], miso}

        The slice bounds are recomputed numerically so the result keeps the
        same total width as the original concatenation.

        Args:
            code: Verilog source text.

        Returns:
            The mutated text, or ``code`` unchanged if nothing matched.
        """
        def to_right_shift(m):
            dst, reg, high, new_bit = m.groups()
            return f'{dst} <= {{{new_bit}, {reg}[{int(high) + 1}:1]}}'

        def to_left_shift(m):
            dst, new_bit, reg, high = m.groups()
            if int(high) < 1:
                # reg[0:1] has no meaningful mirrored slice; leave it alone.
                return m.group(0)
            return f'{dst} <= {{{reg}[{int(high) - 1}:0], {new_bit}}}'

        rewrites = [
            # {reg[N:0], new_bit}  ->  {new_bit, reg[N+1:1]}
            (r'(\w+)\s*<=\s*\{\s*(\w+)\[(\d+):0\]\s*,\s*(\w+)\s*\}',
             to_right_shift),
            # {new_bit, reg[N:1]}  ->  {reg[N-1:0], new_bit}
            (r'(\w+)\s*<=\s*\{\s*(\w+)\s*,\s*(\w+)\[(\d+):1\]\s*\}',
             to_left_shift),
        ]
        for pattern, rewrite in rewrites:
            mutated, hits = re.subn(pattern, rewrite, code, count=1)
            if hits:
                return mutated
        return code

    @staticmethod
    def invert_bit_select(code: str) -> str:
        """Invert a randomly chosen single-bit select on an assignment's RHS.

        Example:
            mosi <= shift_reg[7]  ->  mosi <= ~shift_reg[7]

        Args:
            code: Verilog source text.

        Returns:
            The mutated text, or ``code`` unchanged if nothing matched.
        """
        candidates = list(re.finditer(r'(\w+)\s*<=\s*(\w+)\[(\d+)\]', code))
        if not candidates:
            return code

        chosen = random.choice(candidates)
        target, source, index = chosen.groups()
        # Splice by position so exactly the chosen occurrence is rewritten.
        return (
            f'{code[:chosen.start()]}{target} <= ~{source}[{index}]'
            f'{code[chosen.end():]}'
        )

    @staticmethod
    def wrong_bit_index(code: str) -> str:
        """Use a neighbouring bit index on an assignment's RHS.

        Rewrites, tried in this order: ``[7]->[6]``, ``[6]->[7]``,
        ``[0]->[1]``, ``[1]->[0]``.

        Example:
            mosi <= shift_reg[7]  ->  mosi <= shift_reg[6]

        Args:
            code: Verilog source text.

        Returns:
            The mutated text, or ``code`` unchanged if nothing matched.
        """
        rewrites = [
            (r'(\w+)\s*<=\s*(\w+)\[7\]', r'\1 <= \2[6]'),
            (r'(\w+)\s*<=\s*(\w+)\[6\]', r'\1 <= \2[7]'),
            (r'(\w+)\s*<=\s*(\w+)\[0\]', r'\1 <= \2[1]'),
            (r'(\w+)\s*<=\s*(\w+)\[1\]', r'\1 <= \2[0]'),
        ]
        for pattern, replacement in rewrites:
            mutated, hits = re.subn(pattern, replacement, code, count=1)
            if hits:
                return mutated
        return code

    @staticmethod
    def swap_registers(code: str) -> str:
        """Exchange the right-hand sides of two adjacent assignments.

        Example:
            temp_a <= xxx; temp_b <= yyy;  ->  temp_a <= yyy; temp_b <= xxx;

        One adjacent pair is picked at random. Only the right-hand sides are
        swapped; targets, whitespace and line breaks are preserved.

        Args:
            code: Verilog source text.

        Returns:
            The mutated text, or ``code`` unchanged if no usable pair exists.
        """
        pair_re = re.compile(
            r'(\w+)\s*<=\s*([^;]+);\s*(\w+)\s*<=\s*([^;]+);'
        )
        candidates = [
            m for m in pair_re.finditer(code)
            # A "<=" inside a captured RHS means the match started on a
            # comparison (e.g. inside "if (a <= b) ..."), not an assignment.
            if '<=' not in m.group(2) and '<=' not in m.group(4)
            # Swapping identical expressions would not change anything.
            and m.group(2).strip() != m.group(4).strip()
        ]
        if not candidates:
            return code

        chosen = random.choice(candidates)
        first_rhs, second_rhs = chosen.group(2), chosen.group(4)
        # Position-based splicing swaps only the two expressions; textual
        # str.replace() would also hit names that merely contain them.
        return (
            code[:chosen.start(2)] + second_rhs
            + code[chosen.end(2):chosen.start(4)] + first_rhs
            + code[chosen.end(4):]
        )
class ConservativeMutationOperators:
    """Conservative mutation operators that apply only tiny perturbations.

    The overall RTL structure is left intact, so the mutated design is more
    likely to still partially agree with the reference RTL. Every operator
    takes Verilog source text and returns a new string with at most one
    mutation applied, or the input unchanged when nothing matched.
    """

    @staticmethod
    def single_bit_flip(code: str) -> str:
        """Flip the value assigned to a single-bit signal.

        Examples:
            spi_clk <= 1'b0  ->  spi_clk <= 1'b1
            busy <= 1        ->  busy <= 0

        Rewrites are tried in the order ``1'b0``, ``1'b1``, bare ``1``,
        bare ``0``; the first kind found anywhere in the text is applied once.

        Args:
            code: Verilog source text.

        Returns:
            The mutated text, or ``code`` unchanged if nothing matched.
        """
        rewrites = [
            (r"(\w+)\s*<=\s*1'b0", r"\1 <= 1'b1"),
            (r"(\w+)\s*<=\s*1'b1", r"\1 <= 1'b0"),
            # (?!') keeps the bare-number forms from rewriting the width
            # prefix of a sized literal such as 1'd1.
            (r"(\w+)\s*<=\s*1\b(?!')", r"\1 <= 0"),
            (r"(\w+)\s*<=\s*0\b(?!')", r"\1 <= 1"),
        ]
        for pattern, replacement in rewrites:
            mutated, hits = re.subn(pattern, replacement, code, count=1)
            if hits:
                return mutated
        return code

    @staticmethod
    def tiny_offset(code: str) -> str:
        """Nudge a decimal literal assigned to a counter by one.

        Example:
            bit_count <= 3'd4  ->  bit_count <= 3'd3  or  bit_count <= 3'd5

        Only assignments to ``bit_count``, ``bit_cnt`` or ``counter`` with a
        sized decimal literal (``<width>'d<value>``) are touched, so
        ``localparam`` definitions are not affected. The new value comes from
        the module-level ``_offset_value`` helper and is then kept within the
        literal's declared width.

        Args:
            code: Verilog source text.

        Returns:
            The mutated text, or ``code`` unchanged if nothing matched.
        """
        def nudge(m):
            name, width_text, value = m.group(1), m.group(2), int(m.group(3))
            width = int(width_text)
            shifted = _offset_value(value)
            # An upward step must not overflow the declared width (3'd7 must
            # not become 3'd8); step down instead when that would happen.
            if width > 0 and shifted >= (1 << width):
                shifted = value - 1 if value > 0 else value
            return f"{name} <= {width_text}'d{shifted}"

        return re.sub(
            r"(bit_count|bit_cnt|counter)\s*<=\s*(\d+)'d(\d+)",
            nudge, code, count=1,
        )

    @staticmethod
    def bit_index_shift(code: str) -> str:
        """Move a bit select on a data signal by exactly one position.

        Examples:
            shift_reg[7]  ->  shift_reg[6]
            shift_reg[6]  ->  shift_reg[7]

        Only ``shift_reg``, ``data_in`` and ``data_out`` with index 7 or 6
        are considered; ``[7]`` is preferred over ``[6]``.

        Args:
            code: Verilog source text.

        Returns:
            The mutated text, or ``code`` unchanged if nothing matched.
        """
        rewrites = [
            (r'(shift_reg|data_in|data_out)\[7\]', r'\1[6]'),
            (r'(shift_reg|data_in|data_out)\[6\]', r'\1[7]'),
        ]
        for pattern, replacement in rewrites:
            mutated, hits = re.subn(pattern, replacement, code, count=1)
            if hits:
                return mutated
        return code

    @staticmethod
    def constant_slight_change(code: str) -> str:
        """Flip the least-significant bit of an assigned hex literal.

        Examples:
            data_in <= 8'hAA  ->  data_in <= 8'hAB
            data_in <= 8'h00  ->  data_in <= 8'h01

        Only literals on the right-hand side of a ``<=`` assignment are
        touched, so ``localparam`` definitions are not affected. The letter
        case of the rewritten digit follows the original digit.

        Args:
            code: Verilog source text.

        Returns:
            The mutated text, or ``code`` unchanged if nothing matched.
        """
        # The trailing \b anchors the second group on the literal's last
        # digit, whatever the number of digits.
        literal = r"(\w+\s*<=\s*\d+'[hH][0-9a-fA-F]*)([0-9a-fA-F])\b"

        def flip_lsb(m):
            last_digit = m.group(2)
            flipped = format(int(last_digit, 16) ^ 1, 'x')
            if last_digit.isupper():
                flipped = flipped.upper()
            return m.group(1) + flipped

        return re.sub(literal, flip_lsb, code, count=1)

    @staticmethod
    def invert_single_signal(code: str) -> str:
        """Invert the bit driven onto ``mosi``.

        Example:
            mosi <= shift_reg[7]  ->  mosi <= ~shift_reg[7]

        One matching assignment is picked at random.

        Args:
            code: Verilog source text.

        Returns:
            The mutated text, or ``code`` unchanged if nothing matched.
        """
        candidates = list(re.finditer(r'(mosi)\s*<=\s*(\w+\[)(\d+)\]', code))
        if not candidates:
            return code

        chosen = random.choice(candidates)
        target, source_open, index = chosen.groups()
        # Splice by position so exactly the chosen occurrence is rewritten.
        return (
            f'{code[:chosen.start()]}{target} <= ~{source_open}{index}]'
            f'{code[chosen.end():]}'
        )
def _offset_value(val: int, max_value: int = 15) -> int:
    """Return ``val`` moved by one, staying inside ``0..max_value``.

    Args:
        val: The original numeric value.
        max_value: Largest value the result may take. The default of 15
            reproduces the previous fixed behaviour (random +/-1 for 1..14,
            step down above that).

    Returns:
        * ``val`` unchanged when it is negative (or when ``max_value`` leaves
          no room to move a zero upward);
        * ``1`` when ``val`` is 0;
        * ``val - 1`` when ``val`` is at or above ``max_value``;
        * otherwise ``val - 1`` or ``val + 1``, chosen at random.
    """
    if val < 0:
        return val
    if val == 0:
        return 1 if max_value >= 1 else val
    if val >= max_value:
        # An upward step would leave the allowed range, so only step down.
        return val - 1
    return val + random.choice([-1, 1])
class MutationOperatorRegistry:
    """Registry of the available mutation operators.

    Operators are stored as ``(name, callable)`` pairs in two pools.
    """

    # Conservative pool: selected with the higher default probability.
    CONSERVATIVE_OPERATORS = [
        ("single_bit_flip", ConservativeMutationOperators.single_bit_flip),
        ("bit_index_shift", ConservativeMutationOperators.bit_index_shift),
        ("constant_slight_change", ConservativeMutationOperators.constant_slight_change),
        ("tiny_offset", ConservativeMutationOperators.tiny_offset),
        ("invert_single_signal", ConservativeMutationOperators.invert_single_signal),
    ]

    # Aggressive pool: selected with the remaining probability.
    AGGRESSIVE_OPERATORS = [
        ("flip_state_condition", FSMMutationOperators.flip_state_condition),
        ("invert_counter", FSMMutationOperators.invert_counter_update),
        ("stuck_at_state", FSMMutationOperators.stuck_at_state),
        ("flip_comparison", FSMMutationOperators.flip_comparison_operator),
        ("width_mismatch", DataPathMutationOperators.width_mismatch_assignment),
        ("missing_assignment", DataPathMutationOperators.missing_assignment),
        ("swap_shift", DataPathMutationOperators.swap_shift_direction),
        ("invert_bit", DataPathMutationOperators.invert_bit_select),
        ("wrong_bit_index", DataPathMutationOperators.wrong_bit_index),
        ("swap_registers", DataPathMutationOperators.swap_registers),
    ]

    # Probability of drawing from the conservative pool when none is given.
    DEFAULT_CONSERVATIVE_WEIGHT = 0.7

    @classmethod
    def get_random_operator(cls, conservative_weight: float = None):
        """Pick one operator at random.

        Args:
            conservative_weight: Probability (0.0-1.0) of drawing from the
                conservative pool; ``None`` selects
                ``DEFAULT_CONSERVATIVE_WEIGHT``.

        Returns:
            A ``(name, callable)`` pair from one of the two pools.
        """
        weight = (
            cls.DEFAULT_CONSERVATIVE_WEIGHT
            if conservative_weight is None
            else conservative_weight
        )
        # First draw selects the pool, second draw selects the operator.
        pool = (
            cls.CONSERVATIVE_OPERATORS
            if random.random() < weight
            else cls.AGGRESSIVE_OPERATORS
        )
        return random.choice(pool)

    @classmethod
    def get_all_operators(cls):
        """Return every operator: conservative ones first, then aggressive."""
        combined = cls.CONSERVATIVE_OPERATORS + cls.AGGRESSIVE_OPERATORS
        return combined

    @classmethod
    def get_conservative_operators(cls):
        """Return only the conservative operators."""
        return cls.CONSERVATIVE_OPERATORS

    @classmethod
    def get_aggressive_operators(cls):
        """Return only the aggressive operators."""
        return cls.AGGRESSIVE_OPERATORS