211 lines
6.9 KiB
Python
211 lines
6.9 KiB
Python
"""
|
||
Description : RTL Mutator - 应用 mutation 算子生成错误 RTL
|
||
用于替代 LLM 自由生成,确保语法正确且错误模式可控
|
||
Author : CorrectBench
|
||
"""
|
||
|
||
import random
|
||
import re
|
||
from loader_saver import autologger as logger
|
||
from .mutation_operators import MutationOperatorRegistry, FSMMutationOperators, DataPathMutationOperators
|
||
|
||
|
||
class RTLMutator:
    """Apply mutation operators to a reference RTL to produce controlled faulty RTL.

    Every candidate produced by an operator must differ from its input and
    pass the lightweight structural checks in ``_validate_mutation`` before
    it is accepted.
    """

    # Candidates shorter than this many characters are rejected outright.
    _MIN_CODE_LENGTH = 50
    # Standalone ``module`` keyword. A plain substring test is not enough
    # because ``endmodule`` contains ``module``.
    _MODULE_RE = re.compile(r'\bmodule\b')
    # Word-bounded so that ``endmodule`` / ``endcase`` are not counted as ``end``.
    _BEGIN_RE = re.compile(r'\bbegin\b')
    _END_RE = re.compile(r'\bend\b')

    def __init__(self, reference_rtl: str, seed: int = None):
        """Store the reference RTL and optionally seed the random generator.

        Args:
            reference_rtl: The correct reference RTL code.
            seed: Random seed for reproducibility. When given, it seeds the
                module-level ``random`` generator, which is process-wide
                state shared with every other user of ``random``.
        """
        self.reference_rtl = reference_rtl
        if seed is not None:
            random.seed(seed)

    def _safe_log(self, level: str, message: str):
        """Log ``message`` without ever letting a logging problem propagate.

        If the project logger has no ``logger`` attribute, or that attribute
        is ``None``, the message is dropped. If logging itself raises, the
        message is printed to stdout instead.

        Args:
            level: ``'debug'`` or ``'warning'``; any other value logs at info.
            message: Text to log.
        """
        try:
            if getattr(logger, 'logger', None) is None:
                return
            if level == 'debug':
                logger.debug(message)
            elif level == 'warning':
                logger.warning(message)
            else:
                logger.info(message)
        except Exception:
            # Logging must never break mutation generation. Only ordinary
            # errors are absorbed; KeyboardInterrupt / SystemExit propagate.
            print(f"[RTLMutator {level.upper()}] {message}")

    def _is_accepted(self, candidate, baseline: str) -> bool:
        """Return True if ``candidate`` is non-empty, differs from ``baseline`` and validates."""
        return bool(
            candidate
            and candidate != baseline
            and self._validate_mutation(candidate)
        )

    def generate_mutations(self, num: int, conservative_weight: float = None) -> list[str]:
        """Generate up to ``num`` mutated RTL codes from the reference RTL.

        Operators are drawn from
        ``MutationOperatorRegistry.get_random_operator``. At most ``num * 3``
        attempts are made, so fewer than ``num`` results may be returned (a
        warning is logged in that case). Results are not de-duplicated.

        Args:
            num: Number of mutated RTL codes requested.
            conservative_weight: Weight of conservative operators (0.0-1.0).
                Defaults to
                ``MutationOperatorRegistry.DEFAULT_CONSERVATIVE_WEIGHT``
                (the original note says 0.7 - confirm against the registry).
                Higher values keep the result closer to the original code.

        Returns:
            List of mutated RTL codes.
        """
        if conservative_weight is None:
            conservative_weight = MutationOperatorRegistry.DEFAULT_CONSERVATIVE_WEIGHT

        mutations = []
        attempts = 0
        max_attempts = num * 3  # upper bound on attempts

        while len(mutations) < num and attempts < max_attempts:
            attempts += 1

            # Pick a random operator, biased by the conservative weight.
            name, operator = MutationOperatorRegistry.get_random_operator(conservative_weight)

            try:
                mutated = operator(self.reference_rtl)
                # Keep only real changes that still look structurally sane.
                if self._is_accepted(mutated, self.reference_rtl):
                    mutations.append(mutated)
                    self._safe_log('debug', f"Applied mutation: {name}")
            except Exception as e:
                self._safe_log('debug', f"Mutation failed: {name}, error: {e}")
                continue

        if len(mutations) < num:
            self._safe_log('warning', f"Only generated {len(mutations)}/{num} mutations")

        return mutations

    def _validate_mutation(self, mutated_code: str) -> bool:
        """Run cheap structural sanity checks on mutated code.

        This is a text heuristic, not a parser: comments and string literals
        are counted like any other text.

        Args:
            mutated_code: Code produced by a mutation operator.

        Returns:
            True if every check passes, False otherwise.
        """
        # Reject empty or implausibly short results.
        if not mutated_code or len(mutated_code) < self._MIN_CODE_LENGTH:
            return False

        # A standalone ``module`` keyword must be present.
        if not self._MODULE_RE.search(mutated_code):
            return False

        if 'endmodule' not in mutated_code:
            return False

        # At least one always block or assign statement. Substring match on
        # purpose so that forms such as ``always_ff`` also count.
        if 'always' not in mutated_code and 'assign' not in mutated_code:
            return False

        # Parentheses must balance.
        if mutated_code.count('(') != mutated_code.count(')'):
            return False

        # begin/end must balance; a difference of one is tolerated because
        # the keywords may also appear inside comments.
        begin_count = len(self._BEGIN_RE.findall(mutated_code))
        end_count = len(self._END_RE.findall(mutated_code))
        if abs(begin_count - end_count) > 1:
            return False

        return True

    def _select_operators(self, target_type: str) -> list:
        """Return the operator callables for ``target_type`` ('fsm', 'datapath', else all)."""
        if target_type == 'fsm':
            return [
                FSMMutationOperators.flip_state_condition,
                FSMMutationOperators.swap_next_state,
                FSMMutationOperators.invert_counter_update,
                FSMMutationOperators.stuck_at_state,
                FSMMutationOperators.flip_comparison_operator,
            ]
        if target_type == 'datapath':
            return [
                DataPathMutationOperators.width_mismatch_assignment,
                DataPathMutationOperators.missing_assignment,
                DataPathMutationOperators.swap_shift_direction,
                DataPathMutationOperators.invert_bit_select,
                DataPathMutationOperators.wrong_bit_index,
                DataPathMutationOperators.swap_registers,
            ]
        return [op for _, op in MutationOperatorRegistry.get_all_operators()]

    def generate_targeted_mutations(self, num: int, target_type: str) -> list[str]:
        """Generate up to ``num`` mutations using one family of operators.

        At most ``num * 5`` attempts are made, so fewer than ``num`` results
        may be returned. Results are not de-duplicated.

        Args:
            num: Number of mutated RTL codes requested.
            target_type: ``'fsm'``, ``'datapath'`` or ``'all'``. Any value
                other than ``'fsm'`` / ``'datapath'`` selects every operator
                in the registry.

        Returns:
            List of mutated RTL codes; empty if no operator is available.
        """
        operators = self._select_operators(target_type)
        if not operators:
            # random.choice() would raise IndexError on an empty sequence.
            self._safe_log('warning', f"No mutation operators available for target type: {target_type}")
            return []

        mutations = []
        attempts = 0
        max_attempts = num * 5

        while len(mutations) < num and attempts < max_attempts:
            attempts += 1
            operator = random.choice(operators)

            try:
                mutated = operator(self.reference_rtl)
                if self._is_accepted(mutated, self.reference_rtl):
                    mutations.append(mutated)
            except Exception as e:
                # A failing operator only costs one attempt.
                op_name = getattr(operator, '__name__', repr(operator))
                self._safe_log('debug', f"Mutation failed: {op_name}, error: {e}")
                continue

        return mutations

    def apply_multiple_mutations(self, num_mutations: int = 2) -> str:
        """Stack several mutations on top of each other.

        Makes ``num_mutations`` attempts, each with a randomly chosen
        operator applied to the current code. An attempt that raises,
        changes nothing, or fails validation is skipped, so the result may
        carry fewer than ``num_mutations`` mutations, or none at all.

        Args:
            num_mutations: Number of mutation attempts.

        Returns:
            The mutated RTL code (the reference RTL if nothing was applied).
        """
        current = self.reference_rtl
        operators = [op for _, op in MutationOperatorRegistry.get_all_operators()]
        if not operators:
            # Nothing to apply; avoids IndexError from random.choice().
            return current

        for _ in range(num_mutations):
            operator = random.choice(operators)
            try:
                mutated = operator(current)
                if self._is_accepted(mutated, current):
                    current = mutated
            except Exception as e:
                op_name = getattr(operator, '__name__', repr(operator))
                self._safe_log('debug', f"Mutation failed: {op_name}, error: {e}")
                continue

        return current
|
||
|
||
|
||
def generate_mutation_set(reference_rtl: str, num: int, seed: int = None) -> list[str]:
    """Convenience wrapper: build an ``RTLMutator`` and generate mutations.

    Args:
        reference_rtl: Reference RTL code to mutate.
        num: Number of mutated RTL codes requested.
        seed: Random seed forwarded to ``RTLMutator``.

    Returns:
        List of mutated RTL codes, as returned by
        ``RTLMutator.generate_mutations``.
    """
    return RTLMutator(reference_rtl, seed=seed).generate_mutations(num)
|