init. project

This commit is contained in:
2026-04-13 11:34:23 +08:00
commit c7c0659a85
202 changed files with 31196 additions and 0 deletions

View File

@@ -0,0 +1,532 @@
import base64
import json
import re
from collections import defaultdict
from typing import Any, AsyncGenerator, Dict, List, Optional
from app.core.config import settings
from app.models.chat import Message
from app.models.knowledge import Document, KnowledgeBase
from app.services.embedding.embedding_factory import EmbeddingsFactory
from app.services.fusion_prompts import (
GENERAL_CHAT_PROMPT_TEMPLATE,
GRAPH_GLOBAL_PROMPT_TEMPLATE,
GRAPH_LOCAL_PROMPT_TEMPLATE,
HYBRID_RAG_PROMPT_TEMPLATE,
)
from app.services.graph.graphrag_adapter import GraphRAGAdapter
from app.services.intent_router import route_intent
from app.services.llm.llm_factory import LLMFactory
from app.services.reranker.external_api import ExternalRerankerClient
from app.services.retrieval.multi_kb_retriever import MultiKBRetriever, format_retrieval_context
from app.services.testing_pipeline.pipeline import run_testing_pipeline
from app.services.testing_pipeline.rules import REQUIREMENT_TYPES
from app.services.vector_store import VectorStoreFactory
# Chinese keywords naming the *object* of a test-generation request
# (test items, test cases, expected results, decomposition, ...).
TESTING_TARGET_KEYWORDS = [
    "测试项",
    "测试用例",
    "预期成果",
    "需求类型",
    "测试分解",
    "分解",
    "正常测试",
    "异常测试",
    "测试充分性",
]
# Chinese action verbs signalling a generate/produce request.
# BUG FIX: the original list contained an empty string "", and the
# membership test `"" in text` is always True, which made the action
# check match every query; the empty entry is removed.
TESTING_ACTION_KEYWORDS = [
    "生成",
    "输出",
    "给出",
    "编写",
    "设计",
    "整理",
    "列出",
    "提供",
    "制定",
]
# Aliases mapping loose user phrasing to canonical requirement-type names.
TYPE_ALIAS_MAP = {
    "接口测试": "外部接口测试",
    "ui测试": "人机交互界面测试",
    "界面测试": "人机交互界面测试",
    "恢复测试": "恢复性测试",
    "可靠性": "可靠性测试",
    "安全性": "安全性测试",
    "边界": "边界测试",
    "安装": "安装性测试",
    "互操作": "互操作性测试",
    "敏感性": "敏感性测试",
    "充分性": "测试充分性要求",
}
def _escape_stream_text(text: str) -> str:
return text.replace('"', '\\"').replace("\n", "\\n")
def _extract_stream_text(chunk: Any) -> str:
content = getattr(chunk, "content", chunk)
if isinstance(content, str):
return content
if isinstance(content, list):
parts: List[str] = []
for item in content:
if isinstance(item, str):
parts.append(item)
elif isinstance(item, dict):
maybe_text = item.get("text")
if isinstance(maybe_text, str):
parts.append(maybe_text)
else:
parts.append(str(item))
return "".join(parts)
return str(content)
def _preview_rows(rows: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
preview = []
for row in rows[:10]:
doc = row["document"]
metadata = doc.metadata or {}
preview.append(
{
"kb_id": row.get("kb_id"),
"source": metadata.get("source") or metadata.get("file_name") or "unknown",
"chunk_id": metadata.get("chunk_id") or "unknown",
"score": row.get("final_score", 0),
"reranker_score": row.get("reranker_score"),
}
)
return preview
def _context_rows(rows: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
context_rows: List[Dict[str, Any]] = []
for row in rows:
doc = row["document"]
metadata = dict(doc.metadata or {})
if "kb_id" not in metadata and row.get("kb_id") is not None:
metadata["kb_id"] = row.get("kb_id")
metadata.setdefault("retrieval_score", row.get("final_score", 0))
if row.get("reranker_score") is not None:
metadata.setdefault("reranker_score", row.get("reranker_score"))
context_rows.append(
{
"page_content": doc.page_content.strip(),
"metadata": metadata,
}
)
return context_rows
def _build_local_graph_context_fallback(rows: List[Dict[str, Any]]) -> str:
entities = set()
relations: List[Dict[str, Any]] = []
evidences: List[str] = []
for row in rows:
doc = row["document"]
metadata = doc.metadata or {}
for ent in metadata.get("extracted_entities", []):
entities.add(str(ent))
for rel in metadata.get("extracted_relations", []):
if isinstance(rel, dict):
relations.append(rel)
evidences.append(doc.page_content.strip())
entity_block = "\n".join(f"- {name}" for name in sorted(entities)[:80]) or "- 暂无结构化实体,已使用向量检索回退。"
relation_lines: List[str] = []
for rel in relations[:120]:
src = rel.get("source") or rel.get("src") or rel.get("src_id") or "UNKNOWN"
tgt = rel.get("target") or rel.get("tgt") or rel.get("tgt_id") or "UNKNOWN"
rel_type = rel.get("type") or rel.get("relation_type") or "其他"
desc = rel.get("description") or ""
relation_lines.append(f"- {src} -> {tgt} | 类型={rel_type} | 说明={desc}")
relation_block = "\n".join(relation_lines) or "- 暂无结构化关系,已使用证据片段回答。"
evidence_block = "\n\n".join(
f"[证据{i}] {snippet}" for i, snippet in enumerate(evidences[:8], start=1)
)
if not evidence_block:
evidence_block = "无可用证据。"
return (
"实体列表:\n"
f"{entity_block}\n\n"
"关系列表:\n"
f"{relation_block}\n\n"
"原文证据:\n"
f"{evidence_block}"
)
def _build_global_community_context_fallback(rows: List[Dict[str, Any]]) -> str:
groups: Dict[str, List[str]] = defaultdict(list)
for row in rows:
doc = row["document"]
metadata = doc.metadata or {}
community_ids = metadata.get("community_ids") or []
if isinstance(community_ids, list) and community_ids:
keys = [str(item) for item in community_ids]
else:
source = metadata.get("source") or metadata.get("file_name") or "unknown"
keys = [f"source:{source}"]
for key in keys:
groups[key].append(doc.page_content.strip())
if not groups:
return "暂无社区摘要数据,已回退为基于证据片段的全局总结。"
lines: List[str] = []
for idx, (community_id, snippets) in enumerate(groups.items(), start=1):
merged = " ".join(snippets[:3])
lines.append(f"社区{idx} ({community_id}) 摘要: {merged}")
return "\n\n".join(lines)
async def _build_kb_vector_stores(db: Any, knowledge_bases: List[KnowledgeBase]) -> List[Dict[str, Any]]:
    """Open one vector-store handle per knowledge base that has documents.

    Knowledge bases without ingested documents are skipped so empty
    collections are never queried. Returns ``[{"kb_id": ..., "store": ...}]``.
    """
    embedding_fn = EmbeddingsFactory.create()
    stores: List[Dict[str, Any]] = []
    for knowledge_base in knowledge_bases:
        kb_documents = (
            db.query(Document)
            .filter(Document.knowledge_base_id == knowledge_base.id)
            .all()
        )
        if not kb_documents:
            continue
        vector_store = VectorStoreFactory.create(
            store_type=settings.VECTOR_STORE_TYPE,
            collection_name=f"kb_{knowledge_base.id}",
            embedding_function=embedding_fn,
        )
        stores.append({"kb_id": knowledge_base.id, "store": vector_store})
    return stores
def _build_reranker_client() -> ExternalRerankerClient:
    """Create an external reranker client from the app-wide settings."""
    client_kwargs = {
        "api_url": settings.RERANKER_API_URL,
        "api_key": settings.RERANKER_API_KEY,
        "model": settings.RERANKER_MODEL,
        "timeout_seconds": settings.RERANKER_TIMEOUT_SECONDS,
    }
    return ExternalRerankerClient(**client_kwargs)
def _is_testing_generation_request(query: str) -> bool:
    """Heuristically decide whether *query* asks to generate test artifacts.

    Triggers on the explicit ``/testing`` command, orchestrator tool names,
    or Chinese keyword combinations of a testing target plus a
    generate-style action (polite-request patterns and imperative openings
    included).
    """
    text = (query or "").strip()
    if not text:
        return False
    normalized = text.lower()
    if normalized.startswith("/testing"):
        return True
    explicit_tokens = (
        "testing_orchestrator",
        "testing-orchestrator",
        "identify_requirement_type",
        "identify-requirement-type",
    )
    if any(token in normalized for token in explicit_tokens):
        return True
    mentions_target = any(keyword in text for keyword in TESTING_TARGET_KEYWORDS)
    mentions_action = any(keyword in text for keyword in TESTING_ACTION_KEYWORDS)
    if mentions_target and mentions_action:
        return True
    core_terms = ("测试项", "测试用例", "预期成果")
    if any(term in text for term in core_terms):
        # Polite request pattern, e.g. "请帮我...生成" with a short gap.
        if re.search(r"(请|帮|给|麻烦).{0,12}(写|生成|设计|整理|编写|列出|提供|制定)", text):
            return True
        # Imperative opening, e.g. "生成...测试用例".
        if text.startswith(("生成", "编写", "设计", "整理", "输出", "列出", "提供", "制定")):
            return True
    return False
def _extract_requirement_type_from_query(query: str) -> Optional[str]:
    """Map a free-form query to a canonical requirement type, if any.

    Exact requirement-type names win over aliases (checked both
    case-sensitively and lower-cased); returns None when nothing matches.
    """
    text = (query or "").strip()
    if not text:
        return None
    exact = next((req_type for req_type in REQUIREMENT_TYPES if req_type in text), None)
    if exact is not None:
        return exact
    lowered = text.lower()
    for alias, canonical in TYPE_ALIAS_MAP.items():
        if alias in text or alias in lowered:
            return canonical
    return None
async def generate_response(
    query: str,
    messages: dict,
    knowledge_base_ids: List[int],
    chat_id: int,
    db: Any,
) -> AsyncGenerator[str, None]:
    """Stream an assistant reply for ``query`` over the selected knowledge bases.

    Yields lines in a data-stream wire format:
      * ``0:"<text>"`` — text parts; the first part is a base64 JSON context
        payload followed by the ``__LLM_RESPONSE__`` separator,
      * ``d:{...}`` — a finish part with a stop reason,
      * ``3:<text>`` — an error part on failure.

    Side effects: persists the user message and the full assistant response
    to ``db`` (keyed by ``chat_id``) and closes the session when done.

    Args:
        query: Latest user utterance.
        messages: Prior conversation turns, forwarded to the intent router.
        knowledge_base_ids: KBs to retrieve from; may be empty.
        chat_id: Chat row the persisted messages belong to.
        db: Active ORM session (duck-typed; query/add/commit/close used).
    """
    try:
        # Persist the user turn, then reserve an empty assistant row that is
        # filled in with the full streamed content at the end.
        user_message = Message(content=query, role="user", chat_id=chat_id)
        db.add(user_message)
        db.commit()
        bot_message = Message(content="", role="assistant", chat_id=chat_id)
        db.add(bot_message)
        db.commit()
        if _is_testing_generation_request(query):
            # --- Testing branch: run the testing pipeline instead of the
            # LLM routing below, stream its output, and return early. ---
            explicit_type = _extract_requirement_type_from_query(query)
            retrieval_rows: List[Dict[str, Any]] = []
            knowledge_context = ""
            kb_vector_stores = []
            if knowledge_base_ids:
                testing_kbs = (
                    db.query(KnowledgeBase)
                    .filter(KnowledgeBase.id.in_(knowledge_base_ids))
                    .all()
                )
                kb_vector_stores = await _build_kb_vector_stores(db, testing_kbs)
            if kb_vector_stores:
                # No reranker client here: testing retrieval is vector-only.
                testing_retriever = MultiKBRetriever(
                    reranker_weight=settings.RERANKER_WEIGHT,
                )
                retrieval_rows = await testing_retriever.retrieve(
                    query=query,
                    kb_vector_stores=kb_vector_stores,
                    fetch_k_per_kb=16,
                    top_k=8,
                )
            if retrieval_rows:
                knowledge_context = format_retrieval_context(retrieval_rows)
            pipeline_result = run_testing_pipeline(
                user_requirement_text=query,
                requirement_type_input=explicit_type,
                debug=True,
                knowledge_context=knowledge_context,
                use_model_generation=True,
                max_items_per_group=6,
                cases_per_item=1,
                max_focus_points=6,
                max_llm_calls=2,
            )
            # Context payload mirrors the normal branch's shape, plus
            # testing-pipeline trace data for the frontend to render.
            context_payload = {
                "route": {
                    "intent": "TESTING",
                    "reason": "命中测试生成意图,已自动调用测试工具链。",
                },
                "intent": "TESTING",
                "skill_profile": "testing-orchestrator",
                "tool_chain": [
                    "identify-requirement-type",
                    "decompose-test-items",
                    "generate-test-cases",
                    "build_expected_results",
                    "format_output",
                ],
                "selected_chain": "TESTING_PIPELINE",
                "graph_used": False,
                "reranker_enabled": False,
                "retrieval_preview": _preview_rows(retrieval_rows),
                "context": _context_rows(retrieval_rows),
                "testing_pipeline": {
                    "trace_id": pipeline_result.get("trace_id"),
                    "requirement_type": pipeline_result.get("requirement_type"),
                    "candidates": pipeline_result.get("candidates", []),
                    "pipeline_summary": pipeline_result.get("pipeline_summary", ""),
                    "knowledge_used": pipeline_result.get("knowledge_used", False),
                    "step_logs": pipeline_result.get("step_logs", []),
                },
            }
            escaped_context = json.dumps(context_payload, ensure_ascii=False)
            base64_context = base64.b64encode(escaped_context.encode()).decode()
            separator = "__LLM_RESPONSE__"
            full_response = f"{base64_context}{separator}"
            yield f'0:"{base64_context}{separator}"\n'
            rendered_text = pipeline_result.get("formatted_output", "").strip()
            if not rendered_text:
                rendered_text = "未生成测试内容,请补充更明确的需求后重试。"
            full_response += rendered_text
            yield f'0:"{_escape_stream_text(rendered_text)}"\n'
            yield 'd:{"finishReason":"stop","usage":{"promptTokens":0,"completionTokens":0}}\n'
            bot_message.content = full_response
            db.commit()
            return
        # --- Normal branch: route intent, retrieve context, stream LLM. ---
        knowledge_bases = (
            db.query(KnowledgeBase)
            .filter(KnowledgeBase.id.in_(knowledge_base_ids))
            .all()
        )
        kb_ids = [kb.id for kb in knowledge_bases]
        llm = LLMFactory.create()
        decision = await route_intent(llm=llm, query=query, messages=messages)
        intent = decision["intent"]
        kb_vector_stores = await _build_kb_vector_stores(db, knowledge_bases)
        # Retrieval-dependent intents degrade to general chat (A) when no
        # vector collections are available.
        if intent in {"B", "C", "D"} and not kb_vector_stores:
            intent = "A"
            decision = {
                "intent": "A",
                "reason": "未发现可用知识库向量集合,已降级为通用对话路。",
            }
        reranker_client = _build_reranker_client()
        retriever = MultiKBRetriever(
            reranker_client=reranker_client,
            reranker_weight=settings.RERANKER_WEIGHT,
        )
        retrieval_rows: List[Dict[str, Any]] = []
        graph_used = False
        selected_chain = intent
        prompt_text = ""
        if intent == "A":
            # General chat: no retrieval at all.
            prompt_text = GENERAL_CHAT_PROMPT_TEMPLATE.format(query=query)
        elif intent == "B":
            # Hybrid RAG: multi-KB vector retrieval with reranking.
            retrieval_rows = await retriever.retrieve(
                query=query,
                kb_vector_stores=kb_vector_stores,
                fetch_k_per_kb=16,
                top_k=12,
            )
            context = format_retrieval_context(retrieval_rows) or "无可用证据。"
            prompt_text = HYBRID_RAG_PROMPT_TEMPLATE.format(query=query, context=context)
        elif intent == "C":
            # Graph-local: entity-centric GraphRAG context, with a vector
            # retrieval fallback when the graph yields nothing or errors.
            graph_context = ""
            used_kb_ids: List[int] = []
            if settings.GRAPHRAG_ENABLED and kb_ids:
                try:
                    adapter = GraphRAGAdapter()
                    graph_context, used_kb_ids = await adapter.local_context_multi(
                        kb_ids,
                        query,
                        top_k=settings.GRAPHRAG_LOCAL_TOP_K,
                        level=settings.GRAPHRAG_QUERY_LEVEL,
                    )
                    graph_used = bool(graph_context)
                except Exception:
                    # Best-effort: any adapter failure falls through to
                    # the vector fallback below.
                    graph_context = ""
            if not graph_context:
                retrieval_rows = await retriever.retrieve(
                    query=query,
                    kb_vector_stores=kb_vector_stores,
                    fetch_k_per_kb=18,
                    top_k=14,
                )
                graph_context = _build_local_graph_context_fallback(retrieval_rows)
                selected_chain = "C_fallback_B"
            else:
                selected_chain = "C_graph"
            prompt_text = GRAPH_LOCAL_PROMPT_TEMPLATE.format(
                query=query,
                graph_context=graph_context,
            )
        else:
            # Intent D — graph-global: community-summary context, same
            # fallback strategy as intent C.
            community_context = ""
            if settings.GRAPHRAG_ENABLED and kb_ids:
                try:
                    adapter = GraphRAGAdapter()
                    community_context, used_kb_ids = await adapter.global_context_multi(
                        kb_ids,
                        query,
                        level=settings.GRAPHRAG_QUERY_LEVEL,
                    )
                    graph_used = bool(community_context)
                except Exception:
                    community_context = ""
            if not community_context:
                retrieval_rows = await retriever.retrieve(
                    query=query,
                    kb_vector_stores=kb_vector_stores,
                    fetch_k_per_kb=20,
                    top_k=14,
                )
                community_context = _build_global_community_context_fallback(retrieval_rows)
                selected_chain = "D_fallback_B"
            else:
                selected_chain = "D_graph"
            prompt_text = GRAPH_GLOBAL_PROMPT_TEMPLATE.format(
                query=query,
                community_context=community_context,
            )
        # Emit the routing/evidence payload first, then the LLM answer.
        context_payload = {
            "route": decision,
            "intent": intent,
            "selected_chain": selected_chain,
            "graph_used": graph_used,
            "reranker_enabled": reranker_client.enabled,
            "retrieval_preview": _preview_rows(retrieval_rows),
            "context": _context_rows(retrieval_rows),
        }
        escaped_context = json.dumps(context_payload, ensure_ascii=False)
        base64_context = base64.b64encode(escaped_context.encode()).decode()
        separator = "__LLM_RESPONSE__"
        full_response = f"{base64_context}{separator}"
        yield f'0:"{base64_context}{separator}"\n'
        async for chunk in llm.astream(prompt_text):
            text = _extract_stream_text(chunk)
            if not text:
                continue
            full_response += text
            yield f'0:"{_escape_stream_text(text)}"\n'
        yield 'd:{"finishReason":"stop","usage":{"promptTokens":0,"completionTokens":0}}\n'
        bot_message.content = full_response
        db.commit()
    except Exception as e:
        error_message = f"Error generating response: {str(e)}"
        print(error_message)
        # NOTE(review): unlike the 0:"..." parts, this error text is not
        # quoted/escaped — confirm the frontend parses bare 3:<text> lines.
        yield "3:{text}\n".format(text=error_message)
        # bot_message may not exist if the first db.add/commit failed.
        if "bot_message" in locals():
            bot_message.content = error_message
            db.commit()
    finally:
        # Always release the session, even on early return or error.
        db.close()