init. project
This commit is contained in:
120
rag-web-ui/backend/app/services/intent_router.py
Normal file
120
rag-web-ui/backend/app/services/intent_router.py
Normal file
@@ -0,0 +1,120 @@
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from app.services.fusion_prompts import (
|
||||
ROUTER_SYSTEM_PROMPT,
|
||||
ROUTER_USER_PROMPT_TEMPLATE,
|
||||
)
|
||||
|
||||
# Closed set of router labels; per the heuristic reasons below: A = general
# chat, B = factual lookup (default), C = relational/graph reasoning,
# D = global summary / trend analysis.
VALID_INTENTS = {"A", "B", "C", "D"}
|
||||
|
||||
|
||||
def _extract_json_object(raw_text: str) -> Dict[str, str]:
    """Parse the first JSON object embedded in the raw model output.

    Strips markdown code fences, locates a brace-delimited span, and
    validates the ``intent``/``reason`` fields before returning them.

    Raises:
        ValueError: if no JSON object is present, the JSON is not an
            object, or the intent is not one of ``VALID_INTENTS``.
    """
    # Drop markdown fences the model may have wrapped around the JSON.
    stripped = raw_text.strip().replace("```json", "").replace("```", "").strip()

    found = re.search(r"\{[\s\S]*\}", stripped)
    if found is None:
        raise ValueError("No JSON object found in router output")

    parsed = json.loads(found.group(0))
    if not isinstance(parsed, dict):
        raise ValueError("Router output JSON is not an object")

    intent_code = str(parsed.get("intent", "")).strip().upper()
    explanation = str(parsed.get("reason", "")).strip()
    if intent_code not in VALID_INTENTS:
        raise ValueError(f"Invalid intent: {intent_code}")

    # Fall back to a canned reason when the model omitted one.
    return {
        "intent": intent_code,
        "reason": explanation or "模型未提供理由,已按规则兜底。",
    }
|
||||
|
||||
|
||||
def _build_history_text(messages: dict, max_turns: int = 6) -> str:
|
||||
if not isinstance(messages, dict):
|
||||
return ""
|
||||
|
||||
history = messages.get("messages", [])
|
||||
if not isinstance(history, list):
|
||||
return ""
|
||||
|
||||
tail = history[-max_turns:]
|
||||
rows: List[str] = []
|
||||
for msg in tail:
|
||||
role = str(msg.get("role", "unknown")).strip()
|
||||
content = str(msg.get("content", "")).strip().replace("\n", " ")
|
||||
if content:
|
||||
rows.append(f"{role}: {content}")
|
||||
return "\n".join(rows)
|
||||
|
||||
|
||||
def _heuristic_route(query: str) -> Dict[str, str]:
|
||||
text = query.strip().lower()
|
||||
|
||||
general_chat_patterns = [
|
||||
"你好",
|
||||
"您好",
|
||||
"在吗",
|
||||
"谢谢",
|
||||
"早上好",
|
||||
"晚上好",
|
||||
"你是谁",
|
||||
"讲个笑话",
|
||||
]
|
||||
global_patterns = [
|
||||
"总结",
|
||||
"综述",
|
||||
"整体",
|
||||
"全局",
|
||||
"趋势",
|
||||
"对比",
|
||||
"比较",
|
||||
"宏观",
|
||||
"共性",
|
||||
"差异",
|
||||
]
|
||||
local_graph_patterns = [
|
||||
"关系",
|
||||
"依赖",
|
||||
"影响",
|
||||
"导致",
|
||||
"原因",
|
||||
"链路",
|
||||
"多跳",
|
||||
"传导",
|
||||
"耦合",
|
||||
"约束",
|
||||
]
|
||||
|
||||
if any(token in text for token in general_chat_patterns):
|
||||
return {"intent": "A", "reason": "命中通用对话关键词,且不依赖知识库检索。"}
|
||||
|
||||
if any(token in text for token in global_patterns):
|
||||
return {"intent": "D", "reason": "问题指向全局总结或跨主题趋势分析。"}
|
||||
|
||||
if any(token in text for token in local_graph_patterns):
|
||||
return {"intent": "C", "reason": "问题强调实体关系与链式推理。"}
|
||||
|
||||
return {"intent": "B", "reason": "默认归入事实查询,适合混合检索链路。"}
|
||||
|
||||
|
||||
async def route_intent(llm: Any, query: str, messages: dict) -> Dict[str, str]:
    """Classify *query* into intent A/B/C/D.

    Tries the LLM router first; any failure (model error, unparsable or
    invalid output) falls back to the keyword heuristics, so a route is
    always returned.
    """
    chat_history = _build_history_text(messages) or "无"
    prompt_body = ROUTER_USER_PROMPT_TEMPLATE.format(
        chat_history=chat_history,
        query=query,
    )

    try:
        response = await llm.ainvoke(f"{ROUTER_SYSTEM_PROMPT}\n\n{prompt_body}")
        # LangChain-style responses expose .content; plain strings pass through.
        payload = getattr(response, "content", response)
        if not isinstance(payload, str):
            payload = str(payload)
        return _extract_json_object(payload)
    except Exception:
        # Deliberate broad catch: routing must never break the chat flow.
        return _heuristic_route(query)
|
||||
Reference in New Issue
Block a user