Files
CodeKnowledgeBuild/Code_RAG.py
JYF 6b0f3bb146 上传文件至「/」
建立知识库后进行简单的需求-实现RAG对照
2026-02-04 14:47:16 +08:00

164 lines
6.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
import json
import numpy as np
import faiss
from dashscope import TextEmbedding, Generation
import dashscope
# ========================
# Configuration (adjust the paths as needed)
# ========================
# FAISS index file read by load_knowledge_base().
FAISS_INDEX_PATH = "C:/Users/Administrator/PycharmProjects/PythonProject/satellite_rag.faiss"
# JSON file with one metadata record per indexed function; load_knowledge_base()
# requires the number of records to equal the number of vectors in the index.
METADATAS_JSON_PATH = "C:/Users/Administrator/PycharmProjects/PythonProject/satellite_rag_metadata.json"
# Number of nearest neighbours requested from the index for each requirement.
TOP_K = 8
# DashScope API key. It is empty here.
# NOTE(review): presumably it must be filled in (or supplied another way)
# before the embedding / generation calls can succeed - confirm.
dashscope.api_key =""
def get_qwen_embedding(text: str) -> np.ndarray:
    """Embed *text* with the Qwen ``text-embedding-v4`` model.

    Args:
        text: The text to embed.

    Returns:
        The first embedding of the API response as a ``float32`` NumPy array.

    Raises:
        RuntimeError: If the API response status code is not 200.
    """
    response = TextEmbedding.call(model="text-embedding-v4", input=text)
    if response.status_code == 200:
        first_item = response.output["embeddings"][0]
        return np.array(first_item["embedding"], dtype=np.float32)
    raise RuntimeError(f"Embedding API error: {response}")
def call_qwen_max(prompt: str, max_tokens=800) -> str:
    """Send *prompt* to the ``qwen-max`` model and return its text reply.

    Args:
        prompt: The complete prompt text.
        max_tokens: Value forwarded as the ``max_tokens`` request option.

    Returns:
        The reply text with leading and trailing whitespace removed.

    Raises:
        RuntimeError: If the API response status code is not 200.
    """
    request_options = {
        "model": "qwen-max",
        "prompt": prompt,
        "max_tokens": max_tokens,
        # Near-zero temperature keeps the output as deterministic as possible.
        "temperature": 0.01,
        "result_format": "text",
    }
    reply = Generation.call(**request_options)
    if reply.status_code == 200:
        return reply.output.text.strip()
    raise RuntimeError(f"Qwen-Max API error: {reply}")
def _format_function_context(meta: dict) -> str:
    """Render one function metadata record as the text block shown to the model."""
    func_name = meta.get("name", "未知函数")
    file_path = meta.get("file", "未知文件")
    # ``or ""`` also covers an explicit JSON ``null``, which ``.get`` with a
    # default would hand back as ``None`` and crash on ``.strip()``.
    comment = (meta.get("comment") or "").strip()
    summary = meta.get("summary", "无摘要")
    called_functions = meta.get("calls", [])
    caller_functions = meta.get("called_by", [])
    included_headers = meta.get("includes", [])
    # The original author's note says this layout is meant to be identical to
    # the one used when the index was built, so the literal text is unchanged.
    return (
        f"【实体类型】函数\n"
        f"【函数名】{func_name}\n"
        f"【所在文件】{os.path.basename(file_path)}\n"
        f"【代码注释】{comment if comment else ''}\n"
        f"【功能摘要】{summary}\n"
        f"【调用的函数】{', '.join(called_functions) if called_functions else ''}\n"
        f"【被以下函数调用】{', '.join(caller_functions) if caller_functions else ''}\n"
        f"【包含的头文件】{', '.join(included_headers) if included_headers else ''}"
    )


def load_knowledge_base():
    """Load the FAISS index and the metadata, and rebuild the context texts.

    Reads the index from ``FAISS_INDEX_PATH`` and the metadata records from
    ``METADATAS_JSON_PATH``, then formats one context string per record.

    Returns:
        A tuple ``(index, contexts, metadatas)`` where ``contexts[i]`` is the
        formatted text for ``metadatas[i]``.

    Raises:
        ValueError: If the number of vectors in the index differs from the
            number of metadata records. This is an explicit raise rather than
            an ``assert`` so the check still runs under ``python -O``.
    """
    print("正在加载 FAISS 索引...")
    index = faiss.read_index(FAISS_INDEX_PATH)
    print("正在加载元数据...")
    with open(METADATAS_JSON_PATH, "r", encoding="utf-8") as f:
        metadatas = json.load(f)

    contexts = [_format_function_context(meta) for meta in metadatas]

    # Search results are row numbers of the index; they are only meaningful
    # as positions in ``contexts`` when both have the same length.
    if index.ntotal != len(contexts):
        raise ValueError(
            "FAISS 向量数量与 contexts 数量不匹配"
            f": {index.ntotal} != {len(contexts)}"
        )
    return index, contexts, metadatas
def _parse_judge_response(response_text: str) -> dict:
    """Convert the model's reply into the verdict dictionary."""
    import re

    # Try the raw reply first, then the outermost {...} span so that JSON
    # wrapped in Markdown fences or surrounding prose is still recovered.
    candidates = [response_text]
    start, end = response_text.find("{"), response_text.rfind("}")
    if 0 <= start < end:
        candidates.append(response_text[start:end + 1])

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except ValueError:
            continue
        if not isinstance(parsed, dict):
            continue
        # Guarantee the keys the caller reads, whatever the model left out.
        parsed.setdefault("implemented", False)
        parsed.setdefault("reason", "")
        parsed.setdefault("most_relevant_function", None)
        parsed.setdefault("file", None)
        if isinstance(parsed["implemented"], str):
            # A quoted "false" would otherwise be truthy.
            parsed["implemented"] = parsed["implemented"].strip().lower() == "true"
        return parsed

    print(" Qwen-Max 返回非 JSON 格式,尝试提取关键信息...")
    flag = re.search(r'"implemented"\s*:\s*(true|false)', response_text, re.IGNORECASE)
    if flag:
        implemented = flag.group(1).lower() == "true"
    else:
        # Last resort: the reply has to say "true" near its start. An empty
        # marker must not be used here, since it matches every string.
        implemented = "true" in response_text[:30].lower()
    return {
        "implemented": implemented,
        "reason": response_text[:200],
        "most_relevant_function": None,
        "file": None,
    }


def check_requirement_implementation(user_requirement: str):
    """Judge whether a natural-language requirement is already implemented.

    Loads the knowledge base, embeds the requirement, retrieves the ``TOP_K``
    nearest function descriptions from the FAISS index, and asks Qwen-Max for
    a JSON verdict based on them.

    Args:
        user_requirement: The requirement text entered by the user.

    Returns:
        A dict that always contains the keys ``implemented``, ``reason``,
        ``most_relevant_function`` and ``file``. When the reply cannot be
        parsed as a JSON object, ``implemented`` is derived heuristically and
        ``reason`` holds the first 200 characters of the reply.

    Raises:
        RuntimeError: Propagated from the embedding or generation call when
            the API does not answer with status 200.
    """
    print("正在加载知识库...")
    faiss_index, context_texts, _ = load_knowledge_base()
    print("正在生成需求嵌入...")
    query_vec = get_qwen_embedding(user_requirement)
    query_vec = np.expand_dims(query_vec, axis=0).astype(np.float32)
    print(f"正在检索 Top-{TOP_K} 相关函数...")
    distances, indices = faiss_index.search(query_vec, TOP_K)

    # Build the evidence text from the retrieved neighbours.
    evidence_blocks = []
    for distance, idx in zip(distances[0], indices[0]):
        # FAISS pads with -1 when fewer than TOP_K neighbours exist; indexing
        # with it would silently select the last context instead.
        if idx < 0 or idx >= len(context_texts):
            continue
        # Approximate cosine similarity.
        # NOTE(review): assumes the index returns squared L2 distances over
        # unit-normalised embeddings - confirm against the index build.
        sim = 1 - 0.5 * float(distance)
        rank = len(evidence_blocks) + 1
        evidence_blocks.append(f"【候选 {rank}】(相似度 ≈ {sim:.3f}\n{context_texts[idx]}")
    evidence_text = "\n\n".join(evidence_blocks)

    # Prompt asking for a strict JSON verdict.
    judge_prompt = f"""你是一名资深卫星软件工程师。请根据以下用户提出的功能需求,结合检索到的现有代码函数信息,严格判断该需求是否已经被当前代码库实现。同时需要关注实现的正确与否,检索知识库信息来确认。
【用户需求】
{user_requirement}
【检索到的相关函数信息】
{evidence_text}
请严格按照以下 JSON 格式回答,不要包含任何额外文字、注释或 Markdown
{{
"implemented": true 或 false,
"reason": "不超过100字的理由需引用具体函数名或功能描述",
"most_relevant_function": "最相关的函数名(若无则写 null",
"file": "所在文件名(若无则写 null"
}}
"""
    print("正在调用 Qwen-Max 进行判断...")
    response_text = call_qwen_max(judge_prompt)
    return _parse_judge_response(response_text)
def main():
    """Run the interactive requirement-check loop.

    Reads requirements from standard input until the user types ``quit``,
    ``exit`` or ``q`` (case-insensitive) or closes the input stream. Each
    non-empty line is passed to ``check_requirement_implementation`` and the
    verdict is printed. Errors raised while handling one requirement are
    reported and the loop continues.
    """
    print("卫星需求实现检查工具")
    print("请输入您的自然语言需求(输入 'quit' 退出):")
    while True:
        try:
            user_input = input("\n> ").strip()
        except (EOFError, KeyboardInterrupt):
            # Closed stdin or Ctrl-C at the prompt ends the session cleanly
            # instead of dumping a traceback.
            print("再见!")
            break
        if user_input.lower() in ("quit", "exit", "q"):
            print("再见!")
            break
        if not user_input:
            continue
        try:
            result = check_requirement_implementation(user_input)
            print("\n判断结果:")
            # The verdict needs a visible value for each outcome; an empty
            # string in both branches printed nothing.
            verdict = "是" if result.get("implemented") else "否"
            print(f" 已实现:{verdict}")
            print(f" 理由:{result.get('reason', '')}")
            function_name = result.get("most_relevant_function")
            if function_name:
                file_name = result.get("file")
                # Separate the file from the function name so the two do not
                # run together on output.
                location = f" ({file_name})" if file_name else ""
                print(f" 相关函数:{function_name}{location}")
        except Exception as e:
            # Top-level boundary: report the failure and keep the loop alive.
            print(f"处理出错:{e}")


if __name__ == "__main__":
    main()