import json
import re
from typing import Any, Dict, List

from app.services.fusion_prompts import (
    ROUTER_SYSTEM_PROMPT,
    ROUTER_USER_PROMPT_TEMPLATE,
)
VALID_INTENTS = {"A", "B", "C", "D"}
|
|
|
|
|
|
def _extract_json_object(raw_text: str) -> Dict[str, str]:
|
|
"""Extract and parse the first JSON object from model output."""
|
|
cleaned = raw_text.strip()
|
|
cleaned = cleaned.replace("```json", "").replace("```", "").strip()
|
|
|
|
match = re.search(r"\{[\s\S]*\}", cleaned)
|
|
if not match:
|
|
raise ValueError("No JSON object found in router output")
|
|
|
|
data = json.loads(match.group(0))
|
|
if not isinstance(data, dict):
|
|
raise ValueError("Router output JSON is not an object")
|
|
|
|
intent = str(data.get("intent", "")).strip().upper()
|
|
reason = str(data.get("reason", "")).strip()
|
|
if intent not in VALID_INTENTS:
|
|
raise ValueError(f"Invalid intent: {intent}")
|
|
|
|
if not reason:
|
|
reason = "模型未提供理由,已按规则兜底。"
|
|
|
|
return {"intent": intent, "reason": reason}
def _build_history_text(messages: dict, max_turns: int = 6) -> str:
|
|
if not isinstance(messages, dict):
|
|
return ""
|
|
|
|
history = messages.get("messages", [])
|
|
if not isinstance(history, list):
|
|
return ""
|
|
|
|
tail = history[-max_turns:]
|
|
rows: List[str] = []
|
|
for msg in tail:
|
|
role = str(msg.get("role", "unknown")).strip()
|
|
content = str(msg.get("content", "")).strip().replace("\n", " ")
|
|
if content:
|
|
rows.append(f"{role}: {content}")
|
|
return "\n".join(rows)
def _heuristic_route(query: str) -> Dict[str, str]:
|
|
text = query.strip().lower()
|
|
|
|
general_chat_patterns = [
|
|
"你好",
|
|
"您好",
|
|
"在吗",
|
|
"谢谢",
|
|
"早上好",
|
|
"晚上好",
|
|
"你是谁",
|
|
"讲个笑话",
|
|
]
|
|
global_patterns = [
|
|
"总结",
|
|
"综述",
|
|
"整体",
|
|
"全局",
|
|
"趋势",
|
|
"对比",
|
|
"比较",
|
|
"宏观",
|
|
"共性",
|
|
"差异",
|
|
]
|
|
local_graph_patterns = [
|
|
"关系",
|
|
"依赖",
|
|
"影响",
|
|
"导致",
|
|
"原因",
|
|
"链路",
|
|
"多跳",
|
|
"传导",
|
|
"耦合",
|
|
"约束",
|
|
]
|
|
|
|
if any(token in text for token in general_chat_patterns):
|
|
return {"intent": "A", "reason": "命中通用对话关键词,且不依赖知识库检索。"}
|
|
|
|
if any(token in text for token in global_patterns):
|
|
return {"intent": "D", "reason": "问题指向全局总结或跨主题趋势分析。"}
|
|
|
|
if any(token in text for token in local_graph_patterns):
|
|
return {"intent": "C", "reason": "问题强调实体关系与链式推理。"}
|
|
|
|
return {"intent": "B", "reason": "默认归入事实查询,适合混合检索链路。"}
async def route_intent(llm: Any, query: str, messages: dict) -> Dict[str, str]:
    """Route user query to A/B/C/D with LLM-first and heuristic fallback.

    Tries the LLM router first; if the call fails or its output cannot
    be parsed into a valid intent, falls back to the keyword heuristic
    so routing never raises to the caller.
    """
    chat_history = _build_history_text(messages)
    filled_prompt = ROUTER_USER_PROMPT_TEMPLATE.format(
        chat_history=chat_history or "无",
        query=query,
    )

    try:
        reply = await llm.ainvoke(f"{ROUTER_SYSTEM_PROMPT}\n\n{filled_prompt}")
        payload = getattr(reply, "content", reply)
        text = payload if isinstance(payload, str) else str(payload)
        return _extract_json_object(text)
    except Exception:
        # Best-effort by design: any failure degrades to the heuristic.
        return _heuristic_route(query)
|