370 lines
14 KiB
Python
370 lines
14 KiB
Python
# @line_count 200
|
||
"""大模型API客户端模块"""
|
||
import os
|
||
import yaml
|
||
import json
|
||
import time
|
||
from pathlib import Path
|
||
from typing import Dict, Optional, List, Any
|
||
import httpx
|
||
|
||
|
||
class APIClient:
|
||
"""大模型API客户端,支持多个提供商"""
|
||
|
||
def __init__(self, config_path: Optional[str] = None):
|
||
"""
|
||
初始化API客户端
|
||
|
||
Args:
|
||
config_path: API配置文件路径,默认为config/api_config.yaml
|
||
"""
|
||
if config_path is None:
|
||
current_dir = Path(__file__).parent.parent
|
||
config_path = current_dir / "config" / "api_config.yaml"
|
||
|
||
self.config_path = Path(config_path)
|
||
self.config: Dict[str, Any] = {}
|
||
self.current_provider: str = ""
|
||
self.load_config()
|
||
|
||
def load_config(self):
|
||
"""加载API配置"""
|
||
if not self.config_path.exists():
|
||
raise FileNotFoundError(f"API配置文件不存在: {self.config_path}")
|
||
|
||
with open(self.config_path, 'r', encoding='utf-8') as f:
|
||
self.config = yaml.safe_load(f)
|
||
|
||
self.current_provider = self.config.get('default_provider', 'deepseek')
|
||
|
||
def set_provider(self, provider: str):
|
||
"""
|
||
设置当前使用的API提供商
|
||
|
||
Args:
|
||
provider: 提供商名称 (deepseek, qianwen, openai, openrouter)
|
||
"""
|
||
if provider not in self.config.get('providers', {}):
|
||
raise ValueError(f"不支持的API提供商: {provider}")
|
||
self.current_provider = provider
|
||
|
||
def get_provider_config(self) -> Dict[str, Any]:
|
||
"""获取当前提供商的配置"""
|
||
providers = self.config.get('providers', {})
|
||
if self.current_provider not in providers:
|
||
raise ValueError(f"提供商配置不存在: {self.current_provider}")
|
||
return providers[self.current_provider]
|
||
|
||
def _call_deepseek_api(self, prompt: str, **kwargs) -> str:
|
||
"""调用DeepSeek API"""
|
||
config = self.get_provider_config()
|
||
api_key = config.get('api_key') or os.getenv('DEEPSEEK_API_KEY', '')
|
||
|
||
if not api_key:
|
||
raise ValueError("DeepSeek API密钥未配置")
|
||
|
||
headers = {
|
||
'Content-Type': 'application/json',
|
||
'Authorization': f'Bearer {api_key}'
|
||
}
|
||
|
||
data = {
|
||
'model': config.get('model', 'deepseek-chat'),
|
||
'messages': [
|
||
{'role': 'user', 'content': prompt}
|
||
],
|
||
'temperature': config.get('temperature', 0.7),
|
||
'max_tokens': config.get('max_tokens', 4000)
|
||
}
|
||
|
||
# 使用精细化超时配置:连接10秒,读取180秒
|
||
timeout_config = httpx.Timeout(connect=10.0, read=180.0, write=10.0, pool=10.0)
|
||
|
||
with httpx.Client(timeout=timeout_config) as client:
|
||
response = client.post(
|
||
config.get('base_url'),
|
||
headers=headers,
|
||
json=data
|
||
)
|
||
response.raise_for_status()
|
||
result = response.json()
|
||
return result['choices'][0]['message']['content']
|
||
|
||
# def _call_qianwen_api(self, prompt: str, **kwargs) -> str:
|
||
# """调用通义千问API"""
|
||
# config = self.get_provider_config()
|
||
# api_key = config.get('api_key') or os.getenv('QIANWEN_API_KEY', '')
|
||
|
||
# if not api_key:
|
||
# raise ValueError("通义千问API密钥未配置")
|
||
|
||
# headers = {
|
||
# 'Content-Type': 'application/json',
|
||
# 'Authorization': f'Bearer {api_key}'
|
||
# }
|
||
|
||
# # 通义千问API格式(DashScope)
|
||
# data = {
|
||
# 'model': config.get('model', 'qwen-turbo'),
|
||
# 'input': {
|
||
# 'messages': [
|
||
# {'role': 'user', 'content': prompt}
|
||
# ]
|
||
# },
|
||
# 'parameters': {
|
||
# 'temperature': config.get('temperature', 0.7),
|
||
# 'max_tokens': config.get('max_tokens', 4000)
|
||
# }
|
||
# }
|
||
|
||
# with httpx.Client(timeout=60.0) as client:
|
||
# response = client.post(
|
||
# config.get('base_url'),
|
||
# headers=headers,
|
||
# json=data
|
||
# )
|
||
# response.raise_for_status()
|
||
# result = response.json()
|
||
|
||
# # 通义千问的响应格式适配
|
||
# if 'output' in result:
|
||
# output = result['output']
|
||
# if 'choices' in output and len(output['choices']) > 0:
|
||
# return output['choices'][0]['message']['content']
|
||
# elif 'text' in output:
|
||
# return output['text']
|
||
# elif 'choices' in result and len(result['choices']) > 0:
|
||
# return result['choices'][0]['message']['content']
|
||
|
||
# # 如果都不匹配,返回整个结果(用于调试)
|
||
# raise ValueError(f"无法解析通义千问API响应: {result}")
|
||
|
||
def _call_qianwen_api(self, prompt: str, **kwargs) -> str:
|
||
"""调用通义千问API(兼容模式)"""
|
||
config = self.get_provider_config()
|
||
api_key = config.get('api_key') or os.getenv('QIANWEN_API_KEY', '')
|
||
|
||
if not api_key:
|
||
raise ValueError("通义千问API密钥未配置")
|
||
|
||
headers = {
|
||
'Content-Type': 'application/json',
|
||
'Authorization': f'Bearer {api_key}'
|
||
}
|
||
|
||
# 系统提示,提高生成质量(模拟Web端行为)
|
||
# 针对航空航天软件测试的专业化system message
|
||
system_message = """你是一位拥有20年经验的航空航天软件测试专家。你的任务是生成高质量、可执行的测试用例。
|
||
|
||
【强制要求 - 必须遵守】
|
||
你必须为以下8种黑盒测试方法各生成至少1个测试用例:
|
||
1. 等价类划分 - 有效/无效输入类测试
|
||
2. 边界值分析 - min-1, min, min+1, max-1, max, max+1
|
||
3. 错误推测法 - 字节序错误、超长帧、非法字符
|
||
4. 因果图法 - 多条件联动触发测试
|
||
5. 决策表测试 - 条件组合规则覆盖
|
||
6. 状态转换法 - 状态机路径覆盖
|
||
7. 场景法 - 端到端业务流程测试
|
||
8. 随机测试 - 随机合法输入测试
|
||
|
||
【输出要求】
|
||
- 每个test_case的name字段必须用方括号标注测试方法,如:正常指令解析-[等价类划分]
|
||
- 必须生成至少8个测试用例,确保8种方法全覆盖
|
||
- 测试步骤要具体,包含CAN帧ID、指令编码、响应时间等
|
||
- 只输出JSON,不要任何解释性文字
|
||
- 严格遵循用户指定的JSON格式"""
|
||
|
||
# 使用OpenAI兼容格式,包含system message
|
||
data = {
|
||
'model': config.get('model', 'qwen-max'),
|
||
'messages': [
|
||
{'role': 'system', 'content': system_message},
|
||
{'role': 'user', 'content': prompt}
|
||
],
|
||
'temperature': config.get('temperature', 0.3),
|
||
'max_tokens': config.get('max_tokens', 8192)
|
||
}
|
||
|
||
|
||
print("\n发送Prompt: " + prompt)
|
||
|
||
# 使用精细化超时配置:连接10秒,读取180秒(3分钟)
|
||
# 测试用例生成需要较长时间,特别是复杂的结构化输出
|
||
timeout_config = httpx.Timeout(
|
||
connect=10.0, # 连接超时:10秒
|
||
read=180.0, # 读取超时:180秒(适合复杂生成任务)
|
||
write=10.0, # 写入超时:10秒
|
||
pool=10.0 # 连接池超时:10秒
|
||
)
|
||
|
||
with httpx.Client(timeout=timeout_config) as client:
|
||
response = client.post(
|
||
config.get('base_url'),
|
||
headers=headers,
|
||
json=data
|
||
)
|
||
response.raise_for_status()
|
||
result = response.json()
|
||
|
||
# OpenAI兼容格式的响应
|
||
if 'choices' in result and len(result['choices']) > 0:
|
||
content = result['choices'][0]['message']['content']
|
||
|
||
# 调试输出:显示API返回的完整内容
|
||
print("\n" + "="*60)
|
||
print("[API响应] 完整返回内容:")
|
||
print("="*60)
|
||
print(f"响应长度: {len(content)} 字符")
|
||
print("-"*60)
|
||
# 显示前2000字符
|
||
if len(content) > 2000:
|
||
print(content[:2000])
|
||
print(f"\n... [截断,还有 {len(content)-2000} 字符]")
|
||
else:
|
||
print(content)
|
||
print("="*60 + "\n")
|
||
|
||
return content
|
||
|
||
raise ValueError(f"无法解析通义千问API响应: {result}")
|
||
|
||
|
||
def _call_openai_api(self, prompt: str, **kwargs) -> str:
|
||
"""调用OpenAI API"""
|
||
config = self.get_provider_config()
|
||
api_key = config.get('api_key') or os.getenv('OPENAI_API_KEY', '')
|
||
|
||
if not api_key:
|
||
raise ValueError("OpenAI API密钥未配置")
|
||
|
||
headers = {
|
||
'Content-Type': 'application/json',
|
||
'Authorization': f'Bearer {api_key}'
|
||
}
|
||
|
||
data = {
|
||
'model': config.get('model', 'gpt-3.5-turbo'),
|
||
'messages': [
|
||
{'role': 'user', 'content': prompt}
|
||
],
|
||
'temperature': config.get('temperature', 0.7),
|
||
'max_tokens': config.get('max_tokens', 4000)
|
||
}
|
||
|
||
# 使用精细化超时配置:连接10秒,读取180秒
|
||
timeout_config = httpx.Timeout(connect=10.0, read=180.0, write=10.0, pool=10.0)
|
||
|
||
with httpx.Client(timeout=timeout_config) as client:
|
||
response = client.post(
|
||
config.get('base_url'),
|
||
headers=headers,
|
||
json=data
|
||
)
|
||
response.raise_for_status()
|
||
result = response.json()
|
||
return result['choices'][0]['message']['content']
|
||
|
||
def _call_openrouter_api(self, prompt: str, **kwargs) -> str:
|
||
"""调用OpenRouter API"""
|
||
config = self.get_provider_config()
|
||
api_key = config.get('api_key') or os.getenv('OPENROUTER_API_KEY', '')
|
||
|
||
if not api_key:
|
||
raise ValueError("OpenRouter API密钥未配置")
|
||
|
||
headers = {
|
||
'Content-Type': 'application/json',
|
||
'Authorization': f'Bearer {api_key}'
|
||
}
|
||
|
||
# 添加可选的 HTTP-Referer 和 X-Title headers
|
||
http_referer = config.get('http_referer', '')
|
||
if http_referer:
|
||
headers['HTTP-Referer'] = http_referer
|
||
|
||
x_title = config.get('x_title', '')
|
||
if x_title:
|
||
headers['X-Title'] = x_title
|
||
|
||
data = {
|
||
'model': config.get('model', 'allenai/molmo-2-8b:free'),
|
||
'messages': [
|
||
{'role': 'user', 'content': prompt}
|
||
],
|
||
'temperature': config.get('temperature', 0.7),
|
||
'max_tokens': config.get('max_tokens', 4000)
|
||
}
|
||
|
||
# 使用精细化超时配置:连接10秒,读取180秒
|
||
timeout_config = httpx.Timeout(connect=10.0, read=180.0, write=10.0, pool=10.0)
|
||
|
||
with httpx.Client(timeout=timeout_config) as client:
|
||
response = client.post(
|
||
config.get('base_url'),
|
||
headers=headers,
|
||
json=data
|
||
)
|
||
response.raise_for_status()
|
||
result = response.json()
|
||
|
||
# OpenRouter 使用 OpenAI 兼容格式
|
||
if 'choices' in result and len(result['choices']) > 0:
|
||
return result['choices'][0]['message']['content']
|
||
|
||
raise ValueError(f"无法解析OpenRouter API响应: {result}")
|
||
|
||
def call_api(self, prompt: str, max_retries: int = 3, retry_delay: int = 2) -> str:
|
||
"""
|
||
调用大模型API
|
||
|
||
Args:
|
||
prompt: 提示词
|
||
max_retries: 最大重试次数
|
||
retry_delay: 重试延迟(秒)
|
||
|
||
Returns:
|
||
API返回的文本内容
|
||
"""
|
||
provider_methods = {
|
||
'deepseek': self._call_deepseek_api,
|
||
'qianwen': self._call_qianwen_api,
|
||
'openai': self._call_openai_api,
|
||
'openrouter': self._call_openrouter_api
|
||
}
|
||
|
||
method = provider_methods.get(self.current_provider)
|
||
if not method:
|
||
raise ValueError(f"不支持的API提供商: {self.current_provider}")
|
||
|
||
last_error = None
|
||
for attempt in range(max_retries):
|
||
try:
|
||
return method(prompt)
|
||
except Exception as e:
|
||
last_error = e
|
||
if attempt < max_retries - 1:
|
||
time.sleep(retry_delay * (attempt + 1))
|
||
else:
|
||
raise Exception(f"API调用失败(重试{max_retries}次): {str(e)}") from last_error
|
||
|
||
raise Exception(f"API调用失败: {str(last_error)}")
|
||
|
||
def update_api_key(self, provider: str, api_key: str):
|
||
"""
|
||
更新API密钥
|
||
|
||
Args:
|
||
provider: 提供商名称
|
||
api_key: API密钥
|
||
"""
|
||
if provider not in self.config.get('providers', {}):
|
||
raise ValueError(f"不支持的API提供商: {provider}")
|
||
|
||
self.config['providers'][provider]['api_key'] = api_key
|
||
|
||
# 保存到配置文件
|
||
with open(self.config_path, 'w', encoding='utf-8') as f:
|
||
yaml.dump(self.config, f, allow_unicode=True, default_flow_style=False)
|
||
|