Files
TBgen_App/autoline/test_history.py
2026-03-30 16:46:48 +08:00

580 lines
20 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Description : Test History Manager (Layer 1 Support Module)
- Store and manage test case history
- Support sequence pattern analysis
- Provide diversity statistics
Author : CGA Enhancement Project
Time : 2026/03/16
"""
import json
import os
import logging
from typing import List, Dict, Optional, Any, Tuple, Set
from dataclasses import dataclass, field, asdict
from datetime import datetime
from collections import defaultdict
import hashlib
import re
logger = logging.getLogger(__name__)
# ============================================================================
# 数据结构定义
# ============================================================================
@dataclass
class InputSequence:
    """
    Recorded stimulus applied to a single signal.

    Attributes:
        signal_name: Name of the driven signal.
        values: Assignments in recording order, each stored as a
            ``(time, value)`` pair.
    """
    signal_name: str
    values: List[Tuple[int, Any]] = field(default_factory=list)

    def to_pattern_string(self) -> str:
        """Return the assigned values joined by ``->`` (times are ignored)."""
        rendered = [str(entry[1]) for entry in self.values]
        return "->".join(rendered)

    def get_hash(self) -> str:
        """Return the first 8 hex digits of the MD5 of the pattern string."""
        encoded = self.to_pattern_string().encode()
        digest = hashlib.md5(encoded).hexdigest()
        return digest[:8]
@dataclass
class TestRecord:
    """
    One generated test case together with its outcome.

    Attributes:
        test_id: Identifier of the test.
        code: Generated test code.
        input_sequences: Input sequences extracted from the code.
        target_function: Function point the test was aimed at.
        covered_lines: Source lines covered by the test.
        covered_functions: Function points covered by the test.
        coverage_score: Coverage score of the test.
        diversity_scores: Diversity scores keyed by metric name.
        iteration: Iteration in which the test was produced.
        timestamp: ISO-format creation time (defaults to "now").
        success: Whether the test succeeded.
    """
    test_id: str
    code: str = ""
    input_sequences: List["InputSequence"] = field(default_factory=list)
    target_function: str = ""
    covered_lines: List[int] = field(default_factory=list)
    covered_functions: List[str] = field(default_factory=list)
    coverage_score: float = 0.0
    diversity_scores: Dict[str, float] = field(default_factory=dict)
    iteration: int = 0
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
    success: bool = False

    def get_sequence_patterns(self) -> Dict[str, str]:
        """Map each recorded signal name to its pattern string."""
        patterns: Dict[str, str] = {}
        for sequence in self.input_sequences:
            patterns[sequence.signal_name] = sequence.to_pattern_string()
        return patterns
@dataclass
class SequencePattern:
    """
    Usage statistics for one sequence pattern.

    Attributes:
        pattern: The pattern string.
        count: Number of times the pattern has been seen.
        signal_name: Signal the pattern belongs to.
        test_ids: IDs of the tests that used the pattern.
    """
    pattern: str
    count: int = 0
    signal_name: str = ""
    test_ids: List[str] = field(default_factory=list)

    def is_overused(self, threshold: int = 3) -> bool:
        """Return True once the pattern has been seen at least *threshold* times."""
        return threshold <= self.count
# ============================================================================
# 序列提取器
# ============================================================================
class SequenceExtractor:
    """
    Extract per-signal input sequences from Verilog testbench code.

    The code is scanned line by line with regular expressions; it is not
    parsed. Delays and clock waits advance a running time counter, and each
    recognised assignment is recorded for its signal at the current time.
    Block comments (``/* ... */``) are not recognised.
    """

    # Assignment statements whose right-hand side is a sized literal, a plain
    # decimal number, or x/z. The patterns must not overlap: a former third
    # pattern (`signal = <digits>;`) duplicated the first one and caused every
    # decimal assignment to be recorded twice.
    ASSIGNMENT_PATTERNS = [
        # Blocking assignment: signal = value;
        r'(\w+)\s*=\s*([0-9]+\'[bdh][0-9a-fA-FxXzZ_]+|\d+|x|z)\s*;',
        # Non-blocking assignment: signal <= value;
        r'(\w+)\s*<=\s*([0-9]+\'[bdh][0-9a-fA-FxXzZ_]+|\d+|x|z)\s*;',
    ]
    # Delay control. No trailing ';' is required so that the inline form
    # `#10 a = 1;` is honoured as well as the stand-alone `#10;`.
    DELAY_PATTERN = r'#\s*(\d+)'
    # Wait for a number of clock edges: repeat (N) @(posedge clk)
    CLOCK_WAIT_PATTERN = r'repeat\s*\(\s*(\d+)\s*\)\s*@\s*\(\s*posedge\s+(\w+)\s*\)'
    # Time units assumed for one clock cycle when converting clock waits.
    CLOCK_PERIOD = 10
    # Names treated as loop counters / scratch variables, never as inputs.
    IGNORED_SIGNALS = frozenset({'i', 'j', 'k', 'cnt', 'count', 'temp'})

    def __init__(self):
        self.known_signals: Set[str] = set()

    def set_known_signals(self, signals: List[str]):
        """Restrict extraction to *signals*; an empty list disables the filter."""
        self.known_signals = set(signals)

    def _scan_line(self, line: str) -> List[Tuple[int, str, Any]]:
        """Return the delays, clock waits and assignments in *line*, in source order."""
        events = []
        for match in re.finditer(self.DELAY_PATTERN, line):
            events.append((match.start(), 'delay', match))
        for match in re.finditer(self.CLOCK_WAIT_PATTERN, line, re.IGNORECASE):
            events.append((match.start(), 'clock', match))
        # Guard against overlapping patterns: one statement is reported once.
        seen_starts = set()
        for pattern in self.ASSIGNMENT_PATTERNS:
            for match in re.finditer(pattern, line, re.IGNORECASE):
                if match.start(1) in seen_starts:
                    continue
                seen_starts.add(match.start(1))
                events.append((match.start(), 'assign', match))
        events.sort(key=lambda event: event[0])
        return events

    def extract(self, code: str) -> List["InputSequence"]:
        """
        Extract the input sequences driven by *code*.

        Args:
            code: Verilog test code.

        Returns:
            One InputSequence per assigned signal, in order of first
            appearance. Each value is stored as ``(time, literal_text)``.
        """
        sequences: Dict[str, "InputSequence"] = {}
        current_time = 0
        for raw_line in code.split('\n'):
            # Drop `//` comments (whole-line or trailing) so that commented-out
            # statements are not mistaken for stimulus.
            line = raw_line.split('//', 1)[0].strip()
            if not line:
                continue
            # Events are handled in source order, so a delay only affects the
            # statements that follow it on the same line.
            for _, kind, match in self._scan_line(line):
                if kind == 'delay':
                    current_time += int(match.group(1))
                    continue
                if kind == 'clock':
                    current_time += int(match.group(1)) * self.CLOCK_PERIOD
                    continue
                signal, value = match.group(1), match.group(2)
                # Skip signals outside the configured set, if one is set.
                if self.known_signals and signal not in self.known_signals:
                    continue
                # Skip names that are obviously not inputs.
                if signal.lower() in self.IGNORED_SIGNALS:
                    continue
                if signal not in sequences:
                    sequences[signal] = InputSequence(signal_name=signal)
                sequences[signal].values.append((current_time, value))
                # Each recorded assignment itself takes one time unit.
                current_time += 1
        return list(sequences.values())
# ============================================================================
# 测试历史管理器
# ============================================================================
class TestHistoryManager:
    """
    Keep the history of generated test cases and derive statistics from it.

    Supports:
    - storing and retrieving test records
    - counting how often each input-sequence pattern has been used
    - scoring how different a new test is from the recorded ones
    - saving the history to, and loading it from, a JSON file
    """

    # Only this many leading characters of each program take part in the
    # edit-distance comparison; keeps the quadratic algorithm cheap.
    MAX_COMPARE_LENGTH = 500

    def __init__(self, history_file: Optional[str] = None):
        """
        Args:
            history_file: Optional path of the history file. If the file
                exists it is loaded immediately.
        """
        # Stored first: save() falls back to this path.
        self.history_file = history_file
        self.records: List["TestRecord"] = []
        # pattern hash -> SequencePattern
        self.patterns: Dict[str, "SequencePattern"] = {}
        # signal name -> [pattern hashes]
        self.signal_patterns: Dict[str, List[str]] = defaultdict(list)
        self.sequence_extractor = SequenceExtractor()
        # Aggregate statistics; every key is present from the start.
        self.stats = {
            'total_tests': 0,
            'successful_tests': 0,
            'total_coverage': 0.0,
            'avg_coverage': 0.0,
            'avg_diversity': 0.0
        }
        if history_file and os.path.exists(history_file):
            self.load(history_file)

    # ==================== Record management ====================
    def add_record(self,
                   code: str,
                   test_id: str = None,
                   target_function: str = "",
                   covered_lines: List[int] = None,
                   covered_functions: List[str] = None,
                   coverage_score: float = 0.0,
                   iteration: int = 0,
                   success: bool = False,
                   known_signals: List[str] = None) -> "TestRecord":
        """
        Create a record for *code*, store it and update the statistics.

        Args:
            code: Test code.
            test_id: Test ID; generated from the record count and the
                current time when omitted.
            target_function: Targeted function point.
            covered_lines: Covered source lines.
            covered_functions: Covered function points.
            coverage_score: Coverage score.
            iteration: Iteration number.
            success: Whether the test succeeded.
            known_signals: When given (non-empty), replaces the extractor's
                signal filter; the filter stays in effect for later calls.

        Returns:
            The newly created record.
        """
        if test_id is None:
            test_id = f"test_{len(self.records)}_{datetime.now().strftime('%H%M%S')}"
        if known_signals:
            self.sequence_extractor.set_known_signals(known_signals)
        input_sequences = self.sequence_extractor.extract(code)
        record = TestRecord(
            test_id=test_id,
            code=code,
            input_sequences=input_sequences,
            target_function=target_function,
            covered_lines=covered_lines or [],
            covered_functions=covered_functions or [],
            coverage_score=coverage_score,
            iteration=iteration,
            success=success
        )
        self.records.append(record)
        self._update_patterns(record)
        self._update_stats()
        logger.debug("Added test record: %s, sequences: %d", test_id, len(input_sequences))
        return record

    def get_record(self, test_id: str) -> Optional["TestRecord"]:
        """Return the first record whose ID equals *test_id*, or None."""
        for record in self.records:
            if record.test_id == test_id:
                return record
        return None

    def get_recent_records(self, n: int = 10) -> List["TestRecord"]:
        """Return a new list with the last *n* records (empty when n <= 0)."""
        # `records[-0:]` would be the whole list, so n <= 0 is handled apart.
        if n <= 0:
            return []
        return self.records[-n:]

    def get_successful_records(self) -> List["TestRecord"]:
        """Return all records flagged as successful."""
        return [r for r in self.records if r.success]

    # ==================== Pattern analysis ====================
    def _update_patterns(self, record: "TestRecord"):
        """Fold the sequences of *record* into the pattern statistics."""
        # NOTE(review): patterns are keyed by the hash of the value string
        # only, so two signals driven with the same values share one entry
        # (carrying the signal name of whichever was seen first).
        for seq in record.input_sequences:
            pattern_hash = seq.get_hash()
            entry = self.patterns.get(pattern_hash)
            if entry is None:
                self.patterns[pattern_hash] = SequencePattern(
                    pattern=seq.to_pattern_string(),
                    count=1,
                    signal_name=seq.signal_name,
                    test_ids=[record.test_id]
                )
            else:
                entry.count += 1
                entry.test_ids.append(record.test_id)
            # Index the hash under its signal, once.
            signal_hashes = self.signal_patterns[seq.signal_name]
            if pattern_hash not in signal_hashes:
                signal_hashes.append(pattern_hash)

    def get_overused_patterns(self, threshold: int = 3) -> List["SequencePattern"]:
        """
        Return the patterns used at least *threshold* times.

        Args:
            threshold: Usage count from which a pattern counts as overused.
        """
        return [p for p in self.patterns.values() if p.is_overused(threshold)]

    def get_common_patterns(self, top_n: int = 5) -> List[Tuple[str, int]]:
        """
        Return the most frequently used patterns.

        Args:
            top_n: Maximum number of entries to return.

        Returns:
            ``[(pattern, count), ...]`` ordered by descending count.
        """
        ranked = sorted(self.patterns.values(), key=lambda p: p.count, reverse=True)
        return [(p.pattern, p.count) for p in ranked[:top_n]]

    def get_pattern_for_signal(self, signal_name: str) -> List["SequencePattern"]:
        """Return every pattern indexed under *signal_name*."""
        pattern_hashes = self.signal_patterns.get(signal_name, [])
        return [self.patterns[h] for h in pattern_hashes if h in self.patterns]

    # ==================== Diversity analysis ====================
    def calculate_sequence_diversity(self, new_sequences: List["InputSequence"]) -> float:
        """
        Score how many of the given sequences use a pattern not seen before.

        Args:
            new_sequences: Input sequences of the candidate test.

        Returns:
            Share of distinct new pattern hashes that are not yet recorded
            (0.0 - 1.0). 1.0 when there is no history, 0.0 when no sequences
            are given.
        """
        if not self.records:
            return 1.0
        if not new_sequences:
            return 0.0
        new_hashes = {seq.get_hash() for seq in new_sequences}
        if not new_hashes:
            return 0.0
        unseen = sum(1 for h in new_hashes if h not in self.patterns)
        return unseen / len(new_hashes)

    def calculate_edit_distance_diversity(self, new_code: str) -> float:
        """
        Score how different *new_code* is from the five most recent records.

        Each comparison uses only the first MAX_COMPARE_LENGTH characters of
        both programs and is normalised by the longer of the two compared
        prefixes, so the result is a true ratio in [0, 1]. The smallest ratio
        (the closest recent record) is returned; 1.0 when there is no history.
        """
        if not self.records:
            return 1.0
        limit = self.MAX_COMPARE_LENGTH
        candidate = new_code[:limit]
        closest = 1.0
        for record in self.get_recent_records(5):
            reference = record.code[:limit]
            longest = max(len(candidate), len(reference))
            if longest == 0:
                # Both programs are empty, i.e. identical.
                return 0.0
            ratio = self._levenshtein_distance(candidate, reference) / longest
            closest = min(closest, ratio)
        return closest

    def _levenshtein_distance(self, s1: str, s2: str) -> int:
        """Return the Levenshtein distance of the MAX_COMPARE_LENGTH-long prefixes."""
        # Truncate first so the result never exceeds the comparison window.
        limit = self.MAX_COMPARE_LENGTH
        s1, s2 = s1[:limit], s2[:limit]
        if len(s1) < len(s2):
            s1, s2 = s2, s1
        if not s2:
            return len(s1)
        # Classic two-row dynamic programme.
        previous_row = list(range(len(s2) + 1))
        for i, c1 in enumerate(s1):
            current_row = [i + 1]
            for j, c2 in enumerate(s2):
                insertions = previous_row[j + 1] + 1
                deletions = current_row[j] + 1
                substitutions = previous_row[j] + (c1 != c2)
                current_row.append(min(insertions, deletions, substitutions))
            previous_row = current_row
        return previous_row[-1]

    # ==================== Statistics ====================
    def _update_stats(self):
        """Recompute the record-derived statistics from scratch."""
        total = len(self.records)
        coverage_sum = sum((r.coverage_score for r in self.records), 0.0)
        self.stats['total_tests'] = total
        self.stats['successful_tests'] = sum(1 for r in self.records if r.success)
        self.stats['total_coverage'] = coverage_sum
        # Defined (as 0.0) even when there are no records.
        self.stats['avg_coverage'] = coverage_sum / total if total else 0.0

    def get_statistics(self) -> Dict[str, Any]:
        """Return the statistics plus pattern and signal counts."""
        return {
            **self.stats,
            'total_patterns': len(self.patterns),
            'overused_patterns': len(self.get_overused_patterns()),
            'unique_signals': len(self.signal_patterns)
        }

    @staticmethod
    def _shorten(text: str, limit: int) -> str:
        """Cut *text* to *limit* characters, marking an actual cut with '...'."""
        return text if len(text) <= limit else text[:limit] + "..."

    def get_diversity_report(self) -> str:
        """Build a plain-text report of pattern usage."""
        lines = []
        lines.append("=" * 50)
        lines.append("TEST HISTORY DIVERSITY REPORT")
        lines.append("=" * 50)
        lines.append(f"Total Tests: {self.stats['total_tests']}")
        lines.append(f"Successful Tests: {self.stats['successful_tests']}")
        lines.append(f"Total Patterns: {len(self.patterns)}")
        lines.append("")
        # Most common patterns.
        lines.append("TOP 5 COMMON PATTERNS:")
        for i, (pattern, count) in enumerate(self.get_common_patterns(5), 1):
            lines.append(f" {i}. {self._shorten(pattern, 40)} (x{count})")
        # Overused patterns.
        overused = self.get_overused_patterns()
        if overused:
            lines.append("")
            lines.append("OVERUSED PATTERNS (need diversification):")
            for p in overused[:5]:
                lines.append(
                    f" - {p.signal_name}: {self._shorten(p.pattern, 30)} (used {p.count} times)"
                )
        lines.append("=" * 50)
        return "\n".join(lines)

    # ==================== Persistence ====================
    def save(self, filepath: Optional[str] = None):
        """
        Write the history to a JSON file.

        Args:
            filepath: Target path; defaults to the path given at
                construction. Nothing is written when neither is set.
        """
        filepath = filepath or self.history_file
        if not filepath:
            return
        # Build plain dicts by hand so nested InputSequence objects serialise.
        records_data = []
        for r in self.records:
            records_data.append({
                'test_id': r.test_id,
                'code': r.code,
                'input_sequences': [
                    {'signal_name': seq.signal_name, 'values': seq.values}
                    for seq in r.input_sequences
                ],
                'target_function': r.target_function,
                'covered_lines': r.covered_lines,
                'covered_functions': r.covered_functions,
                'coverage_score': r.coverage_score,
                'diversity_scores': r.diversity_scores,
                'iteration': r.iteration,
                'timestamp': r.timestamp,
                'success': r.success
            })
        data = {
            'records': records_data,
            'stats': self.stats
        }
        # Create the target directory on first save.
        directory = os.path.dirname(filepath)
        if directory:
            os.makedirs(directory, exist_ok=True)
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2, ensure_ascii=False)
        logger.info("Test history saved to %s", filepath)

    def load(self, filepath: str):
        """
        Replace the current history with the contents of *filepath*.

        Does nothing when the file does not exist. Pattern indexes and
        record-derived statistics are rebuilt from the loaded records.
        """
        if not os.path.exists(filepath):
            return
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)
        # Start from a clean slate so repeated loads do not double-count.
        self.records = []
        self.patterns = {}
        self.signal_patterns = defaultdict(list)
        for r in data.get('records', []):
            sequences = [
                InputSequence(
                    signal_name=s['signal_name'],
                    # JSON stores the (time, value) tuples as lists.
                    values=[tuple(v) for v in s.get('values', [])]
                )
                for s in r.get('input_sequences', [])
            ]
            record = TestRecord(
                test_id=r['test_id'],
                code=r['code'],
                input_sequences=sequences,
                target_function=r.get('target_function', ''),
                covered_lines=r.get('covered_lines', []),
                covered_functions=r.get('covered_functions', []),
                coverage_score=r.get('coverage_score', 0.0),
                diversity_scores=r.get('diversity_scores', {}),
                iteration=r.get('iteration', 0),
                timestamp=r.get('timestamp', ''),
                success=r.get('success', False)
            )
            self.records.append(record)
            self._update_patterns(record)
        # Keep stored values the records cannot reproduce (e.g. avg_diversity)
        # and recompute the rest so the numbers always match the records.
        saved_stats = data.get('stats')
        if isinstance(saved_stats, dict):
            self.stats.update(saved_stats)
        self._update_stats()
        logger.info("Loaded %d test records from %s", len(self.records), filepath)
# ============================================================================
# 便捷函数
# ============================================================================
def create_test_history(history_file: str = None) -> TestHistoryManager:
    """Build a TestHistoryManager bound to *history_file* (may be None)."""
    manager = TestHistoryManager(history_file=history_file)
    return manager