"""Intent router: classify a user query into one of four handling routes.

The LLM is asked first; if its call fails or its answer cannot be parsed and
validated, a keyword heuristic decides instead.
"""

import json
import logging
import re
from typing import Any, Dict, List

from app.services.fusion_prompts import (
    ROUTER_SYSTEM_PROMPT,
    ROUTER_USER_PROMPT_TEMPLATE,
)

logger = logging.getLogger(__name__)

# Intent labels the router may return. Per the heuristic's reasons below:
# A = general chat (no knowledge-base retrieval), B = fact lookup via hybrid
# retrieval (default), C = entity-relation / chained reasoning,
# D = global summary or cross-topic trend analysis.
VALID_INTENTS = {"A", "B", "C", "D"}

# Markdown code fences (```json / ```) that models often wrap JSON in.
_CODE_FENCE_RE = re.compile(r"```(?:json)?")

# Keyword tables for the heuristic fallback, checked in this priority order.
_GENERAL_CHAT_PATTERNS = (
    "你好", "您好", "在吗", "谢谢", "早上好", "晚上好", "你是谁", "讲个笑话",
)
_GLOBAL_PATTERNS = (
    "总结", "综述", "整体", "全局", "趋势", "对比", "比较", "宏观", "共性", "差异",
)
_LOCAL_GRAPH_PATTERNS = (
    "关系", "依赖", "影响", "导致", "原因", "链路", "多跳", "传导", "耦合", "约束",
)


def _extract_json_object(raw_text: str) -> Dict[str, str]:
    """Extract and validate the first JSON object from model output.

    Args:
        raw_text: Raw LLM output, possibly wrapped in markdown code fences or
            surrounded by extra prose.

    Returns:
        ``{"intent": <one of VALID_INTENTS>, "reason": <non-empty str>}``.

    Raises:
        ValueError: if no JSON object can be decoded, it is not an object, or
            its ``intent`` is not a valid label.
    """
    cleaned = _CODE_FENCE_RE.sub("", raw_text).strip()

    # Decode starting at each "{" in turn and keep the first complete object.
    # A greedy regex over "{...}" would swallow any trailing brace text and
    # make an otherwise valid answer unparseable.
    decoder = json.JSONDecoder()
    start = cleaned.find("{")
    while start != -1:
        try:
            data, _ = decoder.raw_decode(cleaned, start)
            break
        except json.JSONDecodeError:
            start = cleaned.find("{", start + 1)
    else:
        raise ValueError("No JSON object found in router output")

    if not isinstance(data, dict):
        raise ValueError("Router output JSON is not an object")

    intent = str(data.get("intent", "")).strip().upper()
    reason = str(data.get("reason", "")).strip()
    if intent not in VALID_INTENTS:
        raise ValueError(f"Invalid intent: {intent}")
    if not reason:
        # Placeholder reason: "model gave no reason; filled in by rule".
        reason = "模型未提供理由,已按规则兜底。"
    return {"intent": intent, "reason": reason}


def _build_history_text(messages: dict, max_turns: int = 6) -> str:
    """Render the most recent chat messages as ``"role: content"`` lines.

    Args:
        messages: A dict whose ``"messages"`` key holds a list of message
            dicts with ``"role"`` / ``"content"`` keys. Anything else
            yields an empty string.
        max_turns: Maximum number of trailing list entries (individual
            messages) to include. Values <= 0 yield an empty string.

    Returns:
        Newline-joined lines. Each message's content is flattened to a single
        line, and messages with empty content are skipped.
    """
    # history[-0:] would return the whole list, so guard non-positive limits.
    if not isinstance(messages, dict) or max_turns <= 0:
        return ""
    history = messages.get("messages", [])
    if not isinstance(history, list):
        return ""

    rows: List[str] = []
    for msg in history[-max_turns:]:
        # Skip malformed entries rather than crash: this runs outside
        # route_intent's fallback try-block.
        if not isinstance(msg, dict):
            continue
        role = str(msg.get("role", "unknown")).strip()
        # splitlines() also handles "\r\n" / "\r", which would otherwise
        # leave stray carriage returns in the prompt.
        content = " ".join(str(msg.get("content", "")).strip().splitlines())
        if content:
            rows.append(f"{role}: {content}")
    return "\n".join(rows)


def _heuristic_route(query: str) -> Dict[str, str]:
    """Route a query by keyword matching; used when the LLM route fails.

    Keyword tables are checked in priority order: general chat (A), then
    global summary (D), then entity relations (C). Anything else defaults to
    fact lookup (B). Matching is plain substring search on the lowercased,
    stripped query.
    """
    text = (query or "").strip().lower()

    if any(token in text for token in _GENERAL_CHAT_PATTERNS):
        return {"intent": "A", "reason": "命中通用对话关键词,且不依赖知识库检索。"}
    if any(token in text for token in _GLOBAL_PATTERNS):
        return {"intent": "D", "reason": "问题指向全局总结或跨主题趋势分析。"}
    if any(token in text for token in _LOCAL_GRAPH_PATTERNS):
        return {"intent": "C", "reason": "问题强调实体关系与链式推理。"}
    return {"intent": "B", "reason": "默认归入事实查询,适合混合检索链路。"}


async def route_intent(llm: Any, query: str, messages: dict) -> Dict[str, str]:
    """Route user query to A/B/C/D with LLM-first and heuristic fallback.

    Args:
        llm: Object exposing an async ``ainvoke(prompt)`` method. Its result's
            ``content`` attribute is used if present, otherwise the result
            itself, stringified if needed.
        query: The user's current question.
        messages: Chat state; see ``_build_history_text`` for the expected
            shape.

    Returns:
        ``{"intent": ..., "reason": ...}`` from the LLM when its answer is
        valid, otherwise from ``_heuristic_route(query)``.
    """
    history_text = _build_history_text(messages)
    user_prompt = ROUTER_USER_PROMPT_TEMPLATE.format(
        chat_history=history_text or "无",  # "无" = "none"
        query=query,
    )
    full_prompt = f"{ROUTER_SYSTEM_PROMPT}\n\n{user_prompt}"

    try:
        model_resp = await llm.ainvoke(full_prompt)
        content = getattr(model_resp, "content", model_resp)
        raw_text = content if isinstance(content, str) else str(content)
        return _extract_json_object(raw_text)
    except Exception:
        # Best-effort by design: any LLM/parse failure falls back to the
        # heuristic. Log it so failures are not invisible.
        logger.warning(
            "LLM intent routing failed; using heuristic fallback", exc_info=True
        )
        return _heuristic_route(query)