580 lines
20 KiB
Python
580 lines
20 KiB
Python
"""
|
||
Description : Test History Manager (Layer 1 Support Module)
|
||
- Store and manage test case history
|
||
- Support sequence pattern analysis
|
||
- Provide diversity statistics
|
||
Author : CGA Enhancement Project
|
||
Time : 2026/03/16
|
||
"""
|
||
|
||
import json
|
||
import os
|
||
import logging
|
||
from typing import List, Dict, Optional, Any, Tuple, Set
|
||
from dataclasses import dataclass, field, asdict
|
||
from datetime import datetime
|
||
from collections import defaultdict
|
||
import hashlib
|
||
import re
|
||
|
||
# Module-level logger; used for debug/info messages about adding, saving and
# loading test history records.
logger = logging.getLogger(__name__)
# ============================================================================
|
||
# 数据结构定义
|
||
# ============================================================================
|
||
|
||
@dataclass
class InputSequence:
    """
    Ordered record of the values driven onto a single signal.

    Attributes:
        signal_name: Name of the driven signal.
        values: Assignment events as ``(time, value)`` pairs, in recording order.
    """
    signal_name: str
    values: List[Tuple[int, Any]] = field(default_factory=list)

    def to_pattern_string(self) -> str:
        """Render only the values (timestamps dropped) as ``v1->v2->...``."""
        rendered_values = [str(entry[1]) for entry in self.values]
        return "->".join(rendered_values)

    def get_hash(self) -> str:
        """
        Return an 8-hex-character MD5 fingerprint of the value pattern.

        Only the values feed the digest, so sequences with identical values on
        different signals, or at different times, share the same fingerprint.
        """
        pattern_bytes = self.to_pattern_string().encode()
        full_digest = hashlib.md5(pattern_bytes).hexdigest()
        return full_digest[:8]
@dataclass
class TestRecord:
    """
    One generated test case together with its outcome.

    Attributes:
        test_id: Unique identifier of the test.
        code: Generated test code.
        input_sequences: Input signal sequences extracted from ``code``.
        target_function: Function point the test was aimed at.
        covered_lines: Code lines covered by the test.
        covered_functions: Function points covered by the test.
        coverage_score: Coverage score of the test.
        diversity_scores: Diversity scores keyed by metric name.
        iteration: Iteration in which the test was generated.
        timestamp: ISO-8601 creation time (local, naive) by default.
        success: Whether the test was successful.
    """
    test_id: str
    code: str = ""
    input_sequences: List[InputSequence] = field(default_factory=list)
    target_function: str = ""
    covered_lines: List[int] = field(default_factory=list)
    covered_functions: List[str] = field(default_factory=list)
    coverage_score: float = 0.0
    diversity_scores: Dict[str, float] = field(default_factory=dict)
    iteration: int = 0
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
    success: bool = False

    def get_sequence_patterns(self) -> Dict[str, str]:
        """Map each input signal name to its value-pattern string (later entries win)."""
        patterns_by_signal: Dict[str, str] = {}
        for sequence in self.input_sequences:
            patterns_by_signal[sequence.signal_name] = sequence.to_pattern_string()
        return patterns_by_signal
@dataclass
class SequencePattern:
    """
    Usage statistics for one value pattern.

    Attributes:
        pattern: The pattern string (values joined with ``->``).
        count: Number of times the pattern has been recorded.
        signal_name: Signal the pattern was recorded on.
        test_ids: IDs of the tests in which the pattern was recorded.
    """
    pattern: str
    count: int = 0
    signal_name: str = ""
    test_ids: List[str] = field(default_factory=list)

    def is_overused(self, threshold: int = 3) -> bool:
        """Return True once the pattern has been recorded at least ``threshold`` times."""
        return threshold <= self.count
# ============================================================================
|
||
# 序列提取器
|
||
# ============================================================================
|
||
|
||
class SequenceExtractor:
    """
    Extract per-signal input sequences from Verilog test code.

    The code is scanned line by line while a simulated time counter advances:
      - ``#N`` delays add ``N``;
      - ``repeat (N) @(posedge clk)`` / ``@(negedge clk)`` waits add
        ``N * clock_period``;
      - every recorded assignment adds 1.
    Blocking (``=``) and non-blocking (``<=``) assignments whose right-hand
    side is a literal (based literal such as ``4'b1010`` / ``8'shFF``,
    decimal, ``x`` or ``z``) are recorded as ``(time, value_text)`` under the
    assigned name. Events on one line are processed left to right, and
    ``//`` / ``/* */`` comments are ignored.
    """

    # Literal right-hand side: based literal (optional size, optional signed
    # marker, b/o/d/h base), plain decimal, or a bare x / z.
    _LITERAL = r"(?:[0-9]*'[sS]?[bBoOdDhH][0-9a-fA-FxXzZ_]+|\d+|x|z)"

    # Assignment statements with a literal right-hand side. Each pattern
    # captures (signal, value).
    ASSIGNMENT_PATTERNS = [
        # Blocking assignment: signal = value;
        r'(\w+)\s*=\s*(' + _LITERAL + r')\s*;',
        # Non-blocking assignment: signal <= value;
        r'(\w+)\s*<=\s*(' + _LITERAL + r')\s*;',
    ]

    # Delay such as "#10;" or "#10 a = 1;" (no trailing ';' required).
    DELAY_PATTERN = r'#\s*(\d+)'

    # Multi-cycle wait: repeat (N) @(posedge clk) / @(negedge clk).
    CLOCK_WAIT_PATTERN = r'repeat\s*\(\s*(\d+)\s*\)\s*@\s*\(\s*(?:posedge|negedge)\s+(\w+)\s*\)'

    # Names skipped when no known-signal list is set; these are typical loop
    # counters / scratch variables rather than DUT inputs.
    IGNORED_NAMES = frozenset({'i', 'j', 'k', 'cnt', 'count', 'temp'})

    # Line comments and (possibly multi-line) block comments. Scanning left to
    # right means a "/*" inside a "//" comment is consumed by the line comment.
    _COMMENT_RE = re.compile(r'//[^\n]*|/\*.*?\*/', re.DOTALL)

    def __init__(self, clock_period: int = 10):
        """
        Args:
            clock_period: Time units per clock cycle, used for ``repeat`` waits.
        """
        self.known_signals: Set[str] = set()
        self.clock_period = clock_period

    def set_known_signals(self, signals: List[str]):
        """Restrict extraction to ``signals`` (an empty list disables the filter)."""
        self.known_signals = set(signals)

    def extract(self, code: str) -> List[InputSequence]:
        """
        Extract input sequences from Verilog test code.

        Args:
            code: Verilog test code.

        Returns:
            One InputSequence per assigned signal, in order of first assignment.
        """
        sequences: Dict[str, InputSequence] = {}
        current_time = 0

        code = self._COMMENT_RE.sub(' ', code)

        for raw_line in code.split('\n'):
            line = raw_line.strip()
            if not line:
                continue

            for kind, match in self._line_events(line):
                if kind == 'delay':
                    current_time += int(match.group(1))
                elif kind == 'clock':
                    current_time += int(match.group(1)) * self.clock_period
                else:
                    signal, value = match.group(1), match.group(2)
                    if not self._is_tracked(signal):
                        continue
                    if signal not in sequences:
                        sequences[signal] = InputSequence(signal_name=signal)
                    sequences[signal].values.append((current_time, value))
                    # Each recorded assignment occupies one time unit.
                    current_time += 1

        return list(sequences.values())

    def _is_tracked(self, signal: str) -> bool:
        """Decide whether an assignment target should be recorded."""
        if self.known_signals:
            # An explicit list is authoritative, even for names such as "count".
            return signal in self.known_signals
        return signal.lower() not in self.IGNORED_NAMES

    def _line_events(self, line: str) -> List[Tuple[str, Any]]:
        """Return ``(kind, match)`` events on one line, ordered by position."""
        events = [(m.start(), 'delay', m) for m in re.finditer(self.DELAY_PATTERN, line)]
        events += [
            (m.start(), 'clock', m)
            for m in re.finditer(self.CLOCK_WAIT_PATTERN, line, re.IGNORECASE)
        ]

        seen_starts = set()
        for pattern in self.ASSIGNMENT_PATTERNS:
            for m in re.finditer(pattern, line, re.IGNORECASE):
                # Overlapping patterns must not record one statement twice.
                if m.start() in seen_starts:
                    continue
                seen_starts.add(m.start())
                events.append((m.start(), 'assign', m))

        events.sort(key=lambda event: event[0])
        return [(kind, match) for _, kind, match in events]
# ============================================================================
|
||
# 测试历史管理器
|
||
# ============================================================================
|
||
|
||
class TestHistoryManager:
    """
    History of generated test cases, with optional JSON persistence.

    Supports:
      - storing and retrieving test records;
      - counting how often each input value pattern has been used;
      - simple diversity scores for new candidate tests.
    """

    # Strings are truncated to this many characters before computing edit
    # distance, bounding the O(len1 * len2) dynamic-programming cost.
    EDIT_DISTANCE_MAX_CHARS = 500

    def __init__(self, history_file: Optional[str] = None):
        """
        Args:
            history_file: Optional JSON file path. It is loaded immediately if
                it exists, and it is the default target of ``save()``.
        """
        # Stored first so that save() can fall back to this path.
        self.history_file = history_file

        self.records: List[TestRecord] = []
        # pattern hash -> aggregated usage of that value pattern
        self.patterns: Dict[str, SequencePattern] = {}
        # signal name -> pattern hashes seen on that signal (no duplicates)
        self.signal_patterns: Dict[str, List[str]] = defaultdict(list)
        self.sequence_extractor = SequenceExtractor()

        # Aggregate statistics, refreshed by _update_stats().
        self.stats = {
            'total_tests': 0,
            'successful_tests': 0,
            'total_coverage': 0.0,
            'avg_coverage': 0.0,
            # Not computed by this class; kept for compatibility with saved files.
            'avg_diversity': 0.0,
        }

        if history_file and os.path.exists(history_file):
            self.load(history_file)

    # ==================== Record management ====================

    def add_record(self,
                   code: str,
                   test_id: Optional[str] = None,
                   target_function: str = "",
                   covered_lines: Optional[List[int]] = None,
                   covered_functions: Optional[List[str]] = None,
                   coverage_score: float = 0.0,
                   iteration: int = 0,
                   success: bool = False,
                   known_signals: Optional[List[str]] = None) -> TestRecord:
        """
        Add a test record, extracting its input sequences from ``code``.

        Args:
            code: Test code.
            test_id: Test ID; generated from the record count and current time if omitted.
            target_function: Targeted function point.
            covered_lines: Covered code lines (copied).
            covered_functions: Covered function points (copied).
            coverage_score: Coverage score.
            iteration: Iteration number.
            success: Whether the test succeeded.
            known_signals: If non-empty, restricts sequence extraction to these
                signals. The extractor keeps this filter for later calls.

        Returns:
            The created test record.
        """
        if test_id is None:
            test_id = f"test_{len(self.records)}_{datetime.now().strftime('%H%M%S')}"

        if known_signals:
            self.sequence_extractor.set_known_signals(known_signals)
        input_sequences = self.sequence_extractor.extract(code)

        record = TestRecord(
            test_id=test_id,
            code=code,
            input_sequences=input_sequences,
            target_function=target_function,
            # Copy so later mutation of the caller's lists cannot alter history.
            covered_lines=list(covered_lines or []),
            covered_functions=list(covered_functions or []),
            coverage_score=coverage_score,
            iteration=iteration,
            success=success
        )

        self.records.append(record)
        self._update_patterns(record)
        self._update_stats()

        logger.debug(f"Added test record: {test_id}, sequences: {len(input_sequences)}")

        return record

    def get_record(self, test_id: str) -> Optional[TestRecord]:
        """Return the first record with ``test_id``, or None."""
        for record in self.records:
            if record.test_id == test_id:
                return record
        return None

    def get_recent_records(self, n: int = 10) -> List[TestRecord]:
        """Return a new list of the last ``n`` records (oldest first); [] if n <= 0."""
        if n <= 0:
            return []
        return self.records[-n:]

    def get_successful_records(self) -> List[TestRecord]:
        """Return all successful records."""
        return [r for r in self.records if r.success]

    # ==================== Pattern analysis ====================

    def _update_patterns(self, record: TestRecord):
        """Fold the record's input sequences into the pattern statistics."""
        for seq in record.input_sequences:
            pattern_str = seq.to_pattern_string()
            # NOTE(review): get_hash() fingerprints values only, so identical
            # value sequences on different signals share one entry whose
            # signal_name is the first signal seen.
            pattern_hash = seq.get_hash()

            if pattern_hash not in self.patterns:
                self.patterns[pattern_hash] = SequencePattern(
                    pattern=pattern_str,
                    count=1,
                    signal_name=seq.signal_name,
                    test_ids=[record.test_id]
                )
            else:
                self.patterns[pattern_hash].count += 1
                self.patterns[pattern_hash].test_ids.append(record.test_id)

            # Per-signal index of distinct pattern hashes.
            if pattern_hash not in self.signal_patterns[seq.signal_name]:
                self.signal_patterns[seq.signal_name].append(pattern_hash)

    def get_overused_patterns(self, threshold: int = 3) -> List[SequencePattern]:
        """
        Return patterns whose usage count is at least ``threshold``.

        Args:
            threshold: Overuse threshold.

        Returns:
            Overused patterns, in insertion order.
        """
        return [p for p in self.patterns.values() if p.is_overused(threshold)]

    def get_common_patterns(self, top_n: int = 5) -> List[Tuple[str, int]]:
        """
        Return the most frequently recorded patterns.

        Args:
            top_n: Number of patterns to return.

        Returns:
            ``[(pattern, count), ...]`` sorted by count, descending.
        """
        ranked = sorted(self.patterns.values(), key=lambda p: p.count, reverse=True)
        return [(p.pattern, p.count) for p in ranked[:top_n]]

    def get_pattern_for_signal(self, signal_name: str) -> List[SequencePattern]:
        """Return all patterns indexed under ``signal_name``."""
        pattern_hashes = self.signal_patterns.get(signal_name, [])
        return [self.patterns[h] for h in pattern_hashes if h in self.patterns]

    # ==================== Diversity analysis ====================

    def calculate_sequence_diversity(self, new_sequences: List[InputSequence]) -> float:
        """
        Compute the fraction of distinct patterns in ``new_sequences`` never seen before.

        Args:
            new_sequences: Candidate input sequences.

        Returns:
            1.0 when there is no history, 0.0 when ``new_sequences`` is empty,
            otherwise (#unseen distinct patterns) / (#distinct patterns).
        """
        if not self.records:
            return 1.0
        if not new_sequences:
            return 0.0

        new_patterns = {seq.get_hash() for seq in new_sequences}
        unseen = sum(1 for h in new_patterns if h not in self.patterns)
        return unseen / len(new_patterns)

    def calculate_edit_distance_diversity(self, new_code: str, reference_count: int = 5) -> float:
        """
        Score how different ``new_code`` is from the most recent test codes.

        Each of the last ``reference_count`` records is compared with
        ``new_code`` by Levenshtein distance over the first
        ``EDIT_DISTANCE_MAX_CHARS`` characters of each string. The distance is
        normalized by the longer of those two truncated strings. The result is
        the smallest normalized distance, i.e. the distance to the nearest
        reference.

        Returns:
            A value in [0.0, 1.0]. The result is 1.0 when there is nothing to
            compare against. It is 0.0 when ``new_code`` matches a reference,
            including the case where both are empty.
        """
        references = self.get_recent_records(reference_count)
        if not references:
            return 1.0

        limit = self.EDIT_DISTANCE_MAX_CHARS
        candidate = new_code[:limit]
        nearest = 1.0
        for record in references:
            reference = record.code[:limit]
            longest = max(len(candidate), len(reference))
            if longest == 0:
                return 0.0  # both empty: identical
            distance = self._levenshtein_distance(candidate, reference)
            nearest = min(nearest, distance / longest)
        return nearest

    def _levenshtein_distance(self, s1: str, s2: str) -> int:
        """Compute the Levenshtein distance between the first EDIT_DISTANCE_MAX_CHARS chars of s1 and s2."""
        limit = self.EDIT_DISTANCE_MAX_CHARS
        s1, s2 = s1[:limit], s2[:limit]

        # Keep the shorter string as the row to minimize memory.
        if len(s1) < len(s2):
            s1, s2 = s2, s1
        if not s2:
            return len(s1)

        previous_row = list(range(len(s2) + 1))
        for i, c1 in enumerate(s1, start=1):
            current_row = [i]
            for j, c2 in enumerate(s2, start=1):
                current_row.append(min(
                    previous_row[j] + 1,
                    current_row[j - 1] + 1,
                    previous_row[j - 1] + (c1 != c2),
                ))
            previous_row = current_row

        return previous_row[-1]

    # ==================== Statistics ====================

    def _update_stats(self):
        """Recompute the derived counters in ``self.stats`` from ``self.records``."""
        total = len(self.records)
        total_coverage = sum((r.coverage_score for r in self.records), 0.0)
        self.stats['total_tests'] = total
        self.stats['successful_tests'] = sum(1 for r in self.records if r.success)
        self.stats['total_coverage'] = total_coverage
        self.stats['avg_coverage'] = total_coverage / total if total else 0.0

    def get_statistics(self) -> Dict[str, Any]:
        """Return the stats merged with pattern-level counts."""
        return {
            **self.stats,
            'total_patterns': len(self.patterns),
            'overused_patterns': len(self.get_overused_patterns()),
            'unique_signals': len(self.signal_patterns)
        }

    @staticmethod
    def _truncate(text: str, limit: int) -> str:
        """Cut ``text`` to ``limit`` chars, appending '...' only when it was cut."""
        return text if len(text) <= limit else text[:limit] + "..."

    def get_diversity_report(self) -> str:
        """Build a human-readable diversity report."""
        lines = []
        lines.append("=" * 50)
        lines.append("TEST HISTORY DIVERSITY REPORT")
        lines.append("=" * 50)
        lines.append(f"Total Tests: {self.stats['total_tests']}")
        lines.append(f"Successful Tests: {self.stats['successful_tests']}")
        lines.append(f"Total Patterns: {len(self.patterns)}")
        lines.append("")

        lines.append("TOP 5 COMMON PATTERNS:")
        for i, (pattern, count) in enumerate(self.get_common_patterns(5), 1):
            lines.append(f" {i}. {self._truncate(pattern, 40)} (x{count})")

        # Most-used first so the worst offenders are always shown.
        overused = sorted(self.get_overused_patterns(), key=lambda p: p.count, reverse=True)
        if overused:
            lines.append("")
            lines.append("OVERUSED PATTERNS (need diversification):")
            for p in overused[:5]:
                lines.append(f" - {p.signal_name}: {self._truncate(p.pattern, 30)} (used {p.count} times)")

        lines.append("=" * 50)
        return "\n".join(lines)

    # ==================== Persistence ====================

    @staticmethod
    def _record_to_dict(record: TestRecord) -> Dict[str, Any]:
        """Convert a record into JSON-serializable primitives."""
        return {
            'test_id': record.test_id,
            'code': record.code,
            'input_sequences': [
                {'signal_name': seq.signal_name, 'values': seq.values}
                for seq in record.input_sequences
            ],
            'target_function': record.target_function,
            'covered_lines': record.covered_lines,
            'covered_functions': record.covered_functions,
            'coverage_score': record.coverage_score,
            'diversity_scores': record.diversity_scores,
            'iteration': record.iteration,
            'timestamp': record.timestamp,
            'success': record.success
        }

    @staticmethod
    def _record_from_dict(raw: Dict[str, Any]) -> TestRecord:
        """Rebuild a record from the dict form written by ``_record_to_dict``."""
        sequences = [
            InputSequence(
                signal_name=s['signal_name'],
                # JSON has no tuples; restore the (time, value) pair type.
                values=[tuple(v) for v in s.get('values', [])],
            )
            for s in raw.get('input_sequences', [])
        ]
        return TestRecord(
            test_id=raw['test_id'],
            code=raw['code'],
            input_sequences=sequences,
            target_function=raw.get('target_function', ''),
            covered_lines=raw.get('covered_lines', []),
            covered_functions=raw.get('covered_functions', []),
            coverage_score=raw.get('coverage_score', 0.0),
            diversity_scores=raw.get('diversity_scores', {}),
            iteration=raw.get('iteration', 0),
            timestamp=raw.get('timestamp', ''),
            success=raw.get('success', False)
        )

    def save(self, filepath: Optional[str] = None):
        """
        Save the history as JSON to ``filepath``, or to ``history_file`` if omitted.

        Does nothing when neither path is set. Missing parent directories are
        created. The file is written to a temporary sibling and then swapped
        into place with ``os.replace``.
        """
        filepath = filepath or self.history_file
        if not filepath:
            return

        data = {
            'records': [self._record_to_dict(r) for r in self.records],
            'stats': self.stats
        }

        directory = os.path.dirname(filepath)
        if directory:
            os.makedirs(directory, exist_ok=True)

        # Write to a temp file first so a crash mid-write cannot leave a
        # truncated history file that would later fail to load.
        tmp_path = f"{filepath}.tmp"
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(data, f, indent=2, ensure_ascii=False)
            os.replace(tmp_path, filepath)
        except Exception:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise

        logger.info(f"Test history saved to {filepath}")

    def load(self, filepath: str):
        """
        Replace the in-memory history with the contents of ``filepath``.

        Does nothing if the file does not exist. Pattern indexes are rebuilt
        from scratch, and the derived stats are recomputed from the loaded
        records. Any extra stats keys stored in the file are kept.
        """
        if not os.path.exists(filepath):
            return

        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # Reset derived indexes so repeated loads (or a load after
        # add_record) do not double-count patterns.
        self.records = []
        self.patterns.clear()
        self.signal_patterns.clear()

        for raw in data.get('records', []):
            record = self._record_from_dict(raw)
            self.records.append(record)
            self._update_patterns(record)

        self.stats.update(data.get('stats', {}))
        self._update_stats()
        logger.info(f"Loaded {len(self.records)} test records from {filepath}")
# ============================================================================
|
||
# 便捷函数
|
||
# ============================================================================
|
||
|
||
def create_test_history(history_file: str = None) -> TestHistoryManager:
    """
    Factory for a TestHistoryManager.

    Args:
        history_file: Optional history file path, forwarded unchanged.

    Returns:
        A newly constructed TestHistoryManager.
    """
    manager = TestHistoryManager(history_file=history_file)
    return manager