# Files
# rag_agent/rag-web-ui/backend/app/api/api_v1/testing.py
#
# 157 lines
# 5.9 KiB
# Python

import logging
import asyncio
from typing import Any, Dict, List
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from app.core.config import settings
from app.core.security import get_current_user
from app.db.session import get_db
from app.models.knowledge import Document, KnowledgeBase
from app.models.user import User
from app.schemas.testing import TestingPipelineRequest, TestingPipelineResponse
from app.services.embedding.embedding_factory import EmbeddingsFactory
from app.services.llm.llm_factory import LLMFactory
from app.services.model_config import ModelConfigService
from app.services.retrieval.multi_kb_retriever import MultiKBRetriever, format_retrieval_context
from app.services.testing_pipeline import run_testing_pipeline
from app.services.vector_store import VectorStoreFactory
# Router that the endpoint(s) in this module register themselves on.
router = APIRouter()
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Maximum time, in seconds, the request handler waits for the testing pipeline
# before giving up and answering with a timeout error.
MODEL_PIPELINE_TIMEOUT_SECONDS = 300
async def _build_kb_vector_stores(
    db: Session,
    knowledge_bases: List[KnowledgeBase],
    model_profile: Any,
) -> List[Dict[str, Any]]:
    """Open a vector store for every knowledge base that has at least one document.

    Args:
        db: Session used to check whether a knowledge base owns any documents.
        knowledge_bases: Knowledge bases to open vector stores for.
        model_profile: Model configuration handed to ``EmbeddingsFactory.create``.

    Returns:
        One ``{"kb_id": <id>, "store": <vector store>}`` dict per knowledge base
        that has documents, in the order the knowledge bases were given.
        Knowledge bases without documents are skipped.
    """
    # One embedding function is created up front and shared by all stores.
    embeddings = EmbeddingsFactory.create(model_profile=model_profile)
    kb_vector_stores: List[Dict[str, Any]] = []
    for kb in knowledge_bases:
        # Only existence matters here, so fetch at most one row instead of
        # loading every document of the knowledge base into memory.
        has_documents = (
            db.query(Document)
            .filter(Document.knowledge_base_id == kb.id)
            .first()
            is not None
        )
        if not has_documents:
            continue
        store = VectorStoreFactory.create(
            store_type=settings.VECTOR_STORE_TYPE,
            collection_name=f"kb_{kb.id}",
            embedding_function=embeddings,
        )
        kb_vector_stores.append({"kb_id": kb.id, "store": store})
    return kb_vector_stores
@router.post("/generate", response_model=TestingPipelineResponse)
async def generate_testing_content(
    *,
    payload: TestingPipelineRequest,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
) -> Any:
    """Run the testing pipeline for a requirement and return its result.

    Steps:
      1. Load the caller's active model configuration.
      2. Build the knowledge context: start from ``payload.knowledge_context``
         and, when knowledge base ids and a model configuration are available,
         replace it with retrieved context. Retrieval failures are logged and
         the payload-provided context is kept.
      3. Create the LLM when model generation is requested and possible;
         otherwise (or on failure) run the pipeline without a model.
      4. Run ``run_testing_pipeline`` in a worker thread, bounded by
         ``MODEL_PIPELINE_TIMEOUT_SECONDS``.

    Args:
        payload: Request body with the requirement text and pipeline options.
        current_user: Authenticated user; scopes the model configuration and
            the knowledge bases that may be queried.
        db: Database session.

    Returns:
        Whatever ``run_testing_pipeline`` returns.

    Raises:
        HTTPException: 504 if the pipeline does not finish within the timeout,
            500 if the pipeline raises any other exception.
    """
    model_profile = ModelConfigService.get_active_config(db, current_user.id)

    # Default context is whatever the client supplied; retrieval may override it.
    knowledge_context = (payload.knowledge_context or "").strip()

    if payload.knowledge_base_ids and model_profile is not None:
        try:
            # Restrict to knowledge bases owned by the caller so that ids
            # belonging to other users are silently ignored.
            knowledge_bases = (
                db.query(KnowledgeBase)
                .filter(
                    KnowledgeBase.id.in_(payload.knowledge_base_ids),
                    KnowledgeBase.user_id == current_user.id,
                )
                .all()
            )
            kb_vector_stores = await _build_kb_vector_stores(db, knowledge_bases, model_profile)
            if kb_vector_stores:
                retriever = MultiKBRetriever(
                    reranker_weight=settings.RERANKER_WEIGHT,
                )
                retrieval_rows = await retriever.retrieve(
                    query=payload.requirement_text,
                    kb_vector_stores=kb_vector_stores,
                    # Fetch at least 12 candidates per knowledge base, or twice
                    # the requested top_k if that is larger.
                    fetch_k_per_kb=max(12, payload.retrieval_top_k * 2),
                    top_k=payload.retrieval_top_k,
                )
                if retrieval_rows:
                    knowledge_context = format_retrieval_context(retrieval_rows)
        except Exception as exc:
            # Retrieval is best-effort: keep the payload-provided context and
            # continue with generation.
            logger.exception(
                "Testing generation retrieval fallback triggered for user=%s knowledge_base_ids=%s: %s",
                current_user.id,
                payload.knowledge_base_ids,
                exc,
            )
    elif payload.knowledge_base_ids:
        logger.warning(
            "Testing generation skipped knowledge retrieval because no active model config exists: user=%s knowledge_base_ids=%s",
            current_user.id,
            payload.knowledge_base_ids,
        )

    # Model generation needs both the client's request and an active config.
    use_model_generation = bool(payload.use_model_generation and model_profile is not None)
    llm_model = None
    if use_model_generation:
        try:
            llm_model = LLMFactory.create(streaming=False, model_profile=model_profile)
            ModelConfigService.touch_last_used(db, model_profile)
        except Exception as exc:
            logger.exception(
                "Testing generation LLM initialization failed for user=%s, falling back to rule-based output: %s",
                current_user.id,
                exc,
            )
            use_model_generation = False
            # touch_last_used may fail after the model was created; drop the
            # model so the pipeline never receives one while the flag (and the
            # log line above) say generation has fallen back.
            llm_model = None
    elif payload.use_model_generation:
        logger.warning(
            "Testing generation falling back to rule-based output because no active model config exists: user=%s",
            current_user.id,
        )

    pipeline_kwargs = {
        "user_requirement_text": payload.requirement_text,
        "requirement_type_input": payload.requirement_type,
        "debug": payload.debug,
        "knowledge_context": knowledge_context,
        "use_model_generation": use_model_generation,
        "llm_model": llm_model,
        "max_items_per_group": payload.max_items_per_group,
        "cases_per_item": payload.cases_per_item,
        "max_focus_points": payload.max_focus_points,
        "max_llm_calls": payload.max_llm_calls,
    }

    try:
        # The pipeline is synchronous, so it runs in a worker thread to keep
        # the event loop free.
        # NOTE: on timeout, wait_for only stops waiting; the worker thread
        # cannot be interrupted and may keep running after the 504 is sent.
        result = await asyncio.wait_for(
            asyncio.to_thread(run_testing_pipeline, **pipeline_kwargs),
            timeout=MODEL_PIPELINE_TIMEOUT_SECONDS,
        )
    except asyncio.TimeoutError as exc:
        logger.exception(
            "Testing pipeline timed out for user=%s use_model_generation=%s after %s seconds",
            current_user.id,
            payload.use_model_generation,
            MODEL_PIPELINE_TIMEOUT_SECONDS,
        )
        raise HTTPException(
            status_code=504,
            detail=f"LLM generation timed out after {MODEL_PIPELINE_TIMEOUT_SECONDS} seconds",
        ) from exc
    except Exception as exc:
        logger.exception(
            "Testing pipeline failed for user=%s use_model_generation=%s: %s",
            current_user.id,
            payload.use_model_generation,
            exc,
        )
        # NOTE(review): the exception text is returned to the client; confirm
        # it cannot contain sensitive details (credentials, internal URLs).
        raise HTTPException(
            status_code=500,
            detail=f"LLM generation failed: {exc}",
        ) from exc

    return result