增加代码知识库;修复文档处理内容;增加API设置

This commit is contained in:
2026-05-16 20:20:10 +08:00
parent 69b49d28b2
commit 7aa3ce3294
119 changed files with 182273 additions and 793 deletions

View File

@@ -17,6 +17,7 @@ from app.services.fusion_prompts import (
from app.services.graph.graphrag_adapter import GraphRAGAdapter
from app.services.intent_router import route_intent
from app.services.llm.llm_factory import LLMFactory
from app.services.model_config import ModelConfigService
from app.services.reranker.external_api import ExternalRerankerClient
from app.services.retrieval.multi_kb_retriever import MultiKBRetriever, format_retrieval_context
from app.services.testing_pipeline.pipeline import run_testing_pipeline
@@ -202,8 +203,12 @@ def _build_global_community_context_fallback(rows: List[Dict[str, Any]]) -> str:
return "\n\n".join(lines)
async def _build_kb_vector_stores(db: Any, knowledge_bases: List[KnowledgeBase]) -> List[Dict[str, Any]]:
embeddings = EmbeddingsFactory.create()
async def _build_kb_vector_stores(
db: Any,
knowledge_bases: List[KnowledgeBase],
model_profile: Any,
) -> List[Dict[str, Any]]:
embeddings = EmbeddingsFactory.create(model_profile=model_profile)
kb_vector_stores: List[Dict[str, Any]] = []
for kb in knowledge_bases:
@@ -221,10 +226,13 @@ async def _build_kb_vector_stores(db: Any, knowledge_bases: List[KnowledgeBase])
return kb_vector_stores
# Build the ExternalRerankerClient from global settings, optionally taking the
# API key from the caller's model profile.
# NOTE(review): this span is a rendered diff with no +/- markers and no
# indentation. Judging by the hunk header (-221,10 +226,13), the zero-argument
# `def` line and the `api_key=settings.RERANKER_API_KEY` argument appear to be
# the removed (pre-change) lines, each followed by its replacement -- confirm
# against the real file before applying anything here.
def _build_reranker_client() -> ExternalRerankerClient:
def _build_reranker_client(model_profile: Any = None) -> ExternalRerankerClient:
# Start from the globally configured reranker API key.
api_key = settings.RERANKER_API_KEY
# Only a profile whose provider is "dashscope" may override the key; a missing
# or empty profile key falls back to the global setting via `or`.
if model_profile is not None and getattr(model_profile, "provider", "") == "dashscope":
api_key = getattr(model_profile, "api_key", "") or api_key
# URL, model and timeout always come from global settings; only the key can vary.
return ExternalRerankerClient(
api_url=settings.RERANKER_API_URL,
api_key=settings.RERANKER_API_KEY,
api_key=api_key,
model=settings.RERANKER_MODEL,
timeout_seconds=settings.RERANKER_TIMEOUT_SECONDS,
)
@@ -287,6 +295,7 @@ async def generate_response(
knowledge_base_ids: List[int],
chat_id: int,
db: Any,
user_id: int,
) -> AsyncGenerator[str, None]:
try:
user_message = Message(content=query, role="user", chat_id=chat_id)
@@ -297,6 +306,9 @@ async def generate_response(
db.add(bot_message)
db.commit()
model_profile = ModelConfigService.require_active_config(db, user_id)
ModelConfigService.touch_last_used(db, model_profile)
if _is_testing_generation_request(query):
explicit_type = _extract_requirement_type_from_query(query)
@@ -309,7 +321,7 @@ async def generate_response(
.filter(KnowledgeBase.id.in_(knowledge_base_ids))
.all()
)
kb_vector_stores = await _build_kb_vector_stores(db, testing_kbs)
kb_vector_stores = await _build_kb_vector_stores(db, testing_kbs, model_profile)
if kb_vector_stores:
testing_retriever = MultiKBRetriever(
@@ -330,6 +342,7 @@ async def generate_response(
debug=True,
knowledge_context=knowledge_context,
use_model_generation=True,
llm_model=LLMFactory.create(streaming=False, model_profile=model_profile),
max_items_per_group=6,
cases_per_item=1,
max_focus_points=6,
@@ -391,11 +404,11 @@ async def generate_response(
)
kb_ids = [kb.id for kb in knowledge_bases]
llm = LLMFactory.create()
llm = LLMFactory.create(model_profile=model_profile)
decision = await route_intent(llm=llm, query=query, messages=messages)
intent = decision["intent"]
kb_vector_stores = await _build_kb_vector_stores(db, knowledge_bases)
kb_vector_stores = await _build_kb_vector_stores(db, knowledge_bases, model_profile)
if intent in {"B", "C", "D"} and not kb_vector_stores:
intent = "A"
decision = {
@@ -403,7 +416,7 @@ async def generate_response(
"reason": "未发现可用知识库向量集合,已降级为通用对话路。",
}
reranker_client = _build_reranker_client()
reranker_client = _build_reranker_client(model_profile)
retriever = MultiKBRetriever(
reranker_client=reranker_client,
reranker_weight=settings.RERANKER_WEIGHT,
@@ -432,7 +445,7 @@ async def generate_response(
used_kb_ids: List[int] = []
if settings.GRAPHRAG_ENABLED and kb_ids:
try:
adapter = GraphRAGAdapter()
adapter = GraphRAGAdapter(model_profile=model_profile)
graph_context, used_kb_ids = await adapter.local_context_multi(
kb_ids,
query,
@@ -465,7 +478,7 @@ async def generate_response(
community_context = ""
if settings.GRAPHRAG_ENABLED and kb_ids:
try:
adapter = GraphRAGAdapter()
adapter = GraphRAGAdapter(model_profile=model_profile)
community_context, used_kb_ids = await adapter.global_context_multi(
kb_ids,
query,