from __future__ import annotations

import asyncio
import logging
from datetime import datetime
from typing import Any, Dict, List

from sqlalchemy.orm import Session

from app.core.config import settings
from app.db.session import SessionLocal
from app.models.knowledge import Document, KnowledgeBase
from app.models.tooling import TestingGeneration, ToolJob
from app.services.embedding.embedding_factory import EmbeddingsFactory
from app.services.retrieval.multi_kb_retriever import MultiKBRetriever, format_retrieval_context
from app.services.testing_pipeline import run_testing_pipeline
from app.services.vector_store import VectorStoreFactory

logger = logging.getLogger(__name__)

# Maximum number of characters of a failure message persisted on ToolJob.error_message.
_MAX_ERROR_MESSAGE_LENGTH = 2000


def _flatten_record(value: Dict[str, List[Dict[str, Any]]] | None) -> List[Dict[str, Any]]:
    """Concatenate the per-group item lists of a ``{group: [item, ...]}`` mapping.

    Groups are visited in the mapping's iteration order. A missing/empty mapping,
    or a group whose value is ``None``, contributes no items instead of raising.
    """
    if not value:
        return []
    flattened: List[Dict[str, Any]] = []
    for group_items in value.values():
        flattened.extend(group_items or [])
    return flattened


def _knowledge_base_has_documents(db: Session, kb_id: Any) -> bool:
    """Return True if at least one Document row belongs to knowledge base ``kb_id``."""
    # Existence check only: fetch a single id instead of loading every Document row.
    return (
        db.query(Document.id).filter(Document.knowledge_base_id == kb_id).first()
        is not None
    )


def _build_kb_vector_stores(db: Session, knowledge_bases: List[KnowledgeBase]) -> List[Dict[str, Any]]:
    """Create a vector-store handle for every knowledge base that has documents.

    Knowledge bases with no ``Document`` rows are skipped. Each store is opened on the
    collection ``kb_<id>`` using ``settings.VECTOR_STORE_TYPE``.

    Returns:
        A list of ``{"kb_id": <id>, "store": <vector store>}`` dicts, in the order of
        ``knowledge_bases``; empty when no knowledge base has documents.
    """
    kb_ids = [kb.id for kb in knowledge_bases if _knowledge_base_has_documents(db, kb.id)]
    if not kb_ids:
        # Nothing to search: skip creating the embeddings object entirely.
        return []

    embeddings = EmbeddingsFactory.create()
    return [
        {
            "kb_id": kb_id,
            "store": VectorStoreFactory.create(
                store_type=settings.VECTOR_STORE_TYPE,
                collection_name=f"kb_{kb_id}",
                embedding_function=embeddings,
            ),
        }
        for kb_id in kb_ids
    ]


def _resolve_knowledge_context(
    db: Session,
    *,
    user_id: int,
    requirement_text: str,
    knowledge_base_id: int | None,
) -> str:
    """Retrieve knowledge-base context text for one requirement, best-effort.

    Only a knowledge base owned by ``user_id`` is searched. Retrieval fetches 16
    candidates per knowledge base and keeps the top 8 rows, which are rendered with
    ``format_retrieval_context``.

    Returns:
        The formatted context, or ``""`` when no knowledge base is selected, the
        knowledge base is missing/empty/not owned by the user, nothing was retrieved,
        or any step raised (the failure is logged, never propagated).
    """
    if knowledge_base_id is None:
        return ""
    try:
        knowledge_bases = (
            db.query(KnowledgeBase)
            .filter(
                KnowledgeBase.id == knowledge_base_id,
                KnowledgeBase.user_id == user_id,
            )
            .all()
        )
        kb_vector_stores = _build_kb_vector_stores(db, knowledge_bases)
        if not kb_vector_stores:
            return ""
        retriever = MultiKBRetriever(
            reranker_weight=settings.RERANKER_WEIGHT,
        )
        # The retriever is async; run it to completion on a fresh event loop.
        # NOTE(review): asyncio.run raises if this thread already has a running loop.
        rows = asyncio.run(
            retriever.retrieve(
                query=requirement_text,
                kb_vector_stores=kb_vector_stores,
                fetch_k_per_kb=16,
                top_k=8,
            )
        )
        if rows:
            return format_retrieval_context(rows)
    except Exception:
        # Knowledge context is optional, so degrade to "no context" rather than
        # failing the job -- but keep a trace instead of silently hiding the error.
        logger.warning(
            "Knowledge retrieval failed for knowledge base %s; continuing without context",
            knowledge_base_id,
            exc_info=True,
        )
        return ""
    return ""


def _build_generated_requirement(req: Dict[str, Any], pipeline_result: Dict[str, Any]) -> Dict[str, Any]:
    """Merge pipeline output into a copy of requirement ``req``.

    The snake_case pipeline records are reshaped into camelCase dicts and attached
    under the keys "测试项" (test items), "测试用例" (test cases) and "预期结果"
    (expected results). Keys already present in ``req`` are kept unless overwritten
    by those three.
    """
    test_items = [
        {
            "id": item.get("id"),
            "content": item.get("content"),
        }
        for item in _flatten_record(pipeline_result.get("test_items", {}))
    ]
    test_cases = [
        {
            "id": item.get("id"),
            "itemId": item.get("item_id"),
            "testContent": item.get("test_content"),
            "operationSteps": item.get("operation_steps", []),
            "expectedResultPlaceholder": item.get("expected_result_placeholder"),
        }
        for item in _flatten_record(pipeline_result.get("test_cases", {}))
    ]
    expected_results = [
        {
            "id": item.get("id"),
            "caseId": item.get("case_id"),
            "result": item.get("result"),
        }
        for item in _flatten_record(pipeline_result.get("expected_results", {}))
    ]
    return {
        **req,
        "测试项": test_items,
        "测试用例": test_cases,
        "预期结果": expected_results,
    }


def _mark_job_failed(job_id: int, error_message: str) -> None:
    """Mark ToolJob ``job_id`` as failed using a fresh session.

    A separate session is used so the update does not depend on the state of the
    session whose work failed. Does nothing if the job no longer exists. The stored
    message is truncated to ``_MAX_ERROR_MESSAGE_LENGTH`` characters.
    """
    db = SessionLocal()
    try:
        job = db.query(ToolJob).filter(ToolJob.id == job_id).first()
        if not job:
            return
        job.status = "failed"
        job.completed_at = datetime.utcnow()
        job.error_message = error_message[:_MAX_ERROR_MESSAGE_LENGTH]
        db.commit()
    finally:
        db.close()


def run_testing_generation_job(job_id: int, payload: Dict[str, Any]) -> None:
    """Generate test items/cases/expected results for every requirement of a job.

    Progress is committed to ``job.output_summary`` before each requirement. On
    success a ``TestingGeneration`` row holding the generated file is added and the
    job is marked "completed". On any error the session is rolled back and the job
    is marked "failed" with the error text.

    Args:
        job_id: Id of the ToolJob to run; returns silently if it does not exist.
        payload: Job input. Keys read: "requirements" (list of requirement dicts),
            "source_document_name", "source_job_id", "knowledge_base_id".
    """
    db = SessionLocal()
    try:
        job = db.query(ToolJob).filter(ToolJob.id == job_id).first()
        if not job:
            return

        requirements = payload.get("requirements") or []
        source_document_name = str(payload.get("source_document_name") or job.input_file_name or "")
        source_job_id = payload.get("source_job_id")
        knowledge_base_id = payload.get("knowledge_base_id")

        job.status = "processing"
        job.started_at = datetime.utcnow()
        job.error_message = None
        job.output_summary = {
            "source_document_name": source_document_name,
            "current_step": 0,
            "total_steps": len(requirements),
        }
        db.commit()

        generated_requirements: List[Dict[str, Any]] = []
        for index, req in enumerate(requirements):
            req_id = str(req.get("id") or f"REQ-{index + 1:03d}")
            # Commit progress before the (slow) generation so pollers see which
            # requirement is being processed.
            job.output_summary = {
                "source_document_name": source_document_name,
                "current_step": index + 1,
                "total_steps": len(requirements),
                "current_requirement_id": req_id,
            }
            db.commit()

            description = str(req.get("description") or "").strip()
            if not description:
                # Nothing to generate from: keep the requirement with empty sections.
                generated_requirements.append(
                    {
                        **req,
                        "测试项": [],
                        "测试用例": [],
                        "预期结果": [],
                    }
                )
                continue

            knowledge_context = _resolve_knowledge_context(
                db,
                user_id=job.user_id,
                requirement_text=description,
                knowledge_base_id=knowledge_base_id,
            )
            pipeline_result = run_testing_pipeline(
                user_requirement_text=description,
                requirement_type_input=req.get("requirementType"),
                debug=False,
                knowledge_context=knowledge_context,
                use_model_generation=True,
                max_items_per_group=12,
                cases_per_item=2,
                max_focus_points=6,
                max_llm_calls=10,
            )
            generated_requirements.append(_build_generated_requirement(req, pipeline_result))

        generated_at = datetime.utcnow()
        generated_file = {
            "sourceDocument": source_document_name,
            "sourceJobId": source_job_id,
            "generatedAt": generated_at.isoformat(),
            "totalRequirements": len(generated_requirements),
            "knowledgeBaseId": knowledge_base_id,
            "requirements": generated_requirements,
        }

        generation = TestingGeneration(
            job_id=job.id,
            source_job_id=source_job_id,
            source_document_name=source_document_name,
            generated_at=generated_at,
            total_requirements=len(generated_requirements),
            knowledge_base_id=knowledge_base_id,
            generated_file=generated_file,
        )
        db.add(generation)

        job.status = "completed"
        job.completed_at = datetime.utcnow()
        job.output_summary = {
            "source_document_name": source_document_name,
            "current_step": len(generated_requirements),
            "total_steps": len(generated_requirements),
            "total_requirements": len(generated_requirements),
            "knowledge_base_id": knowledge_base_id,
        }
        db.commit()
    except Exception as exc:
        logger.exception("Testing generation job %s failed", job_id)
        db.rollback()
        # Some exceptions stringify to "", which would leave the failed job with no
        # diagnostic at all; fall back to the exception class name.
        _mark_job_failed(job_id, str(exc) or type(exc).__name__)
    finally:
        db.close()