Files
rag_agent/rag-web-ui/backend/app/api/api_v1/testing.py

126 lines
4.5 KiB
Python

import logging
import asyncio
from typing import Any, Dict, List
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from app.core.config import settings
from app.core.security import get_current_user
from app.db.session import get_db
from app.models.knowledge import Document, KnowledgeBase
from app.models.user import User
from app.schemas.testing import TestingPipelineRequest, TestingPipelineResponse
from app.services.embedding.embedding_factory import EmbeddingsFactory
from app.services.retrieval.multi_kb_retriever import MultiKBRetriever, format_retrieval_context
from app.services.testing_pipeline import run_testing_pipeline
from app.services.vector_store import VectorStoreFactory
# Upper bound, in seconds, on how long the endpoint waits for
# ``run_testing_pipeline``; exceeding it produces an HTTP 504 response.
MODEL_PIPELINE_TIMEOUT_SECONDS = 300

# Module-scoped logger, named after this module.
logger = logging.getLogger(__name__)

# Router that the testing-generation endpoints in this module register on.
router = APIRouter()
async def _build_kb_vector_stores(db: Session, knowledge_bases: List[KnowledgeBase]) -> List[Dict[str, Any]]:
    """Open a vector store for every knowledge base that has at least one document.

    Knowledge bases with no ``Document`` rows are skipped. A single embedding
    function is shared by all opened stores. It is created lazily, only once
    the first non-empty knowledge base is found.

    Args:
        db: SQLAlchemy session used to check for documents.
        knowledge_bases: Knowledge bases to open stores for.

    Returns:
        A list of ``{"kb_id": kb.id, "store": <vector store>}`` dicts, in the
        same order as ``knowledge_bases`` (empty ones omitted).
    """
    # NOTE(review): declared ``async`` but the DB queries below are synchronous,
    # so they block the event loop while they run.
    embeddings = None
    kb_vector_stores: List[Dict[str, Any]] = []
    for kb in knowledge_bases:
        # Existence check only: fetch at most one row (LIMIT 1) instead of
        # loading every document in the knowledge base.
        first_document = (
            db.query(Document)
            .filter(Document.knowledge_base_id == kb.id)
            .first()
        )
        if first_document is None:
            continue
        if embeddings is None:
            embeddings = EmbeddingsFactory.create()
        store = VectorStoreFactory.create(
            store_type=settings.VECTOR_STORE_TYPE,
            collection_name=f"kb_{kb.id}",
            embedding_function=embeddings,
        )
        kb_vector_stores.append({"kb_id": kb.id, "store": store})
    return kb_vector_stores
async def _retrieve_knowledge_context(
    db: Session,
    payload: TestingPipelineRequest,
    current_user: User,
    fallback_context: str,
    *,
    min_fetch_k_per_kb: int = 12,
) -> str:
    """Best-effort retrieval of knowledge context from the user's knowledge bases.

    Only knowledge bases whose ids are in ``payload.knowledge_base_ids`` and
    that belong to ``current_user`` are searched. Any exception is logged and
    swallowed, so retrieval problems never fail the request.

    Args:
        db: SQLAlchemy session.
        payload: Request whose ``requirement_text`` is the retrieval query.
        current_user: Owner used to restrict the knowledge-base lookup.
        fallback_context: Context returned when retrieval yields nothing or fails.
        min_fetch_k_per_kb: Lower bound on candidates fetched per knowledge base.

    Returns:
        The formatted retrieval context, or ``fallback_context`` if no store
        was available, no rows were retrieved, or an error occurred.
    """
    try:
        knowledge_bases = (
            db.query(KnowledgeBase)
            .filter(
                KnowledgeBase.id.in_(payload.knowledge_base_ids),
                KnowledgeBase.user_id == current_user.id,
            )
            .all()
        )
        kb_vector_stores = await _build_kb_vector_stores(db, knowledge_bases)
        if not kb_vector_stores:
            return fallback_context
        retriever = MultiKBRetriever(
            reranker_weight=settings.RERANKER_WEIGHT,
        )
        retrieval_rows = await retriever.retrieve(
            query=payload.requirement_text,
            kb_vector_stores=kb_vector_stores,
            # Over-fetch per knowledge base so reranking has candidates to choose from.
            fetch_k_per_kb=max(min_fetch_k_per_kb, payload.retrieval_top_k * 2),
            top_k=payload.retrieval_top_k,
        )
        if retrieval_rows:
            return format_retrieval_context(retrieval_rows)
    except Exception as exc:
        # Deliberate fallback: generation proceeds with the caller-supplied context.
        logger.exception(
            "Testing generation retrieval fallback triggered for user=%s knowledge_base_ids=%s: %s",
            current_user.id,
            payload.knowledge_base_ids,
            exc,
        )
    return fallback_context


@router.post("/generate", response_model=TestingPipelineResponse)
async def generate_testing_content(
    *,
    payload: TestingPipelineRequest,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
) -> Any:
    """Generate testing content for a requirement via ``run_testing_pipeline``.

    The knowledge context starts as the stripped ``payload.knowledge_context``.
    If knowledge-base ids were given and retrieval yields rows, the retrieved
    context replaces it.

    Returns:
        Whatever ``run_testing_pipeline`` returns. FastAPI validates it against
        ``TestingPipelineResponse``.

    Raises:
        HTTPException: 504 if the pipeline exceeds
            ``MODEL_PIPELINE_TIMEOUT_SECONDS``; 500 for any other pipeline error.
    """
    knowledge_context = (payload.knowledge_context or "").strip()
    if payload.knowledge_base_ids:
        knowledge_context = await _retrieve_knowledge_context(
            db, payload, current_user, knowledge_context
        )

    pipeline_kwargs = {
        "user_requirement_text": payload.requirement_text,
        "requirement_type_input": payload.requirement_type,
        "debug": payload.debug,
        "knowledge_context": knowledge_context,
        "use_model_generation": payload.use_model_generation,
        "max_items_per_group": payload.max_items_per_group,
        "cases_per_item": payload.cases_per_item,
        "max_focus_points": payload.max_focus_points,
        "max_llm_calls": payload.max_llm_calls,
    }
    try:
        # The pipeline is synchronous; run it in a worker thread so the event
        # loop stays responsive. NOTE(review): on timeout the thread cannot be
        # cancelled and keeps running (and occupying a default-executor slot)
        # until run_testing_pipeline returns on its own.
        result = await asyncio.wait_for(
            asyncio.to_thread(run_testing_pipeline, **pipeline_kwargs),
            timeout=MODEL_PIPELINE_TIMEOUT_SECONDS,
        )
    except asyncio.TimeoutError as exc:
        # logger.error rather than logger.exception: the timeout traceback only
        # shows asyncio internals, not where the pipeline was stuck.
        logger.error(
            "Testing pipeline timed out for user=%s use_model_generation=%s after %s seconds",
            current_user.id,
            payload.use_model_generation,
            MODEL_PIPELINE_TIMEOUT_SECONDS,
        )
        raise HTTPException(
            status_code=504,
            detail=f"LLM generation timed out after {MODEL_PIPELINE_TIMEOUT_SECONDS} seconds",
        ) from exc
    except Exception as exc:
        logger.exception(
            "Testing pipeline failed for user=%s use_model_generation=%s: %s",
            current_user.id,
            payload.use_model_generation,
            exc,
        )
        # NOTE(review): the raw exception text is returned to the client; this
        # may expose internal details (hosts, provider errors). Consider a
        # generic message once clients no longer depend on it.
        raise HTTPException(
            status_code=500,
            detail=f"LLM generation failed: {exc}",
        ) from exc
    return result