from __future__ import annotations import asyncio import logging from datetime import datetime from typing import Any, Dict, List from sqlalchemy.orm import Session from app.core.config import settings from app.db.session import SessionLocal from app.models.knowledge import Document, KnowledgeBase from app.models.tooling import TestingGeneration, ToolJob from app.services.embedding.embedding_factory import EmbeddingsFactory from app.services.llm.llm_factory import LLMFactory from app.services.model_config import ModelConfigService from app.services.retrieval.multi_kb_retriever import MultiKBRetriever, format_retrieval_context from app.services.testing_pipeline import run_testing_pipeline from app.services.vector_store import VectorStoreFactory logger = logging.getLogger(__name__) def _flatten_record(value: Dict[str, List[Dict[str, Any]]]) -> List[Dict[str, Any]]: items: List[Dict[str, Any]] = [] for current in value.values(): items.extend(current) return items def _build_kb_vector_stores( db: Session, knowledge_bases: List[KnowledgeBase], model_profile: Any, ) -> List[Dict[str, Any]]: if model_profile is None: return [] embeddings = EmbeddingsFactory.create(model_profile=model_profile) kb_vector_stores: List[Dict[str, Any]] = [] for kb in knowledge_bases: documents = db.query(Document).filter(Document.knowledge_base_id == kb.id).all() if not documents: continue store = VectorStoreFactory.create( store_type=settings.VECTOR_STORE_TYPE, collection_name=f"kb_{kb.id}", embedding_function=embeddings, ) kb_vector_stores.append({"kb_id": kb.id, "store": store}) return kb_vector_stores def _resolve_knowledge_context( db: Session, *, user_id: int, requirement_text: str, knowledge_base_id: int | None, model_profile: Any, ) -> str: if knowledge_base_id is None or model_profile is None: return "" try: knowledge_bases = ( db.query(KnowledgeBase) .filter( KnowledgeBase.id == knowledge_base_id, KnowledgeBase.user_id == user_id, ) .all() ) kb_vector_stores = _build_kb_vector_stores(db, knowledge_bases, model_profile) if not kb_vector_stores: return "" retriever = MultiKBRetriever( reranker_weight=settings.RERANKER_WEIGHT, ) rows = asyncio.run( retriever.retrieve( query=requirement_text, kb_vector_stores=kb_vector_stores, fetch_k_per_kb=16, top_k=8, ) ) if rows: return format_retrieval_context(rows) except Exception: return "" return "" def _build_generated_requirement(req: Dict[str, Any], pipeline_result: Dict[str, Any]) -> Dict[str, Any]: test_items = [ { "id": item.get("id"), "content": item.get("content"), } for item in _flatten_record(pipeline_result.get("test_items", {})) ] test_cases = [ { "id": item.get("id"), "itemId": item.get("item_id"), "testContent": item.get("test_content"), "operationSteps": item.get("operation_steps", []), "expectedResultPlaceholder": item.get("expected_result_placeholder"), } for item in _flatten_record(pipeline_result.get("test_cases", {})) ] expected_results = [ { "id": item.get("id"), "caseId": item.get("case_id"), "result": item.get("result"), } for item in _flatten_record(pipeline_result.get("expected_results", {})) ] return { **req, "测试项": test_items, "测试用例": test_cases, "预期结果": expected_results, } def _mark_job_failed(job_id: int, error_message: str) -> None: db = SessionLocal() try: job = db.query(ToolJob).filter(ToolJob.id == job_id).first() if not job: return job.status = "failed" job.completed_at = datetime.utcnow() job.error_message = error_message[:2000] db.commit() finally: db.close() def run_testing_generation_job(job_id: int, payload: Dict[str, Any]) -> None: db = SessionLocal() try: job = db.query(ToolJob).filter(ToolJob.id == job_id).first() if not job: return requirements = payload.get("requirements") or [] source_document_name = str(payload.get("source_document_name") or job.input_file_name or "") source_job_id = payload.get("source_job_id") knowledge_base_id = payload.get("knowledge_base_id") model_profile = ModelConfigService.get_active_config(db, job.user_id) if model_profile is not None: ModelConfigService.touch_last_used(db, model_profile) use_model_generation = model_profile is not None llm_model = None if use_model_generation: try: llm_model = LLMFactory.create(streaming=False, model_profile=model_profile) except Exception as exc: logger.exception( "Testing generation LLM initialization failed for job=%s, falling back to rule-based output: %s", job_id, exc, ) use_model_generation = False else: logger.info( "Testing generation job=%s has no active model config; using rule-based output.", job_id, ) job.status = "processing" job.started_at = datetime.utcnow() job.error_message = None job.output_summary = { "source_document_name": source_document_name, "current_step": 0, "total_steps": len(requirements), } db.commit() generated_requirements: List[Dict[str, Any]] = [] for index, req in enumerate(requirements): req_id = str(req.get("id") or f"REQ-{index + 1:03d}") job.output_summary = { "source_document_name": source_document_name, "current_step": index + 1, "total_steps": len(requirements), "current_requirement_id": req_id, } db.commit() description = str(req.get("description") or "").strip() if not description: generated_requirements.append( { **req, "测试项": [], "测试用例": [], "预期结果": [], } ) continue knowledge_context = _resolve_knowledge_context( db, user_id=job.user_id, requirement_text=description, knowledge_base_id=knowledge_base_id, model_profile=model_profile, ) pipeline_result = run_testing_pipeline( user_requirement_text=description, requirement_type_input=req.get("requirementType"), debug=False, knowledge_context=knowledge_context, use_model_generation=use_model_generation, llm_model=llm_model, max_items_per_group=12, cases_per_item=2, max_focus_points=6, max_llm_calls=10, ) generated_requirements.append(_build_generated_requirement(req, pipeline_result)) generated_at = datetime.utcnow() generated_file = { "sourceDocument": source_document_name, "sourceJobId": source_job_id, "generatedAt": generated_at.isoformat(), "totalRequirements": len(generated_requirements), "knowledgeBaseId": knowledge_base_id, "requirements": generated_requirements, } generation = TestingGeneration( job_id=job.id, source_job_id=source_job_id, source_document_name=source_document_name, generated_at=generated_at, total_requirements=len(generated_requirements), knowledge_base_id=knowledge_base_id, generated_file=generated_file, ) db.add(generation) job.status = "completed" job.completed_at = datetime.utcnow() job.output_summary = { "source_document_name": source_document_name, "current_step": len(generated_requirements), "total_steps": len(generated_requirements), "total_requirements": len(generated_requirements), "knowledge_base_id": knowledge_base_id, } db.commit() except Exception as exc: db.rollback() _mark_job_failed(job_id, str(exc)) finally: db.close()