init. project

2026-04-13 11:34:23 +08:00
commit c7c0659a85
202 changed files with 31196 additions and 0 deletions

View File

@@ -0,0 +1,61 @@
from typing import List, Optional
from datetime import datetime
import secrets
from sqlalchemy.orm import Session
from app.models.api_key import APIKey
from app.schemas.api_key import APIKeyCreate, APIKeyUpdate
class APIKeyService:
@staticmethod
def get_api_keys(db: Session, user_id: int, skip: int = 0, limit: int = 100) -> List[APIKey]:
return (
db.query(APIKey)
.filter(APIKey.user_id == user_id)
.offset(skip)
.limit(limit)
.all()
)
@staticmethod
def create_api_key(db: Session, user_id: int, name: str) -> APIKey:
api_key = APIKey(
key=f"sk-{secrets.token_hex(32)}",
name=name,
user_id=user_id,
is_active=True
)
db.add(api_key)
db.commit()
db.refresh(api_key)
return api_key
@staticmethod
def get_api_key(db: Session, api_key_id: int) -> Optional[APIKey]:
return db.query(APIKey).filter(APIKey.id == api_key_id).first()
@staticmethod
def get_api_key_by_key(db: Session, key: str) -> Optional[APIKey]:
return db.query(APIKey).filter(APIKey.key == key).first()
@staticmethod
def update_api_key(db: Session, api_key: APIKey, update_data: APIKeyUpdate) -> APIKey:
for field, value in update_data.model_dump(exclude_unset=True).items():
setattr(api_key, field, value)
db.add(api_key)
db.commit()
db.refresh(api_key)
return api_key
@staticmethod
def delete_api_key(db: Session, api_key: APIKey) -> None:
db.delete(api_key)
db.commit()
@staticmethod
def update_last_used(db: Session, api_key: APIKey) -> APIKey:
api_key.last_used_at = datetime.utcnow()
db.add(api_key)
db.commit()
db.refresh(api_key)
return api_key
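
A minimal wiring sketch for this service; the `get_db` and `get_current_user` dependencies and module paths are assumptions, not part of this commit:

from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session

from app.api.deps import get_db, get_current_user  # hypothetical dependency module
from app.services.api_key import APIKeyService  # assumed module path

router = APIRouter()

@router.post("/api-keys")
def create_api_key(name: str, db: Session = Depends(get_db), user=Depends(get_current_user)):
    # The service generates the "sk-" prefixed secret and persists it.
    return APIKeyService.create_api_key(db, user_id=user.id, name=name)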

View File

@@ -0,0 +1,532 @@
import base64
import json
import re
from collections import defaultdict
from typing import Any, AsyncGenerator, Dict, List, Optional
from app.core.config import settings
from app.models.chat import Message
from app.models.knowledge import Document, KnowledgeBase
from app.services.embedding.embedding_factory import EmbeddingsFactory
from app.services.fusion_prompts import (
GENERAL_CHAT_PROMPT_TEMPLATE,
GRAPH_GLOBAL_PROMPT_TEMPLATE,
GRAPH_LOCAL_PROMPT_TEMPLATE,
HYBRID_RAG_PROMPT_TEMPLATE,
)
from app.services.graph.graphrag_adapter import GraphRAGAdapter
from app.services.intent_router import route_intent
from app.services.llm.llm_factory import LLMFactory
from app.services.reranker.external_api import ExternalRerankerClient
from app.services.retrieval.multi_kb_retriever import MultiKBRetriever, format_retrieval_context
from app.services.testing_pipeline.pipeline import run_testing_pipeline
from app.services.testing_pipeline.rules import REQUIREMENT_TYPES
from app.services.vector_store import VectorStoreFactory
TESTING_TARGET_KEYWORDS = [
"测试项",
"测试用例",
"预期成果",
"需求类型",
"测试分解",
"分解",
"正常测试",
"异常测试",
"测试充分性",
]
TESTING_ACTION_KEYWORDS = [
"生成",
"输出",
"给出",
"",
"编写",
"设计",
"整理",
"列出",
"提供",
"制定",
]
TYPE_ALIAS_MAP = {
"接口测试": "外部接口测试",
"ui测试": "人机交互界面测试",
"界面测试": "人机交互界面测试",
"恢复测试": "恢复性测试",
"可靠性": "可靠性测试",
"安全性": "安全性测试",
"边界": "边界测试",
"安装": "安装性测试",
"互操作": "互操作性测试",
"敏感性": "敏感性测试",
"充分性": "测试充分性要求",
}
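# Streaming wire format used throughout this module (one frame per line):
#   0:"<escaped text>"      text delta forwarded to the client
#   d:{"finishReason"...}   terminal frame with usage accounting
#   3:"<error text>"        error frame
# The first 0: frame carries base64(JSON context) + "__LLM_RESPONSE__" so the
# frontend can split routing/retrieval metadata from the streamed answer.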
def _escape_stream_text(text: str) -> str:
return text.replace("\\", "\\\\").replace('"', '\\"').replace("\n", "\\n")
def _extract_stream_text(chunk: Any) -> str:
content = getattr(chunk, "content", chunk)
if isinstance(content, str):
return content
if isinstance(content, list):
parts: List[str] = []
for item in content:
if isinstance(item, str):
parts.append(item)
elif isinstance(item, dict):
maybe_text = item.get("text")
if isinstance(maybe_text, str):
parts.append(maybe_text)
else:
parts.append(str(item))
return "".join(parts)
return str(content)
def _preview_rows(rows: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
preview = []
for row in rows[:10]:
doc = row["document"]
metadata = doc.metadata or {}
preview.append(
{
"kb_id": row.get("kb_id"),
"source": metadata.get("source") or metadata.get("file_name") or "unknown",
"chunk_id": metadata.get("chunk_id") or "unknown",
"score": row.get("final_score", 0),
"reranker_score": row.get("reranker_score"),
}
)
return preview
def _context_rows(rows: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
context_rows: List[Dict[str, Any]] = []
for row in rows:
doc = row["document"]
metadata = dict(doc.metadata or {})
if "kb_id" not in metadata and row.get("kb_id") is not None:
metadata["kb_id"] = row.get("kb_id")
metadata.setdefault("retrieval_score", row.get("final_score", 0))
if row.get("reranker_score") is not None:
metadata.setdefault("reranker_score", row.get("reranker_score"))
context_rows.append(
{
"page_content": doc.page_content.strip(),
"metadata": metadata,
}
)
return context_rows
def _build_local_graph_context_fallback(rows: List[Dict[str, Any]]) -> str:
entities = set()
relations: List[Dict[str, Any]] = []
evidences: List[str] = []
for row in rows:
doc = row["document"]
metadata = doc.metadata or {}
for ent in metadata.get("extracted_entities", []):
entities.add(str(ent))
for rel in metadata.get("extracted_relations", []):
if isinstance(rel, dict):
relations.append(rel)
evidences.append(doc.page_content.strip())
entity_block = "\n".join(f"- {name}" for name in sorted(entities)[:80]) or "- 暂无结构化实体,已使用向量检索回退。"
relation_lines: List[str] = []
for rel in relations[:120]:
src = rel.get("source") or rel.get("src") or rel.get("src_id") or "UNKNOWN"
tgt = rel.get("target") or rel.get("tgt") or rel.get("tgt_id") or "UNKNOWN"
rel_type = rel.get("type") or rel.get("relation_type") or "其他"
desc = rel.get("description") or ""
relation_lines.append(f"- {src} -> {tgt} | 类型={rel_type} | 说明={desc}")
relation_block = "\n".join(relation_lines) or "- 暂无结构化关系,已使用证据片段回答。"
evidence_block = "\n\n".join(
f"[证据{i}] {snippet}" for i, snippet in enumerate(evidences[:8], start=1)
)
if not evidence_block:
evidence_block = "无可用证据。"
return (
"实体列表:\n"
f"{entity_block}\n\n"
"关系列表:\n"
f"{relation_block}\n\n"
"原文证据:\n"
f"{evidence_block}"
)
def _build_global_community_context_fallback(rows: List[Dict[str, Any]]) -> str:
groups: Dict[str, List[str]] = defaultdict(list)
for row in rows:
doc = row["document"]
metadata = doc.metadata or {}
community_ids = metadata.get("community_ids") or []
if isinstance(community_ids, list) and community_ids:
keys = [str(item) for item in community_ids]
else:
source = metadata.get("source") or metadata.get("file_name") or "unknown"
keys = [f"source:{source}"]
for key in keys:
groups[key].append(doc.page_content.strip())
if not groups:
return "暂无社区摘要数据,已回退为基于证据片段的全局总结。"
lines: List[str] = []
for idx, (community_id, snippets) in enumerate(groups.items(), start=1):
merged = " ".join(snippets[:3])
lines.append(f"社区{idx} ({community_id}) 摘要: {merged}")
return "\n\n".join(lines)
async def _build_kb_vector_stores(db: Any, knowledge_bases: List[KnowledgeBase]) -> List[Dict[str, Any]]:
embeddings = EmbeddingsFactory.create()
kb_vector_stores: List[Dict[str, Any]] = []
for kb in knowledge_bases:
documents = db.query(Document).filter(Document.knowledge_base_id == kb.id).all()
if not documents:
continue
store = VectorStoreFactory.create(
store_type=settings.VECTOR_STORE_TYPE,
collection_name=f"kb_{kb.id}",
embedding_function=embeddings,
)
kb_vector_stores.append({"kb_id": kb.id, "store": store})
return kb_vector_stores
def _build_reranker_client() -> ExternalRerankerClient:
return ExternalRerankerClient(
api_url=settings.RERANKER_API_URL,
api_key=settings.RERANKER_API_KEY,
model=settings.RERANKER_MODEL,
timeout_seconds=settings.RERANKER_TIMEOUT_SECONDS,
)
def _is_testing_generation_request(query: str) -> bool:
text = (query or "").strip()
if not text:
return False
normalized = text.lower()
if normalized.startswith("/testing"):
return True
if any(
token in normalized
for token in (
"testing_orchestrator",
"testing-orchestrator",
"identify_requirement_type",
"identify-requirement-type",
)
):
return True
has_target = any(keyword in text for keyword in TESTING_TARGET_KEYWORDS)
has_action = any(keyword in text for keyword in TESTING_ACTION_KEYWORDS)
if has_target and has_action:
return True
if any(keyword in text for keyword in ("测试项", "测试用例", "预期成果")):
if re.search(r"(请|帮|给|麻烦).{0,12}(写|生成|设计|整理|编写|列出|提供|制定)", text):
return True
if text.startswith(("生成", "编写", "设计", "整理", "输出", "列出", "提供", "制定")):
return True
return False
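# Illustrative queries against the rules above:
#   "/testing 生成登录模块测试用例"   -> True (explicit prefix)
#   "请帮我设计边界测试的测试用例"     -> True (target keyword + action verb)
#   "什么是测试充分性?"              -> False (target keyword, no action verb)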
def _extract_requirement_type_from_query(query: str) -> Optional[str]:
text = (query or "").strip()
if not text:
return None
for req_type in REQUIREMENT_TYPES:
if req_type in text:
return req_type
lowered = text.lower()
for alias, req_type in TYPE_ALIAS_MAP.items():
if alias in text or alias in lowered:
return req_type
return None
async def generate_response(
query: str,
messages: dict,
knowledge_base_ids: List[int],
chat_id: int,
db: Any,
) -> AsyncGenerator[str, None]:
try:
user_message = Message(content=query, role="user", chat_id=chat_id)
db.add(user_message)
db.commit()
bot_message = Message(content="", role="assistant", chat_id=chat_id)
db.add(bot_message)
db.commit()
if _is_testing_generation_request(query):
explicit_type = _extract_requirement_type_from_query(query)
retrieval_rows: List[Dict[str, Any]] = []
knowledge_context = ""
kb_vector_stores = []
if knowledge_base_ids:
testing_kbs = (
db.query(KnowledgeBase)
.filter(KnowledgeBase.id.in_(knowledge_base_ids))
.all()
)
kb_vector_stores = await _build_kb_vector_stores(db, testing_kbs)
if kb_vector_stores:
testing_retriever = MultiKBRetriever(
reranker_weight=settings.RERANKER_WEIGHT,
)
retrieval_rows = await testing_retriever.retrieve(
query=query,
kb_vector_stores=kb_vector_stores,
fetch_k_per_kb=16,
top_k=8,
)
if retrieval_rows:
knowledge_context = format_retrieval_context(retrieval_rows)
pipeline_result = run_testing_pipeline(
user_requirement_text=query,
requirement_type_input=explicit_type,
debug=True,
knowledge_context=knowledge_context,
use_model_generation=True,
max_items_per_group=6,
cases_per_item=1,
max_focus_points=6,
max_llm_calls=2,
)
context_payload = {
"route": {
"intent": "TESTING",
"reason": "命中测试生成意图,已自动调用测试工具链。",
},
"intent": "TESTING",
"skill_profile": "testing-orchestrator",
"tool_chain": [
"identify-requirement-type",
"decompose-test-items",
"generate-test-cases",
"build_expected_results",
"format_output",
],
"selected_chain": "TESTING_PIPELINE",
"graph_used": False,
"reranker_enabled": False,
"retrieval_preview": _preview_rows(retrieval_rows),
"context": _context_rows(retrieval_rows),
"testing_pipeline": {
"trace_id": pipeline_result.get("trace_id"),
"requirement_type": pipeline_result.get("requirement_type"),
"candidates": pipeline_result.get("candidates", []),
"pipeline_summary": pipeline_result.get("pipeline_summary", ""),
"knowledge_used": pipeline_result.get("knowledge_used", False),
"step_logs": pipeline_result.get("step_logs", []),
},
}
escaped_context = json.dumps(context_payload, ensure_ascii=False)
base64_context = base64.b64encode(escaped_context.encode()).decode()
separator = "__LLM_RESPONSE__"
full_response = f"{base64_context}{separator}"
yield f'0:"{base64_context}{separator}"\n'
rendered_text = pipeline_result.get("formatted_output", "").strip()
if not rendered_text:
rendered_text = "未生成测试内容,请补充更明确的需求后重试。"
full_response += rendered_text
yield f'0:"{_escape_stream_text(rendered_text)}"\n'
yield 'd:{"finishReason":"stop","usage":{"promptTokens":0,"completionTokens":0}}\n'
bot_message.content = full_response
db.commit()
return
knowledge_bases = (
db.query(KnowledgeBase)
.filter(KnowledgeBase.id.in_(knowledge_base_ids))
.all()
)
kb_ids = [kb.id for kb in knowledge_bases]
llm = LLMFactory.create()
decision = await route_intent(llm=llm, query=query, messages=messages)
intent = decision["intent"]
kb_vector_stores = await _build_kb_vector_stores(db, knowledge_bases)
if intent in {"B", "C", "D"} and not kb_vector_stores:
intent = "A"
decision = {
"intent": "A",
"reason": "未发现可用知识库向量集合,已降级为通用对话路。",
}
reranker_client = _build_reranker_client()
retriever = MultiKBRetriever(
reranker_client=reranker_client,
reranker_weight=settings.RERANKER_WEIGHT,
)
retrieval_rows: List[Dict[str, Any]] = []
graph_used = False
selected_chain = intent
prompt_text = ""
if intent == "A":
prompt_text = GENERAL_CHAT_PROMPT_TEMPLATE.format(query=query)
elif intent == "B":
retrieval_rows = await retriever.retrieve(
query=query,
kb_vector_stores=kb_vector_stores,
fetch_k_per_kb=16,
top_k=12,
)
context = format_retrieval_context(retrieval_rows) or "无可用证据。"
prompt_text = HYBRID_RAG_PROMPT_TEMPLATE.format(query=query, context=context)
elif intent == "C":
graph_context = ""
used_kb_ids: List[int] = []
if settings.GRAPHRAG_ENABLED and kb_ids:
try:
adapter = GraphRAGAdapter()
graph_context, used_kb_ids = await adapter.local_context_multi(
kb_ids,
query,
top_k=settings.GRAPHRAG_LOCAL_TOP_K,
level=settings.GRAPHRAG_QUERY_LEVEL,
)
graph_used = bool(graph_context)
except Exception:
graph_context = ""
if not graph_context:
retrieval_rows = await retriever.retrieve(
query=query,
kb_vector_stores=kb_vector_stores,
fetch_k_per_kb=18,
top_k=14,
)
graph_context = _build_local_graph_context_fallback(retrieval_rows)
selected_chain = "C_fallback_B"
else:
selected_chain = "C_graph"
prompt_text = GRAPH_LOCAL_PROMPT_TEMPLATE.format(
query=query,
graph_context=graph_context,
)
else:
community_context = ""
if settings.GRAPHRAG_ENABLED and kb_ids:
try:
adapter = GraphRAGAdapter()
community_context, used_kb_ids = await adapter.global_context_multi(
kb_ids,
query,
level=settings.GRAPHRAG_QUERY_LEVEL,
)
graph_used = bool(community_context)
except Exception:
community_context = ""
if not community_context:
retrieval_rows = await retriever.retrieve(
query=query,
kb_vector_stores=kb_vector_stores,
fetch_k_per_kb=20,
top_k=14,
)
community_context = _build_global_community_context_fallback(retrieval_rows)
selected_chain = "D_fallback_B"
else:
selected_chain = "D_graph"
prompt_text = GRAPH_GLOBAL_PROMPT_TEMPLATE.format(
query=query,
community_context=community_context,
)
context_payload = {
"route": decision,
"intent": intent,
"selected_chain": selected_chain,
"graph_used": graph_used,
"reranker_enabled": reranker_client.enabled,
"retrieval_preview": _preview_rows(retrieval_rows),
"context": _context_rows(retrieval_rows),
}
escaped_context = json.dumps(context_payload, ensure_ascii=False)
base64_context = base64.b64encode(escaped_context.encode()).decode()
separator = "__LLM_RESPONSE__"
full_response = f"{base64_context}{separator}"
yield f'0:"{base64_context}{separator}"\n'
async for chunk in llm.astream(prompt_text):
text = _extract_stream_text(chunk)
if not text:
continue
full_response += text
yield f'0:"{_escape_stream_text(text)}"\n'
yield 'd:{"finishReason":"stop","usage":{"promptTokens":0,"completionTokens":0}}\n'
bot_message.content = full_response
db.commit()
except Exception as e:
error_message = f"Error generating response: {str(e)}"
print(error_message)
yield "3:{text}\n".format(text=error_message)
if "bot_message" in locals():
bot_message.content = error_message
db.commit()
finally:
db.close()
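
A hedged sketch of the consuming endpoint; the route path, `ChatRequest` schema, and `get_db` dependency are assumptions rather than part of this commit:

from typing import List

from fastapi import APIRouter, Depends
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session

from app.api.deps import get_db  # hypothetical dependency module
from app.services.chat_service import generate_response  # assumed module path

router = APIRouter()

class ChatRequest(BaseModel):
    query: str
    messages: dict = {}
    knowledge_base_ids: List[int] = []

@router.post("/chats/{chat_id}/messages")
async def send_message(chat_id: int, body: ChatRequest, db: Session = Depends(get_db)):
    # generate_response persists the user/assistant messages and closes the session itself.
    stream = generate_response(
        query=body.query,
        messages=body.messages,
        knowledge_base_ids=body.knowledge_base_ids,
        chat_id=chat_id,
        db=db,
    )
    return StreamingResponse(stream, media_type="text/event-stream")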

View File

@@ -0,0 +1,69 @@
from typing import Optional, List, Dict, Set
from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session
from app.core.config import settings
from app.models.knowledge import DocumentChunk
import json
class ChunkRecord:
"""Manages chunk-level record keeping for incremental updates"""
def __init__(self, kb_id: int):
self.kb_id = kb_id
self.engine = create_engine(settings.get_database_url)
def list_chunks(self, file_name: Optional[str] = None) -> Set[str]:
"""List all chunk hashes for the given file"""
with Session(self.engine) as session:
query = session.query(DocumentChunk.hash).filter(
DocumentChunk.kb_id == self.kb_id
)
if file_name:
query = query.filter(DocumentChunk.file_name == file_name)
return {row[0] for row in query.all()}
def add_chunks(self, chunks: List[Dict]):
"""Add new chunks to the database"""
if not chunks:
return
with Session(self.engine) as session:
for chunk_data in chunks:
chunk = DocumentChunk(
id=chunk_data['id'],
kb_id=chunk_data['kb_id'],
document_id=chunk_data['document_id'],
file_name=chunk_data['file_name'],
chunk_metadata=chunk_data['metadata'],
hash=chunk_data['hash']
)
session.merge(chunk) # Use merge instead of add to handle updates
session.commit()
def delete_chunks(self, chunk_ids: List[str]):
"""Delete chunks by their IDs"""
if not chunk_ids:
return
with Session(self.engine) as session:
session.query(DocumentChunk).filter(
DocumentChunk.kb_id == self.kb_id,
DocumentChunk.id.in_(chunk_ids)
).delete(synchronize_session=False)
session.commit()
def get_deleted_chunks(self, current_hashes: Set[str], file_name: Optional[str] = None) -> List[str]:
"""Get IDs of chunks that no longer exist in the current version"""
with Session(self.engine) as session:
query = session.query(DocumentChunk.id).filter(
DocumentChunk.kb_id == self.kb_id
)
if file_name:
query = query.filter(DocumentChunk.file_name == file_name)
if current_hashes:
query = query.filter(DocumentChunk.hash.notin_(current_hashes))
return [row[0] for row in query.all()]
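
A short sketch of the incremental flow this class enables; `chunks` is assumed to have the list-of-dicts shape that add_chunks expects:

record = ChunkRecord(kb_id=1)
existing = record.list_chunks("spec.pdf")  # hashes already indexed for this file
new_rows = [c for c in chunks if c["hash"] not in existing]
record.add_chunks(new_rows)  # upserted via session.merge
stale = record.get_deleted_chunks({c["hash"] for c in chunks}, "spec.pdf")
record.delete_chunks(stale)  # rows whose content disappeared from the new version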

View File

@@ -0,0 +1,582 @@
import logging
import os
import hashlib
import tempfile
import traceback
import json
from app.db.session import SessionLocal
from io import BytesIO
from typing import Optional, List, Dict, Any
from fastapi import UploadFile
from langchain_community.document_loaders import (
PyPDFLoader,
Docx2txtLoader,
UnstructuredMarkdownLoader,
TextLoader
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.documents import Document as LangchainDocument
from pydantic import BaseModel
from sqlalchemy.orm import Session
from app.core.config import settings
from app.core.minio import get_minio_client
from app.models.knowledge import ProcessingTask, Document, DocumentChunk
from app.services.chunk_record import ChunkRecord
from minio.error import MinioException
from minio.commonconfig import CopySource
from app.services.vector_store import VectorStoreFactory
from app.services.embedding.embedding_factory import EmbeddingsFactory
class UploadResult(BaseModel):
file_path: str
file_name: str
file_size: int
content_type: str
file_hash: str
class TextChunk(BaseModel):
content: str
metadata: Optional[Dict] = None
class PreviewResult(BaseModel):
chunks: List[TextChunk]
total_chunks: int
def _estimate_token_count(text: str) -> int:
# Lightweight estimation without adding tokenizer dependencies.
return len(text)
def _build_enriched_chunk_metadata(
*,
source_metadata: Optional[Dict[str, Any]],
chunk_id: str,
file_name: str,
file_path: str,
kb_id: int,
document_id: int,
chunk_index: int,
chunk_text: str,
) -> Dict[str, Any]:
source_metadata = source_metadata or {}
token_count = _estimate_token_count(chunk_text)
return {
**source_metadata,
"source": file_name,
"chunk_id": chunk_id,
"file_name": file_name,
"file_path": file_path,
"kb_id": kb_id,
"document_id": document_id,
"chunk_index": chunk_index,
"chunk_text": chunk_text,
"token_count": token_count,
"language": source_metadata.get("language", "zh"),
"source_type": "document",
"mission_phase": source_metadata.get("mission_phase"),
"section_title": source_metadata.get("section_title"),
"publish_time": source_metadata.get("publish_time"),
# Keep graph-linked fields for future graph/vector federation.
"extracted_entities": source_metadata.get("extracted_entities", []),
"extracted_entity_types": source_metadata.get("extracted_entity_types", []),
"extracted_relations": source_metadata.get("extracted_relations", []),
"graph_node_ids": source_metadata.get("graph_node_ids", []),
"graph_edge_ids": source_metadata.get("graph_edge_ids", []),
"community_ids": source_metadata.get("community_ids", []),
}
def _sanitize_metadata_for_vector_store(metadata: Optional[Dict[str, Any]]) -> Dict[str, Any]:
"""Normalize metadata to satisfy Chroma's strict metadata constraints."""
if not metadata:
return {}
sanitized: Dict[str, Any] = {}
scalar_types = (str, int, float, bool)
for key, value in metadata.items():
if value is None:
continue
if isinstance(value, scalar_types):
sanitized[key] = value
continue
if isinstance(value, list):
primitive_items = [item for item in value if isinstance(item, scalar_types)]
if primitive_items:
sanitized[key] = primitive_items
elif value:
sanitized[key] = json.dumps(value, ensure_ascii=False)
continue
if isinstance(value, dict):
sanitized[key] = json.dumps(value, ensure_ascii=False)
continue
sanitized[key] = str(value)
return sanitized
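# Illustrative behaviour: given
#   {"chunk_index": 3, "mission_phase": None,
#    "extracted_entities": ["A", "B"],
#    "extracted_relations": [{"src": "A", "tgt": "B"}]}
# the sanitizer drops the None, keeps the primitive list as-is, and
# JSON-encodes the list of dicts into a single string value.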
async def process_document(file_path: str, file_name: str, kb_id: int, document_id: int, chunk_size: int = 1000, chunk_overlap: int = 200) -> None:
"""Process document and store in vector database with incremental updates"""
logger = logging.getLogger(__name__)
try:
preview_result = await preview_document(file_path, chunk_size, chunk_overlap)
# Initialize embeddings
logger.info("Initializing OpenAI embeddings...")
embeddings = EmbeddingsFactory.create()
logger.info(f"Initializing vector store with collection: kb_{kb_id}")
vector_store = VectorStoreFactory.create(
store_type=settings.VECTOR_STORE_TYPE,
collection_name=f"kb_{kb_id}",
embedding_function=embeddings,
)
# Initialize chunk record manager
chunk_manager = ChunkRecord(kb_id)
# Get existing chunk hashes for this file
existing_hashes = chunk_manager.list_chunks(file_name)
# Prepare new chunks
new_chunks = []
current_hashes = set()
documents_to_update = []
for i, chunk in enumerate(preview_result.chunks):
# Calculate chunk hash
chunk_hash = hashlib.sha256(
(chunk.content + str(chunk.metadata)).encode()
).hexdigest()
current_hashes.add(chunk_hash)
# Skip if chunk hasn't changed
if chunk_hash in existing_hashes:
continue
# Create unique ID for the chunk
chunk_id = hashlib.sha256(
f"{kb_id}:{file_name}:{chunk_hash}".encode()
).hexdigest()
metadata = _build_enriched_chunk_metadata(
source_metadata=chunk.metadata,
chunk_id=chunk_id,
file_name=file_name,
file_path=file_path,
kb_id=kb_id,
document_id=document_id,
chunk_index=i,
chunk_text=chunk.content,
)
vector_metadata = _sanitize_metadata_for_vector_store(metadata)
new_chunks.append({
"id": chunk_id,
"kb_id": kb_id,
"document_id": document_id,
"file_name": file_name,
"metadata": metadata,
"hash": chunk_hash
})
# Prepare document for vector store
doc = LangchainDocument(
page_content=chunk.content,
metadata=vector_metadata
)
documents_to_update.append(doc)
# Add new chunks to database and vector store
if new_chunks:
logger.info(f"Adding {len(new_chunks)} new/updated chunks")
chunk_manager.add_chunks(new_chunks)
vector_store.add_documents(documents_to_update)
if settings.GRAPHRAG_ENABLED:
try:
from app.services.graph.graphrag_adapter import GraphRAGAdapter
graph_adapter = GraphRAGAdapter()
source_texts = [doc.page_content for doc in documents_to_update if doc.page_content.strip()]
await graph_adapter.ingest_texts(kb_id, source_texts)
logger.info("GraphRAG ingestion completed in incremental processing")
except Exception as graph_exc:
logger.error(f"GraphRAG ingestion failed in incremental processing: {graph_exc}")
# Delete removed chunks
chunks_to_delete = chunk_manager.get_deleted_chunks(current_hashes, file_name)
if chunks_to_delete:
logger.info(f"Removing {len(chunks_to_delete)} deleted chunks")
chunk_manager.delete_chunks(chunks_to_delete)
vector_store.delete(chunks_to_delete)
logger.info("Document processing completed successfully")
except Exception as e:
logger.error(f"Error processing document: {str(e)}")
raise
async def upload_document(file: UploadFile, kb_id: int) -> UploadResult:
"""Step 1: Upload document to MinIO"""
content = await file.read()
file_size = len(content)
file_hash = hashlib.sha256(content).hexdigest()
# Clean and normalize filename
file_name = "".join(c for c in file.filename if c.isalnum() or c in ('-', '_', '.')).strip()
object_path = f"kb_{kb_id}/{file_name}"
content_types = {
".pdf": "application/pdf",
".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
".md": "text/markdown",
".txt": "text/plain"
}
_, ext = os.path.splitext(file_name)
content_type = content_types.get(ext.lower(), "application/octet-stream")
# Upload to MinIO
minio_client = get_minio_client()
try:
minio_client.put_object(
bucket_name=settings.MINIO_BUCKET_NAME,
object_name=object_path,
data=BytesIO(content),
length=file_size,
content_type=content_type
)
except Exception as e:
logging.error(f"Failed to upload file to MinIO: {str(e)}")
raise
return UploadResult(
file_path=object_path,
file_name=file_name,
file_size=file_size,
content_type=content_type,
file_hash=file_hash
)
async def preview_document(file_path: str, chunk_size: int = 1000, chunk_overlap: int = 200) -> PreviewResult:
"""Step 2: Generate preview chunks"""
# Get file from MinIO
minio_client = get_minio_client()
_, ext = os.path.splitext(file_path)
ext = ext.lower()
# Download to temp file
with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as temp_file:
minio_client.fget_object(
bucket_name=settings.MINIO_BUCKET_NAME,
object_name=file_path,
file_path=temp_file.name
)
temp_path = temp_file.name
try:
# Select appropriate loader
if ext == ".pdf":
loader = PyPDFLoader(temp_path)
elif ext == ".docx":
loader = Docx2txtLoader(temp_path)
elif ext == ".md":
loader = UnstructuredMarkdownLoader(temp_path)
else: # Default to text loader
loader = TextLoader(temp_path)
# Load and split the document
documents = loader.load()
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap
)
chunks = text_splitter.split_documents(documents)
# Convert to preview format
preview_chunks = [
TextChunk(
content=chunk.page_content,
metadata=chunk.metadata
)
for chunk in chunks
]
return PreviewResult(
chunks=preview_chunks,
total_chunks=len(chunks)
)
finally:
os.unlink(temp_path)
async def process_document_background(
temp_path: str,
file_name: str,
kb_id: int,
task_id: int,
db: Optional[Session] = None,
chunk_size: int = 1000,
chunk_overlap: int = 200
) -> None:
"""Process document in background"""
logger = logging.getLogger(__name__)
logger.info(f"Starting background processing for task {task_id}, file: {file_name}")
# if we don't pass in db, create a new database session
if db is None:
db = SessionLocal()
should_close_db = True
else:
should_close_db = False
task = db.query(ProcessingTask).get(task_id)
if not task:
logger.error(f"Task {task_id} not found")
return
minio_client = None
local_temp_path = None
try:
logger.info(f"Task {task_id}: Setting status to processing")
task.status = "processing"
db.commit()
# 1. Download the file from the temporary MinIO location
minio_client = get_minio_client()
try:
local_temp_path = f"/tmp/temp_{task_id}_{file_name}" # 使用系统临时目录
logger.info(f"Task {task_id}: Downloading file from MinIO: {temp_path} to {local_temp_path}")
minio_client.fget_object(
bucket_name=settings.MINIO_BUCKET_NAME,
object_name=temp_path,
file_path=local_temp_path
)
logger.info(f"Task {task_id}: File downloaded successfully")
except MinioException as e:
# Idempotent fallback: temp object may already be consumed by another task.
# If the final document is already created, treat current task as completed.
if "NoSuchKey" in str(e) and task.document_upload:
existing_document = db.query(Document).filter(
Document.knowledge_base_id == kb_id,
Document.file_name == file_name,
Document.file_hash == task.document_upload.file_hash,
).first()
if existing_document:
logger.warning(
f"Task {task_id}: Temp object missing but document already exists, "
f"marking task as completed (document_id={existing_document.id})"
)
task.status = "completed"
task.document_id = existing_document.id
task.error_message = None
task.document_upload.status = "completed"
task.document_upload.error_message = None
db.commit()
return
error_msg = f"Failed to download temp file: {str(e)}"
logger.error(f"Task {task_id}: {error_msg}")
raise Exception(error_msg)
try:
# 2. Load and chunk the document
_, ext = os.path.splitext(file_name)
ext = ext.lower()
logger.info(f"Task {task_id}: Loading document with extension {ext}")
# Select the appropriate loader
if ext == ".pdf":
loader = PyPDFLoader(local_temp_path)
elif ext == ".docx":
loader = Docx2txtLoader(local_temp_path)
elif ext == ".md":
loader = UnstructuredMarkdownLoader(local_temp_path)
else:  # default to the text loader
loader = TextLoader(local_temp_path)
logger.info(f"Task {task_id}: Loading document content")
documents = loader.load()
logger.info(f"Task {task_id}: Document loaded successfully")
logger.info(f"Task {task_id}: Splitting document into chunks")
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap
)
chunks = text_splitter.split_documents(documents)
logger.info(f"Task {task_id}: Document split into {len(chunks)} chunks")
# 3. Create the vector store
logger.info(f"Task {task_id}: Initializing vector store")
embeddings = EmbeddingsFactory.create()
vector_store = VectorStoreFactory.create(
store_type=settings.VECTOR_STORE_TYPE,
collection_name=f"kb_{kb_id}",
embedding_function=embeddings,
)
# 4. Move the temp file to permanent storage
permanent_path = f"kb_{kb_id}/{file_name}"
try:
logger.info(f"Task {task_id}: Moving file to permanent storage")
# Copy to the permanent location
source = CopySource(settings.MINIO_BUCKET_NAME, temp_path)
minio_client.copy_object(
bucket_name=settings.MINIO_BUCKET_NAME,
object_name=permanent_path,
source=source
)
logger.info(f"Task {task_id}: File moved to permanent storage")
# Remove the temporary object
logger.info(f"Task {task_id}: Removing temporary file from MinIO")
minio_client.remove_object(
bucket_name=settings.MINIO_BUCKET_NAME,
object_name=temp_path
)
logger.info(f"Task {task_id}: Temporary file removed")
except MinioException as e:
error_msg = f"Failed to move file to permanent storage: {str(e)}"
logger.error(f"Task {task_id}: {error_msg}")
raise Exception(error_msg)
# 5. Create the document record
logger.info(f"Task {task_id}: Creating document record")
document = Document(
file_name=file_name,
file_path=permanent_path,
file_hash=task.document_upload.file_hash,
file_size=task.document_upload.file_size,
content_type=task.document_upload.content_type,
knowledge_base_id=kb_id
)
db.add(document)
db.flush()
db.refresh(document)
logger.info(f"Task {task_id}: Document record created with ID {document.id}")
# 6. Store the document chunks
logger.info(f"Task {task_id}: Storing document chunks")
for i, chunk in enumerate(chunks):
# Generate a unique ID for each chunk
chunk_id = hashlib.sha256(
f"{kb_id}:{file_name}:{chunk.page_content}".encode()
).hexdigest()
metadata = _build_enriched_chunk_metadata(
source_metadata=chunk.metadata,
chunk_id=chunk_id,
file_name=file_name,
file_path=permanent_path,
kb_id=kb_id,
document_id=document.id,
chunk_index=i,
chunk_text=chunk.page_content,
)
chunk.metadata = metadata
doc_chunk = DocumentChunk(
id=chunk_id,  # explicit ID field
document_id=document.id,
kb_id=kb_id,
file_name=file_name,
chunk_metadata={
"page_content": chunk.page_content,
**metadata
},
hash=hashlib.sha256(
(chunk.page_content + str(metadata)).encode()
).hexdigest()
)
db.add(doc_chunk)
if i > 0 and i % 100 == 0:
logger.info(f"Task {task_id}: Stored {i} chunks")
db.flush()
# 7. Add the chunks to the vector store
logger.info(f"Task {task_id}: Adding chunks to vector store")
vector_chunks = [
LangchainDocument(
page_content=chunk.page_content,
metadata=_sanitize_metadata_for_vector_store(chunk.metadata),
)
for chunk in chunks
]
vector_store.add_documents(vector_chunks)
# No explicit persist() call; newer vector store versions persist automatically
logger.info(f"Task {task_id}: Chunks added to vector store")
if settings.GRAPHRAG_ENABLED:
try:
from app.services.graph.graphrag_adapter import GraphRAGAdapter
logger.info(f"Task {task_id}: Starting GraphRAG ingestion")
graph_adapter = GraphRAGAdapter()
source_texts = [doc.page_content for doc in documents if doc.page_content.strip()]
await graph_adapter.ingest_texts(kb_id, source_texts)
logger.info(f"Task {task_id}: GraphRAG ingestion completed")
except Exception as graph_exc:
logger.error(f"Task {task_id}: GraphRAG ingestion failed: {graph_exc}")
# 8. Update the task status
logger.info(f"Task {task_id}: Updating task status to completed")
task.status = "completed"
task.document_id = document.id  # point to the newly created document
# 9. Update the upload record status
upload = task.document_upload  # fetched directly via the relationship
if upload:
logger.info(f"Task {task_id}: Updating upload record status to completed")
upload.status = "completed"
db.commit()
logger.info(f"Task {task_id}: Processing completed successfully")
finally:
# Clean up the local temp file
try:
if local_temp_path and os.path.exists(local_temp_path):
logger.info(f"Task {task_id}: Cleaning up local temp file")
os.remove(local_temp_path)
logger.info(f"Task {task_id}: Local temp file cleaned up")
except Exception as e:
logger.warning(f"Task {task_id}: Failed to clean up local temp file: {str(e)}")
except Exception as e:
logger.error(f"Task {task_id}: Error processing document: {str(e)}")
logger.error(f"Task {task_id}: Stack trace: {traceback.format_exc()}")
db.rollback()
failed_task = db.query(ProcessingTask).get(task_id)
if failed_task:
failed_task.status = "failed"
failed_task.error_message = str(e)
if failed_task.document_upload:
failed_task.document_upload.status = "failed"
failed_task.document_upload.error_message = str(e)
db.commit()
# Clean up the temporary MinIO object
try:
logger.info(f"Task {task_id}: Cleaning up temporary file after error")
if minio_client is not None:
minio_client.remove_object(
bucket_name=settings.MINIO_BUCKET_NAME,
object_name=temp_path
)
logger.info(f"Task {task_id}: Temporary file cleaned up after error")
except Exception:
logger.warning(f"Task {task_id}: Failed to clean up temporary file after error")
finally:
# if we create the db session, we need to close it
if should_close_db and db:
db.close()

View File

@@ -0,0 +1,46 @@
from app.core.config import settings
from langchain_openai import OpenAIEmbeddings
from langchain_ollama import OllamaEmbeddings
# If you plan on adding other embeddings, import them here
# from some_other_module import AnotherEmbeddingClass
class EmbeddingsFactory:
@staticmethod
def create():
"""
Factory method to create an embeddings instance based on .env config.
"""
# Suppose your .env has a value like EMBEDDINGS_PROVIDER=openai
embeddings_provider = settings.EMBEDDINGS_PROVIDER.lower()
if embeddings_provider == "openai":
return OpenAIEmbeddings(
openai_api_key=settings.OPENAI_API_KEY,
openai_api_base=settings.OPENAI_API_BASE,
model=settings.OPENAI_EMBEDDINGS_MODEL
)
elif embeddings_provider == "dashscope":
return OpenAIEmbeddings(
openai_api_key=settings.DASH_SCOPE_API_KEY,
openai_api_base=settings.DASH_SCOPE_API_BASE,
model=settings.DASH_SCOPE_EMBEDDINGS_MODEL,
# DashScope OpenAI-compatible embedding expects string input,
# while LangChain's len-safe path may send token ids.
check_embedding_ctx_length=False,
tiktoken_enabled=False,
skip_empty=True,
# DashScope embedding API supports at most 10 inputs per batch.
chunk_size=10,
)
elif embeddings_provider == "ollama":
return OllamaEmbeddings(
model=settings.OLLAMA_EMBEDDINGS_MODEL,
base_url=settings.OLLAMA_API_BASE
)
# Extend with other providers:
# elif embeddings_provider == "another_provider":
# return AnotherEmbeddingClass(...)
else:
raise ValueError(f"Unsupported embeddings provider: {embeddings_provider}")

View File

@@ -0,0 +1,116 @@
"""Fusion RAG prompts for aerospace Chinese QA."""
ROUTER_SYSTEM_PROMPT = """
你是一个检索路由器。你的唯一任务是把用户请求分类到以下四类之一。
分类标签:
A: 通用对话路
- 适用:问候、寒暄、角色扮演、无须知识库支持的常识闲聊。
- 特征:没有明确的专业实体约束,也不依赖当前知识库文档。
B: 混合检索路 (Hybrid RAG)
- 适用:单实体事实查询、定义解释、时间/数值/指标问答。
- 特征:问题通常可由少量文本片段直接回答,核心是“找准证据”。
C: 局部图检索路 (Graph Local Search)
- 适用:实体关系、多跳因果、组件依赖、跨段落链式推理。
- 特征:问题包含“谁影响谁/为什么/如何传导/依赖链”。
D: 全局图检索路 (Graph Global Search)
- 适用:全局总结、趋势分析、跨系统比较、宏观评估。
- 特征:问题面向整个语料或多个主题社区,不是单点事实。
判定规则(按优先级):
1. 若请求明确是问候、寒暄、开放闲聊,判 A。
2. 若请求强调全局综述、趋势、横向比较,判 D。
3. 若请求强调实体关系、影响路径、多跳推理,判 C。
4. 其余知识查询默认判 B。
输出要求:
- 只能输出 JSON,不要额外文本。
- 格式必须是:
{
"intent": "A/B/C/D",
"reason": "中文简要理由"
}
""".strip()
ROUTER_USER_PROMPT_TEMPLATE = """
请基于以下用户问题进行路由分类。
历史对话(可选):
{chat_history}
用户问题:
{query}
""".strip()
GENERAL_CHAT_PROMPT_TEMPLATE = """
你是中文航天问答助手。当前请求被路由为“通用对话路”。
请直接回答用户问题,要求:
- 简洁自然
- 不要伪造具体文献或数据来源
- 若涉及专业细节但无上下文支撑,请明确说明是一般性知识
用户问题:
{query}
""".strip()
HYBRID_RAG_PROMPT_TEMPLATE = """
你是航天领域事实问答助手。你会收到按相关性排序的文本证据片段,请严格基于证据作答。
要求:
1. 回答正文应自然连贯,不要使用“直接答案”“证据依据”等分节标题。
2. 关键信息需要有可追溯引用,引用编号使用 [1]、[2] 等格式。
3. 引用标号尽量集中放在回答末尾,不要在句中频繁插入。
4. 不得编造未在证据中出现的事实、时间、参数、型号。
5. 若证据不足,明确写:信息不足,缺少 xxx。
6. 输出中文,术语严谨,避免冗长。
问题:
{query}
证据片段:
{context}
""".strip()
GRAPH_LOCAL_PROMPT_TEMPLATE = """
你是航天知识图谱推理助手。你将获得一个局部子图上下文(实体、关系、证据)。
要求:
1. 输出结构固定为:
- 结论
- 推理链路
- 证据映射
- 不确定性
2. 推理链路需按步骤编号(步骤1、步骤2...),明确“实体 -> 关系 -> 实体/结论”的链式过程。
3. 若局部子图不完整,必须指出断点,不能臆造链路。
4. 输出中文。
问题:
{query}
局部子图上下文:
{graph_context}
""".strip()
GRAPH_GLOBAL_PROMPT_TEMPLATE = """
你是航天领域全局分析助手。你将获得多个社区摘要,请进行跨社区综合研判。
要求:
1. 输出结构固定为:
- 总体结论
- 跨社区共性
- 关键差异
- 趋势判断
- 风险与建议
2. 每条关键判断尽量给出对应社区编号。
3. 仅依据输入摘要,证据不足时明确说明。
4. 输出中文,适合技术管理层阅读。
问题:
{query}
社区摘要:
{community_context}
""".strip()

View File

@@ -0,0 +1,3 @@
from app.services.graph.graphrag_adapter import GraphRAGAdapter
__all__ = ["GraphRAGAdapter"]

View File

@@ -0,0 +1,183 @@
import asyncio
import importlib
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
from app.core.config import settings
from app.services.embedding.embedding_factory import EmbeddingsFactory
from app.services.llm.llm_factory import LLMFactory
class GraphRAGAdapter:
_instance_lock = asyncio.Lock()
def __init__(self):
self._graphrag_instances: Dict[int, Any] = {}
self._kb_locks: Dict[int, asyncio.Lock] = {}
self._embedding_model = EmbeddingsFactory.create()
self._llm_model = LLMFactory.create(streaming=False)
self._symbols = self._load_symbols()
def _load_symbols(self) -> Dict[str, Any]:
module = importlib.import_module("nano_graphrag")
storage_module = importlib.import_module("nano_graphrag._storage")
utils_module = importlib.import_module("nano_graphrag._utils")
return {
"GraphRAG": module.GraphRAG,
"QueryParam": module.QueryParam,
"Neo4jStorage": getattr(storage_module, "Neo4jStorage"),
"NetworkXStorage": getattr(storage_module, "NetworkXStorage"),
"EmbeddingFunc": getattr(utils_module, "EmbeddingFunc"),
}
def _get_kb_lock(self, kb_id: int) -> asyncio.Lock:
if kb_id not in self._kb_locks:
self._kb_locks[kb_id] = asyncio.Lock()
return self._kb_locks[kb_id]
async def _llm_complete(self, prompt: str, system_prompt: Optional[str] = None, history_messages: Optional[List[Any]] = None, **kwargs: Any) -> str:
history_messages = history_messages or []
history_lines: List[str] = []
for item in history_messages:
if isinstance(item, dict):
role = str(item.get("role", "user"))
content = item.get("content", "")
if isinstance(content, list):
joined = " ".join(str(part.get("text", "")) for part in content if isinstance(part, dict))
history_lines.append(f"{role}: {joined}")
else:
history_lines.append(f"{role}: {content}")
else:
history_lines.append(str(item))
full_prompt = "\n\n".join(
part
for part in [
f"系统提示: {system_prompt}" if system_prompt else "",
"历史对话:\n" + "\n".join(history_lines) if history_lines else "",
"用户输入:\n" + prompt,
]
if part
)
model = self._llm_model
max_tokens = kwargs.get("max_tokens")
if max_tokens is not None:
try:
model = model.bind(max_tokens=max_tokens)
except Exception:
pass
response = await model.ainvoke(full_prompt)
content = getattr(response, "content", response)
if isinstance(content, str):
return content
return str(content)
async def _embedding_call(self, texts: List[str]) -> np.ndarray:
vectors = await asyncio.to_thread(self._embedding_model.embed_documents, texts)
return np.array(vectors)
async def _get_or_create(self, kb_id: int) -> Any:
if kb_id in self._graphrag_instances:
return self._graphrag_instances[kb_id]
async with GraphRAGAdapter._instance_lock:
if kb_id in self._graphrag_instances:
return self._graphrag_instances[kb_id]
GraphRAG = self._symbols["GraphRAG"]
EmbeddingFunc = self._symbols["EmbeddingFunc"]
embedding_func = EmbeddingFunc(
embedding_dim=settings.GRAPHRAG_EMBEDDING_DIM,
max_token_size=settings.GRAPHRAG_EMBEDDING_MAX_TOKEN_SIZE,
func=self._embedding_call,
)
graph_storage_cls = self._symbols["NetworkXStorage"]
addon_params: Dict[str, Any] = {}
if settings.GRAPHRAG_GRAPH_STORAGE.lower() == "neo4j":
graph_storage_cls = self._symbols["Neo4jStorage"]
addon_params = {
"neo4j_url": settings.NEO4J_URL,
"neo4j_auth": (settings.NEO4J_USERNAME, settings.NEO4J_PASSWORD),
}
working_dir = str(Path(settings.GRAPHRAG_WORKING_DIR) / f"kb_{kb_id}")
rag = GraphRAG(
working_dir=working_dir,
enable_local=True,
enable_naive_rag=True,
graph_storage_cls=graph_storage_cls,
addon_params=addon_params,
embedding_func=embedding_func,
best_model_func=self._llm_complete,
cheap_model_func=self._llm_complete,
entity_extract_max_gleaning=settings.GRAPHRAG_ENTITY_EXTRACT_MAX_GLEANING,
)
self._graphrag_instances[kb_id] = rag
return rag
async def ingest_texts(self, kb_id: int, texts: List[str]) -> None:
cleaned = [text.strip() for text in texts if text and text.strip()]
if not cleaned:
return
rag = await self._get_or_create(kb_id)
lock = self._get_kb_lock(kb_id)
async with lock:
await rag.ainsert(cleaned)
async def local_context(self, kb_id: int, query: str, *, top_k: int = 20, level: int = 2) -> str:
rag = await self._get_or_create(kb_id)
QueryParam = self._symbols["QueryParam"]
param = QueryParam(
mode="local",
top_k=top_k,
level=level,
only_need_context=True,
)
return await rag.aquery(query, param)
async def global_context(self, kb_id: int, query: str, *, level: int = 2) -> str:
rag = await self._get_or_create(kb_id)
QueryParam = self._symbols["QueryParam"]
param = QueryParam(
mode="global",
level=level,
only_need_context=True,
)
return await rag.aquery(query, param)
async def local_context_multi(self, kb_ids: List[int], query: str, *, top_k: int = 20, level: int = 2) -> Tuple[str, List[int]]:
contexts: List[str] = []
used_kb_ids: List[int] = []
for kb_id in kb_ids:
try:
ctx = await self.local_context(kb_id, query, top_k=top_k, level=level)
if ctx:
contexts.append(f"[KB:{kb_id}]\n{ctx}")
used_kb_ids.append(kb_id)
except Exception:
continue
return "\n\n".join(contexts), used_kb_ids
async def global_context_multi(self, kb_ids: List[int], query: str, *, level: int = 2) -> Tuple[str, List[int]]:
contexts: List[str] = []
used_kb_ids: List[int] = []
for kb_id in kb_ids:
try:
ctx = await self.global_context(kb_id, query, level=level)
if ctx:
contexts.append(f"[KB:{kb_id}]\n{ctx}")
used_kb_ids.append(kb_id)
except Exception:
continue
return "\n\n".join(contexts), used_kb_ids

View File

@@ -0,0 +1,85 @@
import re
from typing import Any, Dict, List
from app.services.vector_store.base import BaseVectorStore
def _tokenize_for_keyword_score(text: str) -> List[str]:
"""Simple multilingual tokenizer for lexical matching without extra dependencies."""
tokens = re.findall(r"[A-Za-z0-9_]+|[\u4e00-\u9fff]", text.lower())
return [token for token in tokens if token.strip()]
def _keyword_score(query: str, doc_text: str) -> float:
query_terms = set(_tokenize_for_keyword_score(query))
doc_terms = set(_tokenize_for_keyword_score(doc_text))
if not query_terms or not doc_terms:
return 0.0
overlap = len(query_terms.intersection(doc_terms))
return overlap / max(1, len(query_terms))
def hybrid_search(
vector_store: BaseVectorStore,
query: str,
top_k: int = 6,
fetch_k: int = 20,
alpha: float = 0.65,
) -> List[Dict[str, Any]]:
"""
Hybrid retrieval via vector candidate generation + lexical reranking.
score = alpha * vector_rank_score + (1 - alpha) * keyword_score
"""
raw_results = vector_store.similarity_search_with_score(query, k=fetch_k)
if not raw_results:
return []
ranked: List[Dict[str, Any]] = []
total = len(raw_results)
for index, item in enumerate(raw_results):
if not isinstance(item, (tuple, list)) or len(item) < 1:
continue
doc = item[0]
if not hasattr(doc, "page_content"):
continue
rank_score = 1.0 - (index / max(1, total))
lexical_score = _keyword_score(query, doc.page_content)
final_score = alpha * rank_score + (1.0 - alpha) * lexical_score
ranked.append(
{
"document": doc,
"vector_rank_score": round(rank_score, 6),
"keyword_score": round(lexical_score, 6),
"final_score": round(final_score, 6),
}
)
ranked.sort(key=lambda row: row["final_score"], reverse=True)
return ranked[:top_k]
def format_hybrid_context(rows: List[Dict[str, Any]]) -> str:
parts: List[str] = []
for i, row in enumerate(rows, start=1):
doc = row["document"]
metadata = doc.metadata or {}
source = metadata.get("source") or metadata.get("file_name") or "unknown"
chunk_id = metadata.get("chunk_id") or "unknown"
parts.append(
(
f"[{i}] source={source}, chunk_id={chunk_id}, "
f"score={row['final_score']}\n"
f"{doc.page_content.strip()}"
)
)
return "\n\n".join(parts)

View File

@@ -0,0 +1,120 @@
import json
import re
from typing import Any, Dict, List
from app.services.fusion_prompts import (
ROUTER_SYSTEM_PROMPT,
ROUTER_USER_PROMPT_TEMPLATE,
)
VALID_INTENTS = {"A", "B", "C", "D"}
def _extract_json_object(raw_text: str) -> Dict[str, str]:
"""Extract and parse the first JSON object from model output."""
cleaned = raw_text.strip()
cleaned = cleaned.replace("```json", "").replace("```", "").strip()
match = re.search(r"\{[\s\S]*\}", cleaned)
if not match:
raise ValueError("No JSON object found in router output")
data = json.loads(match.group(0))
if not isinstance(data, dict):
raise ValueError("Router output JSON is not an object")
intent = str(data.get("intent", "")).strip().upper()
reason = str(data.get("reason", "")).strip()
if intent not in VALID_INTENTS:
raise ValueError(f"Invalid intent: {intent}")
if not reason:
reason = "模型未提供理由,已按规则兜底。"
return {"intent": intent, "reason": reason}
def _build_history_text(messages: dict, max_turns: int = 6) -> str:
if not isinstance(messages, dict):
return ""
history = messages.get("messages", [])
if not isinstance(history, list):
return ""
tail = history[-max_turns:]
rows: List[str] = []
for msg in tail:
role = str(msg.get("role", "unknown")).strip()
content = str(msg.get("content", "")).strip().replace("\n", " ")
if content:
rows.append(f"{role}: {content}")
return "\n".join(rows)
def _heuristic_route(query: str) -> Dict[str, str]:
text = query.strip().lower()
general_chat_patterns = [
"你好",
"您好",
"在吗",
"谢谢",
"早上好",
"晚上好",
"你是谁",
"讲个笑话",
]
global_patterns = [
"总结",
"综述",
"整体",
"全局",
"趋势",
"对比",
"比较",
"宏观",
"共性",
"差异",
]
local_graph_patterns = [
"关系",
"依赖",
"影响",
"导致",
"原因",
"链路",
"多跳",
"传导",
"耦合",
"约束",
]
if any(token in text for token in general_chat_patterns):
return {"intent": "A", "reason": "命中通用对话关键词,且不依赖知识库检索。"}
if any(token in text for token in global_patterns):
return {"intent": "D", "reason": "问题指向全局总结或跨主题趋势分析。"}
if any(token in text for token in local_graph_patterns):
return {"intent": "C", "reason": "问题强调实体关系与链式推理。"}
return {"intent": "B", "reason": "默认归入事实查询,适合混合检索链路。"}
async def route_intent(llm: Any, query: str, messages: dict) -> Dict[str, str]:
"""Route user query to A/B/C/D with LLM-first and heuristic fallback."""
history_text = _build_history_text(messages)
user_prompt = ROUTER_USER_PROMPT_TEMPLATE.format(
chat_history=history_text or "",
query=query,
)
try:
full_prompt = f"{ROUTER_SYSTEM_PROMPT}\n\n{user_prompt}"
model_resp = await llm.ainvoke(full_prompt)
content = getattr(model_resp, "content", model_resp)
raw_text = content if isinstance(content, str) else str(content)
return _extract_json_object(raw_text)
except Exception:
return _heuristic_route(query)

View File

@@ -0,0 +1,57 @@
from typing import Optional
from langchain_core.language_models import BaseChatModel
from langchain_openai import ChatOpenAI
from langchain_deepseek import ChatDeepSeek
from langchain_ollama import OllamaLLM
from app.core.config import settings
class LLMFactory:
@staticmethod
def create(
provider: Optional[str] = None,
temperature: float = 0,
streaming: bool = True,
) -> BaseChatModel:
"""
Create a LLM instance based on the provider
"""
# If no provider specified, use the one from settings
provider = provider or settings.CHAT_PROVIDER
if provider.lower() == "openai":
return ChatOpenAI(
temperature=temperature,
streaming=streaming,
model=settings.OPENAI_MODEL,
openai_api_key=settings.OPENAI_API_KEY,
openai_api_base=settings.OPENAI_API_BASE
)
elif provider.lower() == "deepseek":
return ChatDeepSeek(
temperature=temperature,
streaming=streaming,
model=settings.DEEPSEEK_MODEL,
api_key=settings.DEEPSEEK_API_KEY,
api_base=settings.DEEPSEEK_API_BASE
)
elif provider.lower() == "dashscope":
return ChatOpenAI(
temperature=temperature,
streaming=streaming,
model=settings.DASH_SCOPE_CHAT_MODEL,
openai_api_key=settings.DASH_SCOPE_API_KEY,
openai_api_base=settings.DASH_SCOPE_API_BASE,
)
elif provider.lower() == "ollama":
# Initialize Ollama model
return OllamaLLM(
model=settings.OLLAMA_MODEL,
base_url=settings.OLLAMA_API_BASE,
temperature=temperature,
streaming=streaming
)
# Add more providers here as needed
# elif provider.lower() == "anthropic":
# return ChatAnthropic(...)
else:
raise ValueError(f"Unsupported LLM provider: {provider}")

View File

@@ -0,0 +1,3 @@
from app.services.reranker.external_api import ExternalRerankerClient
__all__ = ["ExternalRerankerClient"]

View File

@@ -0,0 +1,164 @@
import asyncio
import json
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from urllib import request
@dataclass
class ExternalRerankerClient:
api_url: str
api_key: str = ""
model: str = ""
timeout_seconds: float = 8.0
@property
def enabled(self) -> bool:
return bool(self.api_url)
@property
def is_dashscope_rerank(self) -> bool:
return "dashscope.aliyuncs.com" in self.api_url and "/services/rerank/" in self.api_url
async def rerank(
self,
*,
query: str,
documents: List[str],
top_n: Optional[int] = None,
metadata: Optional[List[Dict[str, Any]]] = None,
) -> Optional[List[float]]:
if not self.enabled:
return None
if not documents:
return []
payload = self._build_payload(
query=query,
documents=documents,
top_n=top_n or len(documents),
metadata=metadata,
)
try:
response = await asyncio.to_thread(self._post_json, payload)
scores = self._parse_scores(response, len(documents))
return scores
except Exception:
return None
def _post_json(self, payload: Dict[str, Any]) -> Dict[str, Any]:
headers = {"Content-Type": "application/json"}
if self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"
req = request.Request(
self.api_url,
data=json.dumps(payload).encode("utf-8"),
headers=headers,
method="POST",
)
with request.urlopen(req, timeout=self.timeout_seconds) as resp:
body = resp.read().decode("utf-8")
return json.loads(body)
def _build_payload(
self,
*,
query: str,
documents: List[str],
top_n: int,
metadata: Optional[List[Dict[str, Any]]],
) -> Dict[str, Any]:
if self.is_dashscope_rerank:
payload = {
"model": self.model,
"input": {
"query": query,
"documents": documents,
},
"parameters": {
"return_documents": True,
"top_n": top_n,
},
}
if metadata:
payload["metadata"] = metadata
return payload
payload = {
"model": self.model,
"query": query,
"documents": documents,
"top_n": top_n,
}
if metadata:
payload["metadata"] = metadata
return payload
def _parse_scores(self, response: Dict[str, Any], expected_len: int) -> List[float]:
# DashScope format:
# {"output": {"results": [{"index": 0, "relevance_score": 0.98}, ...]}}
output_block = response.get("output")
if isinstance(output_block, dict) and isinstance(output_block.get("results"), list):
raw_results = output_block["results"]
scores = [0.0] * expected_len
for item in raw_results:
if not isinstance(item, dict):
continue
idx = item.get("index")
score = item.get("relevance_score", item.get("score", 0.0))
if isinstance(idx, int) and 0 <= idx < expected_len:
try:
scores[idx] = float(score)
except Exception:
scores[idx] = 0.0
return scores
# Common response format #1:
# {"results": [{"index": 0, "relevance_score": 0.98}, ...]}
if isinstance(response.get("results"), list):
raw_results = response["results"]
scores = [0.0] * expected_len
for item in raw_results:
if not isinstance(item, dict):
continue
idx = item.get("index")
score = item.get("relevance_score", item.get("score", 0.0))
if isinstance(idx, int) and 0 <= idx < expected_len:
try:
scores[idx] = float(score)
except Exception:
scores[idx] = 0.0
return scores
# Common response format #2:
# {"scores": [0.9, 0.1, ...]}
if isinstance(response.get("scores"), list):
values = response["scores"]
scores: List[float] = []
for i in range(expected_len):
try:
scores.append(float(values[i]))
except Exception:
scores.append(0.0)
return scores
# Common response format #3:
# {"data": [{"index": 0, "score": 0.88}, ...]}
if isinstance(response.get("data"), list):
raw_results = response["data"]
scores = [0.0] * expected_len
for item in raw_results:
if not isinstance(item, dict):
continue
idx = item.get("index")
score = item.get("score", item.get("relevance_score", 0.0))
if isinstance(idx, int) and 0 <= idx < expected_len:
try:
scores[idx] = float(score)
except Exception:
scores[idx] = 0.0
return scores
return [0.0] * expected_len
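
A hedged call-site sketch; the endpoint URL and model name are placeholders:

import asyncio

from app.services.reranker.external_api import ExternalRerankerClient

client = ExternalRerankerClient(
    api_url="https://example.com/v1/rerank",  # placeholder endpoint
    api_key="sk-your-key",
    model="your-rerank-model",
)
# Returns scores aligned with the input order, or None if the call fails.
scores = asyncio.run(client.rerank(query="姿控系统组成", documents=["文档A", "文档B"]))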

View File

@@ -0,0 +1,3 @@
from app.services.retrieval.multi_kb_retriever import MultiKBRetriever, format_retrieval_context
__all__ = ["MultiKBRetriever", "format_retrieval_context"]

View File

@@ -0,0 +1,131 @@
import re
from typing import Any, Dict, List, Optional
from app.services.reranker.external_api import ExternalRerankerClient
def _tokenize(text: str) -> List[str]:
tokens = re.findall(r"[A-Za-z0-9_]+|[\u4e00-\u9fff]", text.lower())
return [token for token in tokens if token.strip()]
def _keyword_score(query: str, text: str) -> float:
query_terms = set(_tokenize(query))
text_terms = set(_tokenize(text))
if not query_terms or not text_terms:
return 0.0
overlap = len(query_terms.intersection(text_terms))
return overlap / max(1, len(query_terms))
def format_retrieval_context(rows: List[Dict[str, Any]]) -> str:
blocks: List[str] = []
for i, row in enumerate(rows, start=1):
doc = row["document"]
metadata = doc.metadata or {}
blocks.append(
(
f"[{i}] kb_id={row.get('kb_id')}, source={metadata.get('source') or metadata.get('file_name') or 'unknown'}, "
f"chunk_id={metadata.get('chunk_id') or 'unknown'}, score={row.get('final_score', 0):.6f}\n"
f"{doc.page_content.strip()}"
)
)
return "\n\n".join(blocks)
class MultiKBRetriever:
def __init__(
self,
*,
reranker_client: Optional[ExternalRerankerClient] = None,
reranker_weight: float = 0.75,
vector_weight: float = 0.2,
keyword_weight: float = 0.05,
):
self.reranker_client = reranker_client
self.reranker_weight = reranker_weight
self.vector_weight = vector_weight
self.keyword_weight = keyword_weight
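# Scoring sketch with the defaults above (reranker_weight=0.75,
# vector_weight=0.2, keyword_weight=0.05):
#   base  = 0.2 * vector_rank_score + 0.05 * keyword_score
#   e.g. vector_rank_score=1.0, keyword_score=0.6 -> base = 0.23
#   with reranker_score=0.9 -> final = 0.75 * 0.9 + 0.25 * 0.23 = 0.7325
#   without a reranker, final_score falls back to base alone.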
async def retrieve(
self,
*,
query: str,
kb_vector_stores: List[Dict[str, Any]],
fetch_k_per_kb: int = 12,
top_k: int = 12,
) -> List[Dict[str, Any]]:
candidates: List[Dict[str, Any]] = []
for kb_store in kb_vector_stores:
kb_id = kb_store["kb_id"]
vector_store = kb_store["store"]
raw = vector_store.similarity_search_with_score(query, k=fetch_k_per_kb)
total = len(raw)
for index, item in enumerate(raw):
if not isinstance(item, (tuple, list)) or not item:
continue
doc = item[0]
if not hasattr(doc, "page_content"):
continue
metadata = doc.metadata or {}
rank_score = 1.0 - (index / max(1, total))
lexical_score = _keyword_score(query, doc.page_content)
candidates.append(
{
"kb_id": kb_id,
"document": doc,
"chunk_key": f"{kb_id}:{metadata.get('chunk_id', index)}",
"vector_rank_score": round(rank_score, 6),
"keyword_score": round(lexical_score, 6),
}
)
if not candidates:
return []
# Dedupe by KB + chunk id to avoid repeated chunks from same collection.
unique_map: Dict[str, Dict[str, Any]] = {}
for row in candidates:
key = row["chunk_key"]
existing = unique_map.get(key)
if existing is None:
unique_map[key] = row
continue
if row["vector_rank_score"] > existing["vector_rank_score"]:
unique_map[key] = row
merged = list(unique_map.values())
merged.sort(key=lambda x: x["vector_rank_score"], reverse=True)
reranker_scores: Optional[List[float]] = None
if self.reranker_client is not None and self.reranker_client.enabled:
reranker_scores = await self.reranker_client.rerank(
query=query,
documents=[row["document"].page_content for row in merged],
top_n=min(top_k, len(merged)),
metadata=[{"kb_id": row["kb_id"]} for row in merged],
)
for idx, row in enumerate(merged):
base_score = (
self.vector_weight * row["vector_rank_score"]
+ self.keyword_weight * row["keyword_score"]
)
if reranker_scores is not None:
rerank_value = float(reranker_scores[idx])
final_score = self.reranker_weight * rerank_value + (1 - self.reranker_weight) * base_score
row["reranker_score"] = round(rerank_value, 6)
else:
final_score = base_score
row["reranker_score"] = None
row["final_score"] = round(final_score, 6)
merged.sort(key=lambda x: x["final_score"], reverse=True)
return merged[:top_k]

View File

@@ -0,0 +1,187 @@
from __future__ import annotations
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List
from sqlalchemy.orm import Session
from app.db.session import SessionLocal
from app.models.tooling import SRSExtraction, SRSRequirement, ToolJob
from app.tools.srs_reqs_qwen import get_srs_tool
def run_srs_job(job_id: int) -> None:
db = SessionLocal()
try:
job = db.query(ToolJob).filter(ToolJob.id == job_id).first()
if not job:
return
job.status = "processing"
job.started_at = datetime.utcnow()
job.error_message = None
db.commit()
payload = get_srs_tool().run(job.input_file_path)
extraction = SRSExtraction(
job_id=job.id,
document_name=payload["document_name"],
document_title=payload.get("document_title") or payload["document_name"],
generated_at=_parse_generated_at(payload.get("generated_at")),
total_requirements=len(payload.get("requirements", [])),
statistics=payload.get("statistics", {}),
raw_output=payload.get("raw_output", {}),
)
db.add(extraction)
db.flush()
for item in payload.get("requirements", []):
requirement = SRSRequirement(
extraction_id=extraction.id,
requirement_uid=item["id"],
title=item.get("title") or item["id"],
description=item.get("description") or "",
priority=item.get("priority") or "",
acceptance_criteria=item.get("acceptance_criteria") or ["待补充验收标准"],
source_field=item.get("source_field") or "文档解析",
section_number=item.get("section_number"),
section_title=item.get("section_title"),
requirement_type=item.get("requirement_type"),
sort_order=int(item.get("sort_order") or 0),
)
db.add(requirement)
job.status = "completed"
job.completed_at = datetime.utcnow()
job.output_summary = {
"total_requirements": extraction.total_requirements,
"document_name": extraction.document_name,
}
db.commit()
except Exception as exc:
db.rollback()
_mark_job_failed(job_id=job_id, error_message=str(exc))
finally:
db.close()
def _mark_job_failed(job_id: int, error_message: str) -> None:
db = SessionLocal()
try:
job = db.query(ToolJob).filter(ToolJob.id == job_id).first()
if not job:
return
job.status = "failed"
job.completed_at = datetime.utcnow()
job.error_message = error_message[:2000]
db.commit()
finally:
db.close()
def _parse_generated_at(value: Any) -> datetime:
if isinstance(value, str):
try:
return datetime.fromisoformat(value)
except ValueError:
return datetime.utcnow()
return datetime.utcnow()
def ensure_upload_path(job_id: int, file_name: str) -> Path:
target_dir = Path("uploads") / "srs_jobs" / str(job_id)
target_dir.mkdir(parents=True, exist_ok=True)
return target_dir / file_name
def build_result_response(job: ToolJob, extraction: SRSExtraction) -> Dict[str, Any]:
requirements: List[Dict[str, Any]] = []
for item in extraction.requirements:
requirements.append(
{
"id": item.requirement_uid,
"title": item.title,
"description": item.description,
"priority": item.priority,
"acceptanceCriteria": item.acceptance_criteria or [],
"sourceField": item.source_field,
"sectionNumber": item.section_number,
"sectionTitle": item.section_title,
"requirementType": item.requirement_type,
"sortOrder": item.sort_order,
}
)
return {
"jobId": job.id,
"documentName": extraction.document_name,
"generatedAt": extraction.generated_at.isoformat(),
"statistics": extraction.statistics or {},
"requirements": requirements,
}
def replace_requirements(db: Session, extraction: SRSExtraction, updates: List[Dict[str, Any]]) -> None:
existing = {
req.requirement_uid: req
for req in db.query(SRSRequirement)
.filter(SRSRequirement.extraction_id == extraction.id)
.all()
}
seen_ids = set()
for index, item in enumerate(updates):
uid = item["id"]
seen_ids.add(uid)
req = existing.get(uid)
if req is None:
req = SRSRequirement(
extraction_id=extraction.id,
requirement_uid=uid,
title=item.get("title") or uid,
description=item.get("description") if item.get("description") is not None else "",
priority=item.get("priority") or "",
acceptance_criteria=item.get("acceptanceCriteria") or ["待补充验收标准"],
source_field=item.get("sourceField") or "文档解析",
section_number=item.get("sectionNumber"),
section_title=item.get("sectionTitle"),
requirement_type=item.get("requirementType"),
sort_order=int(item.get("sortOrder") or index),
)
db.add(req)
continue
req.title = item.get("title", req.title)
req.description = item.get("description", req.description)
req.priority = item.get("priority", req.priority)
req.acceptance_criteria = item.get("acceptanceCriteria", req.acceptance_criteria)
req.source_field = item.get("sourceField", req.source_field)
req.section_number = item.get("sectionNumber", req.section_number)
req.section_title = item.get("sectionTitle", req.section_title)
req.requirement_type = item.get("requirementType", req.requirement_type)
req.sort_order = int(item.get("sortOrder", index))
for uid, req in existing.items():
if uid not in seen_ids:
db.delete(req)
extraction.total_requirements = len(updates)
extraction.statistics = {
"total": len(updates),
"by_type": _count_requirement_types(updates),
}
extraction.raw_output = {
"document_name": extraction.document_name,
"generated_at": extraction.generated_at.isoformat(),
"requirements": updates,
}
def _count_requirement_types(items: List[Dict[str, Any]]) -> Dict[str, int]:
stats: Dict[str, int] = {}
for item in items:
req_type = item.get("requirementType") or "functional"
stats[req_type] = stats.get(req_type, 0) + 1
return stats
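# --- Usage sketch (illustrative; assumes the API layer already created the
# ToolJob row and saved the upload via ensure_upload_path) ---
def _example_run_and_fetch(job_id: int) -> Dict[str, Any]:
    # run_srs_job is synchronous and opens its own session; in production it
    # would typically be dispatched to a background worker.
    run_srs_job(job_id)
    db = SessionLocal()
    try:
        job = db.query(ToolJob).filter(ToolJob.id == job_id).first()
        extraction = (
            db.query(SRSExtraction).filter(SRSExtraction.job_id == job_id).first()
        )
        if job is None or extraction is None or job.status != "completed":
            return {"jobId": job_id, "status": getattr(job, "status", "missing")}
        return build_result_response(job, extraction)
    finally:
        db.close()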

View File

@@ -0,0 +1,3 @@
from app.services.testing_pipeline.pipeline import run_testing_pipeline
__all__ = ["run_testing_pipeline"]

View File

@@ -0,0 +1,20 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict
@dataclass
class ToolExecutionResult:
context: Dict[str, Any]
output_summary: str
fallback_used: bool = False
class TestingTool(ABC):
name: str
@abstractmethod
def execute(self, context: Dict[str, Any]) -> ToolExecutionResult:
raise NotImplementedError
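# --- Minimal sketch of a custom tool (illustrative, not wired into the chain) ---
# Every pipeline step follows the same contract: read what it needs from the
# shared context dict, write results back, and return a ToolExecutionResult.
class EchoRequirementTool(TestingTool):
    name = "echo-requirement"
    def execute(self, context: Dict[str, Any]) -> ToolExecutionResult:
        text = str(context.get("user_requirement_text", ""))
        context["echoed_requirement"] = text
        return ToolExecutionResult(
            context=context,
            output_summary=f"echoed_chars={len(text)}",
        )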

View File

@@ -0,0 +1,99 @@
from __future__ import annotations
from time import perf_counter
from typing import Any, Dict, List, Optional
from uuid import uuid4
from app.services.llm.llm_factory import LLMFactory
from app.services.testing_pipeline.tools import build_default_tool_chain
def _build_input_summary(context: Dict[str, Any]) -> str:
req_text = str(context.get("user_requirement_text", "")).strip()
req_type = str(context.get("requirement_type_input", "")).strip() or "auto"
short_text = req_text if len(req_text) <= 60 else f"{req_text[:60]}..."
return f"requirement_type_input={req_type}; requirement_text={short_text}"
def _build_output_summary(context: Dict[str, Any]) -> str:
req_type_result = context.get("requirement_type_result", {})
req_type = req_type_result.get("requirement_type", "")
test_items = context.get("test_items", {})
test_cases = context.get("test_cases", {})
return (
f"requirement_type={req_type}; "
f"items={len(test_items.get('normal', [])) + len(test_items.get('abnormal', []))}; "
f"cases={len(test_cases.get('normal', [])) + len(test_cases.get('abnormal', []))}"
)
def run_testing_pipeline(
user_requirement_text: str,
requirement_type_input: Optional[str] = None,
debug: bool = False,
knowledge_context: Optional[str] = None,
use_model_generation: bool = False,
max_items_per_group: int = 12,
cases_per_item: int = 2,
max_focus_points: int = 6,
max_llm_calls: int = 10,
) -> Dict[str, Any]:
llm_model = None
if use_model_generation:
try:
llm_model = LLMFactory.create(streaming=False)
except Exception:
llm_model = None
context: Dict[str, Any] = {
"trace_id": str(uuid4()),
"user_requirement_text": user_requirement_text,
"requirement_type_input": requirement_type_input,
"debug": bool(debug),
"knowledge_context": (knowledge_context or "").strip(),
"knowledge_used": bool((knowledge_context or "").strip()),
"use_model_generation": bool(use_model_generation),
"llm_model": llm_model,
"max_items_per_group": max(4, min(int(max_items_per_group), 30)),
"cases_per_item": max(1, min(int(cases_per_item), 5)),
"max_focus_points": max(3, min(int(max_focus_points), 12)),
"llm_call_budget": max(0, min(int(max_llm_calls), 100)),
}
step_logs: List[Dict[str, Any]] = []
for tool in build_default_tool_chain():
start = perf_counter()
input_summary = _build_input_summary(context)
execution = tool.execute(context)
context = execution.context
duration_ms = (perf_counter() - start) * 1000
step_logs.append(
{
"step_name": tool.name,
"input_summary": input_summary,
"output_summary": execution.output_summary,
"success": True,
"fallback_used": execution.fallback_used,
"duration_ms": round(duration_ms, 3),
}
)
req_result = context.get("requirement_type_result", {})
return {
"trace_id": context.get("trace_id"),
"requirement_type": req_result.get("requirement_type", "未知类型"),
"reason": req_result.get("reason", ""),
"candidates": req_result.get("candidates", []),
"test_items": context.get("test_items", {"normal": [], "abnormal": []}),
"test_cases": context.get("test_cases", {"normal": [], "abnormal": []}),
"expected_results": context.get("expected_results", {"normal": [], "abnormal": []}),
"formatted_output": context.get("formatted_output", ""),
"pipeline_summary": _build_output_summary(context),
"knowledge_used": bool(context.get("knowledge_used", False)),
"step_logs": step_logs if debug else [],
}
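# --- Usage sketch (illustrative; the requirement text is a made-up example) ---
if __name__ == "__main__":
    # Rule-based run without LLM calls; debug=True also returns per-step logs.
    result = run_testing_pipeline(
        user_requirement_text="系统应支持远程启停控制,异常指令需被拒绝并告警。",
        debug=True,
        use_model_generation=False,
    )
    print(result["requirement_type"])
    print(result["pipeline_summary"])
    print(result["formatted_output"])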

View File

@@ -0,0 +1,203 @@
from __future__ import annotations
from typing import Dict, List
REQUIREMENT_TYPES: List[str] = [
"功能测试",
"性能测试",
"外部接口测试",
"人机交互界面测试",
"强度测试",
"余量测试",
"可靠性测试",
"安全性测试",
"恢复性测试",
"边界测试",
"安装性测试",
"互操作性测试",
"敏感性测试",
"测试充分性要求",
]
TYPE_SIGNAL_RULES: Dict[str, str] = {
"功能测试": "关注功能需求逐项验证、业务流程正确性、输入输出行为、状态转换与边界值处理。",
"性能测试": "关注处理精度、响应时间、处理数据量、系统协调性、负载潜力与运行占用空间。",
"外部接口测试": "关注外部输入输出接口的格式、内容、协议与正常/异常交互表现。",
"人机交互界面测试": "关注界面一致性、界面风格、操作流程、误操作健壮性与错误提示能力。",
"强度测试": "关注系统在极限、超负荷、饱和和降级条件下的稳定性与承受能力。",
"余量测试": "关注存储余量、输入输出通道余量、功能处理时间余量等资源裕度。",
"可靠性测试": "关注真实或仿真环境下的失效等级、运行剖面、输入覆盖和长期稳定运行能力。",
"安全性测试": "关注危险状态响应、安全关键部件、异常输入防护、非法访问阻断和数据完整性保护。",
"恢复性测试": "关注故障探测、备用切换、系统状态保护与从无错误状态继续执行能力。",
"边界测试": "关注输入输出域边界、状态转换端点、功能界限、性能界限与容量界限。",
"安装性测试": "关注不同配置下安装卸载流程和安装规程执行正确性。",
"互操作性测试": "关注多个软件并行运行时的互操作能力与协同正确性。",
"敏感性测试": "关注有效输入类中可能引发不稳定或不正常处理的数据组合。",
"测试充分性要求": "关注需求覆盖率、配置项覆盖、语句覆盖、分支覆盖及未覆盖分析确认。",
}
DECOMPOSE_FORCE_RULES: List[str] = [
"每个软件功能至少应被正常测试与被认可的异常场景覆盖;复杂功能需继续细分。",
"每个测试项必须语义完整、可直接执行。",
"覆盖必须包含:正常流程、边界条件(适用时)、异常条件。",
"粒度需适中,避免过粗或过细。",
"对未知类型必须执行通用分解,并保持正常/异常分组。",
"对需求说明未显式给出但在用户手册或操作手册体现的功能,也应补充测试项覆盖。",
]
REQUIREMENT_RULES: Dict[str, Dict[str, List[str]]] = {
"功能测试": {
"keywords": ["功能", "业务流程", "输入输出", "状态转换", "边界值"],
"normal": [
"正常覆盖功能主路径、基本数据类型、合法边界值与状态转换。",
],
"abnormal": [
"异常覆盖非法输入、不规则输入、非法边界值与最坏情况。",
],
},
"性能测试": {
"keywords": ["性能", "处理精度", "响应时间", "处理数据量", "负载", "占用空间"],
"normal": [
"正常覆盖处理精度、响应时间、处理数据量与模块协调性。",
],
"abnormal": [
"异常覆盖超负荷、软硬件限制、负载潜力上限与资源占用异常。",
],
},
"外部接口测试": {
"keywords": ["外部接口", "输入接口", "输出接口", "格式", "内容", "协议", "异常交互"],
"normal": [
"正常覆盖全部外部接口格式与内容正确性。",
],
"abnormal": [
"异常覆盖每个输入输出接口的错误格式、错误内容与异常交互。",
],
},
"人机交互界面测试": {
"keywords": ["界面", "风格", "交互", "误操作", "错误提示", "操作流程"],
"normal": [
"正常覆盖界面风格一致性与标准操作流程。",
],
"abnormal": [
"异常覆盖误操作、快速操作、非法输入、错误命令与错误流程提示。",
],
},
"强度测试": {
"keywords": ["强度", "极限", "超负荷", "饱和", "降级", "健壮性"],
"normal": [
"正常覆盖设计极限下系统功能和性能表现。",
],
"abnormal": [
"异常覆盖超出极限时的降级行为、健壮性与饱和表现。",
],
},
"余量测试": {
"keywords": ["余量", "存储余量", "通道余量", "处理时间余量", "资源裕度"],
"normal": [
"正常覆盖存储、通道、处理时间余量是否满足要求。",
],
"abnormal": [
"异常覆盖余量不足或耗尽时系统告警与受控行为。",
],
},
"可靠性测试": {
"keywords": ["可靠性", "运行剖面", "失效等级", "输入覆盖", "长期稳定"],
"normal": [
"正常覆盖典型环境、运行剖面与输入变量组合。",
],
"abnormal": [
"异常覆盖失效等级场景、边界环境变化、不合法输入域及失效记录。",
],
},
"安全性测试": {
"keywords": ["安全", "危险状态", "安全关键部件", "非法进入", "完整性", "防护"],
"normal": [
"正常覆盖安全关键部件、安全结构与合法操作路径。",
],
"abnormal": [
"异常覆盖危险状态、故障模式、边界接合部、非法进入与数据完整性保护。",
],
},
"恢复性测试": {
"keywords": ["恢复", "故障探测", "备用切换", "状态保护", "继续执行", "reset"],
"normal": [
"正常覆盖故障探测、备用切换、恢复后继续执行。",
],
"abnormal": [
"异常覆盖故障中作业保护、状态保护与恢复失败路径。",
],
},
"边界测试": {
"keywords": ["边界", "端点", "输入输出域", "状态转换", "性能界限", "容量界限"],
"normal": [
"正常覆盖输入输出域边界、状态转换端点与功能界限。",
],
"abnormal": [
"异常覆盖性能界限、容量界限和越界端点。",
],
},
"安装性测试": {
"keywords": ["安装", "卸载", "配置", "安装规程", "部署", "中断"],
"normal": [
"正常覆盖标准及不同配置下安装卸载流程。",
],
"abnormal": [
"异常覆盖安装规程错误、依赖异常与中断后的处理。",
],
},
"互操作性测试": {
"keywords": ["互操作", "并行运行", "协同", "兼容", "冲突", "互操作失败"],
"normal": [
"正常覆盖两个或多个软件同时运行与互操作过程。",
],
"abnormal": [
"异常覆盖互操作失败、并行冲突与协同异常。",
],
},
"敏感性测试": {
"keywords": ["敏感性", "输入类", "数据组合", "不稳定", "不正常处理"],
"normal": [
"正常覆盖有效输入类中典型数据组合。",
],
"abnormal": [
"异常覆盖引发不稳定或不正常处理的特殊数据组合。",
],
},
"测试充分性要求": {
"keywords": ["测试充分性", "需求覆盖率", "配置项覆盖", "语句覆盖", "分支覆盖", "未覆盖分析"],
"normal": [
"正常覆盖需求覆盖率、配置项覆盖与代码覆盖达标。",
],
"abnormal": [
"异常覆盖未覆盖部分逐项分析、确认与报告输出。",
],
},
}
GENERIC_DECOMPOSITION_RULES: Dict[str, List[str]] = {
"normal": [
"主流程正确性。",
"合法边界值。",
"标准输入输出。",
],
"abnormal": [
"非法输入。",
"越界输入。",
"资源异常或状态冲突。",
],
}
EXPECTED_RESULT_PLACEHOLDER_MAP: Dict[str, str] = {
"{{return_value}}": "接口或函数返回值验证。",
"{{state_change}}": "系统状态变化验证。",
"{{error_message}}": "异常场景错误信息验证。",
"{{data_persistence}}": "数据库或存储落库结果验证。",
"{{ui_display}}": "界面显示反馈验证。",
}
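# --- Illustration only (the real scoring lives in the pipeline tools) ---
# The tables above are consumed by keyword matching: requirement text is
# scored against each type's keywords, roughly as sketched below.
def _example_keyword_hits(text: str) -> Dict[str, int]:
    hits: Dict[str, int] = {}
    for req_type, rule in REQUIREMENT_RULES.items():
        hits[req_type] = sum(1 for keyword in rule["keywords"] if keyword in text)
    return hits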

View File

@@ -0,0 +1,867 @@
from __future__ import annotations
import json
import re
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple
from app.services.testing_pipeline.base import TestingTool, ToolExecutionResult
from app.services.testing_pipeline.rules import (
DECOMPOSE_FORCE_RULES,
EXPECTED_RESULT_PLACEHOLDER_MAP,
GENERIC_DECOMPOSITION_RULES,
REQUIREMENT_RULES,
REQUIREMENT_TYPES,
TYPE_SIGNAL_RULES,
)
def _clean_text(value: str) -> str:
return " ".join((value or "").replace("\n", " ").split())
def _truncate_text(value: str, max_len: int = 2000) -> str:
text = _clean_text(value)
if len(text) <= max_len:
return text
return f"{text[:max_len]}..."
def _safe_int(value: Any, default: int, low: int, high: int) -> int:
try:
parsed = int(value)
except Exception:
parsed = default
return max(low, min(parsed, high))
def _strip_instruction_prefix(value: str) -> str:
text = _clean_text(value)
if not text:
return text
lowered = text.lower()
if lowered.startswith("/testing"):
text = _clean_text(text[len("/testing") :])
prefixes = [
"为以下需求生成测试用例",
"根据以下需求生成测试用例",
"请根据以下需求生成测试用例",
"请根据需求生成测试用例",
"请生成测试用例",
"生成测试用例",
]
for prefix in prefixes:
if text.startswith(prefix):
            for sep in (":", ":"):
idx = text.find(sep)
if idx != -1:
text = _clean_text(text[idx + 1 :])
break
else:
text = _clean_text(text[len(prefix) :])
break
    pattern = re.compile(r"^(请)?(根据|按|基于).{0,40}(需求|场景).{0,30}(生成|输出).{0,20}(测试项|测试用例)[::]")
matched = pattern.match(text)
if matched:
text = _clean_text(text[matched.end() :])
return text
def _extract_focus_points(value: str, max_points: int = 6) -> List[str]:
text = _strip_instruction_prefix(value)
if not text:
return []
parts = [_clean_text(part) for part in re.split(r"[,。;;]", text)]
parts = [part for part in parts if part]
ignored_tokens = ["生成测试用例", "测试项分解", "测试用例生成", "以下需求"]
filtered = [
part
for part in parts
if len(part) >= 4 and not any(token in part for token in ignored_tokens)
]
if not filtered:
filtered = parts
priority_keywords = [
"启停",
"开启",
"关闭",
"远程控制",
"保护",
"联动",
"状态",
"故障",
"恢复",
"切换",
"告警",
"模式",
"边界",
"时序",
]
priority = [part for part in filtered if any(keyword in part for keyword in priority_keywords)]
candidates = priority if priority else filtered
unique: List[str] = []
for part in candidates:
if part not in unique:
unique.append(part)
return unique[:max_points]
def _build_type_scores(text: str) -> Dict[str, int]:
scores: Dict[str, int] = {}
lowered = text.lower()
for req_type, rule in REQUIREMENT_RULES.items():
score = 0
if req_type in text:
score += 5
for keyword in rule.get("keywords", []):
if keyword.lower() in lowered:
score += 2
scores[req_type] = score
return scores
def _top_candidates(scores: Dict[str, int], top_n: int = 3) -> List[str]:
sorted_pairs = sorted(scores.items(), key=lambda pair: pair[1], reverse=True)
non_zero = [name for name, score in sorted_pairs if score > 0]
if non_zero:
return non_zero[:top_n]
return ["功能测试", "边界测试", "性能测试"][:top_n]
def _message_to_text(value: Any) -> str:
content = getattr(value, "content", value)
if isinstance(content, str):
return content
if isinstance(content, list):
chunks: List[str] = []
for item in content:
if isinstance(item, str):
chunks.append(item)
elif isinstance(item, dict):
text = item.get("text")
if isinstance(text, str):
chunks.append(text)
else:
chunks.append(str(item))
return "".join(chunks)
return str(content)
def _extract_json_object(value: str) -> Optional[Dict[str, Any]]:
text = (value or "").strip()
if not text:
return None
if text.startswith("```"):
text = re.sub(r"^```(?:json)?", "", text, flags=re.IGNORECASE).strip()
if text.endswith("```"):
text = text[:-3].strip()
try:
data = json.loads(text)
if isinstance(data, dict):
return data
except Exception:
pass
start = text.find("{")
if start == -1:
return None
depth = 0
for idx in range(start, len(text)):
ch = text[idx]
if ch == "{":
depth += 1
elif ch == "}":
depth -= 1
if depth == 0:
fragment = text[start : idx + 1]
try:
data = json.loads(fragment)
if isinstance(data, dict):
return data
except Exception:
return None
return None
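# Worked examples (illustrative):
#   _extract_json_object('{"a": 1}')                      -> {"a": 1}
#   _extract_json_object('```json\n{"a": 1}\n```')        -> {"a": 1}
#   _extract_json_object('前缀说明 {"a": {"b": 2}} 后缀')   -> {"a": {"b": 2}}
#   _extract_json_object('[1, 2, 3]')                     -> None (not an object)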
def _invoke_llm_json(context: Dict[str, Any], prompt: str) -> Optional[Dict[str, Any]]:
model = context.get("llm_model")
if model is None or not context.get("use_model_generation"):
return None
budget = context.get("llm_call_budget")
if isinstance(budget, int):
if budget <= 0:
return None
context["llm_call_budget"] = budget - 1
try:
response = model.invoke(prompt)
text = _message_to_text(response)
return _extract_json_object(text)
except Exception:
return None
def _invoke_llm_text(context: Dict[str, Any], prompt: str) -> str:
model = context.get("llm_model")
if model is None or not context.get("use_model_generation"):
return ""
budget = context.get("llm_call_budget")
if isinstance(budget, int):
if budget <= 0:
return ""
context["llm_call_budget"] = budget - 1
try:
response = model.invoke(prompt)
return _clean_text(_message_to_text(response))
except Exception:
return ""
def _normalize_item_entry(item: Any) -> Optional[Dict[str, Any]]:
if isinstance(item, str):
content = _clean_text(item)
if not content:
return None
return {"content": content, "coverage_tags": []}
if isinstance(item, dict):
content = _clean_text(str(item.get("content", "")))
if not content:
return None
tags = item.get("coverage_tags") or item.get("covered_points") or []
if not isinstance(tags, list):
tags = [str(tags)]
tags = [_clean_text(str(tag)) for tag in tags if _clean_text(str(tag))]
return {"content": content, "coverage_tags": tags}
return None
def _dedupe_items(items: List[Dict[str, Any]], max_items: int) -> List[Dict[str, Any]]:
merged: Dict[str, Dict[str, Any]] = {}
for item in items:
content = _clean_text(item.get("content", ""))
if not content:
continue
existing = merged.get(content)
if existing is None:
merged[content] = {
"content": content,
"coverage_tags": list(item.get("coverage_tags") or []),
}
else:
existing_tags = set(existing.get("coverage_tags") or [])
for tag in item.get("coverage_tags") or []:
if tag and tag not in existing_tags:
existing_tags.add(tag)
existing["coverage_tags"] = list(existing_tags)
deduped = list(merged.values())
return deduped[:max_items]
def _pick_expected_result_placeholder(content: str, abnormal: bool) -> str:
text = content or ""
if abnormal or any(token in text for token in ["非法", "异常", "错误", "拒绝", "越界", "失败"]):
return "{{error_message}}"
if any(token in text for token in ["状态", "切换", "转换", "恢复"]):
return "{{state_change}}"
if any(token in text for token in ["数据库", "存储", "落库", "持久化"]):
return "{{data_persistence}}"
if any(token in text for token in ["界面", "UI", "页面", "按钮", "提示"]):
return "{{ui_display}}"
return "{{return_value}}"
class IdentifyRequirementTypeTool(TestingTool):
name = "identify-requirement-type"
def execute(self, context: Dict[str, Any]) -> ToolExecutionResult:
raw_text = _clean_text(context.get("user_requirement_text", ""))
text = _strip_instruction_prefix(raw_text)
if not text:
text = raw_text
max_focus_points = _safe_int(context.get("max_focus_points"), 6, 3, 12)
provided_type = _clean_text(context.get("requirement_type_input", ""))
focus_points = _extract_focus_points(text, max_points=max_focus_points)
fallback_used = False
if provided_type in REQUIREMENT_TYPES:
result = {
"requirement_type": provided_type,
"reason": "用户已显式指定需求类型,系统按指定类型执行。",
"candidates": [],
"scores": {},
"secondary_types": [],
}
else:
scores = _build_type_scores(text)
sorted_pairs = sorted(scores.items(), key=lambda pair: pair[1], reverse=True)
best_type, best_score = sorted_pairs[0]
secondary = [name for name, score in sorted_pairs[1:4] if score > 0]
if best_score <= 0:
fallback_used = True
candidates = _top_candidates(scores)
result = {
"requirement_type": "未知类型",
"reason": "未命中明确分类规则,已回退到未知类型并提供最接近候选。",
"candidates": candidates,
"scores": scores,
"secondary_types": [],
}
else:
signal = TYPE_SIGNAL_RULES.get(best_type, "")
result = {
"requirement_type": best_type,
"reason": f"命中{best_type}识别信号。{signal}",
"candidates": [],
"scores": scores,
"secondary_types": secondary,
}
context["requirement_type_result"] = result
context["normalized_requirement_text"] = text
context["requirement_focus_points"] = focus_points
context["knowledge_used"] = bool(context.get("knowledge_context"))
return ToolExecutionResult(
context=context,
output_summary=(
f"type={result['requirement_type']}; candidates={len(result['candidates'])}; "
f"secondary_types={len(result.get('secondary_types', []))}; focus_points={len(focus_points)}"
),
fallback_used=fallback_used,
)
class DecomposeTestItemsTool(TestingTool):
name = "decompose-test-items"
@staticmethod
def _seed_items(
req_type: str,
req_text: str,
focus_points: List[str],
max_items: int,
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
if req_type in REQUIREMENT_RULES:
source_rules = REQUIREMENT_RULES[req_type]
normal_templates = list(source_rules.get("normal", []))
abnormal_templates = list(source_rules.get("abnormal", []))
else:
normal_templates = list(GENERIC_DECOMPOSITION_RULES["normal"])
abnormal_templates = list(GENERIC_DECOMPOSITION_RULES["abnormal"])
normal: List[Dict[str, Any]] = []
abnormal: List[Dict[str, Any]] = []
for template in normal_templates:
normal.append({"content": template, "coverage_tags": [req_type]})
for template in abnormal_templates:
abnormal.append({"content": template, "coverage_tags": [req_type]})
for point in focus_points:
normal.extend(
[
{
"content": f"验证{point}在标准作业流程下稳定执行且结果符合业务约束。",
"coverage_tags": [point, "正常流程"],
},
{
"content": f"验证{point}与相关联动控制、状态同步和回执反馈的一致性。",
"coverage_tags": [point, "联动一致性"],
},
]
)
abnormal.extend(
[
{
"content": f"验证{point}在非法输入、错误指令或权限异常时的保护与拒绝机制。",
"coverage_tags": [point, "异常输入"],
},
{
"content": f"验证{point}在边界条件、时序冲突或设备故障下的告警和恢复行为。",
"coverage_tags": [point, "边界异常"],
},
]
)
if any(token in req_text for token in ["手册", "操作手册", "用户手册", "作业指导"]):
normal.append(
{
"content": "验证需求说明未显式给出但在用户手册或操作手册体现的功能流程。",
"coverage_tags": ["手册功能"],
}
)
return _dedupe_items(normal, max_items), _dedupe_items(abnormal, max_items)
@staticmethod
def _generate_by_llm(context: Dict[str, Any]) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
req_result = context.get("requirement_type_result", {})
req_type = req_result.get("requirement_type", "未知类型")
req_text = context.get("normalized_requirement_text", "")
focus_points = context.get("requirement_focus_points", [])
max_items = _safe_int(context.get("max_items_per_group"), 12, 4, 30)
knowledge_context = _truncate_text(context.get("knowledge_context", ""), max_len=2500)
prompt = f"""
你是资深测试分析师。请根据需求、分解规则和知识库片段,生成尽可能覆盖要点的测试项。
需求文本:{req_text}
需求类型:{req_type}
需求要点:{focus_points}
知识库片段:{knowledge_context or ''}
分解约束:
1. 正常测试与异常测试必须分组输出。
2. 每条测试项必须可执行、可验证,避免模板化空话。
3. 尽可能覆盖全部需求要点,每组建议输出 6-{max_items} 条。
4. 优先生成与需求对象/控制逻辑/异常处理/边界条件强相关的测试项。
请仅输出 JSON 对象,结构如下:
{{
"normal_test_items": [
{{"content": "...", "coverage_tags": ["..."]}}
],
"abnormal_test_items": [
{{"content": "...", "coverage_tags": ["..."]}}
]
}}
""".strip()
data = _invoke_llm_json(context, prompt)
if not data:
return [], []
normal_raw = data.get("normal_test_items", [])
abnormal_raw = data.get("abnormal_test_items", [])
normal: List[Dict[str, Any]] = []
abnormal: List[Dict[str, Any]] = []
for item in normal_raw if isinstance(normal_raw, list) else []:
normalized = _normalize_item_entry(item)
if normalized:
normal.append(normalized)
for item in abnormal_raw if isinstance(abnormal_raw, list) else []:
normalized = _normalize_item_entry(item)
if normalized:
abnormal.append(normalized)
return _dedupe_items(normal, max_items), _dedupe_items(abnormal, max_items)
def execute(self, context: Dict[str, Any]) -> ToolExecutionResult:
req_result = context.get("requirement_type_result", {})
req_type = req_result.get("requirement_type", "未知类型")
req_text = context.get("normalized_requirement_text") or _strip_instruction_prefix(
context.get("user_requirement_text", "")
)
focus_points = context.get("requirement_focus_points", [])
max_items = _safe_int(context.get("max_items_per_group"), 12, 4, 30)
seeded_normal, seeded_abnormal = self._seed_items(req_type, req_text, focus_points, max_items)
llm_normal, llm_abnormal = self._generate_by_llm(context)
merged_normal = _dedupe_items(llm_normal + seeded_normal, max_items)
merged_abnormal = _dedupe_items(llm_abnormal + seeded_abnormal, max_items)
fallback_used = not bool(llm_normal or llm_abnormal)
normal_items: List[Dict[str, Any]] = []
abnormal_items: List[Dict[str, Any]] = []
for idx, item in enumerate(merged_normal, start=1):
normal_items.append(
{
"id": f"N{idx}",
"content": item["content"],
"coverage_tags": item.get("coverage_tags", []),
}
)
for idx, item in enumerate(merged_abnormal, start=1):
abnormal_items.append(
{
"id": f"E{idx}",
"content": item["content"],
"coverage_tags": item.get("coverage_tags", []),
}
)
context["test_items"] = {
"normal": normal_items,
"abnormal": abnormal_items,
}
context["decompose_force_rules"] = DECOMPOSE_FORCE_RULES
return ToolExecutionResult(
context=context,
output_summary=(
f"normal_items={len(normal_items)}; abnormal_items={len(abnormal_items)}; "
f"llm_items={len(llm_normal) + len(llm_abnormal)}"
),
fallback_used=fallback_used,
)
class GenerateTestCasesTool(TestingTool):
name = "generate-test-cases"
@staticmethod
def _build_fallback_steps(item_content: str, abnormal: bool, variant: str) -> List[str]:
if abnormal:
return [
"确认测试前置环境、设备状态与日志采集开关已准备就绪。",
f"准备异常场景“{variant}”所需的输入数据、操作账号和触发条件。",
f"在目标对象执行异常触发操作,重点验证:{item_content}",
"持续观察系统返回码、错误文案、告警信息与日志链路完整性。",
"检查保护机制是否生效,包括拒绝策略、回滚行为和状态一致性。",
"记录证据并复位环境,确认异常处理后系统可恢复到稳定状态。",
]
return [
"确认测试环境、设备连接状态和前置业务数据均已初始化。",
f"准备“{variant}”所需输入参数、操作路径和判定阈值。",
f"在目标对象执行业务控制流程,重点验证:{item_content}",
"校验关键返回值、状态变化、控制回执及界面或接口反馈结果。",
"检查联动模块、日志记录和数据落库是否满足一致性要求。",
"沉淀测试证据并恢复环境,确保后续用例可重复执行。",
]
def _generate_cases_by_llm(
self,
context: Dict[str, Any],
item: Dict[str, Any],
abnormal: bool,
cases_per_item: int,
) -> List[Dict[str, Any]]:
req_text = context.get("normalized_requirement_text", "")
knowledge_context = _truncate_text(context.get("knowledge_context", ""), max_len=1800)
prompt = f"""
你是资深测试工程师。请围绕给定测试项生成详细测试用例。
需求:{req_text}
测试项:{item.get('content', '')}
测试类型:{'异常测试' if abnormal else '正常测试'}
知识库片段:{knowledge_context or ''}
要求:
1. 生成 {cases_per_item}-{cases_per_item + 1} 条测试用例。
2. 每条用例包含 test_content 与 operation_steps。
3. operation_steps 必须详细,至少 5 步,包含前置、执行、观察、校验与证据留存。
4. 内容必须围绕当前测试项,不要输出空洞模板。
仅输出 JSON:
{{
"test_cases": [
{{
"title": "...",
"test_content": "...",
"operation_steps": ["...", "..."]
}}
]
}}
""".strip()
data = _invoke_llm_json(context, prompt)
if not data:
return []
raw_cases = data.get("test_cases", [])
if not isinstance(raw_cases, list):
return []
normalized_cases: List[Dict[str, Any]] = []
for case in raw_cases:
if not isinstance(case, dict):
continue
test_content = _clean_text(str(case.get("test_content", "")))
if not test_content:
continue
steps = case.get("operation_steps", [])
if not isinstance(steps, list):
continue
cleaned_steps = [_clean_text(str(step)) for step in steps if _clean_text(str(step))]
if len(cleaned_steps) < 5:
continue
normalized_cases.append(
{
"title": _clean_text(str(case.get("title", ""))),
"test_content": test_content,
"operation_steps": cleaned_steps,
}
)
return normalized_cases[: max(1, cases_per_item)]
def execute(self, context: Dict[str, Any]) -> ToolExecutionResult:
test_items = context.get("test_items", {})
cases_per_item = _safe_int(context.get("cases_per_item"), 2, 1, 5)
normal_cases: List[Dict[str, Any]] = []
abnormal_cases: List[Dict[str, Any]] = []
llm_case_count = 0
for item in test_items.get("normal", []):
generated = self._generate_cases_by_llm(context, item, abnormal=False, cases_per_item=cases_per_item)
if not generated:
generated = [
{
"title": "标准流程验证",
"test_content": f"验证{item['content']}",
"operation_steps": self._build_fallback_steps(item["content"], False, "标准流程"),
},
{
"title": "边界与联动验证",
"test_content": f"验证{item['content']}在边界条件和联动场景下的稳定性",
"operation_steps": self._build_fallback_steps(item["content"], False, "边界与联动"),
},
][:cases_per_item]
else:
llm_case_count += len(generated)
for idx, case in enumerate(generated, start=1):
merged_content = _clean_text(case.get("test_content", item["content"]))
placeholder = _pick_expected_result_placeholder(merged_content, abnormal=False)
normal_cases.append(
{
"id": f"{item['id']}-C{idx}",
"item_id": item["id"],
"title": _clean_text(case.get("title", "")),
"operation_steps": case.get("operation_steps", []),
"test_content": merged_content,
"expected_result_placeholder": placeholder,
}
)
for item in test_items.get("abnormal", []):
generated = self._generate_cases_by_llm(context, item, abnormal=True, cases_per_item=cases_per_item)
if not generated:
generated = [
{
"title": "非法输入与权限异常验证",
"test_content": f"验证{item['content']}在非法输入与权限异常下的处理表现",
"operation_steps": self._build_fallback_steps(item["content"], True, "非法输入与权限异常"),
},
{
"title": "故障与时序冲突验证",
"test_content": f"验证{item['content']}在故障和时序冲突场景下的保护行为",
"operation_steps": self._build_fallback_steps(item["content"], True, "故障与时序冲突"),
},
][:cases_per_item]
else:
llm_case_count += len(generated)
for idx, case in enumerate(generated, start=1):
merged_content = _clean_text(case.get("test_content", item["content"]))
placeholder = _pick_expected_result_placeholder(merged_content, abnormal=True)
abnormal_cases.append(
{
"id": f"{item['id']}-C{idx}",
"item_id": item["id"],
"title": _clean_text(case.get("title", "")),
"operation_steps": case.get("operation_steps", []),
"test_content": merged_content,
"expected_result_placeholder": placeholder,
}
)
context["test_cases"] = {
"normal": normal_cases,
"abnormal": abnormal_cases,
}
return ToolExecutionResult(
context=context,
output_summary=(
f"normal_cases={len(normal_cases)}; abnormal_cases={len(abnormal_cases)}; llm_cases={llm_case_count}"
),
fallback_used=llm_case_count == 0,
)
class BuildExpectedResultsTool(TestingTool):
name = "build_expected_results"
def _expected_for_case(self, context: Dict[str, Any], case: Dict[str, Any], abnormal: bool) -> str:
placeholder = case.get("expected_result_placeholder", "{{return_value}}")
if placeholder not in EXPECTED_RESULT_PLACEHOLDER_MAP:
placeholder = "{{return_value}}"
req_text = context.get("normalized_requirement_text", "")
knowledge_context = _truncate_text(context.get("knowledge_context", ""), max_len=1200)
prompt = f"""
请基于以下信息生成一条可验证、可度量的测试预期结果,避免模板化空话。
需求:{req_text}
测试内容:{case.get('test_content', '')}
测试类型:{'异常测试' if abnormal else '正常测试'}
占位符语义:{placeholder} -> {EXPECTED_RESULT_PLACEHOLDER_MAP.get(placeholder, '')}
知识库片段:{knowledge_context or ''}
输出要求:
1. 仅输出一句中文预期结果。
2. 结果必须可判定成功/失败。
3. 包含关键观测项(返回值、状态、告警、日志、数据一致性中的相关项)。
""".strip()
llm_text = _invoke_llm_text(context, prompt)
if llm_text:
return _truncate_text(llm_text, max_len=220)
test_content = _clean_text(case.get("test_content", ""))
if placeholder == "{{error_message}}":
return f"触发{test_content}后,系统应返回明确错误码与错误文案,拒绝非法请求且核心状态保持一致。"
if placeholder == "{{state_change}}":
return f"执行{test_content}后,系统状态转换应符合需求定义,状态变化可被日志与回执共同验证。"
if placeholder == "{{data_persistence}}":
return f"执行{test_content}后,数据库或存储层应产生符合约束的持久化结果且无脏数据。"
if placeholder == "{{ui_display}}":
return f"执行{test_content}后,界面应展示与控制结果一致的反馈信息且提示可被用户执行。"
if abnormal:
return f"执行异常场景“{test_content}”后,系统应触发保护策略并输出可追溯日志,业务状态保持可恢复。"
return f"执行“{test_content}”后,返回值与状态变化应满足需求约束,关键结果可通过日志或回执验证。"
def execute(self, context: Dict[str, Any]) -> ToolExecutionResult:
test_cases = context.get("test_cases", {})
normal_expected: List[Dict[str, str]] = []
abnormal_expected: List[Dict[str, str]] = []
for case in test_cases.get("normal", []):
normal_expected.append(
{
"id": case["id"],
"case_id": case["id"],
"result": self._expected_for_case(context, case, abnormal=False),
}
)
for case in test_cases.get("abnormal", []):
abnormal_expected.append(
{
"id": case["id"],
"case_id": case["id"],
"result": self._expected_for_case(context, case, abnormal=True),
}
)
context["expected_results"] = {
"normal": normal_expected,
"abnormal": abnormal_expected,
}
return ToolExecutionResult(
context=context,
output_summary=(
f"normal_expected={len(normal_expected)}; abnormal_expected={len(abnormal_expected)}"
),
)
class FormatOutputTool(TestingTool):
name = "format_output"
@staticmethod
def _format_case_block(case: Dict[str, Any], index: int) -> List[str]:
item_id = case.get("item_id", case.get("id", ""))
title = _clean_text(case.get("title", ""))
block: List[str] = []
block.append(f"{index}. [用例 {case['id']}](对应测试项 {item_id}{case.get('test_content', '')}")
if title:
block.append(f" 场景标题:{title}")
block.append(" 操作步骤:")
for step_idx, step in enumerate(case.get("operation_steps", []), start=1):
block.append(f" {step_idx}) {step}")
return block
def execute(self, context: Dict[str, Any]) -> ToolExecutionResult:
test_items = context.get("test_items", {"normal": [], "abnormal": []})
test_cases = context.get("test_cases", {"normal": [], "abnormal": []})
expected_results = context.get("expected_results", {"normal": [], "abnormal": []})
lines: List[str] = []
lines.append("**测试项**")
lines.append("")
lines.append("**正常测试**")
for index, item in enumerate(test_items.get("normal", []), start=1):
lines.append(f"{index}. [测试项 {item['id']}]{item['content']}")
lines.append("")
lines.append("**异常测试**")
for index, item in enumerate(test_items.get("abnormal", []), start=1):
lines.append(f"{index}. [测试项 {item['id']}]{item['content']}")
lines.append("")
lines.append("**测试用例**")
lines.append("")
lines.append("**正常测试**")
for index, case in enumerate(test_cases.get("normal", []), start=1):
lines.extend(self._format_case_block(case, index))
lines.append("")
lines.append("**异常测试**")
for index, case in enumerate(test_cases.get("abnormal", []), start=1):
lines.extend(self._format_case_block(case, index))
lines.append("")
lines.append("**预期成果**")
lines.append("")
lines.append("**正常测试**")
for index, expected in enumerate(expected_results.get("normal", []), start=1):
            lines.append(
                f"{index}. [预期 {expected['id']}](对应用例 {expected['case_id']}){expected['result']}"
            )
lines.append("")
lines.append("**异常测试**")
for index, expected in enumerate(expected_results.get("abnormal", []), start=1):
            lines.append(
                f"{index}. [预期 {expected['id']}](对应用例 {expected['case_id']}){expected['result']}"
            )
context["formatted_output"] = "\n".join(lines)
context["structured_output"] = {
"test_items": test_items,
"test_cases": test_cases,
"expected_results": expected_results,
}
return ToolExecutionResult(
context=context,
output_summary="formatted_sections=3",
)
def build_default_tool_chain() -> List[TestingTool]:
return [
IdentifyRequirementTypeTool(),
DecomposeTestItemsTool(),
GenerateTestCasesTool(),
BuildExpectedResultsTool(),
FormatOutputTool(),
]

View File

@@ -0,0 +1,122 @@
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
@dataclass
class ChunkVectorMetadata:
"""Metadata payload for vector DB and graph linkage."""
chunk_id: str
kb_id: int
document_id: int
document_name: str
document_path: str
chunk_index: int
chunk_text: str
token_count: int
language: str = "zh"
source_type: str = "document"
mission_phase: Optional[str] = None
section_title: Optional[str] = None
publish_time: Optional[str] = None
extracted_entities: List[str] = field(default_factory=list)
extracted_entity_types: List[str] = field(default_factory=list)
extracted_relations: List[Dict[str, Any]] = field(default_factory=list)
graph_node_ids: List[str] = field(default_factory=list)
graph_edge_ids: List[str] = field(default_factory=list)
community_ids: List[str] = field(default_factory=list)
embedding_model: str = ""
embedding_dim: int = 0
ingest_time: str = field(
default_factory=lambda: datetime.now(timezone.utc).isoformat()
)
def to_payload(self) -> Dict[str, Any]:
return {
"chunk_id": self.chunk_id,
"kb_id": self.kb_id,
"document_id": self.document_id,
"document_name": self.document_name,
"document_path": self.document_path,
"chunk_index": self.chunk_index,
"chunk_text": self.chunk_text,
"token_count": self.token_count,
"language": self.language,
"source_type": self.source_type,
"mission_phase": self.mission_phase,
"section_title": self.section_title,
"publish_time": self.publish_time,
"extracted_entities": self.extracted_entities,
"extracted_entity_types": self.extracted_entity_types,
"extracted_relations": self.extracted_relations,
"graph_node_ids": self.graph_node_ids,
"graph_edge_ids": self.graph_edge_ids,
"community_ids": self.community_ids,
"embedding_model": self.embedding_model,
"embedding_dim": self.embedding_dim,
"ingest_time": self.ingest_time,
}
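# --- Usage sketch (illustrative values; the embedding model name is made up) ---
def _example_chunk_payload() -> Dict[str, Any]:
    meta = ChunkVectorMetadata(
        chunk_id="doc-42-0007",
        kb_id=1,
        document_id=42,
        document_name="需求规格说明书.docx",
        document_path="uploads/kb_1/需求规格说明书.docx",
        chunk_index=7,
        chunk_text="系统应支持远程启停控制……",
        token_count=128,
        embedding_model="example-embedding-model",
        embedding_dim=1024,
    )
    return meta.to_payload()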
def qdrant_collection_schema(collection_name: str, vector_size: int) -> Dict[str, Any]:
"""Qdrant collection and payload index recommendations."""
return {
"collection_name": collection_name,
"vectors": {
"size": vector_size,
"distance": "Cosine",
},
"payload_indexes": [
{"field_name": "kb_id", "field_schema": "integer"},
{"field_name": "document_id", "field_schema": "integer"},
{"field_name": "document_name", "field_schema": "keyword"},
{"field_name": "chunk_id", "field_schema": "keyword"},
{"field_name": "mission_phase", "field_schema": "keyword"},
{"field_name": "community_ids", "field_schema": "keyword"},
{"field_name": "extracted_entities", "field_schema": "keyword"},
{"field_name": "ingest_time", "field_schema": "datetime"},
],
}
def milvus_collection_schema(collection_name: str, vector_size: int) -> Dict[str, Any]:
"""Milvus field design for vector+graph linkage."""
return {
"collection_name": collection_name,
"fields": [
{"name": "id", "type": "VARCHAR", "max_length": 64, "is_primary": True},
{"name": "kb_id", "type": "INT64"},
{"name": "document_id", "type": "INT64"},
{"name": "chunk_index", "type": "INT32"},
{"name": "document_name", "type": "VARCHAR", "max_length": 255},
{"name": "mission_phase", "type": "VARCHAR", "max_length": 64},
{"name": "community_ids", "type": "VARCHAR", "max_length": 512},
{"name": "extracted_entities", "type": "VARCHAR", "max_length": 2048},
{"name": "ingest_time", "type": "VARCHAR", "max_length": 64},
{"name": "embedding", "type": "FLOAT_VECTOR", "dim": vector_size},
],
"index": {
"field_name": "embedding",
"index_type": "HNSW",
"metric_type": "COSINE",
"params": {"M": 16, "efConstruction": 200},
},
}
DOCUMENT_CHUNK_METADATA_DDL = """
ALTER TABLE document_chunks
ADD COLUMN IF NOT EXISTS chunk_index INT NULL,
ADD COLUMN IF NOT EXISTS token_count INT NULL,
ADD COLUMN IF NOT EXISTS language VARCHAR(16) DEFAULT 'zh',
ADD COLUMN IF NOT EXISTS mission_phase VARCHAR(64) NULL,
ADD COLUMN IF NOT EXISTS extracted_entities JSON NULL,
ADD COLUMN IF NOT EXISTS extracted_entity_types JSON NULL,
ADD COLUMN IF NOT EXISTS extracted_relations JSON NULL,
ADD COLUMN IF NOT EXISTS graph_node_ids JSON NULL,
ADD COLUMN IF NOT EXISTS graph_edge_ids JSON NULL,
ADD COLUMN IF NOT EXISTS community_ids JSON NULL,
ADD COLUMN IF NOT EXISTS embedding_model VARCHAR(128) NULL,
ADD COLUMN IF NOT EXISTS embedding_dim INT NULL;
""".strip()

View File

@@ -0,0 +1,11 @@
from .base import BaseVectorStore
from .chroma import ChromaVectorStore
from .qdrant import QdrantStore
from .factory import VectorStoreFactory
__all__ = [
'BaseVectorStore',
'ChromaVectorStore',
'QdrantStore',
'VectorStoreFactory'
]

View File

@@ -0,0 +1,42 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
class BaseVectorStore(ABC):
"""Abstract base class for vector store implementations"""
@abstractmethod
def __init__(self, collection_name: str, embedding_function: Embeddings, **kwargs):
"""Initialize the vector store"""
pass
@abstractmethod
def add_documents(self, documents: List[Document]) -> None:
"""Add documents to the vector store"""
pass
@abstractmethod
def delete(self, ids: List[str]) -> None:
"""Delete documents from the vector store"""
pass
@abstractmethod
def as_retriever(self, **kwargs: Any):
"""Return a retriever interface for the vector store"""
pass
@abstractmethod
def similarity_search(self, query: str, k: int = 4, **kwargs: Any) -> List[Document]:
"""Search for similar documents"""
pass
@abstractmethod
    def similarity_search_with_score(self, query: str, k: int = 4, **kwargs: Any) -> List[Tuple[Document, float]]:
        """Search for similar documents, returning (document, score) pairs"""
pass
@abstractmethod
def delete_collection(self) -> None:
"""Delete the entire collection"""
pass

View File

@@ -0,0 +1,47 @@
from typing import Any, List, Tuple
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_chroma import Chroma
import chromadb
from app.core.config import settings
from .base import BaseVectorStore
class ChromaVectorStore(BaseVectorStore):
"""Chroma vector store implementation"""
def __init__(self, collection_name: str, embedding_function: Embeddings, **kwargs):
"""Initialize Chroma vector store"""
chroma_client = chromadb.HttpClient(
host=settings.CHROMA_DB_HOST,
port=settings.CHROMA_DB_PORT,
)
        self._client = chroma_client
        self._collection_name = collection_name
        self._store = Chroma(
            client=chroma_client,
            collection_name=collection_name,
            embedding_function=embedding_function,
        )
def add_documents(self, documents: List[Document]) -> None:
"""Add documents to Chroma"""
self._store.add_documents(documents)
def delete(self, ids: List[str]) -> None:
"""Delete documents from Chroma"""
self._store.delete(ids)
def as_retriever(self, **kwargs: Any):
"""Return a retriever interface"""
return self._store.as_retriever(**kwargs)
def similarity_search(self, query: str, k: int = 4, **kwargs: Any) -> List[Document]:
"""Search for similar documents in Chroma"""
return self._store.similarity_search(query, k=k, **kwargs)
    def similarity_search_with_score(self, query: str, k: int = 4, **kwargs: Any) -> List[Tuple[Document, float]]:
        """Search for similar documents in Chroma, returning (document, score) pairs"""
return self._store.similarity_search_with_score(query, k=k, **kwargs)
    def delete_collection(self) -> None:
        """Delete the entire collection using the client created in __init__"""
        self._client.delete_collection(self._collection_name)

View File

@@ -0,0 +1,59 @@
from typing import Dict, Type, Any
from langchain_core.embeddings import Embeddings
from .base import BaseVectorStore
from .chroma import ChromaVectorStore
from .qdrant import QdrantStore
class VectorStoreFactory:
"""Factory for creating vector store instances"""
_stores: Dict[str, Type[BaseVectorStore]] = {
'chroma': ChromaVectorStore,
'qdrant': QdrantStore
}
@classmethod
def create(
cls,
store_type: str,
collection_name: str,
embedding_function: Embeddings,
**kwargs: Any
) -> BaseVectorStore:
"""Create a vector store instance
Args:
store_type: Type of vector store ('chroma', 'qdrant', etc.)
collection_name: Name of the collection
embedding_function: Embedding function to use
**kwargs: Additional arguments for specific vector store implementations
Returns:
An instance of the requested vector store
Raises:
ValueError: If store_type is not supported
"""
store_class = cls._stores.get(store_type.lower())
if not store_class:
raise ValueError(
f"Unsupported vector store type: {store_type}. "
f"Supported types are: {', '.join(cls._stores.keys())}"
)
return store_class(
collection_name=collection_name,
embedding_function=embedding_function,
**kwargs
)
@classmethod
def register_store(cls, name: str, store_class: Type[BaseVectorStore]) -> None:
"""Register a new vector store implementation
Args:
name: Name of the vector store type
store_class: Vector store class implementation
"""
cls._stores[name.lower()] = store_class
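# --- Usage sketch (illustrative; the embeddings object is assumed to come
# from EmbeddingsFactory) ---
#   store = VectorStoreFactory.create(
#       store_type="qdrant",
#       collection_name="kb_1",
#       embedding_function=embeddings,
#   )
#   docs = store.similarity_search("远程启停控制", k=4)
# A new backend can be plugged in without touching the factory:
#   VectorStoreFactory.register_store("mystore", MyStore)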

View File

@@ -0,0 +1,43 @@
from typing import Any, List, Tuple
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_community.vectorstores import Qdrant
from qdrant_client import QdrantClient
from app.core.config import settings
from .base import BaseVectorStore
class QdrantStore(BaseVectorStore):
"""Qdrant vector store implementation"""
    def __init__(self, collection_name: str, embedding_function: Embeddings, **kwargs):
        """Initialize Qdrant vector store"""
        # The langchain wrapper expects a QdrantClient instance rather than
        # raw connection parameters, so build the client explicitly.
        self._client = QdrantClient(
            url=settings.QDRANT_URL,
            prefer_grpc=settings.QDRANT_PREFER_GRPC,
        )
        self._collection_name = collection_name
        self._store = Qdrant(
            client=self._client,
            collection_name=collection_name,
            embeddings=embedding_function,
        )
def add_documents(self, documents: List[Document]) -> None:
"""Add documents to Qdrant"""
self._store.add_documents(documents)
def delete(self, ids: List[str]) -> None:
"""Delete documents from Qdrant"""
self._store.delete(ids)
def as_retriever(self, **kwargs: Any):
"""Return a retriever interface"""
return self._store.as_retriever(**kwargs)
def similarity_search(self, query: str, k: int = 4, **kwargs: Any) -> List[Document]:
"""Search for similar documents in Qdrant"""
return self._store.similarity_search(query, k=k, **kwargs)
    def similarity_search_with_score(self, query: str, k: int = 4, **kwargs: Any) -> List[Tuple[Document, float]]:
        """Search for similar documents in Qdrant, returning (document, score) pairs"""
return self._store.similarity_search_with_score(query, k=k, **kwargs)
    def delete_collection(self) -> None:
        """Delete the entire collection"""
        self._client.delete_collection(collection_name=self._collection_name)