"""Fusion chat service: routes a user query to general chat, hybrid RAG,
GraphRAG (local/global), or the testing-generation pipeline, and streams the
answer as data-stream frames."""

import base64
import json
import re
from collections import defaultdict
from typing import Any, AsyncGenerator, Dict, List, Optional

from app.core.config import settings
from app.models.chat import Message
from app.models.knowledge import Document, KnowledgeBase
from app.services.embedding.embedding_factory import EmbeddingsFactory
from app.services.fusion_prompts import (
    GENERAL_CHAT_PROMPT_TEMPLATE,
    GRAPH_GLOBAL_PROMPT_TEMPLATE,
    GRAPH_LOCAL_PROMPT_TEMPLATE,
    HYBRID_RAG_PROMPT_TEMPLATE,
)
from app.services.graph.graphrag_adapter import GraphRAGAdapter
from app.services.intent_router import route_intent
from app.services.llm.llm_factory import LLMFactory
from app.services.reranker.external_api import ExternalRerankerClient
from app.services.retrieval.multi_kb_retriever import MultiKBRetriever, format_retrieval_context
from app.services.testing_pipeline.pipeline import run_testing_pipeline
from app.services.testing_pipeline.rules import REQUIREMENT_TYPES
from app.services.vector_store import VectorStoreFactory

# Keywords (zh-CN) marking the *target* of a testing-generation request,
# e.g. test items, test cases, expected results.
TESTING_TARGET_KEYWORDS = [
    "测试项",
    "测试用例",
    "预期成果",
    "需求类型",
    "测试分解",
    "分解",
    "正常测试",
    "异常测试",
    "测试充分性",
]

# Keywords (zh-CN) marking the *action* verb of such a request,
# e.g. generate, write, design, list.
TESTING_ACTION_KEYWORDS = [
    "生成",
    "输出",
    "给出",
    "写",
    "编写",
    "设计",
    "整理",
    "列出",
    "提供",
    "制定",
]

# Maps colloquial requirement-type aliases to the canonical names used by
# REQUIREMENT_TYPES.
TYPE_ALIAS_MAP = {
    "接口测试": "外部接口测试",
    "ui测试": "人机交互界面测试",
    "界面测试": "人机交互界面测试",
    "恢复测试": "恢复性测试",
    "可靠性": "可靠性测试",
    "安全性": "安全性测试",
    "边界": "边界测试",
    "安装": "安装性测试",
    "互操作": "互操作性测试",
    "敏感性": "敏感性测试",
    "充分性": "测试充分性要求",
}


def _escape_stream_text(text: str) -> str:
    # Escape backslashes first so they are not double-processed by the later
    # replacements, then quotes and newlines, keeping `0:"..."` frames parseable.
    return text.replace("\\", "\\\\").replace('"', '\\"').replace("\n", "\\n")


def _extract_stream_text(chunk: Any) -> str:
    content = getattr(chunk, "content", chunk)
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts: List[str] = []
        for item in content:
            if isinstance(item, str):
                parts.append(item)
            elif isinstance(item, dict):
                maybe_text = item.get("text")
                if isinstance(maybe_text, str):
                    parts.append(maybe_text)
            else:
                parts.append(str(item))
        return "".join(parts)
    return str(content)
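
# A minimal sketch of the stream frames this module emits (assumed to follow
# the Vercel AI SDK data-stream convention; verify against the frontend):
#
#   0:"<base64 context payload>__LLM_RESPONSE__"  <- routing/context frame
#   0:"escaped text delta"                        <- one frame per LLM chunk
#   d:{"finishReason":"stop","usage":{...}}       <- completion frame
#   3:"Error generating response: ..."            <- error frame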


def _preview_rows(rows: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    preview = []
    for row in rows[:10]:
        doc = row["document"]
        metadata = doc.metadata or {}
        preview.append(
            {
                "kb_id": row.get("kb_id"),
                "source": metadata.get("source") or metadata.get("file_name") or "unknown",
                "chunk_id": metadata.get("chunk_id") or "unknown",
                "score": row.get("final_score", 0),
                "reranker_score": row.get("reranker_score"),
            }
        )
    return preview


def _context_rows(rows: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    context_rows: List[Dict[str, Any]] = []
    for row in rows:
        doc = row["document"]
        metadata = dict(doc.metadata or {})
        if "kb_id" not in metadata and row.get("kb_id") is not None:
            metadata["kb_id"] = row.get("kb_id")
        metadata.setdefault("retrieval_score", row.get("final_score", 0))
        if row.get("reranker_score") is not None:
            metadata.setdefault("reranker_score", row.get("reranker_score"))
        context_rows.append(
            {
                "page_content": doc.page_content.strip(),
                "metadata": metadata,
            }
        )
    return context_rows


def _build_local_graph_context_fallback(rows: List[Dict[str, Any]]) -> str:
    entities = set()
    relations: List[Dict[str, Any]] = []
    evidences: List[str] = []
    for row in rows:
        doc = row["document"]
        metadata = doc.metadata or {}
        for ent in metadata.get("extracted_entities", []):
            entities.add(str(ent))
        for rel in metadata.get("extracted_relations", []):
            if isinstance(rel, dict):
                relations.append(rel)
        evidences.append(doc.page_content.strip())
    entity_block = "\n".join(f"- {name}" for name in sorted(entities)[:80]) or "- 暂无结构化实体,已使用向量检索回退。"
    relation_lines: List[str] = []
    for rel in relations[:120]:
        src = rel.get("source") or rel.get("src") or rel.get("src_id") or "UNKNOWN"
        tgt = rel.get("target") or rel.get("tgt") or rel.get("tgt_id") or "UNKNOWN"
        rel_type = rel.get("type") or rel.get("relation_type") or "其他"
        desc = rel.get("description") or ""
        relation_lines.append(f"- {src} -> {tgt} | 类型={rel_type} | 说明={desc}")
    relation_block = "\n".join(relation_lines) or "- 暂无结构化关系,已使用证据片段回答。"
    evidence_block = "\n\n".join(
        f"[证据{i}] {snippet}" for i, snippet in enumerate(evidences[:8], start=1)
    )
    if not evidence_block:
        evidence_block = "无可用证据。"
    return (
        "实体列表:\n"
        f"{entity_block}\n\n"
        "关系列表:\n"
        f"{relation_block}\n\n"
        "原文证据:\n"
        f"{evidence_block}"
    )


def _build_global_community_context_fallback(rows: List[Dict[str, Any]]) -> str:
    groups: Dict[str, List[str]] = defaultdict(list)
    for row in rows:
        doc = row["document"]
        metadata = doc.metadata or {}
        community_ids = metadata.get("community_ids") or []
        if isinstance(community_ids, list) and community_ids:
            keys = [str(item) for item in community_ids]
        else:
            source = metadata.get("source") or metadata.get("file_name") or "unknown"
            keys = [f"source:{source}"]
        for key in keys:
            groups[key].append(doc.page_content.strip())
    if not groups:
        return "暂无社区摘要数据,已回退为基于证据片段的全局总结。"
    lines: List[str] = []
    for idx, (community_id, snippets) in enumerate(groups.items(), start=1):
        merged = " ".join(snippets[:3])
        lines.append(f"社区{idx} ({community_id}) 摘要: {merged}")
    return "\n\n".join(lines)
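
# Illustrative shape of one retrieval row consumed by the helpers above
# (hypothetical values; real rows come from MultiKBRetriever.retrieve):
#
#   {
#       "document": <Document>,      # .page_content, .metadata
#       "kb_id": 3,
#       "final_score": 0.82,
#       "reranker_score": 0.91,      # may be absent when reranking is off
#   }
#
# The graph fallbacks additionally read metadata["extracted_entities"],
# metadata["extracted_relations"], and metadata["community_ids"] when present.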


async def _build_kb_vector_stores(db: Any, knowledge_bases: List[KnowledgeBase]) -> List[Dict[str, Any]]:
    embeddings = EmbeddingsFactory.create()
    kb_vector_stores: List[Dict[str, Any]] = []
    for kb in knowledge_bases:
        documents = db.query(Document).filter(Document.knowledge_base_id == kb.id).all()
        if not documents:
            # Skip knowledge bases with no ingested documents.
            continue
        store = VectorStoreFactory.create(
            store_type=settings.VECTOR_STORE_TYPE,
            collection_name=f"kb_{kb.id}",
            embedding_function=embeddings,
        )
        kb_vector_stores.append({"kb_id": kb.id, "store": store})
    return kb_vector_stores


def _build_reranker_client() -> ExternalRerankerClient:
    return ExternalRerankerClient(
        api_url=settings.RERANKER_API_URL,
        api_key=settings.RERANKER_API_KEY,
        model=settings.RERANKER_MODEL,
        timeout_seconds=settings.RERANKER_TIMEOUT_SECONDS,
    )


def _is_testing_generation_request(query: str) -> bool:
    text = (query or "").strip()
    if not text:
        return False
    normalized = text.lower()
    if normalized.startswith("/testing"):
        return True
    if any(
        token in normalized
        for token in (
            "testing_orchestrator",
            "testing-orchestrator",
            "identify_requirement_type",
            "identify-requirement-type",
        )
    ):
        return True
    has_target = any(keyword in text for keyword in TESTING_TARGET_KEYWORDS)
    has_action = any(keyword in text for keyword in TESTING_ACTION_KEYWORDS)
    if has_target and has_action:
        return True
    if any(keyword in text for keyword in ("测试项", "测试用例", "预期成果")):
        # Polite-request pattern, e.g. "请帮我写…" ("please help me write ...").
        if re.search(r"(请|帮|给|麻烦).{0,12}(写|生成|设计|整理|编写|列出|提供|制定)", text):
            return True
        if text.startswith(("生成", "编写", "设计", "整理", "输出", "列出", "提供", "制定")):
            return True
    return False


def _extract_requirement_type_from_query(query: str) -> Optional[str]:
    text = (query or "").strip()
    if not text:
        return None
    for req_type in REQUIREMENT_TYPES:
        if req_type in text:
            return req_type
    lowered = text.lower()
    for alias, req_type in TYPE_ALIAS_MAP.items():
        if alias in text or alias in lowered:
            return req_type
    return None
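
# Example routing behavior of the two helpers above (illustrative queries;
# the alias case assumes no canonical REQUIREMENT_TYPES name appears in the
# text, so the TYPE_ALIAS_MAP lookup is what fires):
#
#   _is_testing_generation_request("请帮我生成登录模块的测试用例")   -> True
#   _is_testing_generation_request("什么是边界测试?")                -> False
#   _extract_requirement_type_from_query("做一下接口测试分解")        -> "外部接口测试"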


async def generate_response(
    query: str,
    messages: dict,
    knowledge_base_ids: List[int],
    chat_id: int,
    db: Any,
) -> AsyncGenerator[str, None]:
    try:
        user_message = Message(content=query, role="user", chat_id=chat_id)
        db.add(user_message)
        db.commit()

        bot_message = Message(content="", role="assistant", chat_id=chat_id)
        db.add(bot_message)
        db.commit()

        if _is_testing_generation_request(query):
            explicit_type = _extract_requirement_type_from_query(query)
            retrieval_rows: List[Dict[str, Any]] = []
            knowledge_context = ""
            kb_vector_stores = []
            if knowledge_base_ids:
                testing_kbs = (
                    db.query(KnowledgeBase)
                    .filter(KnowledgeBase.id.in_(knowledge_base_ids))
                    .all()
                )
                kb_vector_stores = await _build_kb_vector_stores(db, testing_kbs)
            if kb_vector_stores:
                testing_retriever = MultiKBRetriever(
                    reranker_weight=settings.RERANKER_WEIGHT,
                )
                retrieval_rows = await testing_retriever.retrieve(
                    query=query,
                    kb_vector_stores=kb_vector_stores,
                    fetch_k_per_kb=16,
                    top_k=8,
                )
                if retrieval_rows:
                    knowledge_context = format_retrieval_context(retrieval_rows)

            pipeline_result = run_testing_pipeline(
                user_requirement_text=query,
                requirement_type_input=explicit_type,
                debug=True,
                knowledge_context=knowledge_context,
                use_model_generation=True,
                max_items_per_group=6,
                cases_per_item=1,
                max_focus_points=6,
                max_llm_calls=2,
            )

            context_payload = {
                "route": {
                    "intent": "TESTING",
                    "reason": "命中测试生成意图,已自动调用测试工具链。",
                },
                "intent": "TESTING",
                "skill_profile": "testing-orchestrator",
                "tool_chain": [
                    "identify-requirement-type",
                    "decompose-test-items",
                    "generate-test-cases",
                    "build_expected_results",
                    "format_output",
                ],
                "selected_chain": "TESTING_PIPELINE",
                "graph_used": False,
                "reranker_enabled": False,
                "retrieval_preview": _preview_rows(retrieval_rows),
                "context": _context_rows(retrieval_rows),
                "testing_pipeline": {
                    "trace_id": pipeline_result.get("trace_id"),
                    "requirement_type": pipeline_result.get("requirement_type"),
                    "candidates": pipeline_result.get("candidates", []),
                    "pipeline_summary": pipeline_result.get("pipeline_summary", ""),
                    "knowledge_used": pipeline_result.get("knowledge_used", False),
                    "step_logs": pipeline_result.get("step_logs", []),
                },
            }
            escaped_context = json.dumps(context_payload, ensure_ascii=False)
            base64_context = base64.b64encode(escaped_context.encode()).decode()
            separator = "__LLM_RESPONSE__"
            full_response = f"{base64_context}{separator}"
            yield f'0:"{base64_context}{separator}"\n'

            rendered_text = pipeline_result.get("formatted_output", "").strip()
            if not rendered_text:
                rendered_text = "未生成测试内容,请补充更明确的需求后重试。"
            full_response += rendered_text
            yield f'0:"{_escape_stream_text(rendered_text)}"\n'
            yield 'd:{"finishReason":"stop","usage":{"promptTokens":0,"completionTokens":0}}\n'

            bot_message.content = full_response
            db.commit()
            return

        knowledge_bases = (
            db.query(KnowledgeBase)
            .filter(KnowledgeBase.id.in_(knowledge_base_ids))
            .all()
        )
        kb_ids = [kb.id for kb in knowledge_bases]

        llm = LLMFactory.create()
        decision = await route_intent(llm=llm, query=query, messages=messages)
        intent = decision["intent"]

        kb_vector_stores = await _build_kb_vector_stores(db, knowledge_bases)
        if intent in {"B", "C", "D"} and not kb_vector_stores:
            # Downgrade retrieval/graph intents to plain chat when no KB
            # vector collections are available.
            intent = "A"
            decision = {
                "intent": "A",
                "reason": "未发现可用知识库向量集合,已降级为通用对话链路。",
            }

        reranker_client = _build_reranker_client()
        retriever = MultiKBRetriever(
            reranker_client=reranker_client,
            reranker_weight=settings.RERANKER_WEIGHT,
        )

        retrieval_rows: List[Dict[str, Any]] = []
        graph_used = False
        selected_chain = intent
        prompt_text = ""

        if intent == "A":
            prompt_text = GENERAL_CHAT_PROMPT_TEMPLATE.format(query=query)
        elif intent == "B":
            retrieval_rows = await retriever.retrieve(
                query=query,
                kb_vector_stores=kb_vector_stores,
                fetch_k_per_kb=16,
                top_k=12,
            )
            context = format_retrieval_context(retrieval_rows) or "无可用证据。"
            prompt_text = HYBRID_RAG_PROMPT_TEMPLATE.format(query=query, context=context)
        elif intent == "C":
            graph_context = ""
            used_kb_ids: List[int] = []
            if settings.GRAPHRAG_ENABLED and kb_ids:
                try:
                    adapter = GraphRAGAdapter()
                    graph_context, used_kb_ids = await adapter.local_context_multi(
                        kb_ids,
                        query,
                        top_k=settings.GRAPHRAG_LOCAL_TOP_K,
                        level=settings.GRAPHRAG_QUERY_LEVEL,
                    )
                    graph_used = bool(graph_context)
                except Exception:
                    graph_context = ""
            if not graph_context:
                retrieval_rows = await retriever.retrieve(
                    query=query,
                    kb_vector_stores=kb_vector_stores,
                    fetch_k_per_kb=18,
                    top_k=14,
                )
                graph_context = _build_local_graph_context_fallback(retrieval_rows)
                selected_chain = "C_fallback_B"
            else:
                selected_chain = "C_graph"
            prompt_text = GRAPH_LOCAL_PROMPT_TEMPLATE.format(
                query=query,
                graph_context=graph_context,
            )
        else:
            community_context = ""
            if settings.GRAPHRAG_ENABLED and kb_ids:
                try:
                    adapter = GraphRAGAdapter()
                    community_context, used_kb_ids = await adapter.global_context_multi(
                        kb_ids,
                        query,
                        level=settings.GRAPHRAG_QUERY_LEVEL,
                    )
                    graph_used = bool(community_context)
                except Exception:
                    community_context = ""
            if not community_context:
                retrieval_rows = await retriever.retrieve(
                    query=query,
                    kb_vector_stores=kb_vector_stores,
                    fetch_k_per_kb=20,
                    top_k=14,
                )
                community_context = _build_global_community_context_fallback(retrieval_rows)
                selected_chain = "D_fallback_B"
            else:
                selected_chain = "D_graph"
            prompt_text = GRAPH_GLOBAL_PROMPT_TEMPLATE.format(
                query=query,
                community_context=community_context,
            )

        context_payload = {
            "route": decision,
            "intent": intent,
            "selected_chain": selected_chain,
            "graph_used": graph_used,
            "reranker_enabled": reranker_client.enabled,
            "retrieval_preview": _preview_rows(retrieval_rows),
            "context": _context_rows(retrieval_rows),
        }
        escaped_context = json.dumps(context_payload, ensure_ascii=False)
        base64_context = base64.b64encode(escaped_context.encode()).decode()
        separator = "__LLM_RESPONSE__"
        full_response = f"{base64_context}{separator}"
        yield f'0:"{base64_context}{separator}"\n'

        async for chunk in llm.astream(prompt_text):
            text = _extract_stream_text(chunk)
            if not text:
                continue
            full_response += text
            yield f'0:"{_escape_stream_text(text)}"\n'

        yield 'd:{"finishReason":"stop","usage":{"promptTokens":0,"completionTokens":0}}\n'
        bot_message.content = full_response
        db.commit()
    except Exception as e:
        error_message = f"Error generating response: {str(e)}"
        print(error_message)
        # Emit the error as a JSON string so the `3:` frame stays parseable.
        yield f"3:{json.dumps(error_message, ensure_ascii=False)}\n"
        if "bot_message" in locals():
            bot_message.content = error_message
            db.commit()
    finally:
        db.close()
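

# A minimal consumption sketch (assumptions: a FastAPI app, a SQLAlchemy
# SessionLocal factory, and a ChatRequest schema, none of which are defined in
# this module; adjust names to the real project):
#
#   from fastapi.responses import StreamingResponse
#
#   @app.post("/api/chats/{chat_id}/messages")
#   async def post_message(chat_id: int, body: ChatRequest):
#       db = SessionLocal()
#       return StreamingResponse(
#           generate_response(
#               query=body.query,
#               messages=body.messages,
#               knowledge_base_ids=body.knowledge_base_ids,
#               chat_id=chat_id,
#               db=db,
#           ),
#           media_type="text/event-stream",
#       )
#
# The frontend presumably splits the first `0:` frame on "__LLM_RESPONSE__" to
# recover the base64-encoded routing/context payload before rendering the text.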