181 lines
5.8 KiB
Python
181 lines
5.8 KiB
Python
from __future__ import annotations
|
||
|
||
import logging
|
||
import os
|
||
|
||
import requests
|
||
try:
|
||
from openai import OpenAI
|
||
except ImportError: # pragma: no cover - depends on deployment environment
|
||
OpenAI = None
|
||
import time
|
||
from typing import List
|
||
|
||
try:
|
||
import numpy as np
|
||
except ImportError: # pragma: no cover - depends on deployment environment
|
||
np = None
|
||
|
||
from code_parser import FUNCTION_CALL_GRAPH, CALLED_BY_GRAPH, FILE_DEPENDENCIES
|
||
from config import QWEN_API_KEY, QWEN_API_URL, QWEN_CHAT_MODEL, QWEN_EMBEDDING_MODEL
|
||
|
||
MAX_CODE_LENGTH = 800
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
def _get_qwen_api_key() -> str:
|
||
api_key = (
|
||
os.getenv("DASHSCOPE_API_KEY")
|
||
or os.getenv("DASH_SCOPE_API_KEY")
|
||
or os.getenv("QWEN_API_KEY")
|
||
or QWEN_API_KEY
|
||
)
|
||
if not api_key:
|
||
raise RuntimeError(
|
||
"Qwen/DashScope API key is not configured. Set DASHSCOPE_API_KEY, "
|
||
"DASH_SCOPE_API_KEY, or QWEN_API_KEY."
|
||
)
|
||
return api_key
|
||
|
||
|
||
def _chat_completions_url() -> str:
|
||
return QWEN_API_URL.rstrip("/") + "/chat/completions"
|
||
|
||
def generate_code_summary(func_name, comment, logic, code_snippet, file_path):
|
||
called_functions = FUNCTION_CALL_GRAPH.get(func_name, [])
|
||
caller_functions = CALLED_BY_GRAPH.get(func_name, [])
|
||
included_headers = FILE_DEPENDENCIES.get(file_path, [])
|
||
|
||
prompt = f"""
|
||
你是一名资深的航天软件工程师,请总结以下C++函数的核心功能,**必须严格包含以下6点**:
|
||
|
||
1. 函数流程与逻辑:
|
||
{logic}
|
||
|
||
2. 函数目的:
|
||
→ 结合航天术语总结,必须根据上述“函数流程与逻辑”推导生成,不得依赖假设!
|
||
|
||
3. 输入参数(名称、类型、作用,用括号列出)
|
||
|
||
4. 返回值(类型和含义)
|
||
|
||
5. 与其他函数的关联关系:
|
||
- 被调用的函数: {', '.join(called_functions) or '无'}
|
||
- 调用此函数的函数: {', '.join(caller_functions) or '无'}
|
||
|
||
6. 与跨文件关联的函数(头文件): {', '.join(included_headers) or '无'}
|
||
|
||
---
|
||
函数名: {func_name}
|
||
注释: {comment or '无'}
|
||
代码片段(截断至{MAX_CODE_LENGTH}字符):
|
||
{code_snippet[:MAX_CODE_LENGTH]}
|
||
|
||
输出格式: "功能: [按上述6点组织的连贯段落,不要编号]"
|
||
"""
|
||
headers = {"Authorization": f"Bearer {_get_qwen_api_key()}", "Content-Type": "application/json"}
|
||
payload = {
|
||
"model": QWEN_CHAT_MODEL,
|
||
"messages": [{"role": "user", "content": prompt}],
|
||
"temperature": 0.1
|
||
}
|
||
response = requests.post(
|
||
_chat_completions_url(),
|
||
headers=headers,
|
||
json=payload
|
||
)
|
||
response.raise_for_status()
|
||
summary = response.json()['choices'][0]['message']['content'].strip()
|
||
return summary.replace("功能: ", "")
|
||
|
||
|
||
def generate_code_logic(func_name, comment, code_snippet, file_path):
|
||
prompt = f"""
|
||
你是一名资深的航天软件工程师,请分析以下C++函数,并生成其markdown格式的**核心流程逻辑总结**。
|
||
|
||
**函数名**: {func_name}
|
||
**所在文件**: {file_path}
|
||
**前置注释**: {comment or '无'}
|
||
**函数代码片段**:{code_snippet}请严格按照以下步骤执行:
|
||
1. **步骤分解**:逐行或按代码块分析函数体,用简洁的步骤描述其执行流程。例如:“1. 检查输入参数x是否有效。2. 调用内部函数compute()进行计算。3. 对结果进行格式化处理。4. 返回格式化后的字符串。”
|
||
2. **逻辑归纳**:基于上述步骤,用1-3句话概括这个函数的核心功能与逻辑主线。
|
||
|
||
**输出格式**:只需输出纯粹的、不带Markdown标题的流程逻辑描述文本。
|
||
"""
|
||
headers = {
|
||
"Authorization": f"Bearer {_get_qwen_api_key()}",
|
||
"Content-Type": "application/json"
|
||
}
|
||
payload = {
|
||
"model": QWEN_CHAT_MODEL,
|
||
"messages": [{"role": "user", "content": prompt}],
|
||
"temperature": 0.1
|
||
}
|
||
response = requests.post(
|
||
_chat_completions_url(),
|
||
headers=headers,
|
||
json=payload
|
||
)
|
||
response.raise_for_status()
|
||
result = response.json()
|
||
logic = result['choices'][0]['message']['content'].strip()
|
||
return logic
|
||
|
||
|
||
def get_qwen_embedding(text: str) -> np.ndarray:
|
||
try:
|
||
if OpenAI is None:
|
||
raise RuntimeError("openai package is not installed.")
|
||
client = OpenAI(
|
||
api_key=_get_qwen_api_key(),
|
||
base_url=QWEN_API_URL
|
||
)
|
||
response = client.embeddings.create(
|
||
model=QWEN_EMBEDDING_MODEL,
|
||
input=text
|
||
)
|
||
embedding = response.data[0].embedding
|
||
if np is not None:
|
||
return np.array(embedding, dtype='float32')
|
||
return [float(value) for value in embedding]
|
||
except Exception as e:
|
||
logger.error(f"获取嵌入失败: {e}")
|
||
raise RuntimeError(f"Failed to get embedding: {e}") from e
|
||
|
||
|
||
def call_qwen_max(prompt: str, max_tokens: int = 1500, temperature: float = 0.1) -> str:
|
||
"""
|
||
调用 Qwen-Max 进行推理
|
||
|
||
Args:
|
||
prompt: 输入的提示文本
|
||
max_tokens: 最大输出token数
|
||
temperature: 温度参数,控制随机性
|
||
|
||
Returns:
|
||
模型生成的文本
|
||
"""
|
||
try:
|
||
# 使用 OpenAI 兼容格式调用通义千问
|
||
if OpenAI is None:
|
||
raise RuntimeError("openai package is not installed.")
|
||
client = OpenAI(
|
||
api_key=_get_qwen_api_key(),
|
||
base_url=QWEN_API_URL
|
||
)
|
||
|
||
response = client.chat.completions.create(
|
||
model=QWEN_CHAT_MODEL,
|
||
messages=[
|
||
{"role": "user", "content": prompt}
|
||
],
|
||
max_tokens=max_tokens,
|
||
temperature=temperature
|
||
)
|
||
|
||
return response.choices[0].message.content.strip()
|
||
|
||
except Exception as e:
|
||
logger.error(f"调用Qwen-Max失败: {e}")
|
||
return f"API调用失败: {str(e)}"
|