211 lines
6.9 KiB
Python
211 lines
6.9 KiB
Python
|
|
"""
|
|||
|
|
Description : RTL Mutator - 应用 mutation 算子生成错误 RTL
|
|||
|
|
用于替代 LLM 自由生成,确保语法正确且错误模式可控
|
|||
|
|
Author : CorrectBench
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import random
|
|||
|
|
import re
|
|||
|
|
from loader_saver import autologger as logger
|
|||
|
|
from .mutation_operators import MutationOperatorRegistry, FSMMutationOperators, DataPathMutationOperators
|
|||
|
|
|
|||
|
|
|
|||
|
|
class RTLMutator:
    """
    Apply mutation operators to a reference RTL to produce controlled faulty RTL.

    Every candidate produced by an operator is accepted only if it differs
    from its input and passes the lightweight textual checks in
    ``_validate_mutation``.
    """

    # A module header keyword as a whole word.  A plain substring test for
    # "module" is always satisfied by "endmodule", so word boundaries are needed.
    _MODULE_RE = re.compile(r'\b(?:macro)?module\b')
    _ENDMODULE_RE = re.compile(r'\bendmodule\b')
    _BEGIN_RE = re.compile(r'\bbegin\b')
    _END_RE = re.compile(r'\bend\b')

    def __init__(self, reference_rtl: str, seed: int = None):
        """
        Args:
            reference_rtl: the correct reference RTL code
            seed: random seed (for reproducibility); ``None`` leaves the
                generator state untouched
        """
        self.reference_rtl = reference_rtl
        if seed is not None:
            # Seeds the module-level generator, i.e. the one used by the
            # random.choice calls below.
            # NOTE(review): presumably the mutation operators draw from the
            # same module-level generator - confirm in mutation_operators.
            random.seed(seed)

    def _safe_log(self, level: str, message: str):
        """
        Log ``message`` without ever raising because of the logging backend.

        Nothing is emitted when the project logger has no initialised
        ``logger`` attribute.  If the logging call itself fails, the message is
        printed to stdout instead.

        Args:
            level: 'debug' or 'warning'; any other value is logged as info
            message: text to log
        """
        try:
            if hasattr(logger, 'logger') and logger.logger is not None:
                if level == 'debug':
                    logger.debug(message)
                elif level == 'warning':
                    logger.warning(message)
                else:
                    logger.info(message)
        except Exception:
            # Exception (not a bare except) so that KeyboardInterrupt and
            # SystemExit still propagate.
            print(f"[RTLMutator {level.upper()}] {message}")

    def _try_operator(self, operator, source: str, label: str) -> "str | None":
        """
        Apply one mutation operator and return its result only if acceptable.

        Args:
            operator: callable taking RTL text and returning mutated RTL text
            source: RTL text to mutate
            label: operator name used in log messages

        Returns:
            The mutated text when it is non-empty, differs from ``source`` and
            passes ``_validate_mutation``; otherwise ``None``.  Exceptions
            raised by the operator or by validation are logged at debug level
            and reported as ``None``.
        """
        try:
            mutated = operator(source)
            accepted = (
                bool(mutated)
                and mutated != source
                and self._validate_mutation(mutated)
            )
        except Exception as e:
            self._safe_log('debug', f"Mutation failed: {label}, error: {e}")
            return None

        if not accepted:
            return None
        self._safe_log('debug', f"Applied mutation: {label}")
        return mutated

    def generate_mutations(self, num: int, conservative_weight: float = None) -> list[str]:
        """
        Generate the requested number of mutated RTL codes.

        Operators are drawn from ``MutationOperatorRegistry.get_random_operator``.
        At most ``num * 3`` attempts are made, so fewer than ``num`` results may
        be returned (a warning is logged in that case).  Identical results
        produced by different attempts are not filtered out.

        Args:
            num: number of mutated RTL codes to generate
            conservative_weight: weight of conservative operators (0.0-1.0);
                ``None`` uses ``MutationOperatorRegistry.DEFAULT_CONSERVATIVE_WEIGHT``.
                Per the original documentation, higher values keep the result
                closer to the original code.

        Returns:
            list of mutated RTL codes
        """
        if conservative_weight is None:
            conservative_weight = MutationOperatorRegistry.DEFAULT_CONSERVATIVE_WEIGHT

        mutations = []
        attempts = 0
        max_attempts = num * 3  # upper bound on the number of attempts

        while len(mutations) < num and attempts < max_attempts:
            attempts += 1

            # Pick one mutation operator at random (biased by the weight).
            name, operator = MutationOperatorRegistry.get_random_operator(conservative_weight)

            mutated = self._try_operator(operator, self.reference_rtl, name)
            if mutated is not None:
                mutations.append(mutated)

        if len(mutations) < num:
            self._safe_log('warning', f"Only generated {len(mutations)}/{num} mutations")

        return mutations

    def _validate_mutation(self, mutated_code: str) -> bool:
        """
        Check that mutated code is still roughly well-formed RTL.

        These are textual heuristics, not a parse: minimum length, presence of
        a module header and ``endmodule``, presence of ``always`` or
        ``assign``, balanced parentheses and (nearly) balanced begin/end.

        Args:
            mutated_code: the code after mutation

        Returns:
            True if valid, False otherwise
        """
        # Reject empty or implausibly short results.
        if not mutated_code or len(mutated_code) < 50:
            return False

        # A module header keyword must be present as a whole word.
        if not self._MODULE_RE.search(mutated_code):
            return False

        # The module must be closed.
        if not self._ENDMODULE_RE.search(mutated_code):
            return False

        # There must be at least one always block or assign statement.
        if 'always' not in mutated_code and 'assign' not in mutated_code:
            return False

        # Parentheses must balance.
        if mutated_code.count('(') != mutated_code.count(')'):
            return False

        # begin/end must balance; a difference of 1 is tolerated because the
        # words may also occur inside comments.
        begin_count = len(self._BEGIN_RE.findall(mutated_code))
        end_count = len(self._END_RE.findall(mutated_code))
        if abs(begin_count - end_count) > 1:
            return False

        return True

    def generate_targeted_mutations(self, num: int, target_type: str) -> list[str]:
        """
        Generate mutations restricted to one category of operators.

        At most ``num * 5`` attempts are made, so fewer than ``num`` results
        may be returned (a warning is logged in that case).

        Args:
            num: number of mutated RTL codes to generate
            target_type: 'fsm', 'datapath', or 'all'; any value other than
                'fsm' and 'datapath' selects every registered operator

        Returns:
            list of mutated RTL codes
        """
        if target_type == 'fsm':
            operators = [
                FSMMutationOperators.flip_state_condition,
                FSMMutationOperators.swap_next_state,
                FSMMutationOperators.invert_counter_update,
                FSMMutationOperators.stuck_at_state,
                FSMMutationOperators.flip_comparison_operator,
            ]
        elif target_type == 'datapath':
            operators = [
                DataPathMutationOperators.width_mismatch_assignment,
                DataPathMutationOperators.missing_assignment,
                DataPathMutationOperators.swap_shift_direction,
                DataPathMutationOperators.invert_bit_select,
                DataPathMutationOperators.wrong_bit_index,
                DataPathMutationOperators.swap_registers,
            ]
        else:
            operators = [op for _, op in MutationOperatorRegistry.get_all_operators()]

        mutations = []
        if not operators:
            # random.choice raises IndexError on an empty sequence.
            if num > 0:
                self._safe_log('warning', f"Only generated {len(mutations)}/{num} mutations")
            return mutations

        attempts = 0
        max_attempts = num * 5

        while len(mutations) < num and attempts < max_attempts:
            attempts += 1
            operator = random.choice(operators)
            label = getattr(operator, '__name__', repr(operator))

            mutated = self._try_operator(operator, self.reference_rtl, label)
            if mutated is not None:
                mutations.append(mutated)

        if len(mutations) < num:
            self._safe_log('warning', f"Only generated {len(mutations)}/{num} mutations")

        return mutations

    def apply_multiple_mutations(self, num_mutations: int = 2) -> str:
        """
        Apply several mutations cumulatively to the same RTL.

        Exactly ``num_mutations`` attempts are made; an attempt that fails,
        changes nothing, or fails validation is skipped and not retried, so the
        result may carry fewer mutations than requested (possibly none, in
        which case the reference RTL is returned unchanged).

        Args:
            num_mutations: number of mutation attempts to make

        Returns:
            mutated RTL code
        """
        current = self.reference_rtl
        operators = [op for _, op in MutationOperatorRegistry.get_all_operators()]
        if not operators:
            # Nothing to apply; random.choice would raise IndexError.
            if num_mutations > 0:
                self._safe_log('warning', "No mutation operators available")
            return current

        for _ in range(num_mutations):
            operator = random.choice(operators)
            label = getattr(operator, '__name__', repr(operator))

            mutated = self._try_operator(operator, current, label)
            if mutated is not None:
                current = mutated

        return current
|
|||
|
|
|
|||
|
|
|
|||
|
|
def generate_mutation_set(reference_rtl: str, num: int, seed: int = None) -> list[str]:
    """
    Convenience wrapper: build an ``RTLMutator`` and generate a set of mutated RTL.

    Args:
        reference_rtl: the reference RTL
        num: how many mutated RTL codes to request
        seed: random seed forwarded to ``RTLMutator``

    Returns:
        list of mutated RTL codes
    """
    return RTLMutator(reference_rtl, seed=seed).generate_mutations(num)
|