Files
rag_agent/rag-web-ui/backend/app/services/testing_generation_job_service.py

237 lines
7.6 KiB
Python

from __future__ import annotations

import asyncio
import logging
from datetime import datetime
from typing import Any, Dict, List

from sqlalchemy.orm import Session

from app.core.config import settings
from app.db.session import SessionLocal
from app.models.knowledge import Document, KnowledgeBase
from app.models.tooling import TestingGeneration, ToolJob
from app.services.embedding.embedding_factory import EmbeddingsFactory
from app.services.retrieval.multi_kb_retriever import MultiKBRetriever, format_retrieval_context
from app.services.testing_pipeline import run_testing_pipeline
from app.services.vector_store import VectorStoreFactory
def _flatten_record(value: Dict[str, List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
items: List[Dict[str, Any]] = []
for current in value.values():
items.extend(current)
return items
def _build_kb_vector_stores(db: Session, knowledge_bases: List[KnowledgeBase]) -> List[Dict[str, Any]]:
    """Create a vector store handle for each knowledge base that has documents.

    Args:
        db: Session used to check whether a knowledge base owns any document.
        knowledge_bases: Knowledge bases to open stores for.

    Returns:
        One ``{"kb_id": <kb.id>, "store": <vector store>}`` dict per knowledge
        base that has at least one ``Document`` row, in input order. Knowledge
        bases without documents are skipped.
    """
    # One embedding function is shared by every store created below.
    embeddings = EmbeddingsFactory.create()
    kb_vector_stores: List[Dict[str, Any]] = []
    for kb in knowledge_bases:
        # Only existence matters here, so fetch at most one row instead of
        # loading every document of the knowledge base.
        has_documents = (
            db.query(Document)
            .filter(Document.knowledge_base_id == kb.id)
            .first()
            is not None
        )
        if not has_documents:
            continue
        store = VectorStoreFactory.create(
            store_type=settings.VECTOR_STORE_TYPE,
            collection_name=f"kb_{kb.id}",
            embedding_function=embeddings,
        )
        kb_vector_stores.append({"kb_id": kb.id, "store": store})
    return kb_vector_stores
def _resolve_knowledge_context(
db: Session,
*,
user_id: int,
requirement_text: str,
knowledge_base_id: int | None,
) -> str:
if knowledge_base_id is None:
return ""
try:
knowledge_bases = (
db.query(KnowledgeBase)
.filter(
KnowledgeBase.id == knowledge_base_id,
KnowledgeBase.user_id == user_id,
)
.all()
)
kb_vector_stores = _build_kb_vector_stores(db, knowledge_bases)
if not kb_vector_stores:
return ""
retriever = MultiKBRetriever(
reranker_weight=settings.RERANKER_WEIGHT,
)
rows = asyncio.run(
retriever.retrieve(
query=requirement_text,
kb_vector_stores=kb_vector_stores,
fetch_k_per_kb=16,
top_k=8,
)
)
if rows:
return format_retrieval_context(rows)
except Exception:
return ""
return ""
def _build_generated_requirement(req: Dict[str, Any], pipeline_result: Dict[str, Any]) -> Dict[str, Any]:
    """Combine a requirement with the flattened output of the pipeline.

    Reads the ``test_items``, ``test_cases`` and ``expected_results`` groups
    of ``pipeline_result`` (each defaulting to an empty mapping), flattens
    them and projects every entry onto the camelCase fields of the generated
    file.

    Returns:
        A new dict holding every key of ``req`` plus the keys ``测试项`` (test
        items), ``测试用例`` (test cases) and ``预期结果`` (expected results).
    """
    items_out: List[Dict[str, Any]] = []
    for entry in _flatten_record(pipeline_result.get("test_items", {})):
        items_out.append({"id": entry.get("id"), "content": entry.get("content")})

    cases_out: List[Dict[str, Any]] = []
    for entry in _flatten_record(pipeline_result.get("test_cases", {})):
        case_row = {
            "id": entry.get("id"),
            "itemId": entry.get("item_id"),
            "testContent": entry.get("test_content"),
            "operationSteps": entry.get("operation_steps", []),
            "expectedResultPlaceholder": entry.get("expected_result_placeholder"),
        }
        cases_out.append(case_row)

    results_out: List[Dict[str, Any]] = []
    for entry in _flatten_record(pipeline_result.get("expected_results", {})):
        result_row = {
            "id": entry.get("id"),
            "caseId": entry.get("case_id"),
            "result": entry.get("result"),
        }
        results_out.append(result_row)

    # Start from a shallow copy of the requirement, then add the generated
    # sections; a same-named key in ``req`` is overwritten.
    merged: Dict[str, Any] = {**req}
    merged["测试项"] = items_out
    merged["测试用例"] = cases_out
    merged["预期结果"] = results_out
    return merged
def _mark_job_failed(job_id: int, error_message: str) -> None:
    """Record a failure on the ``ToolJob`` row ``job_id`` using a fresh session.

    Sets the status to ``"failed"``, stamps ``completed_at`` with the current
    UTC time and stores at most the first 2000 characters of
    ``error_message``. Does nothing when the job does not exist. The session
    is always closed.
    """
    session = SessionLocal()
    try:
        lookup = session.query(ToolJob).filter(ToolJob.id == job_id)
        failed_job = lookup.first()
        if failed_job:
            failed_job.status = "failed"
            failed_job.completed_at = datetime.utcnow()
            failed_job.error_message = error_message[:2000]
            session.commit()
    finally:
        session.close()
def run_testing_generation_job(job_id: int, payload: Dict[str, Any]) -> None:
    """Generate test items, cases and expected results for a tool job.

    Loads the ``ToolJob`` row ``job_id`` (returning silently when it does not
    exist), marks it ``"processing"``, then runs ``run_testing_pipeline`` for
    each requirement, committing progress to ``job.output_summary`` before
    each one. Requirements with an empty description get empty result lists
    without running the pipeline. Finally a ``TestingGeneration`` row is
    stored and the job is marked ``"completed"``.

    Any exception is logged with its traceback, the session is rolled back
    and the failure is recorded through ``_mark_job_failed``; the exception
    is not re-raised.

    Args:
        job_id: Primary key of the ``ToolJob`` to run.
        payload: Job input. Keys read here: ``requirements`` (list of dicts
            from which ``id``, ``description`` and ``requirementType`` are
            read), ``source_document_name``, ``source_job_id`` and
            ``knowledge_base_id``.
    """
    db = SessionLocal()
    try:
        job = db.query(ToolJob).filter(ToolJob.id == job_id).first()
        if not job:
            return

        requirements = payload.get("requirements") or []
        total_steps = len(requirements)
        source_document_name = str(payload.get("source_document_name") or job.input_file_name or "")
        source_job_id = payload.get("source_job_id")
        knowledge_base_id = payload.get("knowledge_base_id")

        job.status = "processing"
        job.started_at = datetime.utcnow()
        job.error_message = None
        job.output_summary = {
            "source_document_name": source_document_name,
            "current_step": 0,
            "total_steps": total_steps,
        }
        db.commit()

        generated_requirements: List[Dict[str, Any]] = []
        for index, req in enumerate(requirements):
            req_id = str(req.get("id") or f"REQ-{index + 1:03d}")

            # Commit progress before the (potentially slow) pipeline call so
            # the current step is persisted while the work is running.
            job.output_summary = {
                "source_document_name": source_document_name,
                "current_step": index + 1,
                "total_steps": total_steps,
                "current_requirement_id": req_id,
            }
            db.commit()

            description = str(req.get("description") or "").strip()
            if not description:
                # Nothing to generate from; keep the requirement in the output
                # with empty sections.
                generated_requirements.append(
                    {
                        **req,
                        "测试项": [],
                        "测试用例": [],
                        "预期结果": [],
                    }
                )
                continue

            knowledge_context = _resolve_knowledge_context(
                db,
                user_id=job.user_id,
                requirement_text=description,
                knowledge_base_id=knowledge_base_id,
            )
            pipeline_result = run_testing_pipeline(
                user_requirement_text=description,
                requirement_type_input=req.get("requirementType"),
                debug=False,
                knowledge_context=knowledge_context,
                use_model_generation=True,
                max_items_per_group=12,
                cases_per_item=2,
                max_focus_points=6,
                max_llm_calls=10,
            )
            generated_requirements.append(_build_generated_requirement(req, pipeline_result))

        generated_at = datetime.utcnow()
        generated_file = {
            "sourceDocument": source_document_name,
            "sourceJobId": source_job_id,
            "generatedAt": generated_at.isoformat(),
            "totalRequirements": len(generated_requirements),
            "knowledgeBaseId": knowledge_base_id,
            "requirements": generated_requirements,
        }
        generation = TestingGeneration(
            job_id=job.id,
            source_job_id=source_job_id,
            source_document_name=source_document_name,
            generated_at=generated_at,
            total_requirements=len(generated_requirements),
            knowledge_base_id=knowledge_base_id,
            generated_file=generated_file,
        )
        db.add(generation)

        job.status = "completed"
        job.completed_at = datetime.utcnow()
        job.output_summary = {
            "source_document_name": source_document_name,
            "current_step": len(generated_requirements),
            "total_steps": len(generated_requirements),
            "total_requirements": len(generated_requirements),
            "knowledge_base_id": knowledge_base_id,
        }
        db.commit()
    except Exception as exc:
        # Keep the traceback in the logs; the job row only stores a message.
        logging.getLogger(__name__).exception("Testing generation job %s failed", job_id)
        db.rollback()
        # Some exceptions stringify to "" (e.g. a bare TimeoutError()); fall
        # back to the class name so the stored error is never blank.
        _mark_job_failed(job_id, str(exc) or type(exc).__name__)
    finally:
        db.close()