Files
CGA-bench/autoline/rtl_mutator.py
2026-05-22 10:02:42 +08:00

211 lines
6.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Description : RTL Mutator - 应用 mutation 算子生成错误 RTL
用于替代 LLM 自由生成,确保语法正确且错误模式可控
Author : CorrectBench
"""
import random
import re
from loader_saver import autologger as logger
from .mutation_operators import MutationOperatorRegistry, FSMMutationOperators, DataPathMutationOperators
class RTLMutator:
    """Apply mutation operators to a reference RTL to produce faulty RTL variants.

    Each operator is a callable that takes the RTL source text and returns a
    mutated copy. A result is kept only when it differs from its input and
    passes the lightweight structural checks in ``_validate_mutation``.
    """

    # Word-bounded keyword patterns. A plain substring test for "module" is
    # always satisfied by "endmodule", so the header check must respect word
    # boundaries; "\bend\b" likewise ignores "endmodule" / "endcase".
    _MODULE_RE = re.compile(r'\bmodule\b')
    _BEGIN_RE = re.compile(r'\bbegin\b')
    _END_RE = re.compile(r'\bend\b')

    # Anything shorter than this is treated as destroyed code.
    _MIN_CODE_LENGTH = 50

    def __init__(self, reference_rtl: str, seed: int = None):
        """Store the reference RTL and optionally seed the random generator.

        Args:
            reference_rtl: The correct reference RTL source code.
            seed: Random seed for reproducibility. When given, it seeds the
                module-global ``random`` generator (the one ``random.choice``
                in this class draws from), so it affects the whole process.
        """
        self.reference_rtl = reference_rtl
        if seed is not None:
            random.seed(seed)

    def _safe_log(self, level: str, message: str):
        """Log ``message`` without ever raising.

        The message goes to the project logger when its ``logger`` attribute
        is set; if the logger is present but not initialised the message is
        dropped. If logging itself raises, the message is printed instead.

        Args:
            level: 'debug', 'warning', or anything else (treated as info).
            message: Text to record.
        """
        try:
            if hasattr(logger, 'logger') and logger.logger is not None:
                if level == 'debug':
                    logger.debug(message)
                elif level == 'warning':
                    logger.warning(message)
                else:
                    logger.info(message)
        except Exception:
            # Narrowed from a bare except so that KeyboardInterrupt and
            # SystemExit still propagate.
            print(f"[RTLMutator {level.upper()}] {message}")

    def _try_operator(self, operator, rtl: str, name: str = None):
        """Run one operator on ``rtl`` and return the accepted mutation or None.

        A mutation is accepted when it is non-empty, differs from ``rtl`` and
        passes ``_validate_mutation``. Any exception raised by the operator or
        by validation is logged at debug level and reported as ``None``.
        """
        label = name if name is not None else getattr(operator, '__name__', repr(operator))
        try:
            mutated = operator(rtl)
            if mutated and mutated != rtl and self._validate_mutation(mutated):
                return mutated
        except Exception as exc:
            self._safe_log('debug', f"Mutation failed: {label}, error: {exc}")
        return None

    def generate_mutations(self, num: int, conservative_weight: float = None) -> list[str]:
        """Generate up to ``num`` mutated RTL variants of the reference.

        Args:
            num: Number of mutated RTL codes requested.
            conservative_weight: Weight of conservative operators, passed to
                ``MutationOperatorRegistry.get_random_operator``. Defaults to
                ``MutationOperatorRegistry.DEFAULT_CONSERVATIVE_WEIGHT``
                (documented by the original author as 0.7; higher values are
                said to keep the result closer to the reference).

        Returns:
            List of mutated RTL codes. It may be shorter than ``num`` when the
            attempt budget (``num * 3``) runs out, and is not de-duplicated.
        """
        if conservative_weight is None:
            conservative_weight = MutationOperatorRegistry.DEFAULT_CONSERVATIVE_WEIGHT

        mutations = []
        max_attempts = num * 3  # upper bound on operator applications
        for _ in range(max_attempts):
            if len(mutations) >= num:
                break
            # Pick a random operator, biased by the conservative weight.
            name, operator = MutationOperatorRegistry.get_random_operator(conservative_weight)
            mutated = self._try_operator(operator, self.reference_rtl, name)
            if mutated is not None:
                mutations.append(mutated)
                self._safe_log('debug', f"Applied mutation: {name}")

        if len(mutations) < num:
            self._safe_log('warning', f"Only generated {len(mutations)}/{num} mutations")
        return mutations

    def _validate_mutation(self, mutated_code: str) -> bool:
        """Check that mutated code still looks structurally like Verilog.

        This is a cheap heuristic, not a parser: it looks for the expected
        keywords and for balanced delimiters.

        Args:
            mutated_code: RTL text produced by a mutation operator.

        Returns:
            True if every check passes, False otherwise.
        """
        if not mutated_code or len(mutated_code) < self._MIN_CODE_LENGTH:
            return False

        # A real ``module`` keyword is required; "endmodule" alone must not
        # satisfy this check.
        if not self._MODULE_RE.search(mutated_code):
            return False
        if 'endmodule' not in mutated_code:
            return False

        # Require at least one always block (always / always_ff / ...) or a
        # continuous assignment.
        if 'always' not in mutated_code and 'assign' not in mutated_code:
            return False

        # Parentheses must balance.
        if mutated_code.count('(') != mutated_code.count(')'):
            return False

        # begin/end must roughly balance; a difference of one is tolerated
        # because the keywords may also appear inside comments.
        begin_count = len(self._BEGIN_RE.findall(mutated_code))
        end_count = len(self._END_RE.findall(mutated_code))
        return abs(begin_count - end_count) <= 1

    def generate_targeted_mutations(self, num: int, target_type: str) -> list[str]:
        """Generate up to ``num`` mutations using one family of operators.

        Args:
            num: Number of mutated RTL codes requested.
            target_type: 'fsm' or 'datapath'; any other value (e.g. 'all')
                selects every operator from ``MutationOperatorRegistry``.

        Returns:
            List of mutated RTL codes. It may be shorter than ``num`` when the
            attempt budget (``num * 5``) runs out, and is not de-duplicated.
        """
        if target_type == 'fsm':
            operators = [
                FSMMutationOperators.flip_state_condition,
                FSMMutationOperators.swap_next_state,
                FSMMutationOperators.invert_counter_update,
                FSMMutationOperators.stuck_at_state,
                FSMMutationOperators.flip_comparison_operator,
            ]
        elif target_type == 'datapath':
            operators = [
                DataPathMutationOperators.width_mismatch_assignment,
                DataPathMutationOperators.missing_assignment,
                DataPathMutationOperators.swap_shift_direction,
                DataPathMutationOperators.invert_bit_select,
                DataPathMutationOperators.wrong_bit_index,
                DataPathMutationOperators.swap_registers,
            ]
        else:
            operators = [op for _, op in MutationOperatorRegistry.get_all_operators()]

        mutations = []
        if not operators:
            # random.choice raises IndexError on an empty sequence.
            self._safe_log('warning', f"Only generated {len(mutations)}/{num} mutations")
            return mutations

        max_attempts = num * 5
        for _ in range(max_attempts):
            if len(mutations) >= num:
                break
            mutated = self._try_operator(random.choice(operators), self.reference_rtl)
            if mutated is not None:
                mutations.append(mutated)

        if len(mutations) < num:
            self._safe_log('warning', f"Only generated {len(mutations)}/{num} mutations")
        return mutations

    def apply_multiple_mutations(self, num_mutations: int = 2) -> str:
        """Stack several mutations on top of each other.

        Each round picks a random operator and applies it to the result of the
        previous round; a round whose operator fails or yields an invalid or
        unchanged result is skipped, so fewer than ``num_mutations`` mutations
        may actually be applied.

        Args:
            num_mutations: Number of mutation rounds to attempt.

        Returns:
            The mutated RTL code (the reference itself if no round succeeded).
        """
        current = self.reference_rtl
        operators = [op for _, op in MutationOperatorRegistry.get_all_operators()]
        if not operators:
            # Nothing to apply; avoids IndexError from random.choice.
            return current

        for _ in range(num_mutations):
            mutated = self._try_operator(random.choice(operators), current)
            if mutated is not None:
                current = mutated
        return current
def generate_mutation_set(reference_rtl: str, num: int, seed: int = None) -> list[str]:
    """Convenience wrapper: build a mutator and return a batch of mutated RTL.

    Args:
        reference_rtl: Reference RTL source code to mutate.
        num: Number of mutated RTL codes requested.
        seed: Random seed forwarded to ``RTLMutator``.

    Returns:
        List of mutated RTL codes as returned by
        ``RTLMutator.generate_mutations``.
    """
    return RTLMutator(reference_rtl, seed=seed).generate_mutations(num)