init. project
This commit is contained in:
532
rag-web-ui/backend/app/services/chat_service.py
Normal file
532
rag-web-ui/backend/app/services/chat_service.py
Normal file
@@ -0,0 +1,532 @@
|
||||
import base64
|
||||
import json
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from app.core.config import settings
|
||||
from app.models.chat import Message
|
||||
from app.models.knowledge import Document, KnowledgeBase
|
||||
from app.services.embedding.embedding_factory import EmbeddingsFactory
|
||||
from app.services.fusion_prompts import (
|
||||
GENERAL_CHAT_PROMPT_TEMPLATE,
|
||||
GRAPH_GLOBAL_PROMPT_TEMPLATE,
|
||||
GRAPH_LOCAL_PROMPT_TEMPLATE,
|
||||
HYBRID_RAG_PROMPT_TEMPLATE,
|
||||
)
|
||||
from app.services.graph.graphrag_adapter import GraphRAGAdapter
|
||||
from app.services.intent_router import route_intent
|
||||
from app.services.llm.llm_factory import LLMFactory
|
||||
from app.services.reranker.external_api import ExternalRerankerClient
|
||||
from app.services.retrieval.multi_kb_retriever import MultiKBRetriever, format_retrieval_context
|
||||
from app.services.testing_pipeline.pipeline import run_testing_pipeline
|
||||
from app.services.testing_pipeline.rules import REQUIREMENT_TYPES
|
||||
from app.services.vector_store import VectorStoreFactory
|
||||
|
||||
|
||||
# Nouns that mark a query as being about test artifacts (test items, test
# cases, expected results, ...). Used by _is_testing_generation_request.
TESTING_TARGET_KEYWORDS = [
    "测试项",
    "测试用例",
    "预期成果",
    "需求类型",
    "测试分解",
    "分解",
    "正常测试",
    "异常测试",
    "测试充分性",
]

# Action verbs ("generate", "write", "design", ...). A target keyword plus
# one of these verbs triggers the testing pipeline route.
TESTING_ACTION_KEYWORDS = [
    "生成",
    "输出",
    "给出",
    "写",
    "编写",
    "设计",
    "整理",
    "列出",
    "提供",
    "制定",
]

# Maps informal/short requirement-type mentions to the canonical type names
# expected by the testing pipeline (see REQUIREMENT_TYPES). Keys containing
# ASCII (e.g. "ui测试") are matched case-insensitively by the caller.
TYPE_ALIAS_MAP = {
    "接口测试": "外部接口测试",
    "ui测试": "人机交互界面测试",
    "界面测试": "人机交互界面测试",
    "恢复测试": "恢复性测试",
    "可靠性": "可靠性测试",
    "安全性": "安全性测试",
    "边界": "边界测试",
    "安装": "安装性测试",
    "互操作": "互操作性测试",
    "敏感性": "敏感性测试",
    "充分性": "测试充分性要求",
}
|
||||
|
||||
|
||||
def _escape_stream_text(text: str) -> str:
|
||||
return text.replace('"', '\\"').replace("\n", "\\n")
|
||||
|
||||
|
||||
def _extract_stream_text(chunk: Any) -> str:
|
||||
content = getattr(chunk, "content", chunk)
|
||||
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
|
||||
if isinstance(content, list):
|
||||
parts: List[str] = []
|
||||
for item in content:
|
||||
if isinstance(item, str):
|
||||
parts.append(item)
|
||||
elif isinstance(item, dict):
|
||||
maybe_text = item.get("text")
|
||||
if isinstance(maybe_text, str):
|
||||
parts.append(maybe_text)
|
||||
else:
|
||||
parts.append(str(item))
|
||||
return "".join(parts)
|
||||
|
||||
return str(content)
|
||||
|
||||
|
||||
def _preview_rows(rows: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
preview = []
|
||||
for row in rows[:10]:
|
||||
doc = row["document"]
|
||||
metadata = doc.metadata or {}
|
||||
preview.append(
|
||||
{
|
||||
"kb_id": row.get("kb_id"),
|
||||
"source": metadata.get("source") or metadata.get("file_name") or "unknown",
|
||||
"chunk_id": metadata.get("chunk_id") or "unknown",
|
||||
"score": row.get("final_score", 0),
|
||||
"reranker_score": row.get("reranker_score"),
|
||||
}
|
||||
)
|
||||
return preview
|
||||
|
||||
|
||||
def _context_rows(rows: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
context_rows: List[Dict[str, Any]] = []
|
||||
for row in rows:
|
||||
doc = row["document"]
|
||||
metadata = dict(doc.metadata or {})
|
||||
|
||||
if "kb_id" not in metadata and row.get("kb_id") is not None:
|
||||
metadata["kb_id"] = row.get("kb_id")
|
||||
metadata.setdefault("retrieval_score", row.get("final_score", 0))
|
||||
if row.get("reranker_score") is not None:
|
||||
metadata.setdefault("reranker_score", row.get("reranker_score"))
|
||||
|
||||
context_rows.append(
|
||||
{
|
||||
"page_content": doc.page_content.strip(),
|
||||
"metadata": metadata,
|
||||
}
|
||||
)
|
||||
return context_rows
|
||||
|
||||
|
||||
def _build_local_graph_context_fallback(rows: List[Dict[str, Any]]) -> str:
|
||||
entities = set()
|
||||
relations: List[Dict[str, Any]] = []
|
||||
evidences: List[str] = []
|
||||
|
||||
for row in rows:
|
||||
doc = row["document"]
|
||||
metadata = doc.metadata or {}
|
||||
|
||||
for ent in metadata.get("extracted_entities", []):
|
||||
entities.add(str(ent))
|
||||
|
||||
for rel in metadata.get("extracted_relations", []):
|
||||
if isinstance(rel, dict):
|
||||
relations.append(rel)
|
||||
|
||||
evidences.append(doc.page_content.strip())
|
||||
|
||||
entity_block = "\n".join(f"- {name}" for name in sorted(entities)[:80]) or "- 暂无结构化实体,已使用向量检索回退。"
|
||||
|
||||
relation_lines: List[str] = []
|
||||
for rel in relations[:120]:
|
||||
src = rel.get("source") or rel.get("src") or rel.get("src_id") or "UNKNOWN"
|
||||
tgt = rel.get("target") or rel.get("tgt") or rel.get("tgt_id") or "UNKNOWN"
|
||||
rel_type = rel.get("type") or rel.get("relation_type") or "其他"
|
||||
desc = rel.get("description") or ""
|
||||
relation_lines.append(f"- {src} -> {tgt} | 类型={rel_type} | 说明={desc}")
|
||||
|
||||
relation_block = "\n".join(relation_lines) or "- 暂无结构化关系,已使用证据片段回答。"
|
||||
|
||||
evidence_block = "\n\n".join(
|
||||
f"[证据{i}] {snippet}" for i, snippet in enumerate(evidences[:8], start=1)
|
||||
)
|
||||
if not evidence_block:
|
||||
evidence_block = "无可用证据。"
|
||||
|
||||
return (
|
||||
"实体列表:\n"
|
||||
f"{entity_block}\n\n"
|
||||
"关系列表:\n"
|
||||
f"{relation_block}\n\n"
|
||||
"原文证据:\n"
|
||||
f"{evidence_block}"
|
||||
)
|
||||
|
||||
|
||||
def _build_global_community_context_fallback(rows: List[Dict[str, Any]]) -> str:
|
||||
groups: Dict[str, List[str]] = defaultdict(list)
|
||||
|
||||
for row in rows:
|
||||
doc = row["document"]
|
||||
metadata = doc.metadata or {}
|
||||
community_ids = metadata.get("community_ids") or []
|
||||
|
||||
if isinstance(community_ids, list) and community_ids:
|
||||
keys = [str(item) for item in community_ids]
|
||||
else:
|
||||
source = metadata.get("source") or metadata.get("file_name") or "unknown"
|
||||
keys = [f"source:{source}"]
|
||||
|
||||
for key in keys:
|
||||
groups[key].append(doc.page_content.strip())
|
||||
|
||||
if not groups:
|
||||
return "暂无社区摘要数据,已回退为基于证据片段的全局总结。"
|
||||
|
||||
lines: List[str] = []
|
||||
for idx, (community_id, snippets) in enumerate(groups.items(), start=1):
|
||||
merged = " ".join(snippets[:3])
|
||||
lines.append(f"社区{idx} ({community_id}) 摘要: {merged}")
|
||||
|
||||
return "\n\n".join(lines)
|
||||
|
||||
|
||||
async def _build_kb_vector_stores(db: Any, knowledge_bases: List[KnowledgeBase]) -> List[Dict[str, Any]]:
    """Create one vector-store handle per knowledge base that has documents.

    Args:
        db: SQLAlchemy-style session used to check for document existence.
        knowledge_bases: knowledge bases selected for the current chat.

    Returns:
        A list of ``{"kb_id": <int>, "store": <vector store>}`` entries.
        Knowledge bases without any ingested document are skipped so the
        retriever never queries an empty collection.
    """
    embeddings = EmbeddingsFactory.create()
    kb_vector_stores: List[Dict[str, Any]] = []

    for kb in knowledge_bases:
        # Existence check only: the previous .all() materialized every
        # Document row just to test emptiness; .first() asks the DB for a
        # single row instead.
        has_documents = (
            db.query(Document).filter(Document.knowledge_base_id == kb.id).first()
            is not None
        )
        if not has_documents:
            continue

        store = VectorStoreFactory.create(
            store_type=settings.VECTOR_STORE_TYPE,
            collection_name=f"kb_{kb.id}",
            embedding_function=embeddings,
        )
        kb_vector_stores.append({"kb_id": kb.id, "store": store})

    return kb_vector_stores
|
||||
|
||||
|
||||
def _build_reranker_client() -> ExternalRerankerClient:
    """Instantiate the external reranker client from application settings."""
    reranker_config = {
        "api_url": settings.RERANKER_API_URL,
        "api_key": settings.RERANKER_API_KEY,
        "model": settings.RERANKER_MODEL,
        "timeout_seconds": settings.RERANKER_TIMEOUT_SECONDS,
    }
    return ExternalRerankerClient(**reranker_config)
|
||||
|
||||
|
||||
def _is_testing_generation_request(query: str) -> bool:
|
||||
text = (query or "").strip()
|
||||
if not text:
|
||||
return False
|
||||
|
||||
normalized = text.lower()
|
||||
if normalized.startswith("/testing"):
|
||||
return True
|
||||
|
||||
if any(
|
||||
token in normalized
|
||||
for token in (
|
||||
"testing_orchestrator",
|
||||
"testing-orchestrator",
|
||||
"identify_requirement_type",
|
||||
"identify-requirement-type",
|
||||
)
|
||||
):
|
||||
return True
|
||||
|
||||
has_target = any(keyword in text for keyword in TESTING_TARGET_KEYWORDS)
|
||||
has_action = any(keyword in text for keyword in TESTING_ACTION_KEYWORDS)
|
||||
if has_target and has_action:
|
||||
return True
|
||||
|
||||
if any(keyword in text for keyword in ("测试项", "测试用例", "预期成果")):
|
||||
if re.search(r"(请|帮|给|麻烦).{0,12}(写|生成|设计|整理|编写|列出|提供|制定)", text):
|
||||
return True
|
||||
if text.startswith(("生成", "编写", "设计", "整理", "输出", "列出", "提供", "制定")):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _extract_requirement_type_from_query(query: str) -> Optional[str]:
|
||||
text = (query or "").strip()
|
||||
if not text:
|
||||
return None
|
||||
|
||||
for req_type in REQUIREMENT_TYPES:
|
||||
if req_type in text:
|
||||
return req_type
|
||||
|
||||
lowered = text.lower()
|
||||
for alias, req_type in TYPE_ALIAS_MAP.items():
|
||||
if alias in text or alias in lowered:
|
||||
return req_type
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def generate_response(
    query: str,
    messages: dict,
    knowledge_base_ids: List[int],
    chat_id: int,
    db: Any,
) -> AsyncGenerator[str, None]:
    """Stream a chat answer as data-stream protocol lines.

    Frames emitted (one per yielded string):
      - ``0:"..."``  text delta; the first one carries a base64-encoded JSON
        routing/context payload followed by the ``__LLM_RESPONSE__`` marker.
      - ``d:{...}``  finish frame with a stub usage block.
      - ``3:...``    error frame on failure.

    Side effects: persists the user message and an assistant message row,
    updating the assistant row's content as the final step (or with the
    error message on failure). Always closes ``db``.
    """
    try:
        # Persist both turns up front so the conversation rows exist even if
        # streaming fails midway; the assistant row is filled in at the end.
        user_message = Message(content=query, role="user", chat_id=chat_id)
        db.add(user_message)
        db.commit()

        bot_message = Message(content="", role="assistant", chat_id=chat_id)
        db.add(bot_message)
        db.commit()

        # --- Testing-pipeline short-circuit -------------------------------
        # Testing-generation requests bypass intent routing entirely and are
        # answered by the deterministic testing pipeline.
        if _is_testing_generation_request(query):
            explicit_type = _extract_requirement_type_from_query(query)

            retrieval_rows: List[Dict[str, Any]] = []
            knowledge_context = ""
            kb_vector_stores = []
            if knowledge_base_ids:
                testing_kbs = (
                    db.query(KnowledgeBase)
                    .filter(KnowledgeBase.id.in_(knowledge_base_ids))
                    .all()
                )
                kb_vector_stores = await _build_kb_vector_stores(db, testing_kbs)

            # Optional grounding: feed top retrieval hits to the pipeline.
            # NOTE(review): no reranker client is passed here, so this
            # retriever appears to run without reranking — confirm intended.
            if kb_vector_stores:
                testing_retriever = MultiKBRetriever(
                    reranker_weight=settings.RERANKER_WEIGHT,
                )
                retrieval_rows = await testing_retriever.retrieve(
                    query=query,
                    kb_vector_stores=kb_vector_stores,
                    fetch_k_per_kb=16,
                    top_k=8,
                )
                if retrieval_rows:
                    knowledge_context = format_retrieval_context(retrieval_rows)

            pipeline_result = run_testing_pipeline(
                user_requirement_text=query,
                requirement_type_input=explicit_type,
                debug=True,
                knowledge_context=knowledge_context,
                use_model_generation=True,
                max_items_per_group=6,
                cases_per_item=1,
                max_focus_points=6,
                max_llm_calls=2,
            )

            # Routing/debug payload shipped to the client in the first frame.
            context_payload = {
                "route": {
                    "intent": "TESTING",
                    "reason": "命中测试生成意图,已自动调用测试工具链。",
                },
                "intent": "TESTING",
                "skill_profile": "testing-orchestrator",
                "tool_chain": [
                    "identify-requirement-type",
                    "decompose-test-items",
                    "generate-test-cases",
                    "build_expected_results",
                    "format_output",
                ],
                "selected_chain": "TESTING_PIPELINE",
                "graph_used": False,
                "reranker_enabled": False,
                "retrieval_preview": _preview_rows(retrieval_rows),
                "context": _context_rows(retrieval_rows),
                "testing_pipeline": {
                    "trace_id": pipeline_result.get("trace_id"),
                    "requirement_type": pipeline_result.get("requirement_type"),
                    "candidates": pipeline_result.get("candidates", []),
                    "pipeline_summary": pipeline_result.get("pipeline_summary", ""),
                    "knowledge_used": pipeline_result.get("knowledge_used", False),
                    "step_logs": pipeline_result.get("step_logs", []),
                },
            }

            escaped_context = json.dumps(context_payload, ensure_ascii=False)
            base64_context = base64.b64encode(escaped_context.encode()).decode()
            separator = "__LLM_RESPONSE__"

            # First frame: base64 payload + marker, also stored in the DB row.
            full_response = f"{base64_context}{separator}"
            yield f'0:"{base64_context}{separator}"\n'

            rendered_text = pipeline_result.get("formatted_output", "").strip()
            if not rendered_text:
                rendered_text = "未生成测试内容,请补充更明确的需求后重试。"

            full_response += rendered_text
            yield f'0:"{_escape_stream_text(rendered_text)}"\n'
            yield 'd:{"finishReason":"stop","usage":{"promptTokens":0,"completionTokens":0}}\n'

            bot_message.content = full_response
            db.commit()
            return

        # --- Intent-routed path (A/B/C/D) ---------------------------------
        knowledge_bases = (
            db.query(KnowledgeBase)
            .filter(KnowledgeBase.id.in_(knowledge_base_ids))
            .all()
        )
        kb_ids = [kb.id for kb in knowledge_bases]

        llm = LLMFactory.create()
        decision = await route_intent(llm=llm, query=query, messages=messages)
        intent = decision["intent"]

        # Without usable vector stores, any retrieval-dependent intent is
        # downgraded to plain chat (A) and the route decision rewritten.
        kb_vector_stores = await _build_kb_vector_stores(db, knowledge_bases)
        if intent in {"B", "C", "D"} and not kb_vector_stores:
            intent = "A"
            decision = {
                "intent": "A",
                "reason": "未发现可用知识库向量集合,已降级为通用对话路。",
            }

        reranker_client = _build_reranker_client()
        retriever = MultiKBRetriever(
            reranker_client=reranker_client,
            reranker_weight=settings.RERANKER_WEIGHT,
        )

        retrieval_rows: List[Dict[str, Any]] = []
        graph_used = False
        selected_chain = intent
        prompt_text = ""

        if intent == "A":
            # Plain conversation, no retrieval.
            prompt_text = GENERAL_CHAT_PROMPT_TEMPLATE.format(query=query)

        elif intent == "B":
            # Hybrid RAG: multi-KB retrieval + rerank into one prompt.
            retrieval_rows = await retriever.retrieve(
                query=query,
                kb_vector_stores=kb_vector_stores,
                fetch_k_per_kb=16,
                top_k=12,
            )
            context = format_retrieval_context(retrieval_rows) or "无可用证据。"
            prompt_text = HYBRID_RAG_PROMPT_TEMPLATE.format(query=query, context=context)

        elif intent == "C":
            # Graph-local question: try GraphRAG first, fall back to a
            # pseudo-graph context built from vector retrieval.
            graph_context = ""
            # NOTE(review): used_kb_ids is assigned but never read afterwards.
            used_kb_ids: List[int] = []
            if settings.GRAPHRAG_ENABLED and kb_ids:
                try:
                    adapter = GraphRAGAdapter()
                    graph_context, used_kb_ids = await adapter.local_context_multi(
                        kb_ids,
                        query,
                        top_k=settings.GRAPHRAG_LOCAL_TOP_K,
                        level=settings.GRAPHRAG_QUERY_LEVEL,
                    )
                    graph_used = bool(graph_context)
                except Exception:
                    # Best-effort: any GraphRAG failure falls through to the
                    # vector-retrieval fallback below.
                    graph_context = ""

            if not graph_context:
                retrieval_rows = await retriever.retrieve(
                    query=query,
                    kb_vector_stores=kb_vector_stores,
                    fetch_k_per_kb=18,
                    top_k=14,
                )
                graph_context = _build_local_graph_context_fallback(retrieval_rows)
                selected_chain = "C_fallback_B"

            else:
                selected_chain = "C_graph"

            prompt_text = GRAPH_LOCAL_PROMPT_TEMPLATE.format(
                query=query,
                graph_context=graph_context,
            )

        else:
            # Intent D (global/community question): GraphRAG global context,
            # with a grouped-snippet fallback when it yields nothing.
            community_context = ""
            if settings.GRAPHRAG_ENABLED and kb_ids:
                try:
                    adapter = GraphRAGAdapter()
                    community_context, used_kb_ids = await adapter.global_context_multi(
                        kb_ids,
                        query,
                        level=settings.GRAPHRAG_QUERY_LEVEL,
                    )
                    graph_used = bool(community_context)
                except Exception:
                    community_context = ""

            if not community_context:
                retrieval_rows = await retriever.retrieve(
                    query=query,
                    kb_vector_stores=kb_vector_stores,
                    fetch_k_per_kb=20,
                    top_k=14,
                )
                community_context = _build_global_community_context_fallback(retrieval_rows)
                selected_chain = "D_fallback_B"
            else:
                selected_chain = "D_graph"

            prompt_text = GRAPH_GLOBAL_PROMPT_TEMPLATE.format(
                query=query,
                community_context=community_context,
            )

        # First frame: routing/context payload, base64-encoded, then marker.
        context_payload = {
            "route": decision,
            "intent": intent,
            "selected_chain": selected_chain,
            "graph_used": graph_used,
            "reranker_enabled": reranker_client.enabled,
            "retrieval_preview": _preview_rows(retrieval_rows),
            "context": _context_rows(retrieval_rows),
        }
        escaped_context = json.dumps(context_payload, ensure_ascii=False)
        base64_context = base64.b64encode(escaped_context.encode()).decode()
        separator = "__LLM_RESPONSE__"

        full_response = f"{base64_context}{separator}"
        yield f'0:"{base64_context}{separator}"\n'

        # Stream the model answer, escaping each delta for the quoted frame.
        async for chunk in llm.astream(prompt_text):
            text = _extract_stream_text(chunk)
            if not text:
                continue
            full_response += text
            yield f'0:"{_escape_stream_text(text)}"\n'

        yield 'd:{"finishReason":"stop","usage":{"promptTokens":0,"completionTokens":0}}\n'

        # Persist the full streamed content on the assistant row.
        bot_message.content = full_response
        db.commit()

    except Exception as e:
        # Top-level boundary: report the error in-stream and store it on the
        # assistant row when it was created before the failure.
        # NOTE(review): this error frame is not JSON-quoted/escaped like the
        # 0:"..." frames — confirm the client parses a bare 3: payload.
        error_message = f"Error generating response: {str(e)}"
        print(error_message)
        yield "3:{text}\n".format(text=error_message)

        if "bot_message" in locals():
            bot_message.content = error_message
            db.commit()
    finally:
        db.close()
|
||||
Reference in New Issue
Block a user