Files

370 lines
14 KiB
Python
Raw Permalink Normal View History

2026-02-04 14:38:52 +08:00
# @line_count 200
"""大模型API客户端模块"""
import os
import yaml
import json
import time
from pathlib import Path
from typing import Dict, Optional, List, Any
import httpx
class APIClient:
"""大模型API客户端支持多个提供商"""
def __init__(self, config_path: Optional[str] = None):
"""
初始化API客户端
Args:
config_path: API配置文件路径默认为config/api_config.yaml
"""
if config_path is None:
current_dir = Path(__file__).parent.parent
config_path = current_dir / "config" / "api_config.yaml"
self.config_path = Path(config_path)
self.config: Dict[str, Any] = {}
self.current_provider: str = ""
self.load_config()
def load_config(self):
"""加载API配置"""
if not self.config_path.exists():
raise FileNotFoundError(f"API配置文件不存在: {self.config_path}")
with open(self.config_path, 'r', encoding='utf-8') as f:
self.config = yaml.safe_load(f)
self.current_provider = self.config.get('default_provider', 'deepseek')
def set_provider(self, provider: str):
"""
设置当前使用的API提供商
Args:
provider: 提供商名称 (deepseek, qianwen, openai, openrouter)
"""
if provider not in self.config.get('providers', {}):
raise ValueError(f"不支持的API提供商: {provider}")
self.current_provider = provider
def get_provider_config(self) -> Dict[str, Any]:
"""获取当前提供商的配置"""
providers = self.config.get('providers', {})
if self.current_provider not in providers:
raise ValueError(f"提供商配置不存在: {self.current_provider}")
return providers[self.current_provider]
def _call_deepseek_api(self, prompt: str, **kwargs) -> str:
"""调用DeepSeek API"""
config = self.get_provider_config()
api_key = config.get('api_key') or os.getenv('DEEPSEEK_API_KEY', '')
if not api_key:
raise ValueError("DeepSeek API密钥未配置")
headers = {
'Content-Type': 'application/json',
'Authorization': f'Bearer {api_key}'
}
data = {
'model': config.get('model', 'deepseek-chat'),
'messages': [
{'role': 'user', 'content': prompt}
],
'temperature': config.get('temperature', 0.7),
'max_tokens': config.get('max_tokens', 4000)
}
# 使用精细化超时配置连接10秒读取180秒
timeout_config = httpx.Timeout(connect=10.0, read=180.0, write=10.0, pool=10.0)
with httpx.Client(timeout=timeout_config) as client:
response = client.post(
config.get('base_url'),
headers=headers,
json=data
)
response.raise_for_status()
result = response.json()
return result['choices'][0]['message']['content']
# def _call_qianwen_api(self, prompt: str, **kwargs) -> str:
# """调用通义千问API"""
# config = self.get_provider_config()
# api_key = config.get('api_key') or os.getenv('QIANWEN_API_KEY', '')
# if not api_key:
# raise ValueError("通义千问API密钥未配置")
# headers = {
# 'Content-Type': 'application/json',
# 'Authorization': f'Bearer {api_key}'
# }
# # 通义千问API格式DashScope
# data = {
# 'model': config.get('model', 'qwen-turbo'),
# 'input': {
# 'messages': [
# {'role': 'user', 'content': prompt}
# ]
# },
# 'parameters': {
# 'temperature': config.get('temperature', 0.7),
# 'max_tokens': config.get('max_tokens', 4000)
# }
# }
# with httpx.Client(timeout=60.0) as client:
# response = client.post(
# config.get('base_url'),
# headers=headers,
# json=data
# )
# response.raise_for_status()
# result = response.json()
# # 通义千问的响应格式适配
# if 'output' in result:
# output = result['output']
# if 'choices' in output and len(output['choices']) > 0:
# return output['choices'][0]['message']['content']
# elif 'text' in output:
# return output['text']
# elif 'choices' in result and len(result['choices']) > 0:
# return result['choices'][0]['message']['content']
# # 如果都不匹配,返回整个结果(用于调试)
# raise ValueError(f"无法解析通义千问API响应: {result}")
def _call_qianwen_api(self, prompt: str, **kwargs) -> str:
"""调用通义千问API兼容模式"""
config = self.get_provider_config()
api_key = config.get('api_key') or os.getenv('QIANWEN_API_KEY', '')
if not api_key:
raise ValueError("通义千问API密钥未配置")
headers = {
'Content-Type': 'application/json',
'Authorization': f'Bearer {api_key}'
}
# 系统提示提高生成质量模拟Web端行为
# 针对航空航天软件测试的专业化system message
system_message = """你是一位拥有20年经验的航空航天软件测试专家。你的任务是生成高质量、可执行的测试用例。
强制要求 - 必须遵守
你必须为以下8种黑盒测试方法各生成至少1个测试用例
1. 等价类划分 - 有效/无效输入类测试
2. 边界值分析 - min-1, min, min+1, max-1, max, max+1
3. 错误推测法 - 字节序错误超长帧非法字符
4. 因果图法 - 多条件联动触发测试
5. 决策表测试 - 条件组合规则覆盖
6. 状态转换法 - 状态机路径覆盖
7. 场景法 - 端到端业务流程测试
8. 随机测试 - 随机合法输入测试
输出要求
- 每个test_case的name字段必须用方括号标注测试方法正常指令解析-[等价类划分]
- 必须生成至少8个测试用例确保8种方法全覆盖
- 测试步骤要具体包含CAN帧ID指令编码响应时间等
- 只输出JSON不要任何解释性文字
- 严格遵循用户指定的JSON格式"""
# 使用OpenAI兼容格式包含system message
data = {
'model': config.get('model', 'qwen-max'),
'messages': [
{'role': 'system', 'content': system_message},
{'role': 'user', 'content': prompt}
],
'temperature': config.get('temperature', 0.3),
'max_tokens': config.get('max_tokens', 8192)
}
print("\n发送Prompt: " + prompt)
# 使用精细化超时配置连接10秒读取180秒3分钟
# 测试用例生成需要较长时间,特别是复杂的结构化输出
timeout_config = httpx.Timeout(
connect=10.0, # 连接超时10秒
read=180.0, # 读取超时180秒适合复杂生成任务
write=10.0, # 写入超时10秒
pool=10.0 # 连接池超时10秒
)
with httpx.Client(timeout=timeout_config) as client:
response = client.post(
config.get('base_url'),
headers=headers,
json=data
)
response.raise_for_status()
result = response.json()
# OpenAI兼容格式的响应
if 'choices' in result and len(result['choices']) > 0:
content = result['choices'][0]['message']['content']
# 调试输出显示API返回的完整内容
print("\n" + "="*60)
print("[API响应] 完整返回内容:")
print("="*60)
print(f"响应长度: {len(content)} 字符")
print("-"*60)
# 显示前2000字符
if len(content) > 2000:
print(content[:2000])
print(f"\n... [截断,还有 {len(content)-2000} 字符]")
else:
print(content)
print("="*60 + "\n")
return content
raise ValueError(f"无法解析通义千问API响应: {result}")
def _call_openai_api(self, prompt: str, **kwargs) -> str:
"""调用OpenAI API"""
config = self.get_provider_config()
api_key = config.get('api_key') or os.getenv('OPENAI_API_KEY', '')
if not api_key:
raise ValueError("OpenAI API密钥未配置")
headers = {
'Content-Type': 'application/json',
'Authorization': f'Bearer {api_key}'
}
data = {
'model': config.get('model', 'gpt-3.5-turbo'),
'messages': [
{'role': 'user', 'content': prompt}
],
'temperature': config.get('temperature', 0.7),
'max_tokens': config.get('max_tokens', 4000)
}
# 使用精细化超时配置连接10秒读取180秒
timeout_config = httpx.Timeout(connect=10.0, read=180.0, write=10.0, pool=10.0)
with httpx.Client(timeout=timeout_config) as client:
response = client.post(
config.get('base_url'),
headers=headers,
json=data
)
response.raise_for_status()
result = response.json()
return result['choices'][0]['message']['content']
def _call_openrouter_api(self, prompt: str, **kwargs) -> str:
"""调用OpenRouter API"""
config = self.get_provider_config()
api_key = config.get('api_key') or os.getenv('OPENROUTER_API_KEY', '')
if not api_key:
raise ValueError("OpenRouter API密钥未配置")
headers = {
'Content-Type': 'application/json',
'Authorization': f'Bearer {api_key}'
}
# 添加可选的 HTTP-Referer 和 X-Title headers
http_referer = config.get('http_referer', '')
if http_referer:
headers['HTTP-Referer'] = http_referer
x_title = config.get('x_title', '')
if x_title:
headers['X-Title'] = x_title
data = {
'model': config.get('model', 'allenai/molmo-2-8b:free'),
'messages': [
{'role': 'user', 'content': prompt}
],
'temperature': config.get('temperature', 0.7),
'max_tokens': config.get('max_tokens', 4000)
}
# 使用精细化超时配置连接10秒读取180秒
timeout_config = httpx.Timeout(connect=10.0, read=180.0, write=10.0, pool=10.0)
with httpx.Client(timeout=timeout_config) as client:
response = client.post(
config.get('base_url'),
headers=headers,
json=data
)
response.raise_for_status()
result = response.json()
# OpenRouter 使用 OpenAI 兼容格式
if 'choices' in result and len(result['choices']) > 0:
return result['choices'][0]['message']['content']
raise ValueError(f"无法解析OpenRouter API响应: {result}")
def call_api(self, prompt: str, max_retries: int = 3, retry_delay: int = 2) -> str:
"""
调用大模型API
Args:
prompt: 提示词
max_retries: 最大重试次数
retry_delay: 重试延迟
Returns:
API返回的文本内容
"""
provider_methods = {
'deepseek': self._call_deepseek_api,
'qianwen': self._call_qianwen_api,
'openai': self._call_openai_api,
'openrouter': self._call_openrouter_api
}
method = provider_methods.get(self.current_provider)
if not method:
raise ValueError(f"不支持的API提供商: {self.current_provider}")
last_error = None
for attempt in range(max_retries):
try:
return method(prompt)
except Exception as e:
last_error = e
if attempt < max_retries - 1:
time.sleep(retry_delay * (attempt + 1))
else:
raise Exception(f"API调用失败重试{max_retries}次): {str(e)}") from last_error
raise Exception(f"API调用失败: {str(last_error)}")
def update_api_key(self, provider: str, api_key: str):
"""
更新API密钥
Args:
provider: 提供商名称
api_key: API密钥
"""
if provider not in self.config.get('providers', {}):
raise ValueError(f"不支持的API提供商: {provider}")
self.config['providers'][provider]['api_key'] = api_key
# 保存到配置文件
with open(self.config_path, 'w', encoding='utf-8') as f:
yaml.dump(self.config, f, allow_unicode=True, default_flow_style=False)