"""
Description : RTL Mutator - apply mutation operators to a reference RTL to
              produce buggy RTL variants. Used instead of free-form LLM
              generation so the output stays syntactically plausible and the
              injected error patterns are controllable.
Author      : CorrectBench
"""
import random
import re
from typing import Optional

from loader_saver import autologger as logger
from .mutation_operators import MutationOperatorRegistry, FSMMutationOperators, DataPathMutationOperators

# Word-boundary patterns used by RTLMutator._validate_mutation.
# `\bmodule\b` does not match inside `endmodule`, and `\bend\b` does not match
# `endmodule` / `endcase` / ..., so each keyword is counted on its own.
_MODULE_RE = re.compile(r'\bmodule\b')
_BEGIN_RE = re.compile(r'\bbegin\b')
_END_RE = re.compile(r'\bend\b')


def _named(operators):
    """Pair each operator callable with its `__name__` (used for log messages)."""
    return [(getattr(op, '__name__', repr(op)), op) for op in operators]


class RTLMutator:
    """
    Apply mutation operators to a reference RTL to produce controlled, buggy RTL.
    """

    def __init__(self, reference_rtl: str, seed: Optional[int] = None):
        """
        Args:
            reference_rtl: the correct reference RTL source.
            seed: random seed for reproducibility. NOTE: this seeds the *global*
                `random` module, so it also affects any other code that uses
                `random`. It is kept global because operator selection here uses
                the global `random`, and the operators in `mutation_operators`
                presumably do too (TODO confirm).
        """
        self.reference_rtl = reference_rtl
        if seed is not None:
            random.seed(seed)

    def _safe_log(self, level: str, message: str):
        """
        Log `message` without ever raising.

        `level` is 'debug' or 'warning'; any other value is logged at info.
        If the shared logger has no initialized `.logger`, the message is
        dropped. If the logging call itself raises, the message is printed to
        stdout instead.
        """
        try:
            if hasattr(logger, 'logger') and logger.logger is not None:
                if level == 'debug':
                    logger.debug(message)
                elif level == 'warning':
                    logger.warning(message)
                else:
                    logger.info(message)
        except Exception:
            # Catch Exception rather than using a bare `except:` so that
            # KeyboardInterrupt / SystemExit still propagate.
            print(f"[RTLMutator {level.upper()}] {message}")

    def _collect_mutations(self, num: int, pick_operator, max_attempts: int) -> list[str]:
        """
        Repeatedly apply an operator to the reference RTL until `num` valid
        mutations are collected or `max_attempts` attempts are used up.

        Args:
            num: number of mutated RTLs wanted.
            pick_operator: zero-arg callable returning a `(name, operator)` pair;
                `operator` maps an RTL string to a mutated RTL string.
            max_attempts: upper bound on the number of operator applications.

        Returns:
            list of accepted mutated RTL strings (possibly shorter than `num`).
        """
        mutations = []
        attempts = 0
        while len(mutations) < num and attempts < max_attempts:
            attempts += 1
            name, operator = pick_operator()
            try:
                mutated = operator(self.reference_rtl)
                # Accept only if the operator actually changed something and the
                # result still looks like a complete module.
                accepted = (
                    bool(mutated)
                    and mutated != self.reference_rtl
                    and self._validate_mutation(mutated)
                )
            except Exception as e:
                self._safe_log('debug', f"Mutation failed: {name}, error: {e}")
                continue
            if accepted:
                mutations.append(mutated)
                self._safe_log('debug', f"Applied mutation: {name}")

        if len(mutations) < num:
            self._safe_log('warning', f"Only generated {len(mutations)}/{num} mutations")
        return mutations

    def generate_mutations(self, num: int, conservative_weight: float = None) -> list[str]:
        """
        Generate up to `num` mutated RTLs using randomly chosen operators.

        Args:
            num: number of mutated RTLs to generate.
            conservative_weight: weight (0.0-1.0) given to conservative operators.
                Defaults to `MutationOperatorRegistry.DEFAULT_CONSERVATIVE_WEIGHT`
                (documented as 0.7). Higher values keep results closer to the
                original code.

        Returns:
            list of mutated RTL codes. It may contain fewer than `num` entries
            after `num * 3` attempts; a warning is logged in that case.
        """
        if conservative_weight is None:
            conservative_weight = MutationOperatorRegistry.DEFAULT_CONSERVATIVE_WEIGHT

        return self._collect_mutations(
            num,
            lambda: MutationOperatorRegistry.get_random_operator(conservative_weight),
            max_attempts=num * 3,
        )

    def _validate_mutation(self, mutated_code: str) -> bool:
        """
        Cheap structural sanity check on mutated code. It is not a real parser.

        Args:
            mutated_code: the code after mutation.

        Returns:
            True only if all of the following hold:
            - the code is at least 50 characters long;
            - it contains a standalone `module` keyword and `endmodule`;
            - it has an `always` block or an `assign` statement;
            - parentheses are balanced;
            - `begin`/`end` counts differ by at most 1.
            False otherwise.
        """
        if not mutated_code or len(mutated_code) < 50:
            return False

        # Word-boundary match: a plain substring test would be satisfied by
        # the `module` inside `endmodule`.
        if not _MODULE_RE.search(mutated_code):
            return False

        if 'endmodule' not in mutated_code:
            return False

        if 'always' not in mutated_code and 'assign' not in mutated_code:
            return False

        if mutated_code.count('(') != mutated_code.count(')'):
            return False

        begin_count = len(_BEGIN_RE.findall(mutated_code))
        end_count = len(_END_RE.findall(mutated_code))
        # Allow a difference of 1, since begin/end may also appear inside comments.
        if abs(begin_count - end_count) > 1:
            return False

        return True

    def generate_targeted_mutations(self, num: int, target_type: str) -> list[str]:
        """
        Generate mutations using only one family of operators.

        Args:
            num: number of mutated RTLs to generate.
            target_type: 'fsm', 'datapath', or 'all'. Any other value is
                treated as 'all'.

        Returns:
            list of mutated RTL codes. It may contain fewer than `num` entries
            after `num * 5` attempts; a warning is logged in that case.
        """
        if target_type == 'fsm':
            named_ops = _named([
                FSMMutationOperators.flip_state_condition,
                FSMMutationOperators.swap_next_state,
                FSMMutationOperators.invert_counter_update,
                FSMMutationOperators.stuck_at_state,
                FSMMutationOperators.flip_comparison_operator,
            ])
        elif target_type == 'datapath':
            named_ops = _named([
                DataPathMutationOperators.width_mismatch_assignment,
                DataPathMutationOperators.missing_assignment,
                DataPathMutationOperators.swap_shift_direction,
                DataPathMutationOperators.invert_bit_select,
                DataPathMutationOperators.wrong_bit_index,
                DataPathMutationOperators.swap_registers,
            ])
        else:
            named_ops = list(MutationOperatorRegistry.get_all_operators())

        # Uniform choice over the family. random.choice over a list of the
        # same length yields the same selection sequence as before.
        return self._collect_mutations(
            num,
            lambda: random.choice(named_ops),
            max_attempts=num * 5,
        )

    def apply_multiple_mutations(self, num_mutations: int = 2) -> str:
        """
        Stack several mutations on one RTL.

        Each step picks a random operator and applies it to the current code.
        A step is skipped if the operator raises, makes no change, or produces
        code that fails `_validate_mutation`.

        Args:
            num_mutations: number of mutation steps to attempt.

        Returns:
            the mutated RTL code. It equals the reference RTL if every step
            was skipped.
        """
        current = self.reference_rtl
        named_ops = list(MutationOperatorRegistry.get_all_operators())

        for _ in range(num_mutations):
            name, operator = random.choice(named_ops)
            try:
                mutated = operator(current)
                accepted = (
                    bool(mutated)
                    and mutated != current
                    and self._validate_mutation(mutated)
                )
            except Exception as e:
                self._safe_log('debug', f"Mutation failed: {name}, error: {e}")
                continue
            if accepted:
                current = mutated

        return current


def generate_mutation_set(reference_rtl: str, num: int, seed: int = None) -> list[str]:
    """
    Convenience wrapper: build an `RTLMutator` and call `generate_mutations`.

    Args:
        reference_rtl: the reference RTL.
        num: number of mutations to generate.
        seed: random seed. If given, the global `random` module is seeded.

    Returns:
        list of mutated RTL codes.
    """
    mutator = RTLMutator(reference_rtl, seed=seed)
    return mutator.generate_mutations(num)