上传文件至「/」
建立知识库后进行简单的需求-实现RAG对照
This commit is contained in:
164
Code_RAG.py
Normal file
164
Code_RAG.py
Normal file
@@ -0,0 +1,164 @@
|
||||
import os
|
||||
import json
|
||||
import numpy as np
|
||||
import faiss
|
||||
from dashscope import TextEmbedding, Generation
|
||||
import dashscope
|
||||
|
||||
# ========================
|
||||
# 配置区(按需修改路径)
|
||||
# ========================
|
||||
FAISS_INDEX_PATH = "C:/Users/Administrator/PycharmProjects/PythonProject/satellite_rag.faiss"
|
||||
METADATAS_JSON_PATH = "C:/Users/Administrator/PycharmProjects/PythonProject/satellite_rag_metadata.json"
|
||||
TOP_K = 8
|
||||
dashscope.api_key =""
|
||||
|
||||
|
||||
def get_qwen_embedding(text: str) -> np.ndarray:
|
||||
"""使用 Qwen text-embedding-v4 获取文本向量"""
|
||||
resp = TextEmbedding.call(
|
||||
model="text-embedding-v4",
|
||||
input=text
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
raise RuntimeError(f"Embedding API error: {resp}")
|
||||
embedding = resp.output["embeddings"][0]["embedding"]
|
||||
return np.array(embedding, dtype=np.float32)
|
||||
|
||||
|
||||
def call_qwen_max(prompt: str, max_tokens=800) -> str:
|
||||
"""调用 Qwen-Max 进行推理"""
|
||||
response = Generation.call(
|
||||
model="qwen-max",
|
||||
prompt=prompt,
|
||||
max_tokens=max_tokens,
|
||||
temperature=0.01, # 尽可能确定性输出
|
||||
result_format="text"
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise RuntimeError(f"Qwen-Max API error: {response}")
|
||||
return response.output.text.strip()
|
||||
|
||||
|
||||
def load_knowledge_base():
|
||||
"""加载 FAISS 索引和 metadatas,并动态重建 contexts"""
|
||||
print("正在加载 FAISS 索引...")
|
||||
index = faiss.read_index(FAISS_INDEX_PATH)
|
||||
|
||||
print("正在加载元数据...")
|
||||
with open(METADATAS_JSON_PATH, "r", encoding="utf-8") as f:
|
||||
metadatas = json.load(f)
|
||||
|
||||
# 动态重建 contexts(与构建索引时使用的格式完全一致)
|
||||
contexts = []
|
||||
for meta in metadatas:
|
||||
func_name = meta.get("name", "未知函数")
|
||||
file_path = meta.get("file", "未知文件")
|
||||
comment = meta.get("comment", "").strip()
|
||||
summary = meta.get("summary", "无摘要")
|
||||
called_functions = meta.get("calls", [])
|
||||
caller_functions = meta.get("called_by", [])
|
||||
included_headers = meta.get("includes", [])
|
||||
|
||||
context_str = (
|
||||
f"【实体类型】函数\n"
|
||||
f"【函数名】{func_name}\n"
|
||||
f"【所在文件】{os.path.basename(file_path)}\n"
|
||||
f"【代码注释】{comment if comment else '无'}\n"
|
||||
f"【功能摘要】{summary}\n"
|
||||
f"【调用的函数】{', '.join(called_functions) if called_functions else '无'}\n"
|
||||
f"【被以下函数调用】{', '.join(caller_functions) if caller_functions else '无'}\n"
|
||||
f"【包含的头文件】{', '.join(included_headers) if included_headers else '无'}"
|
||||
)
|
||||
contexts.append(context_str)
|
||||
|
||||
assert len(contexts) == len(metadatas), "重建的 contexts 与 metadatas 长度不一致"
|
||||
assert index.ntotal == len(contexts), "FAISS 向量数量与 contexts 数量不匹配"
|
||||
|
||||
return index, contexts, metadatas
|
||||
|
||||
|
||||
def check_requirement_implementation(user_requirement: str):
|
||||
"""主逻辑:判断需求是否已实现"""
|
||||
print("正在加载知识库...")
|
||||
faiss_index, context_texts, metadatas = load_knowledge_base()
|
||||
|
||||
print("正在生成需求嵌入...")
|
||||
query_vec = get_qwen_embedding(user_requirement)
|
||||
query_vec = np.expand_dims(query_vec, axis=0).astype(np.float32)
|
||||
|
||||
print(f"正在检索 Top-{TOP_K} 相关函数...")
|
||||
distances, indices = faiss_index.search(query_vec, TOP_K)
|
||||
|
||||
# 构造检索结果文本
|
||||
evidence_blocks = []
|
||||
for i, idx in enumerate(indices[0]):
|
||||
sim = 1 - 0.5 * float(distances[0][i]) # 近似余弦相似度(因 embedding 已归一化)
|
||||
block = f"【候选 {i + 1}】(相似度 ≈ {sim:.3f})\n{context_texts[idx]}"
|
||||
evidence_blocks.append(block)
|
||||
|
||||
evidence_text = "\n\n".join(evidence_blocks)
|
||||
|
||||
# 构造判断 Prompt
|
||||
judge_prompt = f"""你是一名资深卫星软件工程师。请根据以下用户提出的功能需求,结合检索到的现有代码函数信息,严格判断该需求是否已经被当前代码库实现。同时需要关注实现的正确与否,检索知识库信息来确认。
|
||||
|
||||
【用户需求】
|
||||
{user_requirement}
|
||||
|
||||
【检索到的相关函数信息】
|
||||
{evidence_text}
|
||||
|
||||
请严格按照以下 JSON 格式回答,不要包含任何额外文字、注释或 Markdown:
|
||||
{{
|
||||
"implemented": true 或 false,
|
||||
"reason": "不超过100字的理由,需引用具体函数名或功能描述",
|
||||
"most_relevant_function": "最相关的函数名(若无则写 null)",
|
||||
"file": "所在文件名(若无则写 null)"
|
||||
}}
|
||||
"""
|
||||
|
||||
print("正在调用 Qwen-Max 进行判断...")
|
||||
response_text = call_qwen_max(judge_prompt)
|
||||
|
||||
# 尝试解析 JSON
|
||||
try:
|
||||
import json
|
||||
result = json.loads(response_text)
|
||||
except Exception as e:
|
||||
print(" Qwen-Max 返回非 JSON 格式,尝试提取关键信息...")
|
||||
implemented = "是" in response_text[:30] or "true" in response_text[:30].lower()
|
||||
result = {
|
||||
"implemented": implemented,
|
||||
"reason": response_text[:200],
|
||||
"most_relevant_function": None,
|
||||
"file": None
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def main():
|
||||
print("卫星需求实现检查工具")
|
||||
print("请输入您的自然语言需求(输入 'quit' 退出):")
|
||||
|
||||
while True:
|
||||
user_input = input("\n> ").strip()
|
||||
if user_input.lower() in ["quit", "exit", "q"]:
|
||||
print("再见!")
|
||||
break
|
||||
if not user_input:
|
||||
continue
|
||||
|
||||
try:
|
||||
result = check_requirement_implementation(user_input)
|
||||
print("\n判断结果:")
|
||||
print(f" 已实现:{'是' if result['implemented'] else '否'}")
|
||||
print(f" 理由:{result['reason']}")
|
||||
if result.get("most_relevant_function"):
|
||||
print(f" 相关函数:{result['most_relevant_function']}({result['file']})")
|
||||
except Exception as e:
|
||||
print(f"处理出错:{e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user