Files
test_item_gen/modules/api_client.py
2026-02-04 14:38:52 +08:00

370 lines
14 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# @line_count 200
"""大模型API客户端模块"""
import os
import yaml
import json
import time
from pathlib import Path
from typing import Dict, Optional, List, Any
import httpx
class APIClient:
"""大模型API客户端支持多个提供商"""
def __init__(self, config_path: Optional[str] = None):
"""
初始化API客户端
Args:
config_path: API配置文件路径默认为config/api_config.yaml
"""
if config_path is None:
current_dir = Path(__file__).parent.parent
config_path = current_dir / "config" / "api_config.yaml"
self.config_path = Path(config_path)
self.config: Dict[str, Any] = {}
self.current_provider: str = ""
self.load_config()
def load_config(self):
"""加载API配置"""
if not self.config_path.exists():
raise FileNotFoundError(f"API配置文件不存在: {self.config_path}")
with open(self.config_path, 'r', encoding='utf-8') as f:
self.config = yaml.safe_load(f)
self.current_provider = self.config.get('default_provider', 'deepseek')
def set_provider(self, provider: str):
"""
设置当前使用的API提供商
Args:
provider: 提供商名称 (deepseek, qianwen, openai, openrouter)
"""
if provider not in self.config.get('providers', {}):
raise ValueError(f"不支持的API提供商: {provider}")
self.current_provider = provider
def get_provider_config(self) -> Dict[str, Any]:
"""获取当前提供商的配置"""
providers = self.config.get('providers', {})
if self.current_provider not in providers:
raise ValueError(f"提供商配置不存在: {self.current_provider}")
return providers[self.current_provider]
def _call_deepseek_api(self, prompt: str, **kwargs) -> str:
"""调用DeepSeek API"""
config = self.get_provider_config()
api_key = config.get('api_key') or os.getenv('DEEPSEEK_API_KEY', '')
if not api_key:
raise ValueError("DeepSeek API密钥未配置")
headers = {
'Content-Type': 'application/json',
'Authorization': f'Bearer {api_key}'
}
data = {
'model': config.get('model', 'deepseek-chat'),
'messages': [
{'role': 'user', 'content': prompt}
],
'temperature': config.get('temperature', 0.7),
'max_tokens': config.get('max_tokens', 4000)
}
# 使用精细化超时配置连接10秒读取180秒
timeout_config = httpx.Timeout(connect=10.0, read=180.0, write=10.0, pool=10.0)
with httpx.Client(timeout=timeout_config) as client:
response = client.post(
config.get('base_url'),
headers=headers,
json=data
)
response.raise_for_status()
result = response.json()
return result['choices'][0]['message']['content']
# def _call_qianwen_api(self, prompt: str, **kwargs) -> str:
# """调用通义千问API"""
# config = self.get_provider_config()
# api_key = config.get('api_key') or os.getenv('QIANWEN_API_KEY', '')
# if not api_key:
# raise ValueError("通义千问API密钥未配置")
# headers = {
# 'Content-Type': 'application/json',
# 'Authorization': f'Bearer {api_key}'
# }
# # 通义千问API格式DashScope
# data = {
# 'model': config.get('model', 'qwen-turbo'),
# 'input': {
# 'messages': [
# {'role': 'user', 'content': prompt}
# ]
# },
# 'parameters': {
# 'temperature': config.get('temperature', 0.7),
# 'max_tokens': config.get('max_tokens', 4000)
# }
# }
# with httpx.Client(timeout=60.0) as client:
# response = client.post(
# config.get('base_url'),
# headers=headers,
# json=data
# )
# response.raise_for_status()
# result = response.json()
# # 通义千问的响应格式适配
# if 'output' in result:
# output = result['output']
# if 'choices' in output and len(output['choices']) > 0:
# return output['choices'][0]['message']['content']
# elif 'text' in output:
# return output['text']
# elif 'choices' in result and len(result['choices']) > 0:
# return result['choices'][0]['message']['content']
# # 如果都不匹配,返回整个结果(用于调试)
# raise ValueError(f"无法解析通义千问API响应: {result}")
def _call_qianwen_api(self, prompt: str, **kwargs) -> str:
"""调用通义千问API兼容模式"""
config = self.get_provider_config()
api_key = config.get('api_key') or os.getenv('QIANWEN_API_KEY', '')
if not api_key:
raise ValueError("通义千问API密钥未配置")
headers = {
'Content-Type': 'application/json',
'Authorization': f'Bearer {api_key}'
}
# 系统提示提高生成质量模拟Web端行为
# 针对航空航天软件测试的专业化system message
system_message = """你是一位拥有20年经验的航空航天软件测试专家。你的任务是生成高质量、可执行的测试用例。
【强制要求 - 必须遵守】
你必须为以下8种黑盒测试方法各生成至少1个测试用例
1. 等价类划分 - 有效/无效输入类测试
2. 边界值分析 - min-1, min, min+1, max-1, max, max+1
3. 错误推测法 - 字节序错误、超长帧、非法字符
4. 因果图法 - 多条件联动触发测试
5. 决策表测试 - 条件组合规则覆盖
6. 状态转换法 - 状态机路径覆盖
7. 场景法 - 端到端业务流程测试
8. 随机测试 - 随机合法输入测试
【输出要求】
- 每个test_case的name字段必须用方括号标注测试方法正常指令解析-[等价类划分]
- 必须生成至少8个测试用例确保8种方法全覆盖
- 测试步骤要具体包含CAN帧ID、指令编码、响应时间等
- 只输出JSON不要任何解释性文字
- 严格遵循用户指定的JSON格式"""
# 使用OpenAI兼容格式包含system message
data = {
'model': config.get('model', 'qwen-max'),
'messages': [
{'role': 'system', 'content': system_message},
{'role': 'user', 'content': prompt}
],
'temperature': config.get('temperature', 0.3),
'max_tokens': config.get('max_tokens', 8192)
}
print("\n发送Prompt: " + prompt)
# 使用精细化超时配置连接10秒读取180秒3分钟
# 测试用例生成需要较长时间,特别是复杂的结构化输出
timeout_config = httpx.Timeout(
connect=10.0, # 连接超时10秒
read=180.0, # 读取超时180秒适合复杂生成任务
write=10.0, # 写入超时10秒
pool=10.0 # 连接池超时10秒
)
with httpx.Client(timeout=timeout_config) as client:
response = client.post(
config.get('base_url'),
headers=headers,
json=data
)
response.raise_for_status()
result = response.json()
# OpenAI兼容格式的响应
if 'choices' in result and len(result['choices']) > 0:
content = result['choices'][0]['message']['content']
# 调试输出显示API返回的完整内容
print("\n" + "="*60)
print("[API响应] 完整返回内容:")
print("="*60)
print(f"响应长度: {len(content)} 字符")
print("-"*60)
# 显示前2000字符
if len(content) > 2000:
print(content[:2000])
print(f"\n... [截断,还有 {len(content)-2000} 字符]")
else:
print(content)
print("="*60 + "\n")
return content
raise ValueError(f"无法解析通义千问API响应: {result}")
def _call_openai_api(self, prompt: str, **kwargs) -> str:
"""调用OpenAI API"""
config = self.get_provider_config()
api_key = config.get('api_key') or os.getenv('OPENAI_API_KEY', '')
if not api_key:
raise ValueError("OpenAI API密钥未配置")
headers = {
'Content-Type': 'application/json',
'Authorization': f'Bearer {api_key}'
}
data = {
'model': config.get('model', 'gpt-3.5-turbo'),
'messages': [
{'role': 'user', 'content': prompt}
],
'temperature': config.get('temperature', 0.7),
'max_tokens': config.get('max_tokens', 4000)
}
# 使用精细化超时配置连接10秒读取180秒
timeout_config = httpx.Timeout(connect=10.0, read=180.0, write=10.0, pool=10.0)
with httpx.Client(timeout=timeout_config) as client:
response = client.post(
config.get('base_url'),
headers=headers,
json=data
)
response.raise_for_status()
result = response.json()
return result['choices'][0]['message']['content']
def _call_openrouter_api(self, prompt: str, **kwargs) -> str:
"""调用OpenRouter API"""
config = self.get_provider_config()
api_key = config.get('api_key') or os.getenv('OPENROUTER_API_KEY', '')
if not api_key:
raise ValueError("OpenRouter API密钥未配置")
headers = {
'Content-Type': 'application/json',
'Authorization': f'Bearer {api_key}'
}
# 添加可选的 HTTP-Referer 和 X-Title headers
http_referer = config.get('http_referer', '')
if http_referer:
headers['HTTP-Referer'] = http_referer
x_title = config.get('x_title', '')
if x_title:
headers['X-Title'] = x_title
data = {
'model': config.get('model', 'allenai/molmo-2-8b:free'),
'messages': [
{'role': 'user', 'content': prompt}
],
'temperature': config.get('temperature', 0.7),
'max_tokens': config.get('max_tokens', 4000)
}
# 使用精细化超时配置连接10秒读取180秒
timeout_config = httpx.Timeout(connect=10.0, read=180.0, write=10.0, pool=10.0)
with httpx.Client(timeout=timeout_config) as client:
response = client.post(
config.get('base_url'),
headers=headers,
json=data
)
response.raise_for_status()
result = response.json()
# OpenRouter 使用 OpenAI 兼容格式
if 'choices' in result and len(result['choices']) > 0:
return result['choices'][0]['message']['content']
raise ValueError(f"无法解析OpenRouter API响应: {result}")
def call_api(self, prompt: str, max_retries: int = 3, retry_delay: int = 2) -> str:
"""
调用大模型API
Args:
prompt: 提示词
max_retries: 最大重试次数
retry_delay: 重试延迟(秒)
Returns:
API返回的文本内容
"""
provider_methods = {
'deepseek': self._call_deepseek_api,
'qianwen': self._call_qianwen_api,
'openai': self._call_openai_api,
'openrouter': self._call_openrouter_api
}
method = provider_methods.get(self.current_provider)
if not method:
raise ValueError(f"不支持的API提供商: {self.current_provider}")
last_error = None
for attempt in range(max_retries):
try:
return method(prompt)
except Exception as e:
last_error = e
if attempt < max_retries - 1:
time.sleep(retry_delay * (attempt + 1))
else:
raise Exception(f"API调用失败重试{max_retries}次): {str(e)}") from last_error
raise Exception(f"API调用失败: {str(last_error)}")
def update_api_key(self, provider: str, api_key: str):
"""
更新API密钥
Args:
provider: 提供商名称
api_key: API密钥
"""
if provider not in self.config.get('providers', {}):
raise ValueError(f"不支持的API提供商: {provider}")
self.config['providers'][provider]['api_key'] = api_key
# 保存到配置文件
with open(self.config_path, 'w', encoding='utf-8') as f:
yaml.dump(self.config, f, allow_unicode=True, default_flow_style=False)