Files
test_item_gen/modules/prompt_manager.py
2026-02-04 14:38:52 +08:00

179 lines
6.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# @line_count 150
"""Prompt模板管理模块"""
import os
import yaml
from pathlib import Path
from typing import Dict, Optional
from .test_standard_manager import TestStandardManager
class PromptManager:
"""Prompt模板管理器"""
def __init__(self, config_path: Optional[str] = None,
use_standards: bool = True,
standard_manager: Optional[TestStandardManager] = None):
"""
初始化Prompt管理器
Args:
config_path: Prompt配置文件路径默认为config/default_prompt.yaml
use_standards: 是否使用测试规范默认True
standard_manager: 测试规范管理器如果为None且use_standards=True则创建新实例
"""
if config_path is None:
# 获取项目根目录
current_dir = Path(__file__).parent.parent
config_path = current_dir / "config" / "default_prompt.yaml"
self.config_path = Path(config_path)
self.prompts: Dict[str, str] = {}
self.use_standards = use_standards
self.standard_manager = standard_manager or (TestStandardManager() if use_standards else None)
# 用户自定义Prompt缓存 {requirement_id或func_point_id: custom_prompt}
self.custom_prompts_cache: Dict[str, str] = {}
self.load_prompts()
def load_prompts(self):
"""加载Prompt模板"""
if not self.config_path.exists():
raise FileNotFoundError(f"Prompt配置文件不存在: {self.config_path}")
with open(self.config_path, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
self.prompts = {
'test_item': config.get('test_item_prompt', ''),
'test_case': config.get('test_case_prompt', ''),
'batch': config.get('batch_generation_prompt', '')
}
def get_prompt(self, prompt_type: str) -> str:
"""
获取指定类型的Prompt模板
Args:
prompt_type: Prompt类型 ('test_item', 'test_case', 'batch')
Returns:
Prompt模板字符串
"""
if prompt_type not in self.prompts:
raise ValueError(f"未知的Prompt类型: {prompt_type}")
return self.prompts[prompt_type]
def format_prompt(self, prompt_type: str, **kwargs) -> str:
"""
格式化Prompt模板替换占位符
Args:
prompt_type: Prompt类型
**kwargs: 要替换的变量如果包含requirement字典且use_standards=True则使用规范化Prompt
Returns:
格式化后的Prompt字符串
"""
# 检查是否有自定义prompt优先级最高
if 'requirement' in kwargs:
requirement = kwargs['requirement']
req_id = requirement.get('requirement_id') or requirement.get('description', '')[:50]
# 调试输出
print(f"\n[PromptManager] 查找自定义Prompt:")
print(f" - requirement_id: {req_id}")
print(f" - 缓存中的Keys: {list(self.custom_prompts_cache.keys())}")
print(f" - 是否匹配: {req_id in self.custom_prompts_cache}")
# 如果有该功能点的自定义prompt直接返回
if req_id in self.custom_prompts_cache:
custom_prompt = self.custom_prompts_cache[req_id]
print(f" ✅ 使用自定义Prompt (长度: {len(custom_prompt)} 字符)")
return custom_prompt
else:
print(f" ⚠️ 未找到自定义Prompt将使用默认生成")
# 如果使用测试规范且提供了requirement使用规范化Prompt
if self.use_standards and self.standard_manager and 'requirement' in kwargs:
requirement = kwargs['requirement']
standard_ids = kwargs.get('standard_ids') # 可选:指定测试规范
print(f" 📋 使用测试规范生成Prompt")
return self.standard_manager.build_prompt(requirement, standard_ids)
# 否则使用传统方式
prompt = self.get_prompt(prompt_type)
print(f" 📄 使用传统Prompt模板")
# 替换占位符 {variable_name}
try:
return prompt.format(**kwargs)
except KeyError as e:
raise ValueError(f"Prompt模板缺少必需的变量: {e}")
def load_custom_prompt(self, file_path: str, prompt_type: str):
"""
从文件加载自定义Prompt模板
Args:
file_path: 自定义Prompt文件路径
prompt_type: Prompt类型
"""
file_path = Path(file_path)
if not file_path.exists():
raise FileNotFoundError(f"自定义Prompt文件不存在: {file_path}")
with open(file_path, 'r', encoding='utf-8') as f:
custom_prompt = f.read()
self.prompts[prompt_type] = custom_prompt
def set_custom_prompt(self, func_point_id: str, custom_prompt: str):
"""
设置功能点的自定义Prompt
Args:
func_point_id: 功能点ID或需求ID
custom_prompt: 自定义Prompt内容
"""
self.custom_prompts_cache[func_point_id] = custom_prompt
def clear_custom_prompts(self):
"""清空所有自定义Prompt缓存"""
self.custom_prompts_cache.clear()
def get_custom_prompt(self, func_point_id: str) -> Optional[str]:
"""
获取功能点的自定义Prompt
Args:
func_point_id: 功能点ID或需求ID
Returns:
自定义Prompt内容如果没有则返回None
"""
return self.custom_prompts_cache.get(func_point_id)
def save_custom_prompt(self, prompt_type: str, prompt_content: str, file_path: Optional[str] = None):
"""
保存自定义Prompt到文件
Args:
prompt_type: Prompt类型
prompt_content: Prompt内容
file_path: 保存路径默认为templates目录
"""
if file_path is None:
current_dir = Path(__file__).parent.parent
templates_dir = current_dir / "templates"
templates_dir.mkdir(exist_ok=True)
file_path = templates_dir / f"{prompt_type}_prompt.txt"
file_path = Path(file_path)
file_path.parent.mkdir(parents=True, exist_ok=True)
with open(file_path, 'w', encoding='utf-8') as f:
f.write(prompt_content)
return str(file_path)