Files
CGA-bench/autoline/rtl_mutator.py

211 lines
6.9 KiB
Python
Raw Permalink Normal View History

2026-05-22 10:02:42 +08:00
"""
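Description : RTL Mutator - applies mutation operators to generate faulty RTL.
              Used in place of free-form LLM generation, so that the output is
              syntactically correct and the error patterns are controllable.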
Author : CorrectBench
"""
import random
import re
from loader_saver import autologger as logger
from .mutation_operators import MutationOperatorRegistry, FSMMutationOperators, DataPathMutationOperators
class RTLMutator:
    """Apply mutation operators to a reference RTL to produce controlled faulty RTL.

    Each generated variant is the result of running one (or several) mutation
    operators on the reference code and is kept only if it differs from its
    input and passes the lightweight textual checks in `_validate_mutation`.
    """

    def __init__(self, reference_rtl: str, seed: int = None):
        """
        Args:
            reference_rtl: The correct reference RTL code.
            seed: Random seed for reproducibility. When given, it seeds the
                process-wide `random` generator (the one `random.choice`
                in this class draws from).
        """
        self.reference_rtl = reference_rtl
        if seed is not None:
            random.seed(seed)

    def _safe_log(self, level: str, message: str):
        """Log `message` without failing when the project logger is unusable.

        Args:
            level: 'debug' or 'warning'; any other value is logged as info.
            message: Text to log.

        If the logger object has no initialised `logger` attribute the message
        is dropped; if logging raises, the message is printed instead.
        """
        try:
            if hasattr(logger, 'logger') and logger.logger is not None:
                if level == 'debug':
                    logger.debug(message)
                elif level == 'warning':
                    logger.warning(message)
                else:
                    logger.info(message)
        except Exception:
            # `Exception` (not a bare except) so Ctrl-C / SystemExit still propagate.
            print(f"[RTLMutator {level.upper()}] {message}")

    def generate_mutations(self, num: int, conservative_weight: float = None) -> list[str]:
        """Generate up to `num` mutated RTL variants of the reference code.

        Args:
            num: Number of mutated RTL codes requested.
            conservative_weight: Weight of conservative operators (0.0-1.0).
                Defaults to MutationOperatorRegistry.DEFAULT_CONSERVATIVE_WEIGHT
                (the original documentation states 0.7). Per that
                documentation, a higher value yields RTL closer to the
                original code.

        Returns:
            List of mutated RTL codes. It may be shorter than `num` when the
            attempt budget (3 * num) is exhausted; a warning is logged then.
        """
        if conservative_weight is None:
            conservative_weight = MutationOperatorRegistry.DEFAULT_CONSERVATIVE_WEIGHT

        mutations = []
        attempts = 0
        max_attempts = num * 3  # upper bound on attempts
        while len(mutations) < num and attempts < max_attempts:
            attempts += 1
            # Pick a random mutation operator, biased by the conservative weight.
            name, operator = MutationOperatorRegistry.get_random_operator(conservative_weight)
            try:
                mutated = operator(self.reference_rtl)
                # Keep only mutations that actually changed the code and that
                # did not wreck its basic structure.
                if (mutated and mutated != self.reference_rtl
                        and self._validate_mutation(mutated)):
                    mutations.append(mutated)
                    self._safe_log('debug', f"Applied mutation: {name}")
            except Exception as e:
                self._safe_log('debug', f"Mutation failed: {name}, error: {e}")

        if len(mutations) < num:
            self._safe_log('warning', f"Only generated {len(mutations)}/{num} mutations")
        return mutations

    def _validate_mutation(self, mutated_code: str) -> bool:
        """Run cheap textual sanity checks on mutated code.

        This is not a parser: it only rejects results that are obviously
        broken (too short, missing keywords, unbalanced delimiters).

        Args:
            mutated_code: Code produced by a mutation operator.

        Returns:
            True if the code passes every check, False otherwise.
        """
        # Reject empty or implausibly short results.
        if not mutated_code or len(mutated_code) < 50:
            return False
        # Require the `module` keyword as a whole word: a plain substring test
        # would always be satisfied by `endmodule` and so check nothing.
        if not re.search(r'\bmodule\b', mutated_code):
            return False
        # Require the closing `endmodule` keyword.
        if not re.search(r'\bendmodule\b', mutated_code):
            return False
        # Require at least one always block or assign statement. Substring
        # matching is kept here so variants such as `always_ff` still count.
        if 'always' not in mutated_code and 'assign' not in mutated_code:
            return False
        # Parentheses must balance.
        if mutated_code.count('(') != mutated_code.count(')'):
            return False
        # begin/end must balance; a difference of 1 is tolerated because the
        # keywords may also appear inside comments.
        begin_count = len(re.findall(r'\bbegin\b', mutated_code))
        end_count = len(re.findall(r'\bend\b', mutated_code))
        if abs(begin_count - end_count) > 1:
            return False
        return True

    def generate_targeted_mutations(self, num: int, target_type: str) -> list[str]:
        """Generate mutations restricted to one family of operators.

        Args:
            num: Number of mutated RTL codes requested.
            target_type: 'fsm' or 'datapath'; any other value (e.g. 'all')
                selects every operator known to MutationOperatorRegistry.

        Returns:
            List of mutated RTL codes. It may be shorter than `num` when the
            attempt budget (5 * num) is exhausted.
        """
        if target_type == 'fsm':
            operators = [
                FSMMutationOperators.flip_state_condition,
                FSMMutationOperators.swap_next_state,
                FSMMutationOperators.invert_counter_update,
                FSMMutationOperators.stuck_at_state,
                FSMMutationOperators.flip_comparison_operator,
            ]
        elif target_type == 'datapath':
            operators = [
                DataPathMutationOperators.width_mismatch_assignment,
                DataPathMutationOperators.missing_assignment,
                DataPathMutationOperators.swap_shift_direction,
                DataPathMutationOperators.invert_bit_select,
                DataPathMutationOperators.wrong_bit_index,
                DataPathMutationOperators.swap_registers,
            ]
        else:
            operators = [op for _, op in MutationOperatorRegistry.get_all_operators()]

        mutations = []
        if not operators:
            # random.choice would raise IndexError on an empty sequence.
            self._safe_log('warning', "No mutation operators available")
            return mutations

        attempts = 0
        max_attempts = num * 5
        while len(mutations) < num and attempts < max_attempts:
            attempts += 1
            operator = random.choice(operators)
            try:
                mutated = operator(self.reference_rtl)
                if (mutated and mutated != self.reference_rtl
                        and self._validate_mutation(mutated)):
                    mutations.append(mutated)
            except Exception as e:
                # A failing operator just costs one attempt; record why.
                op_name = getattr(operator, '__name__', repr(operator))
                self._safe_log('debug', f"Mutation failed: {op_name}, error: {e}")

        if len(mutations) < num:
            self._safe_log('warning', f"Only generated {len(mutations)}/{num} mutations")
        return mutations

    def apply_multiple_mutations(self, num_mutations: int = 2) -> str:
        """Stack several mutations on top of each other in a single RTL.

        Each round picks a random operator and applies it to the result of the
        previous round. A round whose operator fails, changes nothing, or
        yields code rejected by `_validate_mutation` is skipped, so fewer than
        `num_mutations` mutations may actually be applied.

        Args:
            num_mutations: Number of mutation rounds to attempt.

        Returns:
            The mutated RTL code (the reference RTL itself if no round succeeded).
        """
        current = self.reference_rtl
        operators = [op for _, op in MutationOperatorRegistry.get_all_operators()]
        if not operators:
            # random.choice would raise IndexError on an empty sequence.
            self._safe_log('warning', "No mutation operators available")
            return current

        for _ in range(num_mutations):
            operator = random.choice(operators)
            try:
                mutated = operator(current)
                if mutated and mutated != current and self._validate_mutation(mutated):
                    current = mutated
            except Exception as e:
                op_name = getattr(operator, '__name__', repr(operator))
                self._safe_log('debug', f"Mutation failed: {op_name}, error: {e}")
        return current
def generate_mutation_set(reference_rtl: str, num: int, seed: int = None) -> list[str]:
    """Convenience wrapper: build an RTLMutator and return a batch of mutated RTL.

    Args:
        reference_rtl: Reference RTL source to mutate.
        num: Number of mutated variants requested.
        seed: Random seed forwarded to the RTLMutator constructor.

    Returns:
        The list returned by RTLMutator.generate_mutations(num).
    """
    return RTLMutator(reference_rtl, seed=seed).generate_mutations(num)