init. project
This commit is contained in:
0
rag-web-ui/backend/app/services/__init__.py
Normal file
0
rag-web-ui/backend/app/services/__init__.py
Normal file
61
rag-web-ui/backend/app/services/api_key.py
Normal file
61
rag-web-ui/backend/app/services/api_key.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
import secrets
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.api_key import APIKey
|
||||
from app.schemas.api_key import APIKeyCreate, APIKeyUpdate
|
||||
|
||||
class APIKeyService:
    """CRUD helpers for user-owned API keys.

    Every method is a ``@staticmethod`` operating on a caller-supplied
    SQLAlchemy session; mutating methods commit immediately and refresh
    the affected row before returning it.
    """

    @staticmethod
    def get_api_keys(db: Session, user_id: int, skip: int = 0, limit: int = 100) -> List[APIKey]:
        """Return one page of API keys owned by *user_id*."""
        owned = db.query(APIKey).filter(APIKey.user_id == user_id)
        return owned.offset(skip).limit(limit).all()

    @staticmethod
    def create_api_key(db: Session, user_id: int, name: str) -> APIKey:
        """Create, persist and return a new active key with a random secret."""
        # 32 random bytes -> 64 hex chars, prefixed "sk-" like OpenAI-style keys.
        secret = f"sk-{secrets.token_hex(32)}"
        record = APIKey(
            key=secret,
            name=name,
            user_id=user_id,
            is_active=True,
        )
        db.add(record)
        db.commit()
        db.refresh(record)
        return record

    @staticmethod
    def get_api_key(db: Session, api_key_id: int) -> Optional[APIKey]:
        """Look up a key by primary key; None when absent."""
        return db.query(APIKey).filter(APIKey.id == api_key_id).first()

    @staticmethod
    def get_api_key_by_key(db: Session, key: str) -> Optional[APIKey]:
        """Look up a key by its secret value; None when absent."""
        return db.query(APIKey).filter(APIKey.key == key).first()

    @staticmethod
    def update_api_key(db: Session, api_key: APIKey, update_data: APIKeyUpdate) -> APIKey:
        """Apply the fields explicitly set on *update_data* to *api_key*."""
        # exclude_unset: only touch fields the client actually sent.
        changes = update_data.model_dump(exclude_unset=True)
        for attribute, new_value in changes.items():
            setattr(api_key, attribute, new_value)
        db.add(api_key)
        db.commit()
        db.refresh(api_key)
        return api_key

    @staticmethod
    def delete_api_key(db: Session, api_key: APIKey) -> None:
        """Permanently remove *api_key*."""
        db.delete(api_key)
        db.commit()

    @staticmethod
    def update_last_used(db: Session, api_key: APIKey) -> APIKey:
        """Stamp *api_key* with the current UTC time and persist it."""
        api_key.last_used_at = datetime.utcnow()
        db.add(api_key)
        db.commit()
        db.refresh(api_key)
        return api_key
|
||||
532
rag-web-ui/backend/app/services/chat_service.py
Normal file
532
rag-web-ui/backend/app/services/chat_service.py
Normal file
@@ -0,0 +1,532 @@
|
||||
import base64
|
||||
import json
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from app.core.config import settings
|
||||
from app.models.chat import Message
|
||||
from app.models.knowledge import Document, KnowledgeBase
|
||||
from app.services.embedding.embedding_factory import EmbeddingsFactory
|
||||
from app.services.fusion_prompts import (
|
||||
GENERAL_CHAT_PROMPT_TEMPLATE,
|
||||
GRAPH_GLOBAL_PROMPT_TEMPLATE,
|
||||
GRAPH_LOCAL_PROMPT_TEMPLATE,
|
||||
HYBRID_RAG_PROMPT_TEMPLATE,
|
||||
)
|
||||
from app.services.graph.graphrag_adapter import GraphRAGAdapter
|
||||
from app.services.intent_router import route_intent
|
||||
from app.services.llm.llm_factory import LLMFactory
|
||||
from app.services.reranker.external_api import ExternalRerankerClient
|
||||
from app.services.retrieval.multi_kb_retriever import MultiKBRetriever, format_retrieval_context
|
||||
from app.services.testing_pipeline.pipeline import run_testing_pipeline
|
||||
from app.services.testing_pipeline.rules import REQUIREMENT_TYPES
|
||||
from app.services.vector_store import VectorStoreFactory
|
||||
|
||||
|
||||
# Keywords naming the *target* of a testing request (test items, test
# cases, expected results, ...).  Matched verbatim against the raw,
# case-preserved query text.
TESTING_TARGET_KEYWORDS = [
    "测试项",
    "测试用例",
    "预期成果",
    "需求类型",
    "测试分解",
    "分解",
    "正常测试",
    "异常测试",
    "测试充分性",
]

# Imperative verbs indicating the user wants something *produced*.
# A query must hit both a target and an action keyword to auto-route
# into the testing pipeline.
TESTING_ACTION_KEYWORDS = [
    "生成",
    "输出",
    "给出",
    "写",
    "编写",
    "设计",
    "整理",
    "列出",
    "提供",
    "制定",
]

# Informal / shorthand phrasings mapped to the canonical requirement
# type names used by the testing pipeline.
TYPE_ALIAS_MAP = {
    "接口测试": "外部接口测试",
    "ui测试": "人机交互界面测试",
    "界面测试": "人机交互界面测试",
    "恢复测试": "恢复性测试",
    "可靠性": "可靠性测试",
    "安全性": "安全性测试",
    "边界": "边界测试",
    "安装": "安装性测试",
    "互操作": "互操作性测试",
    "敏感性": "敏感性测试",
    "充分性": "测试充分性要求",
}
|
||||
|
||||
|
||||
def _escape_stream_text(text: str) -> str:
    """Escape *text* for embedding inside a double-quoted ``0:"..."`` stream frame.

    Backslashes are escaped FIRST: without that, a literal ``\\`` in the
    model output would merge with the escapes added for quotes/newlines
    and corrupt the frame (the original implementation had this bug).
    """
    return (
        text.replace("\\", "\\\\")
        .replace('"', '\\"')
        .replace("\n", "\\n")
    )
|
||||
|
||||
|
||||
def _extract_stream_text(chunk: Any) -> str:
    """Return the plain-text payload carried by an LLM stream chunk.

    Handles plain strings, message chunks exposing a ``content``
    attribute, and multi-part content lists whose items may be strings
    or ``{"text": ...}`` dicts; any other payload is stringified.
    """
    payload = getattr(chunk, "content", chunk)

    if isinstance(payload, str):
        return payload

    if isinstance(payload, list):
        pieces: List[str] = []
        for part in payload:
            if isinstance(part, str):
                pieces.append(part)
            elif isinstance(part, dict):
                text_value = part.get("text")
                if isinstance(text_value, str):
                    pieces.append(text_value)
                else:
                    pieces.append(str(part))
        return "".join(pieces)

    return str(payload)
|
||||
|
||||
|
||||
def _preview_rows(rows: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Summarize up to ten retrieval rows for a lightweight UI preview."""
    summaries: List[Dict[str, Any]] = []
    for entry in rows[:10]:
        meta = entry["document"].metadata or {}
        summaries.append(
            {
                "kb_id": entry.get("kb_id"),
                "source": meta.get("source") or meta.get("file_name") or "unknown",
                "chunk_id": meta.get("chunk_id") or "unknown",
                "score": entry.get("final_score", 0),
                "reranker_score": entry.get("reranker_score"),
            }
        )
    return summaries
|
||||
|
||||
|
||||
def _context_rows(rows: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Convert retrieval rows into the serializable context format.

    Each document's metadata is copied (the stored docs are never
    mutated) and annotated with kb_id and scoring fields when absent.
    """
    serialized: List[Dict[str, Any]] = []
    for entry in rows:
        document = entry["document"]
        meta = dict(document.metadata or {})

        kb_id = entry.get("kb_id")
        if kb_id is not None and "kb_id" not in meta:
            meta["kb_id"] = kb_id
        meta.setdefault("retrieval_score", entry.get("final_score", 0))
        rerank_score = entry.get("reranker_score")
        if rerank_score is not None:
            meta.setdefault("reranker_score", rerank_score)

        serialized.append(
            {
                "page_content": document.page_content.strip(),
                "metadata": meta,
            }
        )
    return serialized
|
||||
|
||||
|
||||
def _build_local_graph_context_fallback(rows: List[Dict[str, Any]]) -> str:
    """Approximate GraphRAG's local (entity-centric) context from vector hits.

    Entities and relations the indexer stashed in chunk metadata are
    collected alongside the raw chunk texts as evidence, then rendered
    in the layout the local-graph prompt expects.  Chinese fallback
    lines appear for any section with no structured data.
    """
    entity_names = set()
    relation_dicts: List[Dict[str, Any]] = []
    evidence_snippets: List[str] = []

    for entry in rows:
        document = entry["document"]
        meta = document.metadata or {}

        entity_names.update(str(ent) for ent in meta.get("extracted_entities", []))
        relation_dicts.extend(
            rel for rel in meta.get("extracted_relations", []) if isinstance(rel, dict)
        )
        evidence_snippets.append(document.page_content.strip())

    # Cap at 80 entities, alphabetically, to keep the prompt bounded.
    entity_block = "\n".join(
        f"- {name}" for name in sorted(entity_names)[:80]
    ) or "- 暂无结构化实体,已使用向量检索回退。"

    rel_lines: List[str] = []
    for rel in relation_dicts[:120]:
        src = rel.get("source") or rel.get("src") or rel.get("src_id") or "UNKNOWN"
        tgt = rel.get("target") or rel.get("tgt") or rel.get("tgt_id") or "UNKNOWN"
        kind = rel.get("type") or rel.get("relation_type") or "其他"
        note = rel.get("description") or ""
        rel_lines.append(f"- {src} -> {tgt} | 类型={kind} | 说明={note}")
    relation_block = "\n".join(rel_lines) or "- 暂无结构化关系,已使用证据片段回答。"

    evidence_block = "\n\n".join(
        f"[证据{i}] {snippet}" for i, snippet in enumerate(evidence_snippets[:8], start=1)
    ) or "无可用证据。"

    return (
        "实体列表:\n"
        f"{entity_block}\n\n"
        "关系列表:\n"
        f"{relation_block}\n\n"
        "原文证据:\n"
        f"{evidence_block}"
    )
|
||||
|
||||
|
||||
def _build_global_community_context_fallback(rows: List[Dict[str, Any]]) -> str:
    """Approximate GraphRAG's global community summaries from vector hits.

    Rows are grouped by community id when chunk metadata carries a
    non-empty ``community_ids`` list, otherwise by source file.  Each
    group contributes one summary line built from up to three snippets.
    """
    buckets: Dict[str, List[str]] = defaultdict(list)

    for entry in rows:
        document = entry["document"]
        meta = document.metadata or {}
        ids = meta.get("community_ids") or []

        if isinstance(ids, list) and ids:
            group_keys = [str(value) for value in ids]
        else:
            origin = meta.get("source") or meta.get("file_name") or "unknown"
            group_keys = [f"source:{origin}"]

        snippet = document.page_content.strip()
        for group_key in group_keys:
            buckets[group_key].append(snippet)

    if not buckets:
        return "暂无社区摘要数据,已回退为基于证据片段的全局总结。"

    summaries = [
        f"社区{index} ({community_id}) 摘要: {' '.join(snippets[:3])}"
        for index, (community_id, snippets) in enumerate(buckets.items(), start=1)
    ]
    return "\n\n".join(summaries)
|
||||
|
||||
|
||||
async def _build_kb_vector_stores(db: Any, knowledge_bases: List[KnowledgeBase]) -> List[Dict[str, Any]]:
    """Open one vector-store handle per knowledge base that has documents.

    Knowledge bases with no ingested documents are skipped so retrieval
    never queries an empty collection.  One shared embedding function is
    created for all stores.
    """
    embedding_fn = EmbeddingsFactory.create()
    stores: List[Dict[str, Any]] = []

    for kb in knowledge_bases:
        kb_documents = db.query(Document).filter(Document.knowledge_base_id == kb.id).all()
        if not kb_documents:
            continue

        stores.append(
            {
                "kb_id": kb.id,
                "store": VectorStoreFactory.create(
                    store_type=settings.VECTOR_STORE_TYPE,
                    collection_name=f"kb_{kb.id}",
                    embedding_function=embedding_fn,
                ),
            }
        )

    return stores
|
||||
|
||||
|
||||
def _build_reranker_client() -> ExternalRerankerClient:
    """Construct the external reranker client from application settings."""
    config = {
        "api_url": settings.RERANKER_API_URL,
        "api_key": settings.RERANKER_API_KEY,
        "model": settings.RERANKER_MODEL,
        "timeout_seconds": settings.RERANKER_TIMEOUT_SECONDS,
    }
    return ExternalRerankerClient(**config)
|
||||
|
||||
|
||||
def _is_testing_generation_request(query: str) -> bool:
    """Heuristically decide whether *query* asks to generate testing artifacts.

    Triggers, in order, on: the explicit ``/testing`` command, internal
    tool-chain markers, a target keyword combined with an action keyword,
    or imperative phrasing around test items / cases / expected results.
    """
    stripped = (query or "").strip()
    if not stripped:
        return False

    lowered = stripped.lower()
    if lowered.startswith("/testing"):
        return True

    tool_markers = (
        "testing_orchestrator",
        "testing-orchestrator",
        "identify_requirement_type",
        "identify-requirement-type",
    )
    if any(marker in lowered for marker in tool_markers):
        return True

    mentions_target = any(keyword in stripped for keyword in TESTING_TARGET_KEYWORDS)
    mentions_action = any(keyword in stripped for keyword in TESTING_ACTION_KEYWORDS)
    if mentions_target and mentions_action:
        return True

    if any(keyword in stripped for keyword in ("测试项", "测试用例", "预期成果")):
        # Polite-request phrasing, e.g. "请帮我写..." within a short window.
        if re.search(r"(请|帮|给|麻烦).{0,12}(写|生成|设计|整理|编写|列出|提供|制定)", stripped):
            return True
        # Bare imperative opening, e.g. "生成...测试用例".
        if stripped.startswith(("生成", "编写", "设计", "整理", "输出", "列出", "提供", "制定")):
            return True

    return False
|
||||
|
||||
|
||||
def _extract_requirement_type_from_query(query: str) -> Optional[str]:
    """Map *query* onto a canonical requirement type, if one is mentioned.

    Exact canonical names win over aliases; returns ``None`` when the
    query names no recognizable type.
    """
    stripped = (query or "").strip()
    if not stripped:
        return None

    for canonical in REQUIREMENT_TYPES:
        if canonical in stripped:
            return canonical

    lowered = stripped.lower()
    for alias, canonical in TYPE_ALIAS_MAP.items():
        # Check both casings so latin aliases like "ui测试" still match.
        if alias in stripped or alias in lowered:
            return canonical

    return None
|
||||
|
||||
|
||||
async def generate_response(
    query: str,
    messages: dict,
    knowledge_base_ids: List[int],
    chat_id: int,
    db: Any,
) -> AsyncGenerator[str, None]:
    """Stream an assistant reply for *query* as Vercel-AI-style data frames.

    Persists the user message and an (initially empty) assistant message,
    then either runs the testing-generation pipeline (when the query looks
    like a test-artifact request) or routes the query through one of four
    chains: A = general chat, B = hybrid RAG, C = local graph context,
    D = global community context (C/D fall back to vector retrieval when
    GraphRAG yields nothing).  Every yielded string is one stream frame:
    ``0:"..."`` text, ``3:...`` error, ``d:{...}`` finish.  The first text
    frame carries a base64-encoded JSON context payload followed by the
    ``__LLM_RESPONSE__`` separator; the full concatenation is saved as the
    assistant message content.  The session is closed on exit.
    """
    try:
        # Persist the incoming user turn first so it survives any failure below.
        user_message = Message(content=query, role="user", chat_id=chat_id)
        db.add(user_message)
        db.commit()

        # Placeholder assistant row; filled in once streaming finishes (or errors).
        bot_message = Message(content="", role="assistant", chat_id=chat_id)
        db.add(bot_message)
        db.commit()

        if _is_testing_generation_request(query):
            # --- Testing pipeline branch: bypass intent routing entirely. ---
            explicit_type = _extract_requirement_type_from_query(query)

            retrieval_rows: List[Dict[str, Any]] = []
            knowledge_context = ""
            kb_vector_stores = []
            if knowledge_base_ids:
                testing_kbs = (
                    db.query(KnowledgeBase)
                    .filter(KnowledgeBase.id.in_(knowledge_base_ids))
                    .all()
                )
                kb_vector_stores = await _build_kb_vector_stores(db, testing_kbs)

            if kb_vector_stores:
                # No reranker client here: retrieval is purely vector-based.
                testing_retriever = MultiKBRetriever(
                    reranker_weight=settings.RERANKER_WEIGHT,
                )
                retrieval_rows = await testing_retriever.retrieve(
                    query=query,
                    kb_vector_stores=kb_vector_stores,
                    fetch_k_per_kb=16,
                    top_k=8,
                )
                if retrieval_rows:
                    knowledge_context = format_retrieval_context(retrieval_rows)

            pipeline_result = run_testing_pipeline(
                user_requirement_text=query,
                requirement_type_input=explicit_type,
                debug=True,
                knowledge_context=knowledge_context,
                use_model_generation=True,
                max_items_per_group=6,
                cases_per_item=1,
                max_focus_points=6,
                max_llm_calls=2,
            )

            # Structured payload the frontend decodes from the first frame.
            context_payload = {
                "route": {
                    "intent": "TESTING",
                    "reason": "命中测试生成意图,已自动调用测试工具链。",
                },
                "intent": "TESTING",
                "skill_profile": "testing-orchestrator",
                "tool_chain": [
                    "identify-requirement-type",
                    "decompose-test-items",
                    "generate-test-cases",
                    "build_expected_results",
                    "format_output",
                ],
                "selected_chain": "TESTING_PIPELINE",
                "graph_used": False,
                "reranker_enabled": False,
                "retrieval_preview": _preview_rows(retrieval_rows),
                "context": _context_rows(retrieval_rows),
                "testing_pipeline": {
                    "trace_id": pipeline_result.get("trace_id"),
                    "requirement_type": pipeline_result.get("requirement_type"),
                    "candidates": pipeline_result.get("candidates", []),
                    "pipeline_summary": pipeline_result.get("pipeline_summary", ""),
                    "knowledge_used": pipeline_result.get("knowledge_used", False),
                    "step_logs": pipeline_result.get("step_logs", []),
                },
            }

            # base64 keeps the JSON opaque to the text-frame escaping rules.
            escaped_context = json.dumps(context_payload, ensure_ascii=False)
            base64_context = base64.b64encode(escaped_context.encode()).decode()
            separator = "__LLM_RESPONSE__"

            full_response = f"{base64_context}{separator}"
            yield f'0:"{base64_context}{separator}"\n'

            rendered_text = pipeline_result.get("formatted_output", "").strip()
            if not rendered_text:
                rendered_text = "未生成测试内容,请补充更明确的需求后重试。"

            full_response += rendered_text
            yield f'0:"{_escape_stream_text(rendered_text)}"\n'
            yield 'd:{"finishReason":"stop","usage":{"promptTokens":0,"completionTokens":0}}\n'

            bot_message.content = full_response
            db.commit()
            return

        # --- Intent-routed branch (A/B/C/D). ---
        knowledge_bases = (
            db.query(KnowledgeBase)
            .filter(KnowledgeBase.id.in_(knowledge_base_ids))
            .all()
        )
        kb_ids = [kb.id for kb in knowledge_bases]

        llm = LLMFactory.create()
        decision = await route_intent(llm=llm, query=query, messages=messages)
        intent = decision["intent"]

        kb_vector_stores = await _build_kb_vector_stores(db, knowledge_bases)
        # Retrieval-based intents are useless without vector stores: degrade to chat.
        if intent in {"B", "C", "D"} and not kb_vector_stores:
            intent = "A"
            decision = {
                "intent": "A",
                "reason": "未发现可用知识库向量集合,已降级为通用对话路。",
            }

        reranker_client = _build_reranker_client()
        retriever = MultiKBRetriever(
            reranker_client=reranker_client,
            reranker_weight=settings.RERANKER_WEIGHT,
        )

        retrieval_rows: List[Dict[str, Any]] = []
        graph_used = False
        selected_chain = intent
        prompt_text = ""

        if intent == "A":
            # General chat: no retrieval at all.
            prompt_text = GENERAL_CHAT_PROMPT_TEMPLATE.format(query=query)

        elif intent == "B":
            # Hybrid RAG: vector retrieval (+ reranker) feeds the prompt context.
            retrieval_rows = await retriever.retrieve(
                query=query,
                kb_vector_stores=kb_vector_stores,
                fetch_k_per_kb=16,
                top_k=12,
            )
            context = format_retrieval_context(retrieval_rows) or "无可用证据。"
            prompt_text = HYBRID_RAG_PROMPT_TEMPLATE.format(query=query, context=context)

        elif intent == "C":
            # Local graph context; falls back to chain B style retrieval.
            graph_context = ""
            used_kb_ids: List[int] = []
            if settings.GRAPHRAG_ENABLED and kb_ids:
                try:
                    adapter = GraphRAGAdapter()
                    graph_context, used_kb_ids = await adapter.local_context_multi(
                        kb_ids,
                        query,
                        top_k=settings.GRAPHRAG_LOCAL_TOP_K,
                        level=settings.GRAPHRAG_QUERY_LEVEL,
                    )
                    graph_used = bool(graph_context)
                except Exception:
                    # Best-effort: any GraphRAG failure silently triggers fallback.
                    graph_context = ""

            if not graph_context:
                retrieval_rows = await retriever.retrieve(
                    query=query,
                    kb_vector_stores=kb_vector_stores,
                    fetch_k_per_kb=18,
                    top_k=14,
                )
                graph_context = _build_local_graph_context_fallback(retrieval_rows)
                selected_chain = "C_fallback_B"

            else:
                selected_chain = "C_graph"

            prompt_text = GRAPH_LOCAL_PROMPT_TEMPLATE.format(
                query=query,
                graph_context=graph_context,
            )

        else:
            # Intent D: global community context; same fallback pattern as C.
            community_context = ""
            if settings.GRAPHRAG_ENABLED and kb_ids:
                try:
                    adapter = GraphRAGAdapter()
                    community_context, used_kb_ids = await adapter.global_context_multi(
                        kb_ids,
                        query,
                        level=settings.GRAPHRAG_QUERY_LEVEL,
                    )
                    graph_used = bool(community_context)
                except Exception:
                    community_context = ""

            if not community_context:
                retrieval_rows = await retriever.retrieve(
                    query=query,
                    kb_vector_stores=kb_vector_stores,
                    fetch_k_per_kb=20,
                    top_k=14,
                )
                community_context = _build_global_community_context_fallback(retrieval_rows)
                selected_chain = "D_fallback_B"
            else:
                selected_chain = "D_graph"

            prompt_text = GRAPH_GLOBAL_PROMPT_TEMPLATE.format(
                query=query,
                community_context=community_context,
            )

        context_payload = {
            "route": decision,
            "intent": intent,
            "selected_chain": selected_chain,
            "graph_used": graph_used,
            "reranker_enabled": reranker_client.enabled,
            "retrieval_preview": _preview_rows(retrieval_rows),
            "context": _context_rows(retrieval_rows),
        }
        escaped_context = json.dumps(context_payload, ensure_ascii=False)
        base64_context = base64.b64encode(escaped_context.encode()).decode()
        separator = "__LLM_RESPONSE__"

        full_response = f"{base64_context}{separator}"
        yield f'0:"{base64_context}{separator}"\n'

        # Relay model output chunk-by-chunk, accumulating the raw text for the DB.
        async for chunk in llm.astream(prompt_text):
            text = _extract_stream_text(chunk)
            if not text:
                continue
            full_response += text
            yield f'0:"{_escape_stream_text(text)}"\n'

        yield 'd:{"finishReason":"stop","usage":{"promptTokens":0,"completionTokens":0}}\n'

        bot_message.content = full_response
        db.commit()

    except Exception as e:
        error_message = f"Error generating response: {str(e)}"
        print(error_message)
        # NOTE(review): the error text is not escaped for the stream frame —
        # quotes/newlines in it could break the "3:" frame; confirm the
        # frontend tolerates this before hardening.
        yield "3:{text}\n".format(text=error_message)

        # bot_message only exists if we got past the initial inserts.
        if "bot_message" in locals():
            bot_message.content = error_message
            db.commit()
    finally:
        db.close()
|
||||
69
rag-web-ui/backend/app/services/chunk_record.py
Normal file
69
rag-web-ui/backend/app/services/chunk_record.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from typing import Optional, List, Dict, Set
|
||||
from sqlalchemy import create_engine, text
|
||||
from sqlalchemy.orm import Session
|
||||
from app.core.config import settings
|
||||
from app.models.knowledge import DocumentChunk
|
||||
import json
|
||||
|
||||
class ChunkRecord:
    """Per-chunk hash bookkeeping that enables incremental re-indexing.

    Each instance is scoped to a single knowledge base and opens its own
    engine against the configured database URL.  Hashes stored here let
    the processor skip unchanged chunks and detect deleted ones.
    """

    def __init__(self, kb_id: int):
        self.kb_id = kb_id
        self.engine = create_engine(settings.get_database_url)

    def list_chunks(self, file_name: Optional[str] = None) -> Set[str]:
        """Return the stored chunk hashes, optionally limited to one file."""
        with Session(self.engine) as session:
            stmt = session.query(DocumentChunk.hash).filter(
                DocumentChunk.kb_id == self.kb_id
            )
            if file_name:
                stmt = stmt.filter(DocumentChunk.file_name == file_name)
            return {hash_value for (hash_value,) in stmt.all()}

    def add_chunks(self, chunks: List[Dict]):
        """Insert or update chunk records; no-op for an empty list."""
        if not chunks:
            return

        with Session(self.engine) as session:
            for data in chunks:
                record = DocumentChunk(
                    id=data['id'],
                    kb_id=data['kb_id'],
                    document_id=data['document_id'],
                    file_name=data['file_name'],
                    chunk_metadata=data['metadata'],
                    hash=data['hash'],
                )
                # merge (not add) so re-processing an existing chunk updates it.
                session.merge(record)
            session.commit()

    def delete_chunks(self, chunk_ids: List[str]):
        """Remove chunk records by primary key; no-op for an empty list."""
        if not chunk_ids:
            return

        with Session(self.engine) as session:
            session.query(DocumentChunk).filter(
                DocumentChunk.kb_id == self.kb_id,
                DocumentChunk.id.in_(chunk_ids),
            ).delete(synchronize_session=False)
            session.commit()

    def get_deleted_chunks(self, current_hashes: Set[str], file_name: Optional[str] = None) -> List[str]:
        """Return IDs of stored chunks whose hash is absent from *current_hashes*."""
        with Session(self.engine) as session:
            stmt = session.query(DocumentChunk.id).filter(
                DocumentChunk.kb_id == self.kb_id
            )
            if file_name:
                stmt = stmt.filter(DocumentChunk.file_name == file_name)
            if current_hashes:
                stmt = stmt.filter(DocumentChunk.hash.notin_(current_hashes))
            return [chunk_id for (chunk_id,) in stmt.all()]
|
||||
582
rag-web-ui/backend/app/services/document_processor.py
Normal file
582
rag-web-ui/backend/app/services/document_processor.py
Normal file
@@ -0,0 +1,582 @@
|
||||
import logging
|
||||
import os
|
||||
import hashlib
|
||||
import tempfile
|
||||
import traceback
|
||||
import json
|
||||
from app.db.session import SessionLocal
|
||||
from io import BytesIO
|
||||
from typing import Optional, List, Dict, Any
|
||||
from fastapi import UploadFile
|
||||
from langchain_community.document_loaders import (
|
||||
PyPDFLoader,
|
||||
Docx2txtLoader,
|
||||
UnstructuredMarkdownLoader,
|
||||
TextLoader
|
||||
)
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from langchain_core.documents import Document as LangchainDocument
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
from app.core.config import settings
|
||||
from app.core.minio import get_minio_client
|
||||
from app.models.knowledge import ProcessingTask, Document, DocumentChunk
|
||||
from app.services.chunk_record import ChunkRecord
|
||||
from minio.error import MinioException
|
||||
from minio.commonconfig import CopySource
|
||||
from app.services.vector_store import VectorStoreFactory
|
||||
from app.services.embedding.embedding_factory import EmbeddingsFactory
|
||||
|
||||
class UploadResult(BaseModel):
    """Outcome of storing an uploaded file in MinIO."""

    # Object path inside the MinIO bucket (e.g. "kb_<id>/<file_name>").
    file_path: str
    # Sanitized file name (alphanumerics plus '-', '_', '.').
    file_name: str
    # Size of the uploaded payload in bytes.
    file_size: int
    # MIME type inferred from the file extension.
    content_type: str
    # SHA-256 hex digest of the raw file content.
    file_hash: str
||||
|
||||
class TextChunk(BaseModel):
    """One split-out piece of a document, plus its loader metadata."""

    content: str
    # Metadata propagated from the document loader (source, page, ...).
    metadata: Optional[Dict] = None
|
||||
|
||||
class PreviewResult(BaseModel):
    """Chunking preview produced before a document is actually indexed."""

    chunks: List[TextChunk]
    # Total number of chunks produced with the requested chunking settings.
    total_chunks: int
|
||||
|
||||
|
||||
def _estimate_token_count(text: str) -> int:
    """Cheaply approximate the token count of *text*.

    Character length is used as a proxy so no tokenizer dependency is
    needed; this is bookkeeping metadata, not a billing figure.
    """
    return len(text)


def _build_enriched_chunk_metadata(
    *,
    source_metadata: Optional[Dict[str, Any]],
    chunk_id: str,
    file_name: str,
    file_path: str,
    kb_id: int,
    document_id: int,
    chunk_index: int,
    chunk_text: str,
) -> Dict[str, Any]:
    """Merge loader metadata with chunk/document bookkeeping fields.

    Loader-provided keys are preserved but overridden by the canonical
    fields set here.  Graph-related fields are carried along (defaulting
    to empty lists) so a future graph/vector federation can rely on
    their presence.
    """
    loader_meta = dict(source_metadata or {})

    enriched: Dict[str, Any] = dict(loader_meta)
    enriched.update(
        source=file_name,
        chunk_id=chunk_id,
        file_name=file_name,
        file_path=file_path,
        kb_id=kb_id,
        document_id=document_id,
        chunk_index=chunk_index,
        chunk_text=chunk_text,
        token_count=_estimate_token_count(chunk_text),
        language=loader_meta.get("language", "zh"),
        source_type="document",
        mission_phase=loader_meta.get("mission_phase"),
        section_title=loader_meta.get("section_title"),
        publish_time=loader_meta.get("publish_time"),
        # Graph-linked fields, kept for future graph/vector federation.
        extracted_entities=loader_meta.get("extracted_entities", []),
        extracted_entity_types=loader_meta.get("extracted_entity_types", []),
        extracted_relations=loader_meta.get("extracted_relations", []),
        graph_node_ids=loader_meta.get("graph_node_ids", []),
        graph_edge_ids=loader_meta.get("graph_edge_ids", []),
        community_ids=loader_meta.get("community_ids", []),
    )
    return enriched
|
||||
|
||||
|
||||
def _sanitize_metadata_for_vector_store(metadata: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Coerce metadata values into the scalar types Chroma accepts.

    Rules:
    - ``None`` values are dropped.
    - str / int / float / bool pass through unchanged.
    - Lists keep only their scalar items; a non-empty list with no
      scalar items is JSON-encoded instead (empty lists are dropped).
    - Dicts are JSON-encoded.
    - Anything else is stringified.
    """
    if not metadata:
        return {}

    allowed = (str, int, float, bool)
    cleaned: Dict[str, Any] = {}

    for key, value in metadata.items():
        if value is None:
            continue

        if isinstance(value, allowed):
            cleaned[key] = value
        elif isinstance(value, list):
            scalars = [item for item in value if isinstance(item, allowed)]
            if scalars:
                cleaned[key] = scalars
            elif value:
                cleaned[key] = json.dumps(value, ensure_ascii=False)
            # empty list: dropped entirely
        elif isinstance(value, dict):
            cleaned[key] = json.dumps(value, ensure_ascii=False)
        else:
            cleaned[key] = str(value)

    return cleaned
|
||||
|
||||
async def process_document(file_path: str, file_name: str, kb_id: int, document_id: int, chunk_size: int = 1000, chunk_overlap: int = 200) -> None:
    """Process document and store in vector database with incremental updates.

    Re-chunks the file (via ``preview_document``), then diffs the chunk
    hashes against what ``ChunkRecord`` already holds for this file:
    unchanged chunks are skipped, new/changed ones are upserted into both
    the relational chunk table and the vector store, and chunks that no
    longer exist are deleted from both.  GraphRAG ingestion is attempted
    best-effort when enabled.  Raises on any unrecoverable failure.
    """
    logger = logging.getLogger(__name__)

    try:
        preview_result = await preview_document(file_path, chunk_size, chunk_overlap)

        # Initialize embeddings
        logger.info("Initializing OpenAI embeddings...")
        embeddings = EmbeddingsFactory.create()

        logger.info(f"Initializing vector store with collection: kb_{kb_id}")
        vector_store = VectorStoreFactory.create(
            store_type=settings.VECTOR_STORE_TYPE,
            collection_name=f"kb_{kb_id}",
            embedding_function=embeddings,
        )

        # Initialize chunk record manager
        chunk_manager = ChunkRecord(kb_id)

        # Get existing chunk hashes for this file
        existing_hashes = chunk_manager.list_chunks(file_name)

        # Prepare new chunks
        new_chunks = []
        current_hashes = set()
        documents_to_update = []

        for i, chunk in enumerate(preview_result.chunks):
            # Calculate chunk hash over content + loader metadata, so a
            # metadata-only change also counts as a changed chunk.
            chunk_hash = hashlib.sha256(
                (chunk.content + str(chunk.metadata)).encode()
            ).hexdigest()
            current_hashes.add(chunk_hash)

            # Skip if chunk hasn't changed
            if chunk_hash in existing_hashes:
                continue

            # Create unique ID for the chunk (stable across reprocessing:
            # derived from kb, file and content hash, not from position).
            chunk_id = hashlib.sha256(
                f"{kb_id}:{file_name}:{chunk_hash}".encode()
            ).hexdigest()

            metadata = _build_enriched_chunk_metadata(
                source_metadata=chunk.metadata,
                chunk_id=chunk_id,
                file_name=file_name,
                file_path=file_path,
                kb_id=kb_id,
                document_id=document_id,
                chunk_index=i,
                chunk_text=chunk.content,
            )
            # Chroma only accepts scalar metadata values; the full metadata
            # is still persisted in the relational chunk record below.
            vector_metadata = _sanitize_metadata_for_vector_store(metadata)

            new_chunks.append({
                "id": chunk_id,
                "kb_id": kb_id,
                "document_id": document_id,
                "file_name": file_name,
                "metadata": metadata,
                "hash": chunk_hash
            })

            # Prepare document for vector store
            doc = LangchainDocument(
                page_content=chunk.content,
                metadata=vector_metadata
            )
            documents_to_update.append(doc)

        # Add new chunks to database and vector store
        if new_chunks:
            logger.info(f"Adding {len(new_chunks)} new/updated chunks")
            chunk_manager.add_chunks(new_chunks)
            vector_store.add_documents(documents_to_update)
            if settings.GRAPHRAG_ENABLED:
                # Best-effort graph ingestion: failures are logged, never fatal.
                try:
                    from app.services.graph.graphrag_adapter import GraphRAGAdapter

                    graph_adapter = GraphRAGAdapter()
                    source_texts = [doc.page_content for doc in documents_to_update if doc.page_content.strip()]
                    await graph_adapter.ingest_texts(kb_id, source_texts)
                    logger.info("GraphRAG ingestion completed in incremental processing")
                except Exception as graph_exc:
                    logger.error(f"GraphRAG ingestion failed in incremental processing: {graph_exc}")

        # Delete removed chunks (present in the record table but absent
        # from the freshly computed hash set) from both stores.
        chunks_to_delete = chunk_manager.get_deleted_chunks(current_hashes, file_name)
        if chunks_to_delete:
            logger.info(f"Removing {len(chunks_to_delete)} deleted chunks")
            chunk_manager.delete_chunks(chunks_to_delete)
            vector_store.delete(chunks_to_delete)

        logger.info("Document processing completed successfully")

    except Exception as e:
        logger.error(f"Error processing document: {str(e)}")
        raise
|
||||
|
||||
async def upload_document(file: UploadFile, kb_id: int) -> UploadResult:
    """Step 1: Upload document to MinIO.

    Reads the whole upload into memory, hashes it (SHA-256), sanitizes
    the filename to alphanumerics plus ``-``/``_``/``.``, infers a MIME
    type from the extension, and stores the object under
    ``kb_<kb_id>/<file_name>``.  Raises (after logging) if the MinIO put
    fails.
    """
    content = await file.read()
    file_size = len(content)

    # Content hash lets callers detect duplicate uploads.
    file_hash = hashlib.sha256(content).hexdigest()

    # Clean and normalize filename
    # NOTE(review): sanitation can yield an empty name (e.g. a filename of
    # only disallowed characters), and same-named files in one KB overwrite
    # each other — confirm upstream validation covers both cases.
    file_name = "".join(c for c in file.filename if c.isalnum() or c in ('-', '_', '.')).strip()
    object_path = f"kb_{kb_id}/{file_name}"

    # Extension -> MIME type for the formats the loaders support.
    content_types = {
        ".pdf": "application/pdf",
        ".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        ".md": "text/markdown",
        ".txt": "text/plain"
    }

    _, ext = os.path.splitext(file_name)
    content_type = content_types.get(ext.lower(), "application/octet-stream")

    # Upload to MinIO
    minio_client = get_minio_client()
    try:
        minio_client.put_object(
            bucket_name=settings.MINIO_BUCKET_NAME,
            object_name=object_path,
            data=BytesIO(content),
            length=file_size,
            content_type=content_type
        )
    except Exception as e:
        logging.error(f"Failed to upload file to MinIO: {str(e)}")
        raise

    return UploadResult(
        file_path=object_path,
        file_name=file_name,
        file_size=file_size,
        content_type=content_type,
        file_hash=file_hash
    )
|
||||
|
||||
async def preview_document(file_path: str, chunk_size: int = 1000, chunk_overlap: int = 200) -> PreviewResult:
    """Step 2: Generate preview chunks.

    Downloads the object from MinIO into a temp file, picks a loader by
    extension, splits the content with ``RecursiveCharacterTextSplitter`` and
    returns the chunks in preview form. The temp file is always removed.
    """
    client = get_minio_client()
    extension = os.path.splitext(file_path)[1].lower()

    # Download the object into a named temp file so the loaders can read it.
    with tempfile.NamedTemporaryFile(delete=False, suffix=extension) as handle:
        client.fget_object(
            bucket_name=settings.MINIO_BUCKET_NAME,
            object_name=file_path,
            file_path=handle.name,
        )
        local_path = handle.name

    try:
        # Extension-based loader dispatch; plain text is the fallback.
        loader_by_ext = {
            ".pdf": PyPDFLoader,
            ".docx": Docx2txtLoader,
            ".md": UnstructuredMarkdownLoader,
        }
        loader = loader_by_ext.get(extension, TextLoader)(local_path)

        splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )
        pieces = splitter.split_documents(loader.load())

        return PreviewResult(
            chunks=[
                TextChunk(content=piece.page_content, metadata=piece.metadata)
                for piece in pieces
            ],
            total_chunks=len(pieces),
        )
    finally:
        os.unlink(local_path)
|
||||
|
||||
async def process_document_background(
    temp_path: str,
    file_name: str,
    kb_id: int,
    task_id: int,
    db: Session = None,
    chunk_size: int = 1000,
    chunk_overlap: int = 200
) -> None:
    """Process an uploaded document in the background.

    Pipeline: download the temp object from MinIO -> load & split by file
    type -> embed into the KB's vector store -> promote the object to its
    permanent path -> persist Document/DocumentChunk rows -> optionally feed
    GraphRAG -> mark the ProcessingTask (and its upload record) completed.
    On any failure the task and its upload record are marked "failed" and the
    temp object is removed best-effort.

    Args:
        temp_path: Object key of the uploaded temp file in MinIO.
        file_name: Sanitized original file name.
        kb_id: Target knowledge base id.
        task_id: ProcessingTask row id to drive/update.
        db: Optional session; a fresh one is created (and closed) if omitted.
        chunk_size: Splitter chunk size in characters.
        chunk_overlap: Splitter overlap in characters.
    """
    logger = logging.getLogger(__name__)
    logger.info(f"Starting background processing for task {task_id}, file: {file_name}")

    # If no session was passed in, create one and remember to close it.
    if db is None:
        db = SessionLocal()
        should_close_db = True
    else:
        should_close_db = False

    task = db.query(ProcessingTask).get(task_id)
    if not task:
        logger.error(f"Task {task_id} not found")
        return

    # Kept at function scope so the error handler can clean up after failures.
    minio_client = None
    local_temp_path = None

    try:
        logger.info(f"Task {task_id}: Setting status to processing")
        task.status = "processing"
        db.commit()

        # 1. Download the uploaded object from MinIO into the system temp dir.
        minio_client = get_minio_client()
        try:
            local_temp_path = f"/tmp/temp_{task_id}_{file_name}"  # system temp directory
            logger.info(f"Task {task_id}: Downloading file from MinIO: {temp_path} to {local_temp_path}")
            minio_client.fget_object(
                bucket_name=settings.MINIO_BUCKET_NAME,
                object_name=temp_path,
                file_path=local_temp_path
            )
            logger.info(f"Task {task_id}: File downloaded successfully")
        except MinioException as e:
            # Idempotent fallback: temp object may already be consumed by another task.
            # If the final document is already created, treat current task as completed.
            if "NoSuchKey" in str(e) and task.document_upload:
                existing_document = db.query(Document).filter(
                    Document.knowledge_base_id == kb_id,
                    Document.file_name == file_name,
                    Document.file_hash == task.document_upload.file_hash,
                ).first()
                if existing_document:
                    logger.warning(
                        f"Task {task_id}: Temp object missing but document already exists, "
                        f"marking task as completed (document_id={existing_document.id})"
                    )
                    task.status = "completed"
                    task.document_id = existing_document.id
                    task.error_message = None
                    task.document_upload.status = "completed"
                    task.document_upload.error_message = None
                    db.commit()
                    return

            error_msg = f"Failed to download temp file: {str(e)}"
            logger.error(f"Task {task_id}: {error_msg}")
            raise Exception(error_msg)

        try:
            # 2. Load and split the document by file extension.
            _, ext = os.path.splitext(file_name)
            ext = ext.lower()

            logger.info(f"Task {task_id}: Loading document with extension {ext}")
            # Select the appropriate loader for the file type.
            if ext == ".pdf":
                loader = PyPDFLoader(local_temp_path)
            elif ext == ".docx":
                loader = Docx2txtLoader(local_temp_path)
            elif ext == ".md":
                loader = UnstructuredMarkdownLoader(local_temp_path)
            else:  # default to the plain-text loader
                loader = TextLoader(local_temp_path)

            logger.info(f"Task {task_id}: Loading document content")
            documents = loader.load()
            logger.info(f"Task {task_id}: Document loaded successfully")

            logger.info(f"Task {task_id}: Splitting document into chunks")
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap
            )
            chunks = text_splitter.split_documents(documents)
            logger.info(f"Task {task_id}: Document split into {len(chunks)} chunks")

            # 3. Create the embeddings client and the KB's vector store.
            logger.info(f"Task {task_id}: Initializing vector store")
            embeddings = EmbeddingsFactory.create()

            vector_store = VectorStoreFactory.create(
                store_type=settings.VECTOR_STORE_TYPE,
                collection_name=f"kb_{kb_id}",
                embedding_function=embeddings,
            )

            # 4. Promote the temp object to its permanent location.
            permanent_path = f"kb_{kb_id}/{file_name}"
            try:
                logger.info(f"Task {task_id}: Moving file to permanent storage")
                # Copy to the permanent path (MinIO has no server-side move).
                source = CopySource(settings.MINIO_BUCKET_NAME, temp_path)
                minio_client.copy_object(
                    bucket_name=settings.MINIO_BUCKET_NAME,
                    object_name=permanent_path,
                    source=source
                )
                logger.info(f"Task {task_id}: File moved to permanent storage")

                # Remove the temporary object now that the copy succeeded.
                logger.info(f"Task {task_id}: Removing temporary file from MinIO")
                minio_client.remove_object(
                    bucket_name=settings.MINIO_BUCKET_NAME,
                    object_name=temp_path
                )
                logger.info(f"Task {task_id}: Temporary file removed")
            except MinioException as e:
                error_msg = f"Failed to move file to permanent storage: {str(e)}"
                logger.error(f"Task {task_id}: {error_msg}")
                raise Exception(error_msg)

            # 5. Create the Document DB record.
            logger.info(f"Task {task_id}: Creating document record")
            document = Document(
                file_name=file_name,
                file_path=permanent_path,
                file_hash=task.document_upload.file_hash,
                file_size=task.document_upload.file_size,
                content_type=task.document_upload.content_type,
                knowledge_base_id=kb_id
            )
            db.add(document)
            db.flush()
            db.refresh(document)
            logger.info(f"Task {task_id}: Document record created with ID {document.id}")

            # 6. Persist each chunk as a DocumentChunk row.
            logger.info(f"Task {task_id}: Storing document chunks")
            for i, chunk in enumerate(chunks):
                # Deterministic per-chunk ID derived from KB, file and content.
                chunk_id = hashlib.sha256(
                    f"{kb_id}:{file_name}:{chunk.page_content}".encode()
                ).hexdigest()

                metadata = _build_enriched_chunk_metadata(
                    source_metadata=chunk.metadata,
                    chunk_id=chunk_id,
                    file_name=file_name,
                    file_path=permanent_path,
                    kb_id=kb_id,
                    document_id=document.id,
                    chunk_index=i,
                    chunk_text=chunk.page_content,
                )
                chunk.metadata = metadata

                doc_chunk = DocumentChunk(
                    id=chunk_id,  # explicit primary key
                    document_id=document.id,
                    kb_id=kb_id,
                    file_name=file_name,
                    chunk_metadata={
                        "page_content": chunk.page_content,
                        **metadata
                    },
                    # Content+metadata hash used for incremental re-ingestion.
                    hash=hashlib.sha256(
                        (chunk.page_content + str(metadata)).encode()
                    ).hexdigest()
                )
                db.add(doc_chunk)
                # Flush in batches of 100 to bound session memory.
                if i > 0 and i % 100 == 0:
                    logger.info(f"Task {task_id}: Stored {i} chunks")
                    db.flush()

            # 7. Add the chunks to the vector store.
            logger.info(f"Task {task_id}: Adding chunks to vector store")
            vector_chunks = [
                LangchainDocument(
                    page_content=chunk.page_content,
                    metadata=_sanitize_metadata_for_vector_store(chunk.metadata),
                )
                for chunk in chunks
            ]
            vector_store.add_documents(vector_chunks)
            # persist() call removed: newer vector store versions persist automatically.
            logger.info(f"Task {task_id}: Chunks added to vector store")

            # Optional GraphRAG ingestion; failures are logged but never fatal.
            if settings.GRAPHRAG_ENABLED:
                try:
                    from app.services.graph.graphrag_adapter import GraphRAGAdapter

                    logger.info(f"Task {task_id}: Starting GraphRAG ingestion")
                    graph_adapter = GraphRAGAdapter()
                    source_texts = [doc.page_content for doc in documents if doc.page_content.strip()]
                    await graph_adapter.ingest_texts(kb_id, source_texts)
                    logger.info(f"Task {task_id}: GraphRAG ingestion completed")
                except Exception as graph_exc:
                    logger.error(f"Task {task_id}: GraphRAG ingestion failed: {graph_exc}")

            # 8. Mark the task completed.
            logger.info(f"Task {task_id}: Updating task status to completed")
            task.status = "completed"
            task.document_id = document.id  # link to the newly created document

            # 9. Update the upload record status.
            upload = task.document_upload  # fetched via the ORM relationship
            if upload:
                logger.info(f"Task {task_id}: Updating upload record status to completed")
                upload.status = "completed"

            db.commit()
            logger.info(f"Task {task_id}: Processing completed successfully")

        finally:
            # Always clean up the local temp copy, success or failure.
            try:
                if os.path.exists(local_temp_path):
                    logger.info(f"Task {task_id}: Cleaning up local temp file")
                    os.remove(local_temp_path)
                    logger.info(f"Task {task_id}: Local temp file cleaned up")
            except Exception as e:
                logger.warning(f"Task {task_id}: Failed to clean up local temp file: {str(e)}")

    except Exception as e:
        logger.error(f"Task {task_id}: Error processing document: {str(e)}")
        logger.error(f"Task {task_id}: Stack trace: {traceback.format_exc()}")
        db.rollback()

        # Re-query after rollback so we mutate a live, attached row.
        failed_task = db.query(ProcessingTask).get(task_id)
        if failed_task:
            failed_task.status = "failed"
            failed_task.error_message = str(e)
            if failed_task.document_upload:
                failed_task.document_upload.status = "failed"
                failed_task.document_upload.error_message = str(e)
            db.commit()

        # Best-effort removal of the temp object after failure.
        try:
            logger.info(f"Task {task_id}: Cleaning up temporary file after error")
            if minio_client is not None:
                minio_client.remove_object(
                    bucket_name=settings.MINIO_BUCKET_NAME,
                    object_name=temp_path
                )
                logger.info(f"Task {task_id}: Temporary file cleaned up after error")
        except:
            # NOTE(review): bare except is deliberate best-effort cleanup here,
            # but narrowing to Exception would be safer — confirm.
            logger.warning(f"Task {task_id}: Failed to clean up temporary file after error")
    finally:
        # If we created the db session ourselves, we must close it.
        if should_close_db and db:
            db.close()
|
||||
@@ -0,0 +1,46 @@
|
||||
from app.core.config import settings
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
from langchain_ollama import OllamaEmbeddings
|
||||
# If you plan on adding other embeddings, import them here
|
||||
# from some_other_module import AnotherEmbeddingClass
|
||||
|
||||
|
||||
class EmbeddingsFactory:
    """Builds an embeddings client for the provider configured in settings."""

    @staticmethod
    def create():
        """
        Factory method to create an embeddings instance based on .env config.

        Reads ``settings.EMBEDDINGS_PROVIDER`` (e.g. EMBEDDINGS_PROVIDER=openai)
        and returns the matching LangChain embeddings object.
        """
        provider = settings.EMBEDDINGS_PROVIDER.lower()

        if provider == "openai":
            return OpenAIEmbeddings(
                openai_api_key=settings.OPENAI_API_KEY,
                openai_api_base=settings.OPENAI_API_BASE,
                model=settings.OPENAI_EMBEDDINGS_MODEL,
            )

        if provider == "dashscope":
            # DashScope's OpenAI-compatible embedding endpoint expects string
            # input (LangChain's len-safe path may send token ids) and supports
            # at most 10 inputs per batch.
            return OpenAIEmbeddings(
                openai_api_key=settings.DASH_SCOPE_API_KEY,
                openai_api_base=settings.DASH_SCOPE_API_BASE,
                model=settings.DASH_SCOPE_EMBEDDINGS_MODEL,
                check_embedding_ctx_length=False,
                tiktoken_enabled=False,
                skip_empty=True,
                chunk_size=10,
            )

        if provider == "ollama":
            return OllamaEmbeddings(
                model=settings.OLLAMA_EMBEDDINGS_MODEL,
                base_url=settings.OLLAMA_API_BASE,
            )

        # Extend with other providers:
        # if provider == "another_provider":
        #     return AnotherEmbeddingClass(...)
        raise ValueError(f"Unsupported embeddings provider: {provider}")
|
||||
116
rag-web-ui/backend/app/services/fusion_prompts.py
Normal file
116
rag-web-ui/backend/app/services/fusion_prompts.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""Fusion RAG prompts for aerospace Chinese QA."""
|
||||
|
||||
ROUTER_SYSTEM_PROMPT = """
|
||||
你是一个检索路由器。你的唯一任务是把用户请求分类到以下四类之一。
|
||||
|
||||
分类标签:
|
||||
A: 通用对话路
|
||||
- 适用:问候、寒暄、角色扮演、无须知识库支持的常识闲聊。
|
||||
- 特征:没有明确的专业实体约束,也不依赖当前知识库文档。
|
||||
|
||||
B: 混合检索路 (Hybrid RAG)
|
||||
- 适用:单实体事实查询、定义解释、时间/数值/指标问答。
|
||||
- 特征:问题通常可由少量文本片段直接回答,核心是“找准证据”。
|
||||
|
||||
C: 局部图检索路 (Graph Local Search)
|
||||
- 适用:实体关系、多跳因果、组件依赖、跨段落链式推理。
|
||||
- 特征:问题包含“谁影响谁/为什么/如何传导/依赖链”。
|
||||
|
||||
D: 全局图检索路 (Graph Global Search)
|
||||
- 适用:全局总结、趋势分析、跨系统比较、宏观评估。
|
||||
- 特征:问题面向整个语料或多个主题社区,不是单点事实。
|
||||
|
||||
判定规则(按优先级):
|
||||
1. 若请求明确是问候、寒暄、开放闲聊,判 A。
|
||||
2. 若请求强调全局综述、趋势、横向比较,判 D。
|
||||
3. 若请求强调实体关系、影响路径、多跳推理,判 C。
|
||||
4. 其余知识查询默认判 B。
|
||||
|
||||
输出要求:
|
||||
- 只能输出 JSON,不要额外文本。
|
||||
- 格式必须是:
|
||||
{
|
||||
"intent": "A/B/C/D",
|
||||
"reason": "中文简要理由"
|
||||
}
|
||||
""".strip()
|
||||
|
||||
ROUTER_USER_PROMPT_TEMPLATE = """
|
||||
请基于以下用户问题进行路由分类。
|
||||
|
||||
历史对话(可选):
|
||||
{chat_history}
|
||||
|
||||
用户问题:
|
||||
{query}
|
||||
""".strip()
|
||||
|
||||
GENERAL_CHAT_PROMPT_TEMPLATE = """
|
||||
你是中文航天问答助手。当前请求被路由为“通用对话路”。
|
||||
请直接回答用户问题,要求:
|
||||
- 简洁自然
|
||||
- 不要伪造具体文献或数据来源
|
||||
- 若涉及专业细节但无上下文支撑,请明确说明是一般性知识
|
||||
|
||||
用户问题:
|
||||
{query}
|
||||
""".strip()
|
||||
|
||||
HYBRID_RAG_PROMPT_TEMPLATE = """
|
||||
你是航天领域事实问答助手。你会收到按相关性排序的文本证据片段,请严格基于证据作答。
|
||||
|
||||
要求:
|
||||
1. 回答正文应自然连贯,不要使用“直接答案”“证据依据”等分节标题。
|
||||
2. 关键信息需要有可追溯引用,引用编号使用 [1]、[2] 等格式。
|
||||
3. 引用标号尽量集中放在回答末尾,不要在句中频繁插入。
|
||||
4. 不得编造未在证据中出现的事实、时间、参数、型号。
|
||||
5. 若证据不足,明确写:信息不足,缺少 xxx。
|
||||
6. 输出中文,术语严谨,避免冗长。
|
||||
|
||||
问题:
|
||||
{query}
|
||||
|
||||
证据片段:
|
||||
{context}
|
||||
""".strip()
|
||||
|
||||
GRAPH_LOCAL_PROMPT_TEMPLATE = """
|
||||
你是航天知识图谱推理助手。你将获得一个局部子图上下文(实体、关系、证据)。
|
||||
|
||||
要求:
|
||||
1. 输出结构固定为:
|
||||
- 结论
|
||||
- 推理链路
|
||||
- 证据映射
|
||||
- 不确定性
|
||||
2. 推理链路需按步骤编号(步骤1、步骤2...),明确“实体 -> 关系 -> 实体/结论”的链式过程。
|
||||
3. 若局部子图不完整,必须指出断点,不能臆造链路。
|
||||
4. 输出中文。
|
||||
|
||||
问题:
|
||||
{query}
|
||||
|
||||
局部子图上下文:
|
||||
{graph_context}
|
||||
""".strip()
|
||||
|
||||
GRAPH_GLOBAL_PROMPT_TEMPLATE = """
|
||||
你是航天领域全局分析助手。你将获得多个社区摘要,请进行跨社区综合研判。
|
||||
|
||||
要求:
|
||||
1. 输出结构固定为:
|
||||
- 总体结论
|
||||
- 跨社区共性
|
||||
- 关键差异
|
||||
- 趋势判断
|
||||
- 风险与建议
|
||||
2. 每条关键判断尽量给出对应社区编号。
|
||||
3. 仅依据输入摘要,证据不足时明确说明。
|
||||
4. 输出中文,适合技术管理层阅读。
|
||||
|
||||
问题:
|
||||
{query}
|
||||
|
||||
社区摘要:
|
||||
{community_context}
|
||||
""".strip()
|
||||
3
rag-web-ui/backend/app/services/graph/__init__.py
Normal file
3
rag-web-ui/backend/app/services/graph/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from app.services.graph.graphrag_adapter import GraphRAGAdapter
|
||||
|
||||
__all__ = ["GraphRAGAdapter"]
|
||||
183
rag-web-ui/backend/app/services/graph/graphrag_adapter.py
Normal file
183
rag-web-ui/backend/app/services/graph/graphrag_adapter.py
Normal file
@@ -0,0 +1,183 @@
|
||||
import asyncio
|
||||
import importlib
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from app.core.config import settings
|
||||
from app.services.embedding.embedding_factory import EmbeddingsFactory
|
||||
from app.services.llm.llm_factory import LLMFactory
|
||||
|
||||
|
||||
class GraphRAGAdapter:
    """Per-knowledge-base facade over nano-graphrag.

    Lazily builds one GraphRAG instance per kb_id (each with its own working
    directory) and bridges nano-graphrag's callback interfaces onto the
    project's LLM and embedding factories.
    """

    # Serializes lazy creation of per-KB GraphRAG instances across coroutines.
    # NOTE(review): created at class-definition time — assumes a Python version
    # where asyncio.Lock does not bind to a running loop at construction; confirm.
    _instance_lock = asyncio.Lock()

    def __init__(self):
        # kb_id -> initialized GraphRAG instance
        self._graphrag_instances: Dict[int, Any] = {}
        # kb_id -> lock serializing ingestion for that knowledge base
        self._kb_locks: Dict[int, asyncio.Lock] = {}
        self._embedding_model = EmbeddingsFactory.create()
        # Non-streaming: nano-graphrag expects a complete string completion.
        self._llm_model = LLMFactory.create(streaming=False)
        self._symbols = self._load_symbols()

    def _load_symbols(self) -> Dict[str, Any]:
        """Import nano_graphrag lazily and collect the symbols used by this adapter."""
        module = importlib.import_module("nano_graphrag")

        storage_module = importlib.import_module("nano_graphrag._storage")
        utils_module = importlib.import_module("nano_graphrag._utils")

        return {
            "GraphRAG": module.GraphRAG,
            "QueryParam": module.QueryParam,
            "Neo4jStorage": getattr(storage_module, "Neo4jStorage"),
            "NetworkXStorage": getattr(storage_module, "NetworkXStorage"),
            "EmbeddingFunc": getattr(utils_module, "EmbeddingFunc"),
        }

    def _get_kb_lock(self, kb_id: int) -> asyncio.Lock:
        """Return the ingestion lock for *kb_id*, creating it on first use."""
        if kb_id not in self._kb_locks:
            self._kb_locks[kb_id] = asyncio.Lock()
        return self._kb_locks[kb_id]

    async def _llm_complete(self, prompt: str, system_prompt: Optional[str] = None, history_messages: Optional[List[Any]] = None, **kwargs: Any) -> str:
        """nano-graphrag model callback.

        Flattens system prompt, chat history and the user prompt into a single
        text prompt, invokes the configured LLM, and returns its completion as
        a plain string.
        """
        history_messages = history_messages or []

        # Render chat-style history entries (dicts, possibly with list-of-part
        # content) into plain "role: text" lines.
        history_lines: List[str] = []
        for item in history_messages:
            if isinstance(item, dict):
                role = str(item.get("role", "user"))
                content = item.get("content", "")
                if isinstance(content, list):
                    joined = " ".join(str(part.get("text", "")) for part in content if isinstance(part, dict))
                    history_lines.append(f"{role}: {joined}")
                else:
                    history_lines.append(f"{role}: {content}")
            else:
                history_lines.append(str(item))

        # Assemble only the non-empty sections, separated by blank lines.
        full_prompt = "\n\n".join(
            part
            for part in [
                f"系统提示: {system_prompt}" if system_prompt else "",
                "历史对话:\n" + "\n".join(history_lines) if history_lines else "",
                "用户输入:\n" + prompt,
            ]
            if part
        )

        model = self._llm_model
        max_tokens = kwargs.get("max_tokens")
        if max_tokens is not None:
            try:
                # Best-effort: not every LangChain model accepts bind(max_tokens=...).
                model = model.bind(max_tokens=max_tokens)
            except Exception:
                pass

        response = await model.ainvoke(full_prompt)
        # Chat models return a message object; plain LLMs may return str directly.
        content = getattr(response, "content", response)
        if isinstance(content, str):
            return content
        return str(content)

    async def _embedding_call(self, texts: List[str]) -> np.ndarray:
        """nano-graphrag embedding callback; runs the sync embedder in a thread."""
        vectors = await asyncio.to_thread(self._embedding_model.embed_documents, texts)
        return np.array(vectors)

    async def _get_or_create(self, kb_id: int) -> Any:
        """Return the cached GraphRAG for *kb_id*, building it on first use.

        Uses the class-level lock with a double-check so concurrent callers
        build at most one instance per knowledge base.
        """
        if kb_id in self._graphrag_instances:
            return self._graphrag_instances[kb_id]

        async with GraphRAGAdapter._instance_lock:
            # Double-check after acquiring the lock.
            if kb_id in self._graphrag_instances:
                return self._graphrag_instances[kb_id]

            GraphRAG = self._symbols["GraphRAG"]
            EmbeddingFunc = self._symbols["EmbeddingFunc"]

            embedding_func = EmbeddingFunc(
                embedding_dim=settings.GRAPHRAG_EMBEDDING_DIM,
                max_token_size=settings.GRAPHRAG_EMBEDDING_MAX_TOKEN_SIZE,
                func=self._embedding_call,
            )

            # Graph backend: NetworkX by default, Neo4j when configured.
            graph_storage_cls = self._symbols["NetworkXStorage"]
            addon_params: Dict[str, Any] = {}
            if settings.GRAPHRAG_GRAPH_STORAGE.lower() == "neo4j":
                graph_storage_cls = self._symbols["Neo4jStorage"]
                addon_params = {
                    "neo4j_url": settings.NEO4J_URL,
                    "neo4j_auth": (settings.NEO4J_USERNAME, settings.NEO4J_PASSWORD),
                }

            # Each KB gets its own on-disk working directory.
            working_dir = str(Path(settings.GRAPHRAG_WORKING_DIR) / f"kb_{kb_id}")

            rag = GraphRAG(
                working_dir=working_dir,
                enable_local=True,
                enable_naive_rag=True,
                graph_storage_cls=graph_storage_cls,
                addon_params=addon_params,
                embedding_func=embedding_func,
                best_model_func=self._llm_complete,
                cheap_model_func=self._llm_complete,
                entity_extract_max_gleaning=settings.GRAPHRAG_ENTITY_EXTRACT_MAX_GLEANING,
            )
            self._graphrag_instances[kb_id] = rag
            return rag

    async def ingest_texts(self, kb_id: int, texts: List[str]) -> None:
        """Insert non-empty *texts* into the KB's graph, serialized per KB."""
        cleaned = [text.strip() for text in texts if text and text.strip()]
        if not cleaned:
            return

        rag = await self._get_or_create(kb_id)
        lock = self._get_kb_lock(kb_id)
        async with lock:
            await rag.ainsert(cleaned)

    async def local_context(self, kb_id: int, query: str, *, top_k: int = 20, level: int = 2) -> str:
        """Return the local-search context (no answer generation) for *query*."""
        rag = await self._get_or_create(kb_id)
        QueryParam = self._symbols["QueryParam"]
        param = QueryParam(
            mode="local",
            top_k=top_k,
            level=level,
            only_need_context=True,
        )
        return await rag.aquery(query, param)

    async def global_context(self, kb_id: int, query: str, *, level: int = 2) -> str:
        """Return the global-search context (no answer generation) for *query*."""
        rag = await self._get_or_create(kb_id)
        QueryParam = self._symbols["QueryParam"]
        param = QueryParam(
            mode="global",
            level=level,
            only_need_context=True,
        )
        return await rag.aquery(query, param)

    async def local_context_multi(self, kb_ids: List[int], query: str, *, top_k: int = 20, level: int = 2) -> Tuple[str, List[int]]:
        """Collect local contexts across *kb_ids* (best-effort; failing KBs are skipped).

        Returns the concatenated, KB-tagged context and the ids that contributed.
        """
        contexts: List[str] = []
        used_kb_ids: List[int] = []
        for kb_id in kb_ids:
            try:
                ctx = await self.local_context(kb_id, query, top_k=top_k, level=level)
                if ctx:
                    contexts.append(f"[KB:{kb_id}]\n{ctx}")
                    used_kb_ids.append(kb_id)
            except Exception:
                continue
        return "\n\n".join(contexts), used_kb_ids

    async def global_context_multi(self, kb_ids: List[int], query: str, *, level: int = 2) -> Tuple[str, List[int]]:
        """Collect global contexts across *kb_ids* (best-effort; failing KBs are skipped).

        Returns the concatenated, KB-tagged context and the ids that contributed.
        """
        contexts: List[str] = []
        used_kb_ids: List[int] = []
        for kb_id in kb_ids:
            try:
                ctx = await self.global_context(kb_id, query, level=level)
                if ctx:
                    contexts.append(f"[KB:{kb_id}]\n{ctx}")
                    used_kb_ids.append(kb_id)
            except Exception:
                continue
        return "\n\n".join(contexts), used_kb_ids
|
||||
85
rag-web-ui/backend/app/services/hybrid_retriever.py
Normal file
85
rag-web-ui/backend/app/services/hybrid_retriever.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import re
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from app.services.vector_store.base import BaseVectorStore
|
||||
|
||||
|
||||
def _tokenize_for_keyword_score(text: str) -> List[str]:
|
||||
"""Simple multilingual tokenizer for lexical matching without extra dependencies."""
|
||||
tokens = re.findall(r"[A-Za-z0-9_]+|[\u4e00-\u9fff]", text.lower())
|
||||
return [token for token in tokens if token.strip()]
|
||||
|
||||
|
||||
def _keyword_score(query: str, doc_text: str) -> float:
    """Fraction of distinct query terms that also occur in *doc_text* (0.0-1.0)."""
    query_terms = set(_tokenize_for_keyword_score(query))
    doc_terms = set(_tokenize_for_keyword_score(doc_text))

    # Either side empty -> no lexical evidence at all.
    if not query_terms or not doc_terms:
        return 0.0

    shared = query_terms & doc_terms
    return len(shared) / max(1, len(query_terms))
|
||||
|
||||
|
||||
def hybrid_search(
    vector_store: BaseVectorStore,
    query: str,
    top_k: int = 6,
    fetch_k: int = 20,
    alpha: float = 0.65,
) -> List[Dict[str, Any]]:
    """
    Hybrid retrieval via vector candidate generation + lexical reranking.

    score = alpha * vector_rank_score + (1 - alpha) * keyword_score
    """
    candidates = vector_store.similarity_search_with_score(query, k=fetch_k)
    if not candidates:
        return []

    candidate_count = len(candidates)
    scored: List[Dict[str, Any]] = []

    for position, pair in enumerate(candidates):
        # Defensive: skip entries that are not (document, score)-shaped.
        if not isinstance(pair, (tuple, list)) or len(pair) < 1:
            continue

        document = pair[0]
        if not hasattr(document, "page_content"):
            continue

        # Earlier vector hits receive a linearly decaying rank score.
        vector_component = 1.0 - (position / max(1, candidate_count))
        lexical_component = _keyword_score(query, document.page_content)
        blended = alpha * vector_component + (1.0 - alpha) * lexical_component

        scored.append(
            {
                "document": document,
                "vector_rank_score": round(vector_component, 6),
                "keyword_score": round(lexical_component, 6),
                "final_score": round(blended, 6),
            }
        )

    scored.sort(key=lambda entry: entry["final_score"], reverse=True)
    return scored[:top_k]
|
||||
|
||||
|
||||
def format_hybrid_context(rows: List[Dict[str, Any]]) -> str:
    """Render ranked hybrid-search rows as a numbered evidence-context string.

    Each entry becomes "[i] source=..., chunk_id=..., score=..." followed by
    the chunk text; entries are separated by blank lines.
    """
    rendered: List[str] = []

    for position, row in enumerate(rows, start=1):
        document = row["document"]
        meta = document.metadata or {}
        source_label = meta.get("source") or meta.get("file_name") or "unknown"
        chunk_label = meta.get("chunk_id") or "unknown"

        header = (
            f"[{position}] source={source_label}, chunk_id={chunk_label}, "
            f"score={row['final_score']}"
        )
        rendered.append(header + "\n" + document.page_content.strip())

    return "\n\n".join(rendered)
|
||||
120
rag-web-ui/backend/app/services/intent_router.py
Normal file
120
rag-web-ui/backend/app/services/intent_router.py
Normal file
@@ -0,0 +1,120 @@
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from app.services.fusion_prompts import (
|
||||
ROUTER_SYSTEM_PROMPT,
|
||||
ROUTER_USER_PROMPT_TEMPLATE,
|
||||
)
|
||||
|
||||
VALID_INTENTS = {"A", "B", "C", "D"}
|
||||
|
||||
|
||||
def _extract_json_object(raw_text: str) -> Dict[str, str]:
|
||||
"""Extract and parse the first JSON object from model output."""
|
||||
cleaned = raw_text.strip()
|
||||
cleaned = cleaned.replace("```json", "").replace("```", "").strip()
|
||||
|
||||
match = re.search(r"\{[\s\S]*\}", cleaned)
|
||||
if not match:
|
||||
raise ValueError("No JSON object found in router output")
|
||||
|
||||
data = json.loads(match.group(0))
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError("Router output JSON is not an object")
|
||||
|
||||
intent = str(data.get("intent", "")).strip().upper()
|
||||
reason = str(data.get("reason", "")).strip()
|
||||
if intent not in VALID_INTENTS:
|
||||
raise ValueError(f"Invalid intent: {intent}")
|
||||
|
||||
if not reason:
|
||||
reason = "模型未提供理由,已按规则兜底。"
|
||||
|
||||
return {"intent": intent, "reason": reason}
|
||||
|
||||
|
||||
def _build_history_text(messages: dict, max_turns: int = 6) -> str:
|
||||
if not isinstance(messages, dict):
|
||||
return ""
|
||||
|
||||
history = messages.get("messages", [])
|
||||
if not isinstance(history, list):
|
||||
return ""
|
||||
|
||||
tail = history[-max_turns:]
|
||||
rows: List[str] = []
|
||||
for msg in tail:
|
||||
role = str(msg.get("role", "unknown")).strip()
|
||||
content = str(msg.get("content", "")).strip().replace("\n", " ")
|
||||
if content:
|
||||
rows.append(f"{role}: {content}")
|
||||
return "\n".join(rows)
|
||||
|
||||
|
||||
def _heuristic_route(query: str) -> Dict[str, str]:
|
||||
text = query.strip().lower()
|
||||
|
||||
general_chat_patterns = [
|
||||
"你好",
|
||||
"您好",
|
||||
"在吗",
|
||||
"谢谢",
|
||||
"早上好",
|
||||
"晚上好",
|
||||
"你是谁",
|
||||
"讲个笑话",
|
||||
]
|
||||
global_patterns = [
|
||||
"总结",
|
||||
"综述",
|
||||
"整体",
|
||||
"全局",
|
||||
"趋势",
|
||||
"对比",
|
||||
"比较",
|
||||
"宏观",
|
||||
"共性",
|
||||
"差异",
|
||||
]
|
||||
local_graph_patterns = [
|
||||
"关系",
|
||||
"依赖",
|
||||
"影响",
|
||||
"导致",
|
||||
"原因",
|
||||
"链路",
|
||||
"多跳",
|
||||
"传导",
|
||||
"耦合",
|
||||
"约束",
|
||||
]
|
||||
|
||||
if any(token in text for token in general_chat_patterns):
|
||||
return {"intent": "A", "reason": "命中通用对话关键词,且不依赖知识库检索。"}
|
||||
|
||||
if any(token in text for token in global_patterns):
|
||||
return {"intent": "D", "reason": "问题指向全局总结或跨主题趋势分析。"}
|
||||
|
||||
if any(token in text for token in local_graph_patterns):
|
||||
return {"intent": "C", "reason": "问题强调实体关系与链式推理。"}
|
||||
|
||||
return {"intent": "B", "reason": "默认归入事实查询,适合混合检索链路。"}
|
||||
|
||||
|
||||
async def route_intent(llm: Any, query: str, messages: dict) -> Dict[str, str]:
    """Route user query to A/B/C/D: ask the LLM first, fall back to keywords.

    Any model or parsing failure (bad JSON, invalid intent, transport error)
    silently falls back to the deterministic keyword heuristic.
    """
    chat_history = _build_history_text(messages) or "无"
    user_prompt = ROUTER_USER_PROMPT_TEMPLATE.format(
        chat_history=chat_history,
        query=query,
    )

    try:
        response = await llm.ainvoke(f"{ROUTER_SYSTEM_PROMPT}\n\n{user_prompt}")
        payload = getattr(response, "content", response)
        if not isinstance(payload, str):
            payload = str(payload)
        return _extract_json_object(payload)
    except Exception:
        return _heuristic_route(query)
|
||||
57
rag-web-ui/backend/app/services/llm/llm_factory.py
Normal file
57
rag-web-ui/backend/app/services/llm/llm_factory.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from typing import Optional
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_deepseek import ChatDeepSeek
|
||||
from langchain_ollama import OllamaLLM
|
||||
from app.core.config import settings
|
||||
|
||||
class LLMFactory:
    """Builds a chat model for the configured (or explicitly given) provider."""

    @staticmethod
    def create(
        provider: Optional[str] = None,
        temperature: float = 0,
        streaming: bool = True,
    ) -> BaseChatModel:
        """
        Create a LLM instance based on the provider.

        Falls back to ``settings.CHAT_PROVIDER`` when *provider* is omitted;
        provider matching is case-insensitive.
        """
        provider = provider or settings.CHAT_PROVIDER
        key = provider.lower()

        if key == "openai":
            return ChatOpenAI(
                temperature=temperature,
                streaming=streaming,
                model=settings.OPENAI_MODEL,
                openai_api_key=settings.OPENAI_API_KEY,
                openai_api_base=settings.OPENAI_API_BASE,
            )

        if key == "deepseek":
            return ChatDeepSeek(
                temperature=temperature,
                streaming=streaming,
                model=settings.DEEPSEEK_MODEL,
                api_key=settings.DEEPSEEK_API_KEY,
                api_base=settings.DEEPSEEK_API_BASE,
            )

        if key == "dashscope":
            # DashScope exposes an OpenAI-compatible chat endpoint.
            return ChatOpenAI(
                temperature=temperature,
                streaming=streaming,
                model=settings.DASH_SCOPE_CHAT_MODEL,
                openai_api_key=settings.DASH_SCOPE_API_KEY,
                openai_api_base=settings.DASH_SCOPE_API_BASE,
            )

        if key == "ollama":
            return OllamaLLM(
                model=settings.OLLAMA_MODEL,
                base_url=settings.OLLAMA_API_BASE,
                temperature=temperature,
                streaming=streaming,
            )

        # Add more providers here as needed, e.g.:
        # if key == "anthropic":
        #     return ChatAnthropic(...)
        raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||
3
rag-web-ui/backend/app/services/reranker/__init__.py
Normal file
3
rag-web-ui/backend/app/services/reranker/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from app.services.reranker.external_api import ExternalRerankerClient
|
||||
|
||||
__all__ = ["ExternalRerankerClient"]
|
||||
164
rag-web-ui/backend/app/services/reranker/external_api.py
Normal file
164
rag-web-ui/backend/app/services/reranker/external_api.py
Normal file
@@ -0,0 +1,164 @@
|
||||
import asyncio
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib import request
|
||||
|
||||
|
||||
@dataclass
class ExternalRerankerClient:
    """Thin HTTP client for an external document-reranking service.

    Supports the DashScope rerank endpoint's payload shape as well as a
    generic ``{model, query, documents, top_n}`` JSON contract, and parses
    several common response formats into a dense per-document score list.
    """

    api_url: str                  # empty string disables the client
    api_key: str = ""             # sent as a Bearer token when provided
    model: str = ""               # model name forwarded in the payload
    timeout_seconds: float = 8.0  # HTTP timeout for one rerank call

    @property
    def enabled(self) -> bool:
        """True when an endpoint URL has been configured."""
        return bool(self.api_url)

    @property
    def is_dashscope_rerank(self) -> bool:
        """Heuristic: does the configured URL point at DashScope's rerank service?"""
        return "dashscope.aliyuncs.com" in self.api_url and "/services/rerank/" in self.api_url

    async def rerank(
        self,
        *,
        query: str,
        documents: List[str],
        top_n: Optional[int] = None,
        metadata: Optional[List[Dict[str, Any]]] = None,
    ) -> Optional[List[float]]:
        """Score *documents* against *query* via the remote service.

        Returns one float per input document, ``[]`` for empty input, or
        ``None`` when the client is disabled or the request/parse fails
        (callers treat reranking as best-effort).
        """
        if not self.enabled:
            return None
        if not documents:
            return []

        payload = self._build_payload(
            query=query,
            documents=documents,
            top_n=top_n or len(documents),
            metadata=metadata,
        )

        try:
            # urllib is blocking; run the request off the event loop.
            response = await asyncio.to_thread(self._post_json, payload)
            return self._parse_scores(response, len(documents))
        except Exception:
            # Best-effort contract: any network/decoding failure degrades
            # to "no rerank available" rather than raising.
            return None

    def _post_json(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """POST *payload* as JSON to ``api_url`` and decode the JSON reply."""
        headers = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        req = request.Request(
            self.api_url,
            data=json.dumps(payload).encode("utf-8"),
            headers=headers,
            method="POST",
        )
        with request.urlopen(req, timeout=self.timeout_seconds) as resp:
            body = resp.read().decode("utf-8")
        return json.loads(body)

    def _build_payload(
        self,
        *,
        query: str,
        documents: List[str],
        top_n: int,
        metadata: Optional[List[Dict[str, Any]]],
    ) -> Dict[str, Any]:
        """Build the request body for either the DashScope or generic API shape."""
        if self.is_dashscope_rerank:
            payload = {
                "model": self.model,
                "input": {
                    "query": query,
                    "documents": documents,
                },
                "parameters": {
                    "return_documents": True,
                    "top_n": top_n,
                },
            }
            if metadata:
                payload["metadata"] = metadata
            return payload

        payload = {
            "model": self.model,
            "query": query,
            "documents": documents,
            "top_n": top_n,
        }
        if metadata:
            payload["metadata"] = metadata
        return payload

    @staticmethod
    def _scores_from_results(
        raw_results: List[Any],
        expected_len: int,
        primary_key: str,
        fallback_key: str,
    ) -> List[float]:
        """Fold a list of ``{index, score}`` items into a dense score list.

        Unknown indexes and malformed items are ignored; unparsable score
        values fall back to 0.0. Shared by all indexed response formats.
        """
        scores = [0.0] * expected_len
        for item in raw_results:
            if not isinstance(item, dict):
                continue
            idx = item.get("index")
            score = item.get(primary_key, item.get(fallback_key, 0.0))
            if isinstance(idx, int) and 0 <= idx < expected_len:
                try:
                    scores[idx] = float(score)
                except Exception:
                    scores[idx] = 0.0
        return scores

    def _parse_scores(self, response: Dict[str, Any], expected_len: int) -> List[float]:
        """Extract per-document scores from any of the supported response shapes."""
        # DashScope format:
        # {"output": {"results": [{"index": 0, "relevance_score": 0.98}, ...]}}
        output_block = response.get("output")
        if isinstance(output_block, dict) and isinstance(output_block.get("results"), list):
            return self._scores_from_results(
                output_block["results"], expected_len, "relevance_score", "score"
            )

        # Common response format #1:
        # {"results": [{"index": 0, "relevance_score": 0.98}, ...]}
        if isinstance(response.get("results"), list):
            return self._scores_from_results(
                response["results"], expected_len, "relevance_score", "score"
            )

        # Common response format #2:
        # {"scores": [0.9, 0.1, ...]}
        if isinstance(response.get("scores"), list):
            values = response["scores"]
            scores: List[float] = []
            for i in range(expected_len):
                try:
                    scores.append(float(values[i]))
                except Exception:
                    scores.append(0.0)
            return scores

        # Common response format #3 (note: "score" preferred over "relevance_score"):
        # {"data": [{"index": 0, "score": 0.88}, ...]}
        if isinstance(response.get("data"), list):
            return self._scores_from_results(
                response["data"], expected_len, "score", "relevance_score"
            )

        return [0.0] * expected_len
|
||||
3
rag-web-ui/backend/app/services/retrieval/__init__.py
Normal file
3
rag-web-ui/backend/app/services/retrieval/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from app.services.retrieval.multi_kb_retriever import MultiKBRetriever, format_retrieval_context
|
||||
|
||||
__all__ = ["MultiKBRetriever", "format_retrieval_context"]
|
||||
131
rag-web-ui/backend/app/services/retrieval/multi_kb_retriever.py
Normal file
131
rag-web-ui/backend/app/services/retrieval/multi_kb_retriever.py
Normal file
@@ -0,0 +1,131 @@
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.services.reranker.external_api import ExternalRerankerClient
|
||||
|
||||
|
||||
def _tokenize(text: str) -> List[str]:
|
||||
tokens = re.findall(r"[A-Za-z0-9_]+|[\u4e00-\u9fff]", text.lower())
|
||||
return [token for token in tokens if token.strip()]
|
||||
|
||||
|
||||
def _keyword_score(query: str, text: str) -> float:
    """Fraction of distinct query tokens that also occur in *text* (0.0-1.0)."""
    query_terms = set(_tokenize(query))
    text_terms = set(_tokenize(text))
    if not (query_terms and text_terms):
        return 0.0
    shared = query_terms & text_terms
    return len(shared) / max(1, len(query_terms))
|
||||
|
||||
|
||||
def format_retrieval_context(rows: List[Dict[str, Any]]) -> str:
    """Render retrieved rows as numbered, source-annotated context blocks."""
    rendered: List[str] = []
    for position, row in enumerate(rows, start=1):
        doc = row["document"]
        metadata = doc.metadata or {}
        source = metadata.get("source") or metadata.get("file_name") or "unknown"
        chunk_id = metadata.get("chunk_id") or "unknown"
        header = (
            f"[{position}] kb_id={row.get('kb_id')}, source={source}, "
            f"chunk_id={chunk_id}, score={row.get('final_score', 0):.6f}"
        )
        rendered.append(f"{header}\n{doc.page_content.strip()}")
    return "\n\n".join(rendered)
|
||||
|
||||
|
||||
class MultiKBRetriever:
    """Retrieves and fuses chunks from several knowledge-base vector stores.

    Candidates from each store get a rank-based vector score and a keyword
    overlap score, are de-duplicated per chunk, optionally re-scored by an
    external reranker, and returned sorted by a weighted final score.
    """

    def __init__(
        self,
        *,
        reranker_client: Optional[ExternalRerankerClient] = None,
        reranker_weight: float = 0.75,
        vector_weight: float = 0.2,
        keyword_weight: float = 0.05,
    ):
        self.reranker_client = reranker_client
        self.reranker_weight = reranker_weight
        self.vector_weight = vector_weight
        self.keyword_weight = keyword_weight

    async def retrieve(
        self,
        *,
        query: str,
        kb_vector_stores: List[Dict[str, Any]],
        fetch_k_per_kb: int = 12,
        top_k: int = 12,
    ) -> List[Dict[str, Any]]:
        """Return the *top_k* fused candidate rows for *query* across all stores."""
        pool: List[Dict[str, Any]] = []

        for entry in kb_vector_stores:
            kb_id = entry["kb_id"]
            store = entry["store"]
            hits = store.similarity_search_with_score(query, k=fetch_k_per_kb)
            hit_count = len(hits)

            for position, hit in enumerate(hits):
                # Each hit should be a (document, score) pair; skip malformed ones.
                if not isinstance(hit, (tuple, list)) or not hit:
                    continue

                doc = hit[0]
                if not hasattr(doc, "page_content"):
                    continue

                meta = doc.metadata or {}
                # Rank-based score: earlier hits score higher, linear falloff.
                positional = 1.0 - (position / max(1, hit_count))
                lexical = _keyword_score(query, doc.page_content)

                pool.append(
                    {
                        "kb_id": kb_id,
                        "document": doc,
                        "chunk_key": f"{kb_id}:{meta.get('chunk_id', position)}",
                        "vector_rank_score": round(positional, 6),
                        "keyword_score": round(lexical, 6),
                    }
                )

        if not pool:
            return []

        # Dedupe by KB + chunk id, keeping the best-ranked occurrence.
        best_by_key: Dict[str, Dict[str, Any]] = {}
        for row in pool:
            kept = best_by_key.get(row["chunk_key"])
            if kept is None or row["vector_rank_score"] > kept["vector_rank_score"]:
                best_by_key[row["chunk_key"]] = row

        fused = sorted(
            best_by_key.values(), key=lambda r: r["vector_rank_score"], reverse=True
        )

        rerank_values: Optional[List[float]] = None
        if self.reranker_client is not None and self.reranker_client.enabled:
            rerank_values = await self.reranker_client.rerank(
                query=query,
                documents=[row["document"].page_content for row in fused],
                top_n=min(top_k, len(fused)),
                metadata=[{"kb_id": row["kb_id"]} for row in fused],
            )

        for position, row in enumerate(fused):
            blended = (
                self.vector_weight * row["vector_rank_score"]
                + self.keyword_weight * row["keyword_score"]
            )
            if rerank_values is not None:
                external = float(rerank_values[position])
                row["reranker_score"] = round(external, 6)
                blended = self.reranker_weight * external + (1 - self.reranker_weight) * blended
            else:
                row["reranker_score"] = None
            row["final_score"] = round(blended, 6)

        fused.sort(key=lambda r: r["final_score"], reverse=True)
        return fused[:top_k]
|
||||
187
rag-web-ui/backend/app/services/srs_job_service.py
Normal file
187
rag-web-ui/backend/app/services/srs_job_service.py
Normal file
@@ -0,0 +1,187 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db.session import SessionLocal
|
||||
from app.models.tooling import SRSExtraction, SRSRequirement, ToolJob
|
||||
from app.tools.srs_reqs_qwen import get_srs_tool
|
||||
|
||||
|
||||
def run_srs_job(job_id: int) -> None:
    """Execute the SRS extraction tool for *job_id* and persist the results.

    Opens its own DB session, moves the job through processing -> completed,
    and on any exception rolls back and records a failed status in a fresh
    session (so the failure marker survives the rollback).
    """
    session = SessionLocal()
    try:
        job = session.query(ToolJob).filter(ToolJob.id == job_id).first()
        if job is None:
            return

        # Mark the job as in-flight before the (potentially slow) extraction runs.
        job.status = "processing"
        job.started_at = datetime.utcnow()
        job.error_message = None
        session.commit()

        payload = get_srs_tool().run(job.input_file_path)

        extraction = SRSExtraction(
            job_id=job.id,
            document_name=payload["document_name"],
            document_title=payload.get("document_title") or payload["document_name"],
            generated_at=_parse_generated_at(payload.get("generated_at")),
            total_requirements=len(payload.get("requirements", [])),
            statistics=payload.get("statistics", {}),
            raw_output=payload.get("raw_output", {}),
        )
        session.add(extraction)
        session.flush()  # obtain extraction.id for the requirement rows

        for item in payload.get("requirements", []):
            session.add(
                SRSRequirement(
                    extraction_id=extraction.id,
                    requirement_uid=item["id"],
                    title=item.get("title") or item["id"],
                    description=item.get("description") or "",
                    priority=item.get("priority") or "中",
                    acceptance_criteria=item.get("acceptance_criteria") or ["待补充验收标准"],
                    source_field=item.get("source_field") or "文档解析",
                    section_number=item.get("section_number"),
                    section_title=item.get("section_title"),
                    requirement_type=item.get("requirement_type"),
                    sort_order=int(item.get("sort_order") or 0),
                )
            )

        job.status = "completed"
        job.completed_at = datetime.utcnow()
        job.output_summary = {
            "total_requirements": extraction.total_requirements,
            "document_name": extraction.document_name,
        }
        session.commit()
    except Exception as exc:
        session.rollback()
        # Failure is recorded via a separate session (see _mark_job_failed).
        _mark_job_failed(job_id=job_id, error_message=str(exc))
    finally:
        session.close()
|
||||
|
||||
|
||||
def _mark_job_failed(job_id: int, error_message: str) -> None:
    """Persist a failed status and (truncated) error message for *job_id*.

    Uses its own DB session so it works even after the caller rolled back.
    """
    session = SessionLocal()
    try:
        job = session.query(ToolJob).filter(ToolJob.id == job_id).first()
        if job is None:
            return
        job.status = "failed"
        job.completed_at = datetime.utcnow()
        job.error_message = error_message[:2000]  # keep within column bounds
        session.commit()
    finally:
        session.close()
|
||||
|
||||
|
||||
def _parse_generated_at(value: Any) -> datetime:
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return datetime.fromisoformat(value)
|
||||
except ValueError:
|
||||
return datetime.utcnow()
|
||||
return datetime.utcnow()
|
||||
|
||||
|
||||
def ensure_upload_path(job_id: int, file_name: str) -> Path:
    """Return the upload destination for a job, creating parent dirs as needed."""
    destination_dir = Path("uploads") / "srs_jobs" / str(job_id)
    destination_dir.mkdir(parents=True, exist_ok=True)
    return destination_dir / file_name
|
||||
|
||||
|
||||
def build_result_response(job: ToolJob, extraction: SRSExtraction) -> Dict[str, Any]:
    """Serialize an extraction and its requirements into the API response shape."""
    serialized: List[Dict[str, Any]] = [
        {
            "id": req.requirement_uid,
            "title": req.title,
            "description": req.description,
            "priority": req.priority,
            "acceptanceCriteria": req.acceptance_criteria or [],
            "sourceField": req.source_field,
            "sectionNumber": req.section_number,
            "sectionTitle": req.section_title,
            "requirementType": req.requirement_type,
            "sortOrder": req.sort_order,
        }
        for req in extraction.requirements
    ]

    return {
        "jobId": job.id,
        "documentName": extraction.document_name,
        "generatedAt": extraction.generated_at.isoformat(),
        "statistics": extraction.statistics or {},
        "requirements": serialized,
    }
|
||||
|
||||
|
||||
def replace_requirements(db: Session, extraction: SRSExtraction, updates: List[Dict[str, Any]]) -> None:
    """Synchronize stored requirements with *updates* (upsert + delete missing).

    Rows present in *updates* are created or field-patched; rows absent from
    *updates* are deleted. The extraction's counters, statistics and raw
    output snapshot are refreshed to match.
    """
    current_rows = (
        db.query(SRSRequirement)
        .filter(SRSRequirement.extraction_id == extraction.id)
        .all()
    )
    by_uid = {row.requirement_uid: row for row in current_rows}
    kept_uids = set()

    for position, item in enumerate(updates):
        uid = item["id"]
        kept_uids.add(uid)
        row = by_uid.get(uid)

        if row is None:
            # New requirement: create with defaults for missing optional fields.
            db.add(
                SRSRequirement(
                    extraction_id=extraction.id,
                    requirement_uid=uid,
                    title=item.get("title") or uid,
                    description=item.get("description") if item.get("description") is not None else "",
                    priority=item.get("priority") or "中",
                    acceptance_criteria=item.get("acceptanceCriteria") or ["待补充验收标准"],
                    source_field=item.get("sourceField") or "文档解析",
                    section_number=item.get("sectionNumber"),
                    section_title=item.get("sectionTitle"),
                    requirement_type=item.get("requirementType"),
                    sort_order=int(item.get("sortOrder") or position),
                )
            )
            continue

        # Existing requirement: overwrite only the fields supplied in the update.
        row.title = item.get("title", row.title)
        row.description = item.get("description", row.description)
        row.priority = item.get("priority", row.priority)
        row.acceptance_criteria = item.get("acceptanceCriteria", row.acceptance_criteria)
        row.source_field = item.get("sourceField", row.source_field)
        row.section_number = item.get("sectionNumber", row.section_number)
        row.section_title = item.get("sectionTitle", row.section_title)
        row.requirement_type = item.get("requirementType", row.requirement_type)
        row.sort_order = int(item.get("sortOrder", position))

    # Remove rows no longer present in the update payload.
    for uid, row in by_uid.items():
        if uid not in kept_uids:
            db.delete(row)

    extraction.total_requirements = len(updates)
    extraction.statistics = {
        "total": len(updates),
        "by_type": _count_requirement_types(updates),
    }
    extraction.raw_output = {
        "document_name": extraction.document_name,
        "generated_at": extraction.generated_at.isoformat(),
        "requirements": updates,
    }
||||
|
||||
|
||||
def _count_requirement_types(items: List[Dict[str, Any]]) -> Dict[str, int]:
|
||||
stats: Dict[str, int] = {}
|
||||
for item in items:
|
||||
req_type = item.get("requirementType") or "functional"
|
||||
stats[req_type] = stats.get(req_type, 0) + 1
|
||||
return stats
|
||||
@@ -0,0 +1,3 @@
|
||||
from app.services.testing_pipeline.pipeline import run_testing_pipeline
|
||||
|
||||
__all__ = ["run_testing_pipeline"]
|
||||
20
rag-web-ui/backend/app/services/testing_pipeline/base.py
Normal file
20
rag-web-ui/backend/app/services/testing_pipeline/base.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
@dataclass
class ToolExecutionResult:
    """Outcome of a single pipeline tool run."""

    # Updated pipeline context to hand to the next tool in the chain.
    context: Dict[str, Any]
    # Short human-readable description of what the tool produced.
    output_summary: str
    # True when the tool took a fallback path instead of its primary one.
    fallback_used: bool = False
|
||||
|
||||
|
||||
class TestingTool(ABC):
|
||||
name: str
|
||||
|
||||
@abstractmethod
|
||||
def execute(self, context: Dict[str, Any]) -> ToolExecutionResult:
|
||||
raise NotImplementedError
|
||||
99
rag-web-ui/backend/app/services/testing_pipeline/pipeline.py
Normal file
99
rag-web-ui/backend/app/services/testing_pipeline/pipeline.py
Normal file
@@ -0,0 +1,99 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from time import perf_counter
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from app.services.llm.llm_factory import LLMFactory
|
||||
from app.services.testing_pipeline.tools import build_default_tool_chain
|
||||
|
||||
|
||||
def _build_input_summary(context: Dict[str, Any]) -> str:
|
||||
req_text = str(context.get("user_requirement_text", "")).strip()
|
||||
req_type = str(context.get("requirement_type_input", "")).strip() or "auto"
|
||||
short_text = req_text if len(req_text) <= 60 else f"{req_text[:60]}..."
|
||||
return f"requirement_type_input={req_type}; requirement_text={short_text}"
|
||||
|
||||
|
||||
def _build_output_summary(context: Dict[str, Any]) -> str:
|
||||
req_type_result = context.get("requirement_type_result", {})
|
||||
req_type = req_type_result.get("requirement_type", "")
|
||||
test_items = context.get("test_items", {})
|
||||
test_cases = context.get("test_cases", {})
|
||||
|
||||
return (
|
||||
f"requirement_type={req_type}; "
|
||||
f"items={len(test_items.get('normal', [])) + len(test_items.get('abnormal', []))}; "
|
||||
f"cases={len(test_cases.get('normal', [])) + len(test_cases.get('abnormal', []))}"
|
||||
)
|
||||
|
||||
|
||||
def run_testing_pipeline(
    user_requirement_text: str,
    requirement_type_input: Optional[str] = None,
    debug: bool = False,
    knowledge_context: Optional[str] = None,
    use_model_generation: bool = False,
    max_items_per_group: int = 12,
    cases_per_item: int = 2,
    max_focus_points: int = 6,
    max_llm_calls: int = 10,
) -> Dict[str, Any]:
    """Run the requirement -> test-case pipeline and return its result bundle.

    Builds a shared context dict, pushes it through every tool in the
    default chain, and collects per-step timing logs (returned only when
    *debug* is true).
    """
    llm_model = None
    if use_model_generation:
        try:
            llm_model = LLMFactory.create(streaming=False)
        except Exception:
            # Model creation is best-effort; the pipeline still runs without it.
            llm_model = None

    trimmed_knowledge = (knowledge_context or "").strip()
    context: Dict[str, Any] = {
        "trace_id": str(uuid4()),
        "user_requirement_text": user_requirement_text,
        "requirement_type_input": requirement_type_input,
        "debug": bool(debug),
        "knowledge_context": trimmed_knowledge,
        "knowledge_used": bool(trimmed_knowledge),
        "use_model_generation": bool(use_model_generation),
        "llm_model": llm_model,
        # Clamp the tunables into their safe operating ranges.
        "max_items_per_group": max(4, min(int(max_items_per_group), 30)),
        "cases_per_item": max(1, min(int(cases_per_item), 5)),
        "max_focus_points": max(3, min(int(max_focus_points), 12)),
        "llm_call_budget": max(0, min(int(max_llm_calls), 100)),
    }

    step_logs: List[Dict[str, Any]] = []

    for tool in build_default_tool_chain():
        started = perf_counter()
        summary_before = _build_input_summary(context)

        outcome = tool.execute(context)
        context = outcome.context

        elapsed_ms = (perf_counter() - started) * 1000
        step_logs.append(
            {
                "step_name": tool.name,
                "input_summary": summary_before,
                "output_summary": outcome.output_summary,
                "success": True,
                "fallback_used": outcome.fallback_used,
                "duration_ms": round(elapsed_ms, 3),
            }
        )

    type_result = context.get("requirement_type_result", {})

    return {
        "trace_id": context.get("trace_id"),
        "requirement_type": type_result.get("requirement_type", "未知类型"),
        "reason": type_result.get("reason", ""),
        "candidates": type_result.get("candidates", []),
        "test_items": context.get("test_items", {"normal": [], "abnormal": []}),
        "test_cases": context.get("test_cases", {"normal": [], "abnormal": []}),
        "expected_results": context.get("expected_results", {"normal": [], "abnormal": []}),
        "formatted_output": context.get("formatted_output", ""),
        "pipeline_summary": _build_output_summary(context),
        "knowledge_used": bool(context.get("knowledge_used", False)),
        "step_logs": step_logs if debug else [],
    }
|
||||
203
rag-web-ui/backend/app/services/testing_pipeline/rules.py
Normal file
203
rag-web-ui/backend/app/services/testing_pipeline/rules.py
Normal file
@@ -0,0 +1,203 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
# The closed set of requirement/test categories the pipeline recognizes.
REQUIREMENT_TYPES: List[str] = [
    "功能测试",
    "性能测试",
    "外部接口测试",
    "人机交互界面测试",
    "强度测试",
    "余量测试",
    "可靠性测试",
    "安全性测试",
    "恢复性测试",
    "边界测试",
    "安装性测试",
    "互操作性测试",
    "敏感性测试",
    "测试充分性要求",
]


# One-sentence focus description per requirement type (guidance text).
TYPE_SIGNAL_RULES: Dict[str, str] = {
    "功能测试": "关注功能需求逐项验证、业务流程正确性、输入输出行为、状态转换与边界值处理。",
    "性能测试": "关注处理精度、响应时间、处理数据量、系统协调性、负载潜力与运行占用空间。",
    "外部接口测试": "关注外部输入输出接口的格式、内容、协议与正常/异常交互表现。",
    "人机交互界面测试": "关注界面一致性、界面风格、操作流程、误操作健壮性与错误提示能力。",
    "强度测试": "关注系统在极限、超负荷、饱和和降级条件下的稳定性与承受能力。",
    "余量测试": "关注存储余量、输入输出通道余量、功能处理时间余量等资源裕度。",
    "可靠性测试": "关注真实或仿真环境下的失效等级、运行剖面、输入覆盖和长期稳定运行能力。",
    "安全性测试": "关注危险状态响应、安全关键部件、异常输入防护、非法访问阻断和数据完整性保护。",
    "恢复性测试": "关注故障探测、备用切换、系统状态保护与从无错误状态继续执行能力。",
    "边界测试": "关注输入输出域边界、状态转换端点、功能界限、性能界限与容量界限。",
    "安装性测试": "关注不同配置下安装卸载流程和安装规程执行正确性。",
    "互操作性测试": "关注多个软件并行运行时的互操作能力与协同正确性。",
    "敏感性测试": "关注有效输入类中可能引发不稳定或不正常处理的数据组合。",
    "测试充分性要求": "关注需求覆盖率、配置项覆盖、语句覆盖、分支覆盖及未覆盖分析确认。",
}


# Hard constraints that test-item decomposition must always obey.
DECOMPOSE_FORCE_RULES: List[str] = [
    "每个软件功能至少应被正常测试与被认可的异常场景覆盖;复杂功能需继续细分。",
    "每个测试项必须语义完整、可直接执行。",
    "覆盖必须包含:正常流程、边界条件(适用时)、异常条件。",
    "粒度需适中,避免过粗或过细。",
    "对未知类型必须执行通用分解,并保持正常/异常分组。",
    "对需求说明未显式给出但在用户手册或操作手册体现的功能,也应补充测试项覆盖。",
]


# Per-type decomposition rules: detection keywords plus the required
# normal/abnormal coverage statements for that requirement type.
REQUIREMENT_RULES: Dict[str, Dict[str, List[str]]] = {
    "功能测试": {
        "keywords": ["功能", "业务流程", "输入输出", "状态转换", "边界值"],
        "normal": [
            "正常覆盖功能主路径、基本数据类型、合法边界值与状态转换。",
        ],
        "abnormal": [
            "异常覆盖非法输入、不规则输入、非法边界值与最坏情况。",
        ],
    },
    "性能测试": {
        "keywords": ["性能", "处理精度", "响应时间", "处理数据量", "负载", "占用空间"],
        "normal": [
            "正常覆盖处理精度、响应时间、处理数据量与模块协调性。",
        ],
        "abnormal": [
            "异常覆盖超负荷、软硬件限制、负载潜力上限与资源占用异常。",
        ],
    },
    "外部接口测试": {
        "keywords": ["外部接口", "输入接口", "输出接口", "格式", "内容", "协议", "异常交互"],
        "normal": [
            "正常覆盖全部外部接口格式与内容正确性。",
        ],
        "abnormal": [
            "异常覆盖每个输入输出接口的错误格式、错误内容与异常交互。",
        ],
    },
    "人机交互界面测试": {
        "keywords": ["界面", "风格", "交互", "误操作", "错误提示", "操作流程"],
        "normal": [
            "正常覆盖界面风格一致性与标准操作流程。",
        ],
        "abnormal": [
            "异常覆盖误操作、快速操作、非法输入、错误命令与错误流程提示。",
        ],
    },
    "强度测试": {
        "keywords": ["强度", "极限", "超负荷", "饱和", "降级", "健壮性"],
        "normal": [
            "正常覆盖设计极限下系统功能和性能表现。",
        ],
        "abnormal": [
            "异常覆盖超出极限时的降级行为、健壮性与饱和表现。",
        ],
    },
    "余量测试": {
        "keywords": ["余量", "存储余量", "通道余量", "处理时间余量", "资源裕度"],
        "normal": [
            "正常覆盖存储、通道、处理时间余量是否满足要求。",
        ],
        "abnormal": [
            "异常覆盖余量不足或耗尽时系统告警与受控行为。",
        ],
    },
    "可靠性测试": {
        "keywords": ["可靠性", "运行剖面", "失效等级", "输入覆盖", "长期稳定"],
        "normal": [
            "正常覆盖典型环境、运行剖面与输入变量组合。",
        ],
        "abnormal": [
            "异常覆盖失效等级场景、边界环境变化、不合法输入域及失效记录。",
        ],
    },
    "安全性测试": {
        "keywords": ["安全", "危险状态", "安全关键部件", "非法进入", "完整性", "防护"],
        "normal": [
            "正常覆盖安全关键部件、安全结构与合法操作路径。",
        ],
        "abnormal": [
            "异常覆盖危险状态、故障模式、边界接合部、非法进入与数据完整性保护。",
        ],
    },
    "恢复性测试": {
        "keywords": ["恢复", "故障探测", "备用切换", "状态保护", "继续执行", "reset"],
        "normal": [
            "正常覆盖故障探测、备用切换、恢复后继续执行。",
        ],
        "abnormal": [
            "异常覆盖故障中作业保护、状态保护与恢复失败路径。",
        ],
    },
    "边界测试": {
        "keywords": ["边界", "端点", "输入输出域", "状态转换", "性能界限", "容量界限"],
        "normal": [
            "正常覆盖输入输出域边界、状态转换端点与功能界限。",
        ],
        "abnormal": [
            "异常覆盖性能界限、容量界限和越界端点。",
        ],
    },
    "安装性测试": {
        "keywords": ["安装", "卸载", "配置", "安装规程", "部署", "中断"],
        "normal": [
            "正常覆盖标准及不同配置下安装卸载流程。",
        ],
        "abnormal": [
            "异常覆盖安装规程错误、依赖异常与中断后的处理。",
        ],
    },
    "互操作性测试": {
        "keywords": ["互操作", "并行运行", "协同", "兼容", "冲突", "互操作失败"],
        "normal": [
            "正常覆盖两个或多个软件同时运行与互操作过程。",
        ],
        "abnormal": [
            "异常覆盖互操作失败、并行冲突与协同异常。",
        ],
    },
    "敏感性测试": {
        "keywords": ["敏感性", "输入类", "数据组合", "不稳定", "不正常处理"],
        "normal": [
            "正常覆盖有效输入类中典型数据组合。",
        ],
        "abnormal": [
            "异常覆盖引发不稳定或不正常处理的特殊数据组合。",
        ],
    },
    "测试充分性要求": {
        "keywords": ["测试充分性", "需求覆盖率", "配置项覆盖", "语句覆盖", "分支覆盖", "未覆盖分析"],
        "normal": [
            "正常覆盖需求覆盖率、配置项覆盖与代码覆盖达标。",
        ],
        "abnormal": [
            "异常覆盖未覆盖部分逐项分析、确认与报告输出。",
        ],
    },
}


# Fallback decomposition used when no specific requirement type matches.
GENERIC_DECOMPOSITION_RULES: Dict[str, List[str]] = {
    "normal": [
        "主流程正确性。",
        "合法边界值。",
        "标准输入输出。",
    ],
    "abnormal": [
        "非法输入。",
        "越界输入。",
        "资源异常或状态冲突。",
    ],
}


# Maps expected-result template placeholders to their verification wording.
EXPECTED_RESULT_PLACEHOLDER_MAP: Dict[str, str] = {
    "{{return_value}}": "接口或函数返回值验证。",
    "{{state_change}}": "系统状态变化验证。",
    "{{error_message}}": "异常场景错误信息验证。",
    "{{data_persistence}}": "数据库或存储落库结果验证。",
    "{{ui_display}}": "界面显示反馈验证。",
}
|
||||
867
rag-web-ui/backend/app/services/testing_pipeline/tools.py
Normal file
867
rag-web-ui/backend/app/services/testing_pipeline/tools.py
Normal file
@@ -0,0 +1,867 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from app.services.testing_pipeline.base import TestingTool, ToolExecutionResult
|
||||
from app.services.testing_pipeline.rules import (
|
||||
DECOMPOSE_FORCE_RULES,
|
||||
EXPECTED_RESULT_PLACEHOLDER_MAP,
|
||||
GENERIC_DECOMPOSITION_RULES,
|
||||
REQUIREMENT_RULES,
|
||||
REQUIREMENT_TYPES,
|
||||
TYPE_SIGNAL_RULES,
|
||||
)
|
||||
|
||||
|
||||
def _clean_text(value: str) -> str:
|
||||
return " ".join((value or "").replace("\n", " ").split())
|
||||
|
||||
|
||||
def _truncate_text(value: str, max_len: int = 2000) -> str:
    """Clean *value* and cap it at *max_len* characters (appending an ellipsis)."""
    cleaned = _clean_text(value)
    if len(cleaned) <= max_len:
        return cleaned
    return f"{cleaned[:max_len]}..."
|
||||
|
||||
|
||||
def _safe_int(value: Any, default: int, low: int, high: int) -> int:
|
||||
try:
|
||||
parsed = int(value)
|
||||
except Exception:
|
||||
parsed = default
|
||||
return max(low, min(parsed, high))
|
||||
|
||||
|
||||
def _strip_instruction_prefix(value: str) -> str:
    """Strip leading "/testing" commands and "generate test cases ..." boilerplate.

    Leaves only the underlying requirement text so downstream tools see the
    requirement itself rather than the user's instruction wrapper.
    """
    text = _clean_text(value)
    if not text:
        return text

    # Drop a leading "/testing" slash-command (case-insensitive).
    if text.lower().startswith("/testing"):
        text = _clean_text(text[len("/testing") :])

    boilerplate_prefixes = [
        "为以下需求生成测试用例",
        "根据以下需求生成测试用例",
        "请根据以下需求生成测试用例",
        "请根据需求生成测试用例",
        "请生成测试用例",
        "生成测试用例",
    ]
    for prefix in boilerplate_prefixes:
        if not text.startswith(prefix):
            continue
        # Prefer cutting after the first colon; otherwise strip the bare prefix.
        cut_at = -1
        for sep in (":", "："):
            cut_at = text.find(sep)
            if cut_at != -1:
                text = _clean_text(text[cut_at + 1 :])
                break
        if cut_at == -1:
            text = _clean_text(text[len(prefix) :])
        break

    # Catch free-form instruction sentences ending in a colon, e.g.
    # "请根据...需求...生成...测试用例:".
    instruction_pattern = re.compile(
        r"^(请)?(根据|按|基于).{0,40}(需求|场景).{0,30}(生成|输出).{0,20}(测试项|测试用例)[:：]"
    )
    match = instruction_pattern.match(text)
    if match:
        text = _clean_text(text[match.end() :])

    return text
|
||||
|
||||
|
||||
def _extract_focus_points(value: str, max_points: int = 6) -> List[str]:
    """Pull up to *max_points* salient requirement fragments out of *value*.

    Splits on Chinese/ASCII punctuation, drops boilerplate fragments, then
    prefers fragments containing high-signal control keywords; the result is
    de-duplicated while preserving order.
    """
    text = _strip_instruction_prefix(value)
    if not text:
        return []

    fragments = [_clean_text(piece) for piece in re.split(r"[,，。;；]", text)]
    fragments = [piece for piece in fragments if piece]

    boilerplate_tokens = ["生成测试用例", "测试项分解", "测试用例生成", "以下需求"]
    meaningful = [
        piece
        for piece in fragments
        if len(piece) >= 4 and not any(token in piece for token in boilerplate_tokens)
    ]
    if not meaningful:
        meaningful = fragments

    signal_keywords = [
        "启停",
        "开启",
        "关闭",
        "远程控制",
        "保护",
        "联动",
        "状态",
        "故障",
        "恢复",
        "切换",
        "告警",
        "模式",
        "边界",
        "时序",
    ]
    prioritized = [
        piece for piece in meaningful if any(keyword in piece for keyword in signal_keywords)
    ]
    pool = prioritized if prioritized else meaningful

    deduped: List[str] = []
    for piece in pool:
        if piece not in deduped:
            deduped.append(piece)

    return deduped[:max_points]
|
||||
|
||||
|
||||
def _build_type_scores(text: str) -> Dict[str, int]:
    """Score each requirement type: +5 for a name mention, +2 per keyword hit."""
    lowered = text.lower()
    scores: Dict[str, int] = {}

    for type_name, rule in REQUIREMENT_RULES.items():
        points = 5 if type_name in text else 0
        points += sum(
            2 for keyword in rule.get("keywords", []) if keyword.lower() in lowered
        )
        scores[type_name] = points

    return scores
|
||||
|
||||
|
||||
def _top_candidates(scores: Dict[str, int], top_n: int = 3) -> List[str]:
|
||||
sorted_pairs = sorted(scores.items(), key=lambda pair: pair[1], reverse=True)
|
||||
non_zero = [name for name, score in sorted_pairs if score > 0]
|
||||
if non_zero:
|
||||
return non_zero[:top_n]
|
||||
return ["功能测试", "边界测试", "性能测试"][:top_n]
|
||||
|
||||
|
||||
def _message_to_text(value: Any) -> str:
|
||||
content = getattr(value, "content", value)
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
chunks: List[str] = []
|
||||
for item in content:
|
||||
if isinstance(item, str):
|
||||
chunks.append(item)
|
||||
elif isinstance(item, dict):
|
||||
text = item.get("text")
|
||||
if isinstance(text, str):
|
||||
chunks.append(text)
|
||||
else:
|
||||
chunks.append(str(item))
|
||||
return "".join(chunks)
|
||||
return str(content)
|
||||
|
||||
|
||||
def _extract_json_object(value: str) -> Optional[Dict[str, Any]]:
    """Best-effort extraction of a JSON object from raw LLM output.

    Handles three shapes: bare JSON, JSON wrapped in a ``` / ```json
    markdown fence, and JSON embedded in surrounding prose (the first
    balanced ``{...}`` span).  Returns None when no dict is recoverable.
    """
    text = (value or "").strip()
    if not text:
        return None

    # Peel a markdown code fence, tolerating an optional "json" tag.
    if text.startswith("```"):
        text = re.sub(r"^```(?:json)?", "", text, flags=re.IGNORECASE).strip()
        if text.endswith("```"):
            text = text[:-3].strip()

    def _loads_dict(candidate: str) -> Optional[Dict[str, Any]]:
        try:
            parsed = json.loads(candidate)
        except Exception:
            return None
        return parsed if isinstance(parsed, dict) else None

    whole = _loads_dict(text)
    if whole is not None:
        return whole

    start = text.find("{")
    if start == -1:
        return None

    # Scan forward counting brace depth; the first balanced span is the
    # candidate.  Braces inside string literals are not accounted for
    # (deliberate heuristic, same as the original).
    depth = 0
    for pos, ch in enumerate(text[start:], start):
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return _loads_dict(text[start : pos + 1])
    return None
|
||||
|
||||
|
||||
def _invoke_llm_json(context: Dict[str, Any], prompt: str) -> Optional[Dict[str, Any]]:
    """Dispatch *prompt* to the configured LLM and parse a JSON object.

    Returns None when generation is disabled, no model is configured,
    the call budget is exhausted, the call raises, or the reply contains
    no JSON object.  Each dispatch consumes one budget unit.
    """
    model = context.get("llm_model")
    if model is None or not context.get("use_model_generation"):
        return None

    remaining = context.get("llm_call_budget")
    if isinstance(remaining, int):
        if remaining <= 0:
            return None
        # Budget is charged up front, even if the call subsequently fails.
        context["llm_call_budget"] = remaining - 1

    try:
        reply = model.invoke(prompt)
        return _extract_json_object(_message_to_text(reply))
    except Exception:
        # Best-effort: any model/parse failure degrades to "no answer".
        return None
|
||||
|
||||
|
||||
def _invoke_llm_text(context: Dict[str, Any], prompt: str) -> str:
    """Dispatch *prompt* to the configured LLM and return cleaned text.

    Returns "" when generation is disabled, no model is configured, the
    call budget is exhausted, or the call raises.  Each dispatch consumes
    one budget unit.
    """
    model = context.get("llm_model")
    if model is None or not context.get("use_model_generation"):
        return ""

    remaining = context.get("llm_call_budget")
    if isinstance(remaining, int):
        if remaining <= 0:
            return ""
        # Budget is charged up front, even if the call subsequently fails.
        context["llm_call_budget"] = remaining - 1

    try:
        reply = model.invoke(prompt)
        return _clean_text(_message_to_text(reply))
    except Exception:
        # Best-effort: any model failure degrades to an empty answer.
        return ""
|
||||
|
||||
|
||||
def _normalize_item_entry(item: Any) -> Optional[Dict[str, Any]]:
    """Coerce one raw LLM test-item entry into ``{"content", "coverage_tags"}``.

    Accepts either a bare string or a dict carrying a "content" key; tags
    may arrive under "coverage_tags" or the legacy "covered_points" key.
    Returns None for empty content or unsupported types.
    """
    if isinstance(item, str):
        content = _clean_text(item)
        return {"content": content, "coverage_tags": []} if content else None

    if isinstance(item, dict):
        content = _clean_text(str(item.get("content", "")))
        if not content:
            return None
        raw_tags = item.get("coverage_tags") or item.get("covered_points") or []
        if not isinstance(raw_tags, list):
            raw_tags = [str(raw_tags)]
        # Clean each tag exactly once (the previous version called
        # _clean_text twice per tag: once in the filter, once in the value).
        tags = [cleaned for tag in raw_tags if (cleaned := _clean_text(str(tag)))]
        return {"content": content, "coverage_tags": tags}

    return None
|
||||
|
||||
|
||||
def _dedupe_items(items: List[Dict[str, Any]], max_items: int) -> List[Dict[str, Any]]:
    """Merge duplicate test items by cleaned content, capped at *max_items*.

    Items whose cleaned content matches are collapsed into a single entry
    whose coverage_tags is the union of all duplicates' tags.  The union
    preserves first-seen order; the previous implementation rebuilt the
    tag list from a set, which made the output order nondeterministic.
    Items with empty content are dropped.
    """
    merged: Dict[str, Dict[str, Any]] = {}
    for item in items:
        content = _clean_text(item.get("content", ""))
        if not content:
            continue
        incoming_tags = [tag for tag in (item.get("coverage_tags") or []) if tag]
        entry = merged.get(content)
        if entry is None:
            # dict.fromkeys dedupes while keeping first-seen order.
            merged[content] = {
                "content": content,
                "coverage_tags": list(dict.fromkeys(incoming_tags)),
            }
        else:
            seen = set(entry["coverage_tags"])
            for tag in incoming_tags:
                if tag not in seen:
                    seen.add(tag)
                    entry["coverage_tags"].append(tag)

    return list(merged.values())[:max_items]
|
||||
|
||||
|
||||
def _pick_expected_result_placeholder(content: str, abnormal: bool) -> str:
    """Choose the expected-result placeholder best matching a test item.

    Abnormal items — and items mentioning failure vocabulary — map to
    {{error_message}}; otherwise the first matching keyword family wins,
    with {{return_value}} as the fallback.
    """
    text = content or ""
    keyword_families = (
        ("{{error_message}}", ("非法", "异常", "错误", "拒绝", "越界", "失败")),
        ("{{state_change}}", ("状态", "切换", "转换", "恢复")),
        ("{{data_persistence}}", ("数据库", "存储", "落库", "持久化")),
        ("{{ui_display}}", ("界面", "UI", "页面", "按钮", "提示")),
    )
    if abnormal:
        return "{{error_message}}"
    for placeholder, tokens in keyword_families:
        if any(token in text for token in tokens):
            return placeholder
    return "{{return_value}}"
|
||||
|
||||
|
||||
class IdentifyRequirementTypeTool(TestingTool):
    """Pipeline step 1: classify the requirement text into a known type.

    An explicitly user-provided type (when it is in REQUIREMENT_TYPES) is
    honored as-is; otherwise every rule in REQUIREMENT_RULES is scored and
    the best-scoring type wins.  When nothing scores, the tool falls back
    to "未知类型" and surfaces the closest candidates instead.
    """

    # NOTE(review): earlier tools use hyphenated names while later ones use
    # snake_case (e.g. "build_expected_results") — confirm whatever looks
    # tools up by name tolerates both conventions.
    name = "identify-requirement-type"

    def execute(self, context: Dict[str, Any]) -> ToolExecutionResult:
        """Classify the requirement and seed the shared pipeline context.

        Reads ``user_requirement_text``, ``requirement_type_input``,
        ``max_focus_points`` and ``knowledge_context`` from *context*;
        writes ``requirement_type_result``, ``normalized_requirement_text``,
        ``requirement_focus_points`` and ``knowledge_used`` back into it.
        """
        raw_text = _clean_text(context.get("user_requirement_text", ""))
        text = _strip_instruction_prefix(raw_text)
        if not text:
            # Stripping removed everything — keep the original text instead.
            text = raw_text

        max_focus_points = _safe_int(context.get("max_focus_points"), 6, 3, 12)
        provided_type = _clean_text(context.get("requirement_type_input", ""))
        focus_points = _extract_focus_points(text, max_points=max_focus_points)
        fallback_used = False

        if provided_type in REQUIREMENT_TYPES:
            # Explicit user choice short-circuits scoring entirely.
            result = {
                "requirement_type": provided_type,
                "reason": "用户已显式指定需求类型,系统按指定类型执行。",
                "candidates": [],
                "scores": {},
                "secondary_types": [],
            }
        else:
            scores = _build_type_scores(text)
            sorted_pairs = sorted(scores.items(), key=lambda pair: pair[1], reverse=True)
            best_type, best_score = sorted_pairs[0]
            # Up to three runner-up types that still scored positively.
            secondary = [name for name, score in sorted_pairs[1:4] if score > 0]

            if best_score <= 0:
                # No rule matched at all: report "unknown" plus candidates.
                fallback_used = True
                candidates = _top_candidates(scores)
                result = {
                    "requirement_type": "未知类型",
                    "reason": "未命中明确分类规则,已回退到未知类型并提供最接近候选。",
                    "candidates": candidates,
                    "scores": scores,
                    "secondary_types": [],
                }
            else:
                signal = TYPE_SIGNAL_RULES.get(best_type, "")
                result = {
                    "requirement_type": best_type,
                    "reason": f"命中{best_type}识别信号。{signal}",
                    "candidates": [],
                    "scores": scores,
                    "secondary_types": secondary,
                }

        context["requirement_type_result"] = result
        context["normalized_requirement_text"] = text
        context["requirement_focus_points"] = focus_points
        context["knowledge_used"] = bool(context.get("knowledge_context"))

        return ToolExecutionResult(
            context=context,
            output_summary=(
                f"type={result['requirement_type']}; candidates={len(result['candidates'])}; "
                f"secondary_types={len(result.get('secondary_types', []))}; focus_points={len(focus_points)}"
            ),
            fallback_used=fallback_used,
        )
|
||||
|
||||
|
||||
class DecomposeTestItemsTool(TestingTool):
    """Pipeline step 2: decompose the requirement into normal/abnormal test items.

    Items come from two sources merged together: rule-based seeds
    (templates from REQUIREMENT_RULES / GENERIC_DECOMPOSITION_RULES plus
    per-focus-point items) and optional LLM-generated items.  LLM items
    take precedence in the merged, deduplicated output.
    """

    name = "decompose-test-items"

    @staticmethod
    def _seed_items(
        req_type: str,
        req_text: str,
        focus_points: List[str],
        max_items: int,
    ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
        """Build rule-based (normal, abnormal) item seeds without any LLM call.

        Templates come from the matched requirement rule (or the generic
        rules when the type is unknown); each focus point additionally
        contributes two normal and two abnormal items.  Both groups are
        deduplicated and capped at *max_items*.
        """
        if req_type in REQUIREMENT_RULES:
            source_rules = REQUIREMENT_RULES[req_type]
            normal_templates = list(source_rules.get("normal", []))
            abnormal_templates = list(source_rules.get("abnormal", []))
        else:
            # Unknown type: fall back to the generic decomposition templates.
            normal_templates = list(GENERIC_DECOMPOSITION_RULES["normal"])
            abnormal_templates = list(GENERIC_DECOMPOSITION_RULES["abnormal"])

        normal: List[Dict[str, Any]] = []
        abnormal: List[Dict[str, Any]] = []

        for template in normal_templates:
            normal.append({"content": template, "coverage_tags": [req_type]})
        for template in abnormal_templates:
            abnormal.append({"content": template, "coverage_tags": [req_type]})

        # Each extracted focus point yields two normal and two abnormal items.
        for point in focus_points:
            normal.extend(
                [
                    {
                        "content": f"验证{point}在标准作业流程下稳定执行且结果符合业务约束。",
                        "coverage_tags": [point, "正常流程"],
                    },
                    {
                        "content": f"验证{point}与相关联动控制、状态同步和回执反馈的一致性。",
                        "coverage_tags": [point, "联动一致性"],
                    },
                ]
            )
            abnormal.extend(
                [
                    {
                        "content": f"验证{point}在非法输入、错误指令或权限异常时的保护与拒绝机制。",
                        "coverage_tags": [point, "异常输入"],
                    },
                    {
                        "content": f"验证{point}在边界条件、时序冲突或设备故障下的告警和恢复行为。",
                        "coverage_tags": [point, "边界异常"],
                    },
                ]
            )

        # Requirements that mention manuals get one extra coverage item.
        if any(token in req_text for token in ["手册", "操作手册", "用户手册", "作业指导"]):
            normal.append(
                {
                    "content": "验证需求说明未显式给出但在用户手册或操作手册体现的功能流程。",
                    "coverage_tags": ["手册功能"],
                }
            )

        return _dedupe_items(normal, max_items), _dedupe_items(abnormal, max_items)

    @staticmethod
    def _generate_by_llm(context: Dict[str, Any]) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
        """Ask the LLM for (normal, abnormal) items; ([], []) on any failure.

        The reply is expected to be a JSON object with
        ``normal_test_items`` / ``abnormal_test_items`` lists; each entry
        is normalized via _normalize_item_entry and deduplicated.
        """
        req_result = context.get("requirement_type_result", {})
        req_type = req_result.get("requirement_type", "未知类型")
        req_text = context.get("normalized_requirement_text", "")
        focus_points = context.get("requirement_focus_points", [])
        max_items = _safe_int(context.get("max_items_per_group"), 12, 4, 30)
        knowledge_context = _truncate_text(context.get("knowledge_context", ""), max_len=2500)

        prompt = f"""
你是资深测试分析师。请根据需求、分解规则和知识库片段,生成尽可能覆盖要点的测试项。

需求文本:{req_text}
需求类型:{req_type}
需求要点:{focus_points}
知识库片段:{knowledge_context or '无'}

分解约束:
1. 正常测试与异常测试必须分组输出。
2. 每条测试项必须可执行、可验证,避免模板化空话。
3. 尽可能覆盖全部需求要点;每组建议输出6-{max_items}条。
4. 优先生成与需求对象/控制逻辑/异常处理/边界条件强相关的测试项。

请仅输出 JSON 对象,结构如下:
{{
  "normal_test_items": [
    {{"content": "...", "coverage_tags": ["..."]}}
  ],
  "abnormal_test_items": [
    {{"content": "...", "coverage_tags": ["..."]}}
  ]
}}
""".strip()

        data = _invoke_llm_json(context, prompt)
        if not data:
            return [], []

        normal_raw = data.get("normal_test_items", [])
        abnormal_raw = data.get("abnormal_test_items", [])

        normal: List[Dict[str, Any]] = []
        abnormal: List[Dict[str, Any]] = []

        # Non-list replies are treated as empty rather than raising.
        for item in normal_raw if isinstance(normal_raw, list) else []:
            normalized = _normalize_item_entry(item)
            if normalized:
                normal.append(normalized)

        for item in abnormal_raw if isinstance(abnormal_raw, list) else []:
            normalized = _normalize_item_entry(item)
            if normalized:
                abnormal.append(normalized)

        return _dedupe_items(normal, max_items), _dedupe_items(abnormal, max_items)

    def execute(self, context: Dict[str, Any]) -> ToolExecutionResult:
        """Merge LLM and seeded items, assign N*/E* ids, store ``test_items``.

        LLM items are listed first so they survive deduplication; the
        fallback flag is set when the LLM produced nothing at all.
        """
        req_result = context.get("requirement_type_result", {})
        req_type = req_result.get("requirement_type", "未知类型")
        req_text = context.get("normalized_requirement_text") or _strip_instruction_prefix(
            context.get("user_requirement_text", "")
        )
        focus_points = context.get("requirement_focus_points", [])
        max_items = _safe_int(context.get("max_items_per_group"), 12, 4, 30)

        seeded_normal, seeded_abnormal = self._seed_items(req_type, req_text, focus_points, max_items)
        llm_normal, llm_abnormal = self._generate_by_llm(context)

        # LLM items come first so dedupe keeps them over the seeds.
        merged_normal = _dedupe_items(llm_normal + seeded_normal, max_items)
        merged_abnormal = _dedupe_items(llm_abnormal + seeded_abnormal, max_items)

        fallback_used = not bool(llm_normal or llm_abnormal)

        normal_items: List[Dict[str, Any]] = []
        abnormal_items: List[Dict[str, Any]] = []

        # Ids: N1..Nn for normal items, E1..En for abnormal items.
        for idx, item in enumerate(merged_normal, start=1):
            normal_items.append(
                {
                    "id": f"N{idx}",
                    "content": item["content"],
                    "coverage_tags": item.get("coverage_tags", []),
                }
            )

        for idx, item in enumerate(merged_abnormal, start=1):
            abnormal_items.append(
                {
                    "id": f"E{idx}",
                    "content": item["content"],
                    "coverage_tags": item.get("coverage_tags", []),
                }
            )

        context["test_items"] = {
            "normal": normal_items,
            "abnormal": abnormal_items,
        }
        context["decompose_force_rules"] = DECOMPOSE_FORCE_RULES

        return ToolExecutionResult(
            context=context,
            output_summary=(
                f"normal_items={len(normal_items)}; abnormal_items={len(abnormal_items)}; "
                f"llm_items={len(llm_normal) + len(llm_abnormal)}"
            ),
            fallback_used=fallback_used,
        )
|
||||
|
||||
|
||||
class GenerateTestCasesTool(TestingTool):
    """Pipeline step 3: expand each test item into concrete test cases.

    For every item the tool first asks the LLM for detailed cases; when
    that yields nothing it falls back to two templated cases built from
    _build_fallback_steps.  Results are stored under ``test_cases`` with
    ids of the form ``<item_id>-C<n>``.
    """

    name = "generate-test-cases"

    @staticmethod
    def _build_fallback_steps(item_content: str, abnormal: bool, variant: str) -> List[str]:
        """Return a six-step templated operation sequence for *item_content*.

        *variant* names the scenario flavour embedded in the steps; the
        abnormal template emphasises protection/recovery checks, the
        normal one emphasises result/consistency checks.
        """
        if abnormal:
            return [
                "确认测试前置环境、设备状态与日志采集开关已准备就绪。",
                f"准备异常场景“{variant}”所需的输入数据、操作账号和触发条件。",
                f"在目标对象执行异常触发操作,重点验证:{item_content}",
                "持续观察系统返回码、错误文案、告警信息与日志链路完整性。",
                "检查保护机制是否生效,包括拒绝策略、回滚行为和状态一致性。",
                "记录证据并复位环境,确认异常处理后系统可恢复到稳定状态。",
            ]

        return [
            "确认测试环境、设备连接状态和前置业务数据均已初始化。",
            f"准备“{variant}”所需输入参数、操作路径和判定阈值。",
            f"在目标对象执行业务控制流程,重点验证:{item_content}",
            "校验关键返回值、状态变化、控制回执及界面或接口反馈结果。",
            "检查联动模块、日志记录和数据落库是否满足一致性要求。",
            "沉淀测试证据并恢复环境,确保后续用例可重复执行。",
        ]

    def _generate_cases_by_llm(
        self,
        context: Dict[str, Any],
        item: Dict[str, Any],
        abnormal: bool,
        cases_per_item: int,
    ) -> List[Dict[str, Any]]:
        """Ask the LLM for cases for one item; [] on failure.

        Replies must carry a ``test_cases`` list; entries without content
        or with fewer than 5 operation steps are discarded.  At most
        *cases_per_item* cases are returned.
        """
        req_text = context.get("normalized_requirement_text", "")
        knowledge_context = _truncate_text(context.get("knowledge_context", ""), max_len=1800)

        # NOTE(review): max(cases_per_item + 1, cases_per_item) always equals
        # cases_per_item + 1 — the max() call is redundant.
        prompt = f"""
你是资深测试工程师。请围绕给定测试项生成详细测试用例。

需求:{req_text}
测试项:{item.get('content', '')}
测试类型:{'异常测试' if abnormal else '正常测试'}
知识库片段:{knowledge_context or '无'}

要求:
1. 生成 {cases_per_item}-{max(cases_per_item + 1, cases_per_item)} 条测试用例。
2. 每条用例包含 test_content 与 operation_steps。
3. operation_steps 必须详细,至少5步,包含前置、执行、观察、校验与证据留存。
4. 内容必须围绕当前测试项,不要输出空洞模板。

仅输出 JSON:
{{
  "test_cases": [
    {{
      "title": "...",
      "test_content": "...",
      "operation_steps": ["...", "..."]
    }}
  ]
}}
""".strip()

        data = _invoke_llm_json(context, prompt)
        if not data:
            return []

        raw_cases = data.get("test_cases", [])
        if not isinstance(raw_cases, list):
            return []

        normalized_cases: List[Dict[str, Any]] = []
        for case in raw_cases:
            if not isinstance(case, dict):
                continue
            test_content = _clean_text(str(case.get("test_content", "")))
            if not test_content:
                continue
            steps = case.get("operation_steps", [])
            if not isinstance(steps, list):
                continue
            cleaned_steps = [_clean_text(str(step)) for step in steps if _clean_text(str(step))]
            if len(cleaned_steps) < 5:
                # Too few steps to satisfy the "detailed case" contract.
                continue
            normalized_cases.append(
                {
                    "title": _clean_text(str(case.get("title", ""))),
                    "test_content": test_content,
                    "operation_steps": cleaned_steps,
                }
            )

        return normalized_cases[: max(1, cases_per_item)]

    def execute(self, context: Dict[str, Any]) -> ToolExecutionResult:
        """Generate cases for every normal/abnormal item into ``test_cases``.

        fallback_used is True only when no case at all came from the LLM.
        """
        test_items = context.get("test_items", {})
        cases_per_item = _safe_int(context.get("cases_per_item"), 2, 1, 5)

        normal_cases: List[Dict[str, Any]] = []
        abnormal_cases: List[Dict[str, Any]] = []
        llm_case_count = 0

        for item in test_items.get("normal", []):
            generated = self._generate_cases_by_llm(context, item, abnormal=False, cases_per_item=cases_per_item)
            if not generated:
                # Templated fallback: two canned scenarios, capped at the quota.
                generated = [
                    {
                        "title": "标准流程验证",
                        "test_content": f"验证{item['content']}",
                        "operation_steps": self._build_fallback_steps(item["content"], False, "标准流程"),
                    },
                    {
                        "title": "边界与联动验证",
                        "test_content": f"验证{item['content']}在边界条件和联动场景下的稳定性",
                        "operation_steps": self._build_fallback_steps(item["content"], False, "边界与联动"),
                    },
                ][:cases_per_item]
            else:
                llm_case_count += len(generated)

            for idx, case in enumerate(generated, start=1):
                merged_content = _clean_text(case.get("test_content", item["content"]))
                placeholder = _pick_expected_result_placeholder(merged_content, abnormal=False)
                normal_cases.append(
                    {
                        "id": f"{item['id']}-C{idx}",
                        "item_id": item["id"],
                        "title": _clean_text(case.get("title", "")),
                        "operation_steps": case.get("operation_steps", []),
                        "test_content": merged_content,
                        "expected_result_placeholder": placeholder,
                    }
                )

        for item in test_items.get("abnormal", []):
            generated = self._generate_cases_by_llm(context, item, abnormal=True, cases_per_item=cases_per_item)
            if not generated:
                generated = [
                    {
                        "title": "非法输入与权限异常验证",
                        "test_content": f"验证{item['content']}在非法输入与权限异常下的处理表现",
                        "operation_steps": self._build_fallback_steps(item["content"], True, "非法输入与权限异常"),
                    },
                    {
                        "title": "故障与时序冲突验证",
                        "test_content": f"验证{item['content']}在故障和时序冲突场景下的保护行为",
                        "operation_steps": self._build_fallback_steps(item["content"], True, "故障与时序冲突"),
                    },
                ][:cases_per_item]
            else:
                llm_case_count += len(generated)

            for idx, case in enumerate(generated, start=1):
                merged_content = _clean_text(case.get("test_content", item["content"]))
                placeholder = _pick_expected_result_placeholder(merged_content, abnormal=True)
                abnormal_cases.append(
                    {
                        "id": f"{item['id']}-C{idx}",
                        "item_id": item["id"],
                        "title": _clean_text(case.get("title", "")),
                        "operation_steps": case.get("operation_steps", []),
                        "test_content": merged_content,
                        "expected_result_placeholder": placeholder,
                    }
                )

        context["test_cases"] = {
            "normal": normal_cases,
            "abnormal": abnormal_cases,
        }

        return ToolExecutionResult(
            context=context,
            output_summary=(
                f"normal_cases={len(normal_cases)}; abnormal_cases={len(abnormal_cases)}; llm_cases={llm_case_count}"
            ),
            fallback_used=llm_case_count == 0,
        )
|
||||
|
||||
|
||||
class BuildExpectedResultsTool(TestingTool):
    """Pipeline step 4: attach one expected result to every test case.

    Each case's placeholder (chosen earlier) selects the result semantics;
    the result text itself comes from the LLM when available, otherwise
    from a placeholder-specific template.
    """

    name = "build_expected_results"

    def _expected_for_case(self, context: Dict[str, Any], case: Dict[str, Any], abnormal: bool) -> str:
        """Produce one verifiable expected-result sentence for *case*.

        Tries the LLM first (truncated to 220 chars); falls back to a
        template keyed by the case's placeholder, with generic abnormal /
        normal templates as the last resort.
        """
        placeholder = case.get("expected_result_placeholder", "{{return_value}}")
        if placeholder not in EXPECTED_RESULT_PLACEHOLDER_MAP:
            # Unknown placeholders degrade to the generic return-value one.
            placeholder = "{{return_value}}"

        req_text = context.get("normalized_requirement_text", "")
        knowledge_context = _truncate_text(context.get("knowledge_context", ""), max_len=1200)
        prompt = f"""
请基于以下信息生成一条可验证、可度量的测试预期结果,避免模板化空话。

需求:{req_text}
测试内容:{case.get('test_content', '')}
测试类型:{'异常测试' if abnormal else '正常测试'}
占位符语义:{placeholder} -> {EXPECTED_RESULT_PLACEHOLDER_MAP.get(placeholder, '')}
知识库片段:{knowledge_context or '无'}

输出要求:
1. 仅输出一句中文预期结果。
2. 结果必须可判定成功/失败。
3. 包含关键观测项(返回值、状态、告警、日志、数据一致性中的相关项)。
""".strip()

        llm_text = _invoke_llm_text(context, prompt)
        if llm_text:
            return _truncate_text(llm_text, max_len=220)

        # LLM unavailable / failed: fall back to placeholder templates.
        test_content = _clean_text(case.get("test_content", ""))
        if placeholder == "{{error_message}}":
            return f"触发{test_content}后,系统应返回明确错误码与错误文案,拒绝非法请求且核心状态保持一致。"
        if placeholder == "{{state_change}}":
            return f"执行{test_content}后,系统状态转换应符合需求定义,状态变化可被日志与回执共同验证。"
        if placeholder == "{{data_persistence}}":
            return f"执行{test_content}后,数据库或存储层应产生符合约束的持久化结果且无脏数据。"
        if placeholder == "{{ui_display}}":
            return f"执行{test_content}后,界面应展示与控制结果一致的反馈信息且提示可被用户执行。"

        if abnormal:
            return f"执行异常场景“{test_content}”后,系统应触发保护策略并输出可追溯日志,业务状态保持可恢复。"

        return f"执行“{test_content}”后,返回值与状态变化应满足需求约束,关键结果可通过日志或回执验证。"

    def execute(self, context: Dict[str, Any]) -> ToolExecutionResult:
        """Build ``expected_results`` entries mirroring every test case.

        Each entry carries both ``id`` and ``case_id`` (intentionally the
        same value) so downstream formatting can reference either.
        """
        test_cases = context.get("test_cases", {})

        normal_expected: List[Dict[str, str]] = []
        abnormal_expected: List[Dict[str, str]] = []

        for case in test_cases.get("normal", []):
            normal_expected.append(
                {
                    "id": case["id"],
                    "case_id": case["id"],
                    "result": self._expected_for_case(context, case, abnormal=False),
                }
            )

        for case in test_cases.get("abnormal", []):
            abnormal_expected.append(
                {
                    "id": case["id"],
                    "case_id": case["id"],
                    "result": self._expected_for_case(context, case, abnormal=True),
                }
            )

        context["expected_results"] = {
            "normal": normal_expected,
            "abnormal": abnormal_expected,
        }

        return ToolExecutionResult(
            context=context,
            output_summary=(
                f"normal_expected={len(normal_expected)}; abnormal_expected={len(abnormal_expected)}"
            ),
        )
|
||||
|
||||
|
||||
class FormatOutputTool(TestingTool):
    """Pipeline step 5: render items, cases and expected results as markdown.

    Produces ``formatted_output`` (a markdown string with three sections:
    测试项 / 测试用例 / 预期成果, each split into normal and abnormal) and
    ``structured_output`` (the raw dicts for programmatic consumers).
    """

    name = "format_output"

    @staticmethod
    def _format_case_block(case: Dict[str, Any], index: int) -> List[str]:
        """Render one test case as a list of markdown lines.

        The header line links the case id back to its test item; an
        optional title line and the numbered operation steps follow.
        """
        item_id = case.get("item_id", case.get("id", ""))
        title = _clean_text(case.get("title", ""))

        block: List[str] = []
        block.append(f"{index}. [用例 {case['id']}](对应测试项 {item_id}):{case.get('test_content', '')}")
        if title:
            block.append(f" 场景标题:{title}")
        block.append(" 操作步骤:")
        for step_idx, step in enumerate(case.get("operation_steps", []), start=1):
            block.append(f" {step_idx}) {step}")
        return block

    def execute(self, context: Dict[str, Any]) -> ToolExecutionResult:
        """Assemble the final markdown report and structured payload."""
        test_items = context.get("test_items", {"normal": [], "abnormal": []})
        test_cases = context.get("test_cases", {"normal": [], "abnormal": []})
        expected_results = context.get("expected_results", {"normal": [], "abnormal": []})

        lines: List[str] = []

        # Section 1: test items.
        lines.append("**测试项**")
        lines.append("")
        lines.append("**正常测试**:")
        for index, item in enumerate(test_items.get("normal", []), start=1):
            lines.append(f"{index}. [测试项 {item['id']}]:{item['content']}")
        lines.append("")
        lines.append("**异常测试**:")
        for index, item in enumerate(test_items.get("abnormal", []), start=1):
            lines.append(f"{index}. [测试项 {item['id']}]:{item['content']}")

        # Section 2: test cases.
        lines.append("")
        lines.append("**测试用例**")
        lines.append("")
        lines.append("**正常测试**:")
        for index, case in enumerate(test_cases.get("normal", []), start=1):
            lines.extend(self._format_case_block(case, index))
        lines.append("")
        lines.append("**异常测试**:")
        for index, case in enumerate(test_cases.get("abnormal", []), start=1):
            lines.extend(self._format_case_block(case, index))

        # Section 3: expected results, linked back to their cases.
        lines.append("")
        lines.append("**预期成果**")
        lines.append("")
        lines.append("**正常测试**:")
        for index, expected in enumerate(expected_results.get("normal", []), start=1):
            lines.append(
                f"{index}. [预期 {expected['id']}](对应用例 {expected['case_id']}):{expected['result']}"
            )
        lines.append("")
        lines.append("**异常测试**:")
        for index, expected in enumerate(expected_results.get("abnormal", []), start=1):
            lines.append(
                f"{index}. [预期 {expected['id']}](对应用例 {expected['case_id']}):{expected['result']}"
            )

        context["formatted_output"] = "\n".join(lines)
        context["structured_output"] = {
            "test_items": test_items,
            "test_cases": test_cases,
            "expected_results": expected_results,
        }

        return ToolExecutionResult(
            context=context,
            output_summary="formatted_sections=3",
        )
|
||||
|
||||
|
||||
def build_default_tool_chain() -> List[TestingTool]:
    """Assemble the standard five-stage test-generation pipeline, in order:
    type identification → item decomposition → case generation →
    expected-result construction → output formatting.
    """
    stages = (
        IdentifyRequirementTypeTool,
        DecomposeTestItemsTool,
        GenerateTestCasesTool,
        BuildExpectedResultsTool,
        FormatOutputTool,
    )
    return [stage() for stage in stages]
|
||||
122
rag-web-ui/backend/app/services/vector_schema.py
Normal file
122
rag-web-ui/backend/app/services/vector_schema.py
Normal file
@@ -0,0 +1,122 @@
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
@dataclass
class ChunkVectorMetadata:
    """Metadata payload for vector DB and graph linkage.

    One instance describes a single document chunk: its identity and
    position inside the source document, extraction results (entities,
    relations), graph linkage ids, and the embedding used to index it.
    """

    chunk_id: str        # unique chunk identifier (vector-store primary key)
    kb_id: int           # owning knowledge base
    document_id: int     # owning document
    document_name: str
    document_path: str
    chunk_index: int     # position of the chunk within the document
    chunk_text: str
    token_count: int
    language: str = "zh"
    source_type: str = "document"
    mission_phase: Optional[str] = None
    section_title: Optional[str] = None
    publish_time: Optional[str] = None
    extracted_entities: List[str] = field(default_factory=list)
    extracted_entity_types: List[str] = field(default_factory=list)
    extracted_relations: List[Dict[str, Any]] = field(default_factory=list)
    graph_node_ids: List[str] = field(default_factory=list)
    graph_edge_ids: List[str] = field(default_factory=list)
    community_ids: List[str] = field(default_factory=list)
    embedding_model: str = ""
    embedding_dim: int = 0
    # Set automatically at construction time (UTC, ISO-8601).
    ingest_time: str = field(
        default_factory=lambda: datetime.now(timezone.utc).isoformat()
    )

    def to_payload(self) -> Dict[str, Any]:
        """Return all fields as a dict, in declaration order.

        Derived from the dataclass field list instead of a hand-written
        dict so new fields can never silently drift out of the payload.
        List/dict values are shared (not copied), matching the previous
        hand-rolled implementation.
        """
        from dataclasses import fields  # local import: file-level imports untouched

        return {f.name: getattr(self, f.name) for f in fields(self)}
|
||||
|
||||
|
||||
def qdrant_collection_schema(collection_name: str, vector_size: int) -> Dict[str, Any]:
    """Qdrant collection and payload index recommendations.

    Cosine-distance vectors of *vector_size*, plus payload indexes on the
    fields the retrieval layer filters by (kb/document scoping, graph
    community ids, extracted entities, ingest time).
    """
    indexed_fields = (
        ("kb_id", "integer"),
        ("document_id", "integer"),
        ("document_name", "keyword"),
        ("chunk_id", "keyword"),
        ("mission_phase", "keyword"),
        ("community_ids", "keyword"),
        ("extracted_entities", "keyword"),
        ("ingest_time", "datetime"),
    )
    return {
        "collection_name": collection_name,
        "vectors": {
            "size": vector_size,
            "distance": "Cosine",
        },
        "payload_indexes": [
            {"field_name": name, "field_schema": schema}
            for name, schema in indexed_fields
        ],
    }
|
||||
|
||||
|
||||
def milvus_collection_schema(collection_name: str, vector_size: int) -> Dict[str, Any]:
    """Milvus field design for vector+graph linkage.

    VARCHAR primary key plus scalar filter fields, a FLOAT_VECTOR
    embedding of *vector_size*, and an HNSW/COSINE index on it.
    """

    def varchar(name: str, max_length: int, **extra: Any) -> Dict[str, Any]:
        return {"name": name, "type": "VARCHAR", "max_length": max_length, **extra}

    schema_fields = [
        varchar("id", 64, is_primary=True),
        {"name": "kb_id", "type": "INT64"},
        {"name": "document_id", "type": "INT64"},
        {"name": "chunk_index", "type": "INT32"},
        varchar("document_name", 255),
        varchar("mission_phase", 64),
        varchar("community_ids", 512),
        varchar("extracted_entities", 2048),
        varchar("ingest_time", 64),
        {"name": "embedding", "type": "FLOAT_VECTOR", "dim": vector_size},
    ]
    return {
        "collection_name": collection_name,
        "fields": schema_fields,
        "index": {
            "field_name": "embedding",
            "index_type": "HNSW",
            "metric_type": "COSINE",
            "params": {"M": 16, "efConstruction": 200},
        },
    }
|
||||
|
||||
|
||||
DOCUMENT_CHUNK_METADATA_DDL = """
|
||||
ALTER TABLE document_chunks
|
||||
ADD COLUMN IF NOT EXISTS chunk_index INT NULL,
|
||||
ADD COLUMN IF NOT EXISTS token_count INT NULL,
|
||||
ADD COLUMN IF NOT EXISTS language VARCHAR(16) DEFAULT 'zh',
|
||||
ADD COLUMN IF NOT EXISTS mission_phase VARCHAR(64) NULL,
|
||||
ADD COLUMN IF NOT EXISTS extracted_entities JSON NULL,
|
||||
ADD COLUMN IF NOT EXISTS extracted_entity_types JSON NULL,
|
||||
ADD COLUMN IF NOT EXISTS extracted_relations JSON NULL,
|
||||
ADD COLUMN IF NOT EXISTS graph_node_ids JSON NULL,
|
||||
ADD COLUMN IF NOT EXISTS graph_edge_ids JSON NULL,
|
||||
ADD COLUMN IF NOT EXISTS community_ids JSON NULL,
|
||||
ADD COLUMN IF NOT EXISTS embedding_model VARCHAR(128) NULL,
|
||||
ADD COLUMN IF NOT EXISTS embedding_dim INT NULL;
|
||||
""".strip()
|
||||
11
rag-web-ui/backend/app/services/vector_store/__init__.py
Normal file
11
rag-web-ui/backend/app/services/vector_store/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
# Public surface of the vector_store package: the abstract base class,
# the concrete backends, and the factory that selects between them.
from .base import BaseVectorStore
from .chroma import ChromaVectorStore
from .qdrant import QdrantStore
from .factory import VectorStoreFactory

__all__ = [
    'BaseVectorStore',
    'ChromaVectorStore',
    'QdrantStore',
    'VectorStoreFactory'
]
|
||||
42
rag-web-ui/backend/app/services/vector_store/base.py
Normal file
42
rag-web-ui/backend/app/services/vector_store/base.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple

from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
|
||||
|
||||
class BaseVectorStore(ABC):
    """Abstract base class for vector store implementations.

    Concrete backends (e.g. Chroma, Qdrant) implement this interface so
    the rest of the application can swap stores via VectorStoreFactory.
    """

    @abstractmethod
    def __init__(self, collection_name: str, embedding_function: Embeddings, **kwargs):
        """Initialize the vector store for *collection_name* with *embedding_function*."""
        pass

    @abstractmethod
    def add_documents(self, documents: List[Document]) -> None:
        """Add documents to the vector store."""
        pass

    @abstractmethod
    def delete(self, ids: List[str]) -> None:
        """Delete documents with the given ids from the vector store."""
        pass

    @abstractmethod
    def as_retriever(self, **kwargs: Any):
        """Return a retriever interface for the vector store."""
        pass

    @abstractmethod
    def similarity_search(self, query: str, k: int = 4, **kwargs: Any) -> List[Document]:
        """Return the *k* documents most similar to *query*."""
        pass

    @abstractmethod
    def similarity_search_with_score(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Tuple[Document, float]]:
        """Return the *k* most similar documents paired with their scores.

        Annotation fixed: *_with_score searches return (document, score)
        tuples, not bare documents.
        """
        pass

    @abstractmethod
    def delete_collection(self) -> None:
        """Delete the entire collection from the backing store."""
        pass
|
||||
47
rag-web-ui/backend/app/services/vector_store/chroma.py
Normal file
47
rag-web-ui/backend/app/services/vector_store/chroma.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from typing import List, Any
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_chroma import Chroma
|
||||
import chromadb
|
||||
from app.core.config import settings
|
||||
|
||||
from .base import BaseVectorStore
|
||||
|
||||
class ChromaVectorStore(BaseVectorStore):
    """Chroma vector store implementation backed by a remote Chroma server."""

    def __init__(self, collection_name: str, embedding_function: Embeddings, **kwargs):
        """Initialize Chroma vector store.

        Args:
            collection_name: Name of the Chroma collection to open or create.
            embedding_function: Embeddings implementation used to vectorize text.
            **kwargs: Accepted for factory compatibility; currently unused.
        """
        # Connection details come from application settings (remote Chroma server).
        chroma_client = chromadb.HttpClient(
            host=settings.CHROMA_DB_HOST,
            port=settings.CHROMA_DB_PORT,
        )

        self._store = Chroma(
            client=chroma_client,
            collection_name=collection_name,
            embedding_function=embedding_function,
        )

    def add_documents(self, documents: List[Document]) -> None:
        """Add documents to Chroma."""
        self._store.add_documents(documents)

    def delete(self, ids: List[str]) -> None:
        """Delete documents with the given ids from Chroma."""
        self._store.delete(ids)

    def as_retriever(self, **kwargs: Any):
        """Return a retriever interface over the Chroma store."""
        return self._store.as_retriever(**kwargs)

    def similarity_search(self, query: str, k: int = 4, **kwargs: Any) -> List[Document]:
        """Return the ``k`` documents most similar to ``query``."""
        return self._store.similarity_search(query, k=k, **kwargs)

    def similarity_search_with_score(self, query: str, k: int = 4, **kwargs: Any):
        """Return ``(Document, score)`` pairs for the ``k`` most similar documents.

        Note: the underlying call yields tuples, so the old
        ``List[Document]`` return annotation was wrong and has been dropped.
        """
        return self._store.similarity_search_with_score(query, k=k, **kwargs)

    def delete_collection(self) -> None:
        """Delete the entire collection.

        Uses Chroma's public ``delete_collection()`` API instead of reaching
        into the private ``_client`` / ``_collection`` attributes, which are
        implementation details that may change between releases.
        """
        self._store.delete_collection()
|
||||
59
rag-web-ui/backend/app/services/vector_store/factory.py
Normal file
59
rag-web-ui/backend/app/services/vector_store/factory.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from typing import Dict, Type, Any
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
from .base import BaseVectorStore
|
||||
from .chroma import ChromaVectorStore
|
||||
from .qdrant import QdrantStore
|
||||
|
||||
class VectorStoreFactory:
    """Factory for creating vector store instances.

    Maps a case-insensitive backend name to its ``BaseVectorStore``
    implementation; new backends can be added at runtime via
    :meth:`register_store`.
    """

    _stores: Dict[str, Type[BaseVectorStore]] = {
        'chroma': ChromaVectorStore,
        'qdrant': QdrantStore
    }

    @classmethod
    def create(
        cls,
        store_type: str,
        collection_name: str,
        embedding_function: Embeddings,
        **kwargs: Any
    ) -> BaseVectorStore:
        """Create a vector store instance.

        Args:
            store_type: Type of vector store ('chroma', 'qdrant', etc.)
            collection_name: Name of the collection
            embedding_function: Embedding function to use
            **kwargs: Additional arguments for specific vector store implementations

        Returns:
            An instance of the requested vector store

        Raises:
            ValueError: If store_type is not supported
        """
        # EAFP lookup: unknown names surface as a ValueError listing the
        # supported backends (KeyError context suppressed deliberately).
        try:
            store_cls = cls._stores[store_type.lower()]
        except KeyError:
            raise ValueError(
                f"Unsupported vector store type: {store_type}. "
                f"Supported types are: {', '.join(cls._stores.keys())}"
            ) from None

        return store_cls(
            collection_name=collection_name,
            embedding_function=embedding_function,
            **kwargs
        )

    @classmethod
    def register_store(cls, name: str, store_class: Type[BaseVectorStore]) -> None:
        """Register a new vector store implementation.

        Args:
            name: Name of the vector store type (stored lowercased)
            store_class: Vector store class implementation
        """
        cls._stores[name.lower()] = store_class
|
||||
43
rag-web-ui/backend/app/services/vector_store/qdrant.py
Normal file
43
rag-web-ui/backend/app/services/vector_store/qdrant.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from typing import List, Any
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_community.vectorstores import Qdrant
|
||||
from app.core.config import settings
|
||||
|
||||
from .base import BaseVectorStore
|
||||
|
||||
class QdrantStore(BaseVectorStore):
    """Qdrant vector store implementation backed by LangChain's Qdrant wrapper."""

    def __init__(self, collection_name: str, embedding_function: Embeddings, **kwargs):
        """Initialize Qdrant vector store.

        Args:
            collection_name: Name of the Qdrant collection to use.
            embedding_function: Embeddings implementation used to vectorize text.
            **kwargs: Accepted for factory compatibility; currently unused.

        Note: ``Qdrant.__init__`` requires a ready ``QdrantClient`` instance;
        it does not accept ``url``/``prefer_grpc`` kwargs (those belong to the
        client), so the previous direct call raised ``TypeError``.
        """
        # Local import: qdrant_client is only required when the Qdrant
        # backend is actually selected.
        from qdrant_client import QdrantClient

        client = QdrantClient(
            url=settings.QDRANT_URL,
            prefer_grpc=settings.QDRANT_PREFER_GRPC,
        )
        self._store = Qdrant(
            client=client,
            collection_name=collection_name,
            embeddings=embedding_function,
        )

    def add_documents(self, documents: List[Document]) -> None:
        """Add documents to Qdrant."""
        self._store.add_documents(documents)

    def delete(self, ids: List[str]) -> None:
        """Delete documents with the given ids from Qdrant."""
        self._store.delete(ids)

    def as_retriever(self, **kwargs: Any):
        """Return a retriever interface over the Qdrant store."""
        return self._store.as_retriever(**kwargs)

    def similarity_search(self, query: str, k: int = 4, **kwargs: Any) -> List[Document]:
        """Return the ``k`` documents most similar to ``query``."""
        return self._store.similarity_search(query, k=k, **kwargs)

    def similarity_search_with_score(self, query: str, k: int = 4, **kwargs: Any):
        """Return ``(Document, score)`` pairs for the ``k`` most similar documents.

        Note: the underlying call yields tuples, so the old
        ``List[Document]`` return annotation was wrong and has been dropped.
        """
        return self._store.similarity_search_with_score(query, k=k, **kwargs)

    def delete_collection(self) -> None:
        """Delete the entire collection from the Qdrant server.

        The LangChain wrapper exposes the underlying client and collection
        name as public attributes (``client``, ``collection_name``); the
        previously referenced ``_client`` / ``_collection_name`` do not exist
        and raised ``AttributeError``.
        """
        self._store.client.delete_collection(collection_name=self._store.collection_name)
|
||||
Reference in New Issue
Block a user