Files
rag_agent/rag-web-ui/backend/app/services/intent_router.py

121 lines
3.5 KiB
Python
Raw Normal View History

2026-04-13 11:34:23 +08:00
import json
import re
from typing import Any, Dict, List
from app.services.fusion_prompts import (
ROUTER_SYSTEM_PROMPT,
ROUTER_USER_PROMPT_TEMPLATE,
)
VALID_INTENTS = {"A", "B", "C", "D"}
def _extract_json_object(raw_text: str) -> Dict[str, str]:
"""Extract and parse the first JSON object from model output."""
cleaned = raw_text.strip()
cleaned = cleaned.replace("```json", "").replace("```", "").strip()
match = re.search(r"\{[\s\S]*\}", cleaned)
if not match:
raise ValueError("No JSON object found in router output")
data = json.loads(match.group(0))
if not isinstance(data, dict):
raise ValueError("Router output JSON is not an object")
intent = str(data.get("intent", "")).strip().upper()
reason = str(data.get("reason", "")).strip()
if intent not in VALID_INTENTS:
raise ValueError(f"Invalid intent: {intent}")
if not reason:
reason = "模型未提供理由,已按规则兜底。"
return {"intent": intent, "reason": reason}
def _build_history_text(messages: dict, max_turns: int = 6) -> str:
if not isinstance(messages, dict):
return ""
history = messages.get("messages", [])
if not isinstance(history, list):
return ""
tail = history[-max_turns:]
rows: List[str] = []
for msg in tail:
role = str(msg.get("role", "unknown")).strip()
content = str(msg.get("content", "")).strip().replace("\n", " ")
if content:
rows.append(f"{role}: {content}")
return "\n".join(rows)
def _heuristic_route(query: str) -> Dict[str, str]:
text = query.strip().lower()
general_chat_patterns = [
"你好",
"您好",
"在吗",
"谢谢",
"早上好",
"晚上好",
"你是谁",
"讲个笑话",
]
global_patterns = [
"总结",
"综述",
"整体",
"全局",
"趋势",
"对比",
"比较",
"宏观",
"共性",
"差异",
]
local_graph_patterns = [
"关系",
"依赖",
"影响",
"导致",
"原因",
"链路",
"多跳",
"传导",
"耦合",
"约束",
]
if any(token in text for token in general_chat_patterns):
return {"intent": "A", "reason": "命中通用对话关键词,且不依赖知识库检索。"}
if any(token in text for token in global_patterns):
return {"intent": "D", "reason": "问题指向全局总结或跨主题趋势分析。"}
if any(token in text for token in local_graph_patterns):
return {"intent": "C", "reason": "问题强调实体关系与链式推理。"}
return {"intent": "B", "reason": "默认归入事实查询,适合混合检索链路。"}
async def route_intent(llm: Any, query: str, messages: dict) -> Dict[str, str]:
    """Classify *query* into intent A/B/C/D.

    Tries the LLM router first; any failure (transport error, unparsable
    output, invalid intent) falls back to the keyword heuristic so this
    function never raises.
    """
    chat_history = _build_history_text(messages)
    prompt_body = ROUTER_USER_PROMPT_TEMPLATE.format(
        chat_history=chat_history or "",
        query=query,
    )
    try:
        response = await llm.ainvoke(f"{ROUTER_SYSTEM_PROMPT}\n\n{prompt_body}")
        # Some LLM clients return a message object with .content, others a str.
        payload = getattr(response, "content", response)
        if not isinstance(payload, str):
            payload = str(payload)
        return _extract_json_object(payload)
    except Exception:
        # Deliberate broad catch: routing must always yield an intent.
        return _heuristic_route(query)