"""Interactive checker: is a natural-language requirement already implemented?

The tool embeds the requirement with a DashScope embedding model, retrieves the
closest function descriptions from a prebuilt FAISS index, and asks Qwen-Max to
judge, from that evidence, whether the requirement is implemented.
"""
import json
import os
from typing import Any, Dict, List, Optional, Tuple

import dashscope
import faiss
import numpy as np
from dashscope import Generation, TextEmbedding

# ========================
# Configuration (adjust the paths as needed)
# ========================
FAISS_INDEX_PATH = "C:/Users/Administrator/PycharmProjects/PythonProject/satellite_rag.faiss"
METADATAS_JSON_PATH = "C:/Users/Administrator/PycharmProjects/PythonProject/satellite_rag_metadata.json"
TOP_K = 8
# Prefer the key from the environment so it never has to be committed; when the
# variable is unset this is the same empty placeholder as before.
dashscope.api_key = os.environ.get("DASHSCOPE_API_KEY", "")


def get_qwen_embedding(text: str) -> np.ndarray:
    """Return the ``text-embedding-v4`` vector for *text* as a float32 array.

    Raises:
        RuntimeError: if the embedding API answers with a non-200 status.
    """
    resp = TextEmbedding.call(
        model="text-embedding-v4",
        input=text,
    )
    if resp.status_code != 200:
        raise RuntimeError(f"Embedding API error: {resp}")
    embedding = resp.output["embeddings"][0]["embedding"]
    return np.array(embedding, dtype=np.float32)


def call_qwen_max(prompt: str, max_tokens=800) -> str:
    """Send *prompt* to Qwen-Max and return the stripped text completion.

    Raises:
        RuntimeError: if the generation API answers with a non-200 status.
    """
    response = Generation.call(
        model="qwen-max",
        prompt=prompt,
        max_tokens=max_tokens,
        temperature=0.01,  # keep the output as deterministic as possible
        result_format="text",
    )
    if response.status_code != 200:
        raise RuntimeError(f"Qwen-Max API error: {response}")
    # Guard against a missing text field so the caller gets "" instead of an
    # AttributeError on None.
    return (response.output.text or "").strip()


def _format_context(meta: Dict[str, Any]) -> str:
    """Render one metadata record as the labelled text block shown to the model."""

    def joined(names) -> str:
        return ", ".join(names) if names else "无"

    func_name = meta.get("name", "未知函数")
    file_path = meta.get("file", "未知文件")
    # "comment" may be stored as JSON null; treat that like an absent comment.
    comment = (meta.get("comment") or "").strip()
    summary = meta.get("summary", "无摘要")

    return (
        f"【实体类型】函数\n"
        f"【函数名】{func_name}\n"
        f"【所在文件】{os.path.basename(file_path)}\n"
        f"【代码注释】{comment if comment else '无'}\n"
        f"【功能摘要】{summary}\n"
        f"【调用的函数】{joined(meta.get('calls', []))}\n"
        f"【被以下函数调用】{joined(meta.get('called_by', []))}\n"
        f"【包含的头文件】{joined(meta.get('includes', []))}"
    )


def load_knowledge_base():
    """Load the FAISS index and metadata, rebuilding one context string per record.

    Returns:
        A ``(index, contexts, metadatas)`` tuple where ``contexts[i]`` is the
        text rendering of ``metadatas[i]``.

    Raises:
        ValueError: if the metadata file is not a JSON list, or the number of
            vectors in the index differs from the number of metadata records.
            (Raised explicitly rather than via ``assert`` so the check still
            runs under ``python -O``.)
    """
    print("正在加载 FAISS 索引...")
    index = faiss.read_index(FAISS_INDEX_PATH)

    print("正在加载元数据...")
    with open(METADATAS_JSON_PATH, "r", encoding="utf-8") as f:
        metadatas = json.load(f)

    if not isinstance(metadatas, list):
        raise ValueError("metadata JSON must be a list of function records")

    # Rebuild the contexts; per the original author's note this format must
    # stay identical to the one used when the index was built.
    contexts = [_format_context(meta) for meta in metadatas]

    if index.ntotal != len(contexts):
        raise ValueError("FAISS 向量数量与 contexts 数量不匹配")

    return index, contexts, metadatas


def _none_if_null(value: Any) -> Any:
    """Map the textual placeholders the model may emit for "nothing" to None."""
    if isinstance(value, str) and value.strip().lower() in ("", "null", "none"):
        return None
    return value


def _extract_json_object(text: str) -> Optional[Dict[str, Any]]:
    """Return the JSON object contained in *text*, or None if there is none.

    The whole text is tried first; failing that, the span from the first ``{``
    to the last ``}`` is tried, which covers replies wrapped in a Markdown code
    fence or surrounded by extra prose.
    """
    candidates = [text.strip()]
    start, end = text.find("{"), text.rfind("}")
    if 0 <= start < end:
        candidates.append(text[start:end + 1])

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except ValueError:  # json.JSONDecodeError is a ValueError
            continue
        if isinstance(parsed, dict):
            return parsed
    return None


def _guess_from_free_text(response_text: str) -> Dict[str, Any]:
    """Best-effort verdict for a reply that contains no JSON object.

    Only the first 30 characters are inspected. "是否" is removed first and
    explicit negations win, so "不是" or "是否已实现:否" is no longer read as
    a positive answer merely because it contains "是".
    """
    head = response_text[:30].lower().replace("是否", "")
    negative = any(token in head for token in ("不是", "否", "false"))
    implemented = not negative and ("是" in head or "true" in head)
    return {
        "implemented": implemented,
        "reason": response_text[:200],
        "most_relevant_function": None,
        "file": None,
    }


def _parse_judgement(response_text: str) -> Dict[str, Any]:
    """Turn the model reply into a dict that always has the four result keys."""
    parsed = _extract_json_object(response_text)
    if parsed is None:
        print(" Qwen-Max 返回非 JSON 格式,尝试提取关键信息...")
        return _guess_from_free_text(response_text)

    implemented = parsed.get("implemented", False)
    if isinstance(implemented, str):
        # A quoted "false" would otherwise be truthy.
        implemented = implemented.strip().lower() in ("true", "yes", "是")

    result = dict(parsed)  # keep any extra keys the model returned
    result["implemented"] = bool(implemented)
    result["reason"] = parsed.get("reason", "")
    result["most_relevant_function"] = _none_if_null(parsed.get("most_relevant_function"))
    result["file"] = _none_if_null(parsed.get("file"))
    return result


def _retrieve_evidence(faiss_index, context_texts: List[str], query_vec: np.ndarray) -> str:
    """Search the index and format the hits as numbered candidate blocks."""
    # Never ask for more neighbours than the index holds.
    k = min(TOP_K, int(faiss_index.ntotal))
    if k <= 0:
        return "(未检索到相关函数)"

    distances, indices = faiss_index.search(query_vec, k)

    evidence_blocks = []
    for idx, dist in zip(indices[0], distances[0]):
        idx = int(idx)
        # FAISS pads missing results with label -1; indexing with it would
        # silently pick the LAST context as bogus evidence.
        if idx < 0 or idx >= len(context_texts):
            continue
        # Approximate cosine similarity. NOTE(review): this presumes squared-L2
        # distances over unit-length vectors — confirm against how the index
        # was built (an inner-product index would need a different formula).
        sim = 1 - 0.5 * float(dist)
        rank = len(evidence_blocks) + 1
        evidence_blocks.append(f"【候选 {rank}】(相似度 ≈ {sim:.3f})\n{context_texts[idx]}")

    if not evidence_blocks:
        return "(未检索到相关函数)"
    return "\n\n".join(evidence_blocks)


def check_requirement_implementation(user_requirement: str, knowledge_base: Optional[Tuple] = None):
    """Judge whether *user_requirement* is already implemented in the code base.

    Args:
        user_requirement: the requirement in natural language.
        knowledge_base: optional ``(index, contexts, metadatas)`` tuple as
            returned by ``load_knowledge_base``. When omitted it is loaded from
            disk, as in the original behaviour; pass it to avoid reloading the
            index for every query.

    Returns:
        A dict with at least the keys ``implemented`` (bool), ``reason``,
        ``most_relevant_function`` and ``file`` (the last two may be None).

    Raises:
        RuntimeError: if an embedding or generation API call fails.
        ValueError: if the knowledge base is inconsistent or the embedding
            dimension does not match the index.
    """
    if knowledge_base is None:
        print("正在加载知识库...")
        knowledge_base = load_knowledge_base()
    faiss_index, context_texts, _metadatas = knowledge_base

    print("正在生成需求嵌入...")
    query_vec = get_qwen_embedding(user_requirement)
    query_vec = np.expand_dims(query_vec, axis=0).astype(np.float32)
    if query_vec.shape[1] != faiss_index.d:
        raise ValueError(
            f"embedding dimension {query_vec.shape[1]} does not match "
            f"index dimension {faiss_index.d}"
        )

    print(f"正在检索 Top-{TOP_K} 相关函数...")
    evidence_text = _retrieve_evidence(faiss_index, context_texts, query_vec)

    # Build the judging prompt.
    judge_prompt = f"""你是一名资深卫星软件工程师。请根据以下用户提出的功能需求,结合检索到的现有代码函数信息,严格判断该需求是否已经被当前代码库实现。同时需要关注实现的正确与否,检索知识库信息来确认。

【用户需求】
{user_requirement}

【检索到的相关函数信息】
{evidence_text}

请严格按照以下 JSON 格式回答,不要包含任何额外文字、注释或 Markdown:
{{
  "implemented": true 或 false,
  "reason": "不超过100字的理由,需引用具体函数名或功能描述",
  "most_relevant_function": "最相关的函数名(若无则写 null)",
  "file": "所在文件名(若无则写 null)"
}}
"""

    print("正在调用 Qwen-Max 进行判断...")
    response_text = call_qwen_max(judge_prompt)

    return _parse_judgement(response_text)


def main():
    """Run the interactive loop: read a requirement, print the verdict, repeat."""
    print("卫星需求实现检查工具")
    print("请输入您的自然语言需求(输入 'quit' 退出):")

    # Loaded on the first query and reused afterwards; a failed load leaves
    # this None so the next query retries.
    knowledge_base = None

    while True:
        try:
            user_input = input("\n> ").strip()
        except (EOFError, KeyboardInterrupt):
            # Closed stdin or Ctrl-C: leave cleanly instead of with a traceback.
            print("再见!")
            break
        if user_input.lower() in ["quit", "exit", "q"]:
            print("再见!")
            break
        if not user_input:
            continue

        # Top-level boundary: report any failure and keep the session alive.
        try:
            if knowledge_base is None:
                print("正在加载知识库...")
                knowledge_base = load_knowledge_base()
            result = check_requirement_implementation(user_input, knowledge_base)
            print("\n判断结果:")
            print(f" 已实现:{'是' if result['implemented'] else '否'}")
            print(f" 理由:{result['reason']}")
            if result.get("most_relevant_function"):
                print(f" 相关函数:{result['most_relevant_function']}({result['file']})")
        except Exception as e:
            print(f"处理出错:{e}")


if __name__ == "__main__":
    main()