# NOTE: removed a duplicated scraper artifact banner ("533 lines / 18 KiB / Python")
# that was not part of the module source.
import base64
|
|
import json
|
|
import re
|
|
from collections import defaultdict
|
|
from typing import Any, AsyncGenerator, Dict, List, Optional
|
|
|
|
from app.core.config import settings
|
|
from app.models.chat import Message
|
|
from app.models.knowledge import Document, KnowledgeBase
|
|
from app.services.embedding.embedding_factory import EmbeddingsFactory
|
|
from app.services.fusion_prompts import (
|
|
GENERAL_CHAT_PROMPT_TEMPLATE,
|
|
GRAPH_GLOBAL_PROMPT_TEMPLATE,
|
|
GRAPH_LOCAL_PROMPT_TEMPLATE,
|
|
HYBRID_RAG_PROMPT_TEMPLATE,
|
|
)
|
|
from app.services.graph.graphrag_adapter import GraphRAGAdapter
|
|
from app.services.intent_router import route_intent
|
|
from app.services.llm.llm_factory import LLMFactory
|
|
from app.services.reranker.external_api import ExternalRerankerClient
|
|
from app.services.retrieval.multi_kb_retriever import MultiKBRetriever, format_retrieval_context
|
|
from app.services.testing_pipeline.pipeline import run_testing_pipeline
|
|
from app.services.testing_pipeline.rules import REQUIREMENT_TYPES
|
|
from app.services.vector_store import VectorStoreFactory
|
|
|
|
|
|
# Nouns that mark a query as being ABOUT test artifacts (test items, test
# cases, expected results, requirement types, decomposition, adequacy).
# These are matched verbatim against user text by
# _is_testing_generation_request, so do not translate the literals.
TESTING_TARGET_KEYWORDS = [
    "测试项",
    "测试用例",
    "预期成果",
    "需求类型",
    "测试分解",
    "分解",
    "正常测试",
    "异常测试",
    "测试充分性",
]

# Verbs signaling the user wants content produced (generate, output, give,
# write, compose, design, organize, list, provide, formulate).
TESTING_ACTION_KEYWORDS = [
    "生成",
    "输出",
    "给出",
    "写",
    "编写",
    "设计",
    "整理",
    "列出",
    "提供",
    "制定",
]

# Maps loose or abbreviated requirement-type mentions (e.g. "接口测试",
# "ui测试") to the canonical names checked against REQUIREMENT_TYPES by
# _extract_requirement_type_from_query.  Keys containing ASCII letters are
# lowercase because the query is lowercased before alias matching.
TYPE_ALIAS_MAP = {
    "接口测试": "外部接口测试",
    "ui测试": "人机交互界面测试",
    "界面测试": "人机交互界面测试",
    "恢复测试": "恢复性测试",
    "可靠性": "可靠性测试",
    "安全性": "安全性测试",
    "边界": "边界测试",
    "安装": "安装性测试",
    "互操作": "互操作性测试",
    "敏感性": "敏感性测试",
    "充分性": "测试充分性要求",
}
def _escape_stream_text(text: str) -> str:
|
|
return text.replace('"', '\\"').replace("\n", "\\n")
|
|
|
|
|
|
def _extract_stream_text(chunk: Any) -> str:
|
|
content = getattr(chunk, "content", chunk)
|
|
|
|
if isinstance(content, str):
|
|
return content
|
|
|
|
if isinstance(content, list):
|
|
parts: List[str] = []
|
|
for item in content:
|
|
if isinstance(item, str):
|
|
parts.append(item)
|
|
elif isinstance(item, dict):
|
|
maybe_text = item.get("text")
|
|
if isinstance(maybe_text, str):
|
|
parts.append(maybe_text)
|
|
else:
|
|
parts.append(str(item))
|
|
return "".join(parts)
|
|
|
|
return str(content)
|
|
|
|
|
|
def _preview_rows(rows: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
|
preview = []
|
|
for row in rows[:10]:
|
|
doc = row["document"]
|
|
metadata = doc.metadata or {}
|
|
preview.append(
|
|
{
|
|
"kb_id": row.get("kb_id"),
|
|
"source": metadata.get("source") or metadata.get("file_name") or "unknown",
|
|
"chunk_id": metadata.get("chunk_id") or "unknown",
|
|
"score": row.get("final_score", 0),
|
|
"reranker_score": row.get("reranker_score"),
|
|
}
|
|
)
|
|
return preview
|
|
|
|
|
|
def _context_rows(rows: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
|
context_rows: List[Dict[str, Any]] = []
|
|
for row in rows:
|
|
doc = row["document"]
|
|
metadata = dict(doc.metadata or {})
|
|
|
|
if "kb_id" not in metadata and row.get("kb_id") is not None:
|
|
metadata["kb_id"] = row.get("kb_id")
|
|
metadata.setdefault("retrieval_score", row.get("final_score", 0))
|
|
if row.get("reranker_score") is not None:
|
|
metadata.setdefault("reranker_score", row.get("reranker_score"))
|
|
|
|
context_rows.append(
|
|
{
|
|
"page_content": doc.page_content.strip(),
|
|
"metadata": metadata,
|
|
}
|
|
)
|
|
return context_rows
|
|
|
|
|
|
def _build_local_graph_context_fallback(rows: List[Dict[str, Any]]) -> str:
|
|
entities = set()
|
|
relations: List[Dict[str, Any]] = []
|
|
evidences: List[str] = []
|
|
|
|
for row in rows:
|
|
doc = row["document"]
|
|
metadata = doc.metadata or {}
|
|
|
|
for ent in metadata.get("extracted_entities", []):
|
|
entities.add(str(ent))
|
|
|
|
for rel in metadata.get("extracted_relations", []):
|
|
if isinstance(rel, dict):
|
|
relations.append(rel)
|
|
|
|
evidences.append(doc.page_content.strip())
|
|
|
|
entity_block = "\n".join(f"- {name}" for name in sorted(entities)[:80]) or "- 暂无结构化实体,已使用向量检索回退。"
|
|
|
|
relation_lines: List[str] = []
|
|
for rel in relations[:120]:
|
|
src = rel.get("source") or rel.get("src") or rel.get("src_id") or "UNKNOWN"
|
|
tgt = rel.get("target") or rel.get("tgt") or rel.get("tgt_id") or "UNKNOWN"
|
|
rel_type = rel.get("type") or rel.get("relation_type") or "其他"
|
|
desc = rel.get("description") or ""
|
|
relation_lines.append(f"- {src} -> {tgt} | 类型={rel_type} | 说明={desc}")
|
|
|
|
relation_block = "\n".join(relation_lines) or "- 暂无结构化关系,已使用证据片段回答。"
|
|
|
|
evidence_block = "\n\n".join(
|
|
f"[证据{i}] {snippet}" for i, snippet in enumerate(evidences[:8], start=1)
|
|
)
|
|
if not evidence_block:
|
|
evidence_block = "无可用证据。"
|
|
|
|
return (
|
|
"实体列表:\n"
|
|
f"{entity_block}\n\n"
|
|
"关系列表:\n"
|
|
f"{relation_block}\n\n"
|
|
"原文证据:\n"
|
|
f"{evidence_block}"
|
|
)
|
|
|
|
|
|
def _build_global_community_context_fallback(rows: List[Dict[str, Any]]) -> str:
|
|
groups: Dict[str, List[str]] = defaultdict(list)
|
|
|
|
for row in rows:
|
|
doc = row["document"]
|
|
metadata = doc.metadata or {}
|
|
community_ids = metadata.get("community_ids") or []
|
|
|
|
if isinstance(community_ids, list) and community_ids:
|
|
keys = [str(item) for item in community_ids]
|
|
else:
|
|
source = metadata.get("source") or metadata.get("file_name") or "unknown"
|
|
keys = [f"source:{source}"]
|
|
|
|
for key in keys:
|
|
groups[key].append(doc.page_content.strip())
|
|
|
|
if not groups:
|
|
return "暂无社区摘要数据,已回退为基于证据片段的全局总结。"
|
|
|
|
lines: List[str] = []
|
|
for idx, (community_id, snippets) in enumerate(groups.items(), start=1):
|
|
merged = " ".join(snippets[:3])
|
|
lines.append(f"社区{idx} ({community_id}) 摘要: {merged}")
|
|
|
|
return "\n\n".join(lines)
|
|
|
|
|
|
async def _build_kb_vector_stores(db: Any, knowledge_bases: List[KnowledgeBase]) -> List[Dict[str, Any]]:
    """Create one vector-store handle per knowledge base that has documents.

    Args:
        db: active database session (SQLAlchemy-style ``query`` interface).
        knowledge_bases: the knowledge bases selected for this request.

    Returns:
        A list of ``{"kb_id": ..., "store": ...}`` dicts.  Knowledge bases
        without ingested documents are skipped, so an empty result signals to
        callers that no retrieval is possible.
    """
    embeddings = EmbeddingsFactory.create()
    kb_vector_stores: List[Dict[str, Any]] = []

    for kb in knowledge_bases:
        # Existence check only: .first() avoids materializing every Document
        # row of the KB (the previous code loaded the full result set with
        # .all() just to test emptiness).
        has_documents = (
            db.query(Document)
            .filter(Document.knowledge_base_id == kb.id)
            .first()
            is not None
        )
        if not has_documents:
            continue

        # One collection per knowledge base, named "kb_<id>".
        store = VectorStoreFactory.create(
            store_type=settings.VECTOR_STORE_TYPE,
            collection_name=f"kb_{kb.id}",
            embedding_function=embeddings,
        )
        kb_vector_stores.append({"kb_id": kb.id, "store": store})

    return kb_vector_stores
def _build_reranker_client() -> ExternalRerankerClient:
    """Construct the external reranker client from application settings."""
    client = ExternalRerankerClient(
        api_url=settings.RERANKER_API_URL,
        api_key=settings.RERANKER_API_KEY,
        model=settings.RERANKER_MODEL,
        timeout_seconds=settings.RERANKER_TIMEOUT_SECONDS,
    )
    return client
def _is_testing_generation_request(query: str) -> bool:
|
|
text = (query or "").strip()
|
|
if not text:
|
|
return False
|
|
|
|
normalized = text.lower()
|
|
if normalized.startswith("/testing"):
|
|
return True
|
|
|
|
if any(
|
|
token in normalized
|
|
for token in (
|
|
"testing_orchestrator",
|
|
"testing-orchestrator",
|
|
"identify_requirement_type",
|
|
"identify-requirement-type",
|
|
)
|
|
):
|
|
return True
|
|
|
|
has_target = any(keyword in text for keyword in TESTING_TARGET_KEYWORDS)
|
|
has_action = any(keyword in text for keyword in TESTING_ACTION_KEYWORDS)
|
|
if has_target and has_action:
|
|
return True
|
|
|
|
if any(keyword in text for keyword in ("测试项", "测试用例", "预期成果")):
|
|
if re.search(r"(请|帮|给|麻烦).{0,12}(写|生成|设计|整理|编写|列出|提供|制定)", text):
|
|
return True
|
|
if text.startswith(("生成", "编写", "设计", "整理", "输出", "列出", "提供", "制定")):
|
|
return True
|
|
|
|
return False
|
|
|
|
|
|
def _extract_requirement_type_from_query(query: str) -> Optional[str]:
|
|
text = (query or "").strip()
|
|
if not text:
|
|
return None
|
|
|
|
for req_type in REQUIREMENT_TYPES:
|
|
if req_type in text:
|
|
return req_type
|
|
|
|
lowered = text.lower()
|
|
for alias, req_type in TYPE_ALIAS_MAP.items():
|
|
if alias in text or alias in lowered:
|
|
return req_type
|
|
|
|
return None
|
|
|
|
|
|
async def generate_response(
    query: str,
    messages: dict,
    knowledge_base_ids: List[int],
    chat_id: int,
    db: Any,
) -> AsyncGenerator[str, None]:
    """Stream an assistant reply for *query* and persist the exchange.

    Yields data-stream protocol lines: ``0:"<text>"`` content chunks, a final
    ``d:{...}`` finish event, and a ``3:`` line on error.  The first ``0:``
    chunk carries a base64-encoded JSON context payload (routing decision,
    retrieval preview, context rows) followed by the ``__LLM_RESPONSE__``
    separator; the persisted assistant message stores that same prefix plus
    the full answer text.

    Routing: testing-generation requests short-circuit into the deterministic
    testing pipeline.  All other queries are routed by intent — A=general
    chat, B=hybrid RAG, C=graph-local, D=graph-global — where the graph
    intents fall back to vector retrieval when GraphRAG yields no context.

    Args:
        query: latest user message text.
        messages: prior conversation payload, forwarded verbatim to
            ``route_intent`` (shape assumed to match its contract — TODO
            confirm; annotated ``dict`` here).
        knowledge_base_ids: knowledge bases selected for this chat.
        chat_id: chat that both persisted messages belong to.
        db: database session; closed in ``finally`` when the stream ends.
    """
    try:
        # Persist the user turn and an empty assistant placeholder up front so
        # the assistant row exists even if streaming fails midway; its content
        # is filled in (or overwritten with the error) later.
        user_message = Message(content=query, role="user", chat_id=chat_id)
        db.add(user_message)
        db.commit()

        bot_message = Message(content="", role="assistant", chat_id=chat_id)
        db.add(bot_message)
        db.commit()

        # ---- Testing-generation fast path (bypasses the LLM intent router) ----
        if _is_testing_generation_request(query):
            explicit_type = _extract_requirement_type_from_query(query)

            retrieval_rows: List[Dict[str, Any]] = []
            knowledge_context = ""
            kb_vector_stores = []
            if knowledge_base_ids:
                testing_kbs = (
                    db.query(KnowledgeBase)
                    .filter(KnowledgeBase.id.in_(knowledge_base_ids))
                    .all()
                )
                kb_vector_stores = await _build_kb_vector_stores(db, testing_kbs)

            # Optional knowledge grounding for the pipeline.  Note that no
            # reranker client is passed on this path (vector scores only).
            if kb_vector_stores:
                testing_retriever = MultiKBRetriever(
                    reranker_weight=settings.RERANKER_WEIGHT,
                )
                retrieval_rows = await testing_retriever.retrieve(
                    query=query,
                    kb_vector_stores=kb_vector_stores,
                    fetch_k_per_kb=16,
                    top_k=8,
                )
                if retrieval_rows:
                    knowledge_context = format_retrieval_context(retrieval_rows)

            # Deterministic pipeline; limits cap LLM usage to two calls.
            pipeline_result = run_testing_pipeline(
                user_requirement_text=query,
                requirement_type_input=explicit_type,
                debug=True,
                knowledge_context=knowledge_context,
                use_model_generation=True,
                max_items_per_group=6,
                cases_per_item=1,
                max_focus_points=6,
                max_llm_calls=2,
            )

            # Frontend context payload mirroring the router-path shape, with
            # the testing pipeline's trace attached.
            context_payload = {
                "route": {
                    "intent": "TESTING",
                    "reason": "命中测试生成意图,已自动调用测试工具链。",
                },
                "intent": "TESTING",
                "skill_profile": "testing-orchestrator",
                "tool_chain": [
                    "identify-requirement-type",
                    "decompose-test-items",
                    "generate-test-cases",
                    "build_expected_results",
                    "format_output",
                ],
                "selected_chain": "TESTING_PIPELINE",
                "graph_used": False,
                "reranker_enabled": False,
                "retrieval_preview": _preview_rows(retrieval_rows),
                "context": _context_rows(retrieval_rows),
                "testing_pipeline": {
                    "trace_id": pipeline_result.get("trace_id"),
                    "requirement_type": pipeline_result.get("requirement_type"),
                    "candidates": pipeline_result.get("candidates", []),
                    "pipeline_summary": pipeline_result.get("pipeline_summary", ""),
                    "knowledge_used": pipeline_result.get("knowledge_used", False),
                    "step_logs": pipeline_result.get("step_logs", []),
                },
            }

            # Context payload travels base64-encoded ahead of the separator;
            # base64 output is escape-safe inside the quoted stream line.
            escaped_context = json.dumps(context_payload, ensure_ascii=False)
            base64_context = base64.b64encode(escaped_context.encode()).decode()
            separator = "__LLM_RESPONSE__"

            full_response = f"{base64_context}{separator}"
            yield f'0:"{base64_context}{separator}"\n'

            rendered_text = pipeline_result.get("formatted_output", "").strip()
            if not rendered_text:
                rendered_text = "未生成测试内容,请补充更明确的需求后重试。"

            full_response += rendered_text
            yield f'0:"{_escape_stream_text(rendered_text)}"\n'
            yield 'd:{"finishReason":"stop","usage":{"promptTokens":0,"completionTokens":0}}\n'

            bot_message.content = full_response
            db.commit()
            return

        # ---- Intent-routed path (A/B/C/D) ----
        knowledge_bases = (
            db.query(KnowledgeBase)
            .filter(KnowledgeBase.id.in_(knowledge_base_ids))
            .all()
        )
        kb_ids = [kb.id for kb in knowledge_bases]

        llm = LLMFactory.create()
        decision = await route_intent(llm=llm, query=query, messages=messages)
        intent = decision["intent"]

        # Retrieval-based intents degrade to general chat when no KB has a
        # usable vector collection.
        kb_vector_stores = await _build_kb_vector_stores(db, knowledge_bases)
        if intent in {"B", "C", "D"} and not kb_vector_stores:
            intent = "A"
            decision = {
                "intent": "A",
                "reason": "未发现可用知识库向量集合,已降级为通用对话路。",
            }

        reranker_client = _build_reranker_client()
        retriever = MultiKBRetriever(
            reranker_client=reranker_client,
            reranker_weight=settings.RERANKER_WEIGHT,
        )

        retrieval_rows: List[Dict[str, Any]] = []
        graph_used = False
        selected_chain = intent
        prompt_text = ""

        if intent == "A":
            # General chat: no retrieval, prompt is built from the query alone.
            prompt_text = GENERAL_CHAT_PROMPT_TEMPLATE.format(query=query)

        elif intent == "B":
            # Hybrid RAG: multi-KB retrieval feeds the prompt context.
            retrieval_rows = await retriever.retrieve(
                query=query,
                kb_vector_stores=kb_vector_stores,
                fetch_k_per_kb=16,
                top_k=12,
            )
            context = format_retrieval_context(retrieval_rows) or "无可用证据。"
            prompt_text = HYBRID_RAG_PROMPT_TEMPLATE.format(query=query, context=context)

        elif intent == "C":
            # Graph-local: GraphRAG entity-level context, with vector-retrieval
            # fallback when GraphRAG is disabled, errors, or returns nothing.
            graph_context = ""
            used_kb_ids: List[int] = []  # unpacked below but not otherwise used
            if settings.GRAPHRAG_ENABLED and kb_ids:
                try:
                    adapter = GraphRAGAdapter()
                    graph_context, used_kb_ids = await adapter.local_context_multi(
                        kb_ids,
                        query,
                        top_k=settings.GRAPHRAG_LOCAL_TOP_K,
                        level=settings.GRAPHRAG_QUERY_LEVEL,
                    )
                    graph_used = bool(graph_context)
                except Exception:
                    # Best-effort: any GraphRAG failure silently triggers the
                    # vector fallback below.
                    graph_context = ""

            if not graph_context:
                retrieval_rows = await retriever.retrieve(
                    query=query,
                    kb_vector_stores=kb_vector_stores,
                    fetch_k_per_kb=18,
                    top_k=14,
                )
                graph_context = _build_local_graph_context_fallback(retrieval_rows)
                selected_chain = "C_fallback_B"

            else:
                selected_chain = "C_graph"

            prompt_text = GRAPH_LOCAL_PROMPT_TEMPLATE.format(
                query=query,
                graph_context=graph_context,
            )

        else:
            # Intent D (graph-global): GraphRAG community summaries, with the
            # same style of vector-retrieval fallback as intent C.
            community_context = ""
            if settings.GRAPHRAG_ENABLED and kb_ids:
                try:
                    adapter = GraphRAGAdapter()
                    community_context, used_kb_ids = await adapter.global_context_multi(
                        kb_ids,
                        query,
                        level=settings.GRAPHRAG_QUERY_LEVEL,
                    )
                    graph_used = bool(community_context)
                except Exception:
                    community_context = ""

            if not community_context:
                retrieval_rows = await retriever.retrieve(
                    query=query,
                    kb_vector_stores=kb_vector_stores,
                    fetch_k_per_kb=20,
                    top_k=14,
                )
                community_context = _build_global_community_context_fallback(retrieval_rows)
                selected_chain = "D_fallback_B"
            else:
                selected_chain = "D_graph"

            prompt_text = GRAPH_GLOBAL_PROMPT_TEMPLATE.format(
                query=query,
                community_context=community_context,
            )

        # Context payload precedes the LLM answer, separated by the marker.
        context_payload = {
            "route": decision,
            "intent": intent,
            "selected_chain": selected_chain,
            "graph_used": graph_used,
            "reranker_enabled": reranker_client.enabled,
            "retrieval_preview": _preview_rows(retrieval_rows),
            "context": _context_rows(retrieval_rows),
        }
        escaped_context = json.dumps(context_payload, ensure_ascii=False)
        base64_context = base64.b64encode(escaped_context.encode()).decode()
        separator = "__LLM_RESPONSE__"

        full_response = f"{base64_context}{separator}"
        yield f'0:"{base64_context}{separator}"\n'

        # Stream the model answer chunk by chunk, accumulating the unescaped
        # text for persistence.
        async for chunk in llm.astream(prompt_text):
            text = _extract_stream_text(chunk)
            if not text:
                continue
            full_response += text
            yield f'0:"{_escape_stream_text(text)}"\n'

        yield 'd:{"finishReason":"stop","usage":{"promptTokens":0,"completionTokens":0}}\n'

        bot_message.content = full_response
        db.commit()

    except Exception as e:
        # Top-level boundary: surface the error on the stream and persist it
        # in place of the assistant reply.
        error_message = f"Error generating response: {str(e)}"
        print(error_message)
        # NOTE(review): unlike the '0:' chunks, this payload is not quoted or
        # escaped — confirm the stream consumer accepts a bare '3:' line.
        yield "3:{text}\n".format(text=error_message)

        # bot_message may not exist if the very first commit raised.
        if "bot_message" in locals():
            bot_message.content = error_message
            db.commit()
    finally:
        db.close()