init rep
This commit is contained in:
178
modules/prompt_manager.py
Normal file
178
modules/prompt_manager.py
Normal file
@@ -0,0 +1,178 @@
|
||||
# @line_count 150
|
||||
"""Prompt模板管理模块"""
|
||||
import os
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
from .test_standard_manager import TestStandardManager
|
||||
|
||||
|
||||
class PromptManager:
|
||||
"""Prompt模板管理器"""
|
||||
|
||||
def __init__(self, config_path: Optional[str] = None,
|
||||
use_standards: bool = True,
|
||||
standard_manager: Optional[TestStandardManager] = None):
|
||||
"""
|
||||
初始化Prompt管理器
|
||||
|
||||
Args:
|
||||
config_path: Prompt配置文件路径,默认为config/default_prompt.yaml
|
||||
use_standards: 是否使用测试规范(默认True)
|
||||
standard_manager: 测试规范管理器,如果为None且use_standards=True则创建新实例
|
||||
"""
|
||||
if config_path is None:
|
||||
# 获取项目根目录
|
||||
current_dir = Path(__file__).parent.parent
|
||||
config_path = current_dir / "config" / "default_prompt.yaml"
|
||||
|
||||
self.config_path = Path(config_path)
|
||||
self.prompts: Dict[str, str] = {}
|
||||
self.use_standards = use_standards
|
||||
self.standard_manager = standard_manager or (TestStandardManager() if use_standards else None)
|
||||
# 用户自定义Prompt缓存 {requirement_id或func_point_id: custom_prompt}
|
||||
self.custom_prompts_cache: Dict[str, str] = {}
|
||||
self.load_prompts()
|
||||
|
||||
def load_prompts(self):
|
||||
"""加载Prompt模板"""
|
||||
if not self.config_path.exists():
|
||||
raise FileNotFoundError(f"Prompt配置文件不存在: {self.config_path}")
|
||||
|
||||
with open(self.config_path, 'r', encoding='utf-8') as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
self.prompts = {
|
||||
'test_item': config.get('test_item_prompt', ''),
|
||||
'test_case': config.get('test_case_prompt', ''),
|
||||
'batch': config.get('batch_generation_prompt', '')
|
||||
}
|
||||
|
||||
def get_prompt(self, prompt_type: str) -> str:
|
||||
"""
|
||||
获取指定类型的Prompt模板
|
||||
|
||||
Args:
|
||||
prompt_type: Prompt类型 ('test_item', 'test_case', 'batch')
|
||||
|
||||
Returns:
|
||||
Prompt模板字符串
|
||||
"""
|
||||
if prompt_type not in self.prompts:
|
||||
raise ValueError(f"未知的Prompt类型: {prompt_type}")
|
||||
|
||||
return self.prompts[prompt_type]
|
||||
|
||||
def format_prompt(self, prompt_type: str, **kwargs) -> str:
|
||||
"""
|
||||
格式化Prompt模板,替换占位符
|
||||
|
||||
Args:
|
||||
prompt_type: Prompt类型
|
||||
**kwargs: 要替换的变量,如果包含requirement字典且use_standards=True,则使用规范化Prompt
|
||||
|
||||
Returns:
|
||||
格式化后的Prompt字符串
|
||||
"""
|
||||
# 检查是否有自定义prompt(优先级最高)
|
||||
if 'requirement' in kwargs:
|
||||
requirement = kwargs['requirement']
|
||||
req_id = requirement.get('requirement_id') or requirement.get('description', '')[:50]
|
||||
|
||||
# 调试输出
|
||||
print(f"\n[PromptManager] 查找自定义Prompt:")
|
||||
print(f" - requirement_id: {req_id}")
|
||||
print(f" - 缓存中的Keys: {list(self.custom_prompts_cache.keys())}")
|
||||
print(f" - 是否匹配: {req_id in self.custom_prompts_cache}")
|
||||
|
||||
# 如果有该功能点的自定义prompt,直接返回
|
||||
if req_id in self.custom_prompts_cache:
|
||||
custom_prompt = self.custom_prompts_cache[req_id]
|
||||
print(f" ✅ 使用自定义Prompt (长度: {len(custom_prompt)} 字符)")
|
||||
return custom_prompt
|
||||
else:
|
||||
print(f" ⚠️ 未找到自定义Prompt,将使用默认生成")
|
||||
|
||||
# 如果使用测试规范且提供了requirement,使用规范化Prompt
|
||||
if self.use_standards and self.standard_manager and 'requirement' in kwargs:
|
||||
requirement = kwargs['requirement']
|
||||
standard_ids = kwargs.get('standard_ids') # 可选:指定测试规范
|
||||
print(f" 📋 使用测试规范生成Prompt")
|
||||
return self.standard_manager.build_prompt(requirement, standard_ids)
|
||||
|
||||
# 否则使用传统方式
|
||||
prompt = self.get_prompt(prompt_type)
|
||||
print(f" 📄 使用传统Prompt模板")
|
||||
|
||||
# 替换占位符 {variable_name}
|
||||
try:
|
||||
return prompt.format(**kwargs)
|
||||
except KeyError as e:
|
||||
raise ValueError(f"Prompt模板缺少必需的变量: {e}")
|
||||
|
||||
def load_custom_prompt(self, file_path: str, prompt_type: str):
|
||||
"""
|
||||
从文件加载自定义Prompt模板
|
||||
|
||||
Args:
|
||||
file_path: 自定义Prompt文件路径
|
||||
prompt_type: Prompt类型
|
||||
"""
|
||||
file_path = Path(file_path)
|
||||
if not file_path.exists():
|
||||
raise FileNotFoundError(f"自定义Prompt文件不存在: {file_path}")
|
||||
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
custom_prompt = f.read()
|
||||
|
||||
self.prompts[prompt_type] = custom_prompt
|
||||
|
||||
def set_custom_prompt(self, func_point_id: str, custom_prompt: str):
|
||||
"""
|
||||
设置功能点的自定义Prompt
|
||||
|
||||
Args:
|
||||
func_point_id: 功能点ID或需求ID
|
||||
custom_prompt: 自定义Prompt内容
|
||||
"""
|
||||
self.custom_prompts_cache[func_point_id] = custom_prompt
|
||||
|
||||
def clear_custom_prompts(self):
|
||||
"""清空所有自定义Prompt缓存"""
|
||||
self.custom_prompts_cache.clear()
|
||||
|
||||
def get_custom_prompt(self, func_point_id: str) -> Optional[str]:
|
||||
"""
|
||||
获取功能点的自定义Prompt
|
||||
|
||||
Args:
|
||||
func_point_id: 功能点ID或需求ID
|
||||
|
||||
Returns:
|
||||
自定义Prompt内容,如果没有则返回None
|
||||
"""
|
||||
return self.custom_prompts_cache.get(func_point_id)
|
||||
|
||||
def save_custom_prompt(self, prompt_type: str, prompt_content: str, file_path: Optional[str] = None):
|
||||
"""
|
||||
保存自定义Prompt到文件
|
||||
|
||||
Args:
|
||||
prompt_type: Prompt类型
|
||||
prompt_content: Prompt内容
|
||||
file_path: 保存路径,默认为templates目录
|
||||
"""
|
||||
if file_path is None:
|
||||
current_dir = Path(__file__).parent.parent
|
||||
templates_dir = current_dir / "templates"
|
||||
templates_dir.mkdir(exist_ok=True)
|
||||
file_path = templates_dir / f"{prompt_type}_prompt.txt"
|
||||
|
||||
file_path = Path(file_path)
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(file_path, 'w', encoding='utf-8') as f:
|
||||
f.write(prompt_content)
|
||||
|
||||
return str(file_path)
|
||||
|
||||
Reference in New Issue
Block a user