Files
rag_agent/rag-web-ui/backend/app/services/testing_generation_job_service.py

274 lines
8.9 KiB
Python
Raw Permalink Normal View History

from __future__ import annotations
import asyncio
import logging
from datetime import datetime
from typing import Any, Dict, List
from sqlalchemy.orm import Session
from app.core.config import settings
from app.db.session import SessionLocal
from app.models.knowledge import Document, KnowledgeBase
from app.models.tooling import TestingGeneration, ToolJob
from app.services.embedding.embedding_factory import EmbeddingsFactory
from app.services.llm.llm_factory import LLMFactory
from app.services.model_config import ModelConfigService
from app.services.retrieval.multi_kb_retriever import MultiKBRetriever, format_retrieval_context
from app.services.testing_pipeline import run_testing_pipeline
from app.services.vector_store import VectorStoreFactory
logger = logging.getLogger(__name__)
def _flatten_record(value: Dict[str, List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
items: List[Dict[str, Any]] = []
for current in value.values():
items.extend(current)
return items
def _build_kb_vector_stores(
db: Session,
knowledge_bases: List[KnowledgeBase],
model_profile: Any,
) -> List[Dict[str, Any]]:
if model_profile is None:
return []
embeddings = EmbeddingsFactory.create(model_profile=model_profile)
kb_vector_stores: List[Dict[str, Any]] = []
for kb in knowledge_bases:
documents = db.query(Document).filter(Document.knowledge_base_id == kb.id).all()
if not documents:
continue
store = VectorStoreFactory.create(
store_type=settings.VECTOR_STORE_TYPE,
collection_name=f"kb_{kb.id}",
embedding_function=embeddings,
)
kb_vector_stores.append({"kb_id": kb.id, "store": store})
return kb_vector_stores
def _resolve_knowledge_context(
db: Session,
*,
user_id: int,
requirement_text: str,
knowledge_base_id: int | None,
model_profile: Any,
) -> str:
if knowledge_base_id is None or model_profile is None:
return ""
try:
knowledge_bases = (
db.query(KnowledgeBase)
.filter(
KnowledgeBase.id == knowledge_base_id,
KnowledgeBase.user_id == user_id,
)
.all()
)
kb_vector_stores = _build_kb_vector_stores(db, knowledge_bases, model_profile)
if not kb_vector_stores:
return ""
retriever = MultiKBRetriever(
reranker_weight=settings.RERANKER_WEIGHT,
)
rows = asyncio.run(
retriever.retrieve(
query=requirement_text,
kb_vector_stores=kb_vector_stores,
fetch_k_per_kb=16,
top_k=8,
)
)
if rows:
return format_retrieval_context(rows)
except Exception:
return ""
return ""
def _build_generated_requirement(req: Dict[str, Any], pipeline_result: Dict[str, Any]) -> Dict[str, Any]:
    """Merge pipeline output into a copy of the requirement record.

    Args:
        req: The requirement dict; it is copied, not modified.
        pipeline_result: Pipeline output read under the keys ``test_items``,
            ``test_cases`` and ``expected_results``; each is flattened with
            ``_flatten_record`` and defaults to an empty mapping when absent.

    Returns:
        A new dict with all keys of ``req`` plus three list-valued keys for
        the test items, test cases and expected results.
    """
    test_items: List[Dict[str, Any]] = []
    for entry in _flatten_record(pipeline_result.get("test_items", {})):
        test_items.append(
            {
                "id": entry.get("id"),
                "content": entry.get("content"),
            }
        )

    test_cases: List[Dict[str, Any]] = []
    for entry in _flatten_record(pipeline_result.get("test_cases", {})):
        test_cases.append(
            {
                "id": entry.get("id"),
                "itemId": entry.get("item_id"),
                "testContent": entry.get("test_content"),
                "operationSteps": entry.get("operation_steps", []),
                "expectedResultPlaceholder": entry.get("expected_result_placeholder"),
            }
        )

    expected_results: List[Dict[str, Any]] = []
    for entry in _flatten_record(pipeline_result.get("expected_results", {})):
        expected_results.append(
            {
                "id": entry.get("id"),
                "caseId": entry.get("case_id"),
                "result": entry.get("result"),
            }
        )

    # Copy first, then overlay, so generated keys win over same-named keys in req.
    generated: Dict[str, Any] = {**req}
    generated.update(
        {
            "测试项": test_items,
            "测试用例": test_cases,
            "预期结果": expected_results,
        }
    )
    return generated
def _mark_job_failed(job_id: int, error_message: str) -> None:
    """Record a job as failed using a dedicated, short-lived session.

    Args:
        job_id: Primary key of the ``ToolJob`` to update.
        error_message: Failure description; only the first 2000 characters
            are stored.

    Does nothing when no job with ``job_id`` exists. The session is always
    closed, whether or not the update succeeds.
    """
    session = SessionLocal()
    try:
        failed_job = session.query(ToolJob).filter(ToolJob.id == job_id).first()
        if failed_job:
            failed_job.status = "failed"
            failed_job.completed_at = datetime.utcnow()
            failed_job.error_message = error_message[:2000]
            session.commit()
    finally:
        session.close()
def run_testing_generation_job(job_id: int, payload: Dict[str, Any]) -> None:
    """Run a testing-generation job end to end and persist its result.

    For each requirement in the payload the testing pipeline is executed
    (optionally enriched with knowledge-base context), progress is committed
    to the job's ``output_summary`` before each requirement, and on success a
    ``TestingGeneration`` row is stored and the job is marked ``completed``.
    Any exception rolls the session back, is logged with its traceback, and
    marks the job ``failed``.

    Args:
        job_id: Primary key of the ``ToolJob`` to run. The function returns
            silently when no such job exists.
        payload: Job input. Keys read: ``requirements`` (list of requirement
            dicts using ``id``, ``description`` and ``requirementType``),
            ``source_document_name``, ``source_job_id`` and
            ``knowledge_base_id``.
    """
    db = SessionLocal()
    try:
        job = db.query(ToolJob).filter(ToolJob.id == job_id).first()
        if not job:
            return

        requirements = payload.get("requirements") or []
        total_steps = len(requirements)
        source_document_name = str(payload.get("source_document_name") or job.input_file_name or "")
        source_job_id = payload.get("source_job_id")
        knowledge_base_id = payload.get("knowledge_base_id")

        model_profile = ModelConfigService.get_active_config(db, job.user_id)
        if model_profile is not None:
            ModelConfigService.touch_last_used(db, model_profile)

        use_model_generation = model_profile is not None
        llm_model = None
        if use_model_generation:
            try:
                llm_model = LLMFactory.create(streaming=False, model_profile=model_profile)
            except Exception as exc:
                # An unusable LLM must not fail the job; continue without it.
                logger.exception(
                    "Testing generation LLM initialization failed for job=%s, falling back to rule-based output: %s",
                    job_id,
                    exc,
                )
                use_model_generation = False
        else:
            logger.info(
                "Testing generation job=%s has no active model config; using rule-based output.",
                job_id,
            )

        job.status = "processing"
        job.started_at = datetime.utcnow()
        job.error_message = None
        job.output_summary = {
            "source_document_name": source_document_name,
            "current_step": 0,
            "total_steps": total_steps,
        }
        db.commit()

        generated_requirements: List[Dict[str, Any]] = []
        for index, req in enumerate(requirements):
            req_id = str(req.get("id") or f"REQ-{index + 1:03d}")
            # Commit progress before the (potentially slow) pipeline call so
            # other sessions can see which requirement is being processed.
            job.output_summary = {
                "source_document_name": source_document_name,
                "current_step": index + 1,
                "total_steps": total_steps,
                "current_requirement_id": req_id,
            }
            db.commit()

            description = str(req.get("description") or "").strip()
            if not description:
                # Nothing to generate from; keep the requirement with empty results.
                generated_requirements.append(
                    {
                        **req,
                        "测试项": [],
                        "测试用例": [],
                        "预期结果": [],
                    }
                )
                continue

            knowledge_context = _resolve_knowledge_context(
                db,
                user_id=job.user_id,
                requirement_text=description,
                knowledge_base_id=knowledge_base_id,
                model_profile=model_profile,
            )
            pipeline_result = run_testing_pipeline(
                user_requirement_text=description,
                requirement_type_input=req.get("requirementType"),
                debug=False,
                knowledge_context=knowledge_context,
                use_model_generation=use_model_generation,
                llm_model=llm_model,
                max_items_per_group=12,
                cases_per_item=2,
                max_focus_points=6,
                max_llm_calls=10,
            )
            generated_requirements.append(_build_generated_requirement(req, pipeline_result))

        generated_at = datetime.utcnow()
        generated_file = {
            "sourceDocument": source_document_name,
            "sourceJobId": source_job_id,
            "generatedAt": generated_at.isoformat(),
            "totalRequirements": len(generated_requirements),
            "knowledgeBaseId": knowledge_base_id,
            "requirements": generated_requirements,
        }
        generation = TestingGeneration(
            job_id=job.id,
            source_job_id=source_job_id,
            source_document_name=source_document_name,
            generated_at=generated_at,
            total_requirements=len(generated_requirements),
            knowledge_base_id=knowledge_base_id,
            generated_file=generated_file,
        )
        db.add(generation)

        job.status = "completed"
        job.completed_at = datetime.utcnow()
        job.output_summary = {
            "source_document_name": source_document_name,
            "current_step": len(generated_requirements),
            "total_steps": len(generated_requirements),
            "total_requirements": len(generated_requirements),
            "knowledge_base_id": knowledge_base_id,
        }
        db.commit()
    except Exception as exc:
        # Keep the traceback in the logs; the job row only stores a short message.
        logger.exception("Testing generation job=%s failed", job_id)
        db.rollback()
        # str(exc) is empty for exceptions raised without arguments; fall back
        # to repr so the stored error always identifies the failure.
        _mark_job_failed(job_id, str(exc) or repr(exc))
    finally:
        db.close()