增加代码知识库;修复文档处理内容;增加API设置

This commit is contained in:
2026-05-16 20:20:10 +08:00
parent 69b49d28b2
commit 7aa3ce3294
119 changed files with 182273 additions and 793 deletions

View File

@@ -17,6 +17,7 @@ from app.services.fusion_prompts import (
from app.services.graph.graphrag_adapter import GraphRAGAdapter
from app.services.intent_router import route_intent
from app.services.llm.llm_factory import LLMFactory
from app.services.model_config import ModelConfigService
from app.services.reranker.external_api import ExternalRerankerClient
from app.services.retrieval.multi_kb_retriever import MultiKBRetriever, format_retrieval_context
from app.services.testing_pipeline.pipeline import run_testing_pipeline
@@ -202,8 +203,12 @@ def _build_global_community_context_fallback(rows: List[Dict[str, Any]]) -> str:
return "\n\n".join(lines)
async def _build_kb_vector_stores(db: Any, knowledge_bases: List[KnowledgeBase]) -> List[Dict[str, Any]]:
embeddings = EmbeddingsFactory.create()
async def _build_kb_vector_stores(
db: Any,
knowledge_bases: List[KnowledgeBase],
model_profile: Any,
) -> List[Dict[str, Any]]:
embeddings = EmbeddingsFactory.create(model_profile=model_profile)
kb_vector_stores: List[Dict[str, Any]] = []
for kb in knowledge_bases:
@@ -221,10 +226,13 @@ async def _build_kb_vector_stores(db: Any, knowledge_bases: List[KnowledgeBase])
return kb_vector_stores
def _build_reranker_client() -> ExternalRerankerClient:
def _build_reranker_client(model_profile: Any = None) -> ExternalRerankerClient:
api_key = settings.RERANKER_API_KEY
if model_profile is not None and getattr(model_profile, "provider", "") == "dashscope":
api_key = getattr(model_profile, "api_key", "") or api_key
return ExternalRerankerClient(
api_url=settings.RERANKER_API_URL,
api_key=settings.RERANKER_API_KEY,
api_key=api_key,
model=settings.RERANKER_MODEL,
timeout_seconds=settings.RERANKER_TIMEOUT_SECONDS,
)
@@ -287,6 +295,7 @@ async def generate_response(
knowledge_base_ids: List[int],
chat_id: int,
db: Any,
user_id: int,
) -> AsyncGenerator[str, None]:
try:
user_message = Message(content=query, role="user", chat_id=chat_id)
@@ -297,6 +306,9 @@ async def generate_response(
db.add(bot_message)
db.commit()
model_profile = ModelConfigService.require_active_config(db, user_id)
ModelConfigService.touch_last_used(db, model_profile)
if _is_testing_generation_request(query):
explicit_type = _extract_requirement_type_from_query(query)
@@ -309,7 +321,7 @@ async def generate_response(
.filter(KnowledgeBase.id.in_(knowledge_base_ids))
.all()
)
kb_vector_stores = await _build_kb_vector_stores(db, testing_kbs)
kb_vector_stores = await _build_kb_vector_stores(db, testing_kbs, model_profile)
if kb_vector_stores:
testing_retriever = MultiKBRetriever(
@@ -330,6 +342,7 @@ async def generate_response(
debug=True,
knowledge_context=knowledge_context,
use_model_generation=True,
llm_model=LLMFactory.create(streaming=False, model_profile=model_profile),
max_items_per_group=6,
cases_per_item=1,
max_focus_points=6,
@@ -391,11 +404,11 @@ async def generate_response(
)
kb_ids = [kb.id for kb in knowledge_bases]
llm = LLMFactory.create()
llm = LLMFactory.create(model_profile=model_profile)
decision = await route_intent(llm=llm, query=query, messages=messages)
intent = decision["intent"]
kb_vector_stores = await _build_kb_vector_stores(db, knowledge_bases)
kb_vector_stores = await _build_kb_vector_stores(db, knowledge_bases, model_profile)
if intent in {"B", "C", "D"} and not kb_vector_stores:
intent = "A"
decision = {
@@ -403,7 +416,7 @@ async def generate_response(
"reason": "未发现可用知识库向量集合,已降级为通用对话路。",
}
reranker_client = _build_reranker_client()
reranker_client = _build_reranker_client(model_profile)
retriever = MultiKBRetriever(
reranker_client=reranker_client,
reranker_weight=settings.RERANKER_WEIGHT,
@@ -432,7 +445,7 @@ async def generate_response(
used_kb_ids: List[int] = []
if settings.GRAPHRAG_ENABLED and kb_ids:
try:
adapter = GraphRAGAdapter()
adapter = GraphRAGAdapter(model_profile=model_profile)
graph_context, used_kb_ids = await adapter.local_context_multi(
kb_ids,
query,
@@ -465,7 +478,7 @@ async def generate_response(
community_context = ""
if settings.GRAPHRAG_ENABLED and kb_ids:
try:
adapter = GraphRAGAdapter()
adapter = GraphRAGAdapter(model_profile=model_profile)
community_context, used_kb_ids = await adapter.global_context_multi(
kb_ids,
query,

View File

@@ -0,0 +1,10 @@
from app.services.code_kb.adapter import CodeKnowledgeBaseAdapter
from app.services.code_kb.schema import CodeFunctionEvidence, CodeGraphContext, CodeSearchHit
__all__ = [
"CodeFunctionEvidence",
"CodeGraphContext",
"CodeKnowledgeBaseAdapter",
"CodeSearchHit",
]

View File

@@ -0,0 +1,517 @@
from __future__ import annotations
import json
import logging
import math
import re
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional
from app.services.code_kb.graph import CodeCallGraph
from app.services.code_kb.schema import CodeFunctionEvidence, CodeGraphContext, CodeSearchHit
logger = logging.getLogger(__name__)
# Maps each canonical evidence field to the metadata keys that may carry it.
# _normalize_metadata hands these tuples to _first_value, which tries the
# aliases in order, so the first non-empty alias wins.
FIELD_ALIASES = {
"name": ("name", "function_name"),
"file": ("file", "file_path"),
"summary": ("summary",),
"logic_flow": ("logic_flow", "logic"),
"code_snippet": ("code_snippet", "source", "code"),
"calls": ("calls", "called_functions"),
"called_by": ("called_by", "caller_functions", "callers"),
}
class SimpleVectorIndex:
    """Brute-force squared-L2 vector index backed by a JSON payload.

    Exposes the attributes and method the adapter reads from an index:
    ``d`` (dimension), ``ntotal`` (vector count) and ``search``.  numpy is
    used when it can be imported; otherwise a pure-Python path is taken.
    """

    def __init__(self, vectors: Any, dimension: int) -> None:
        """Store *vectors* as lists of floats.

        Args:
            vectors: Iterable of numeric rows; ``None`` or empty gives an
                empty index.
            dimension: Declared dimension.  When falsy, the length of the
                first stored row is used instead.
        """
        try:
            import numpy as np
        except ImportError:  # pragma: no cover - depends on deployment environment
            np = None
        self._np = np
        # float32 copy of self.vectors, built on the first numpy search so the
        # conversion is not repeated for every query.
        self._matrix: Any = None
        self.vectors = [[float(value) for value in row] for row in (vectors or [])]
        inferred_dimension = len(self.vectors[0]) if self.vectors else 0
        self.d = int(dimension or inferred_dimension)
        self.ntotal = len(self.vectors)

    @classmethod
    def from_file(cls, vector_path: str) -> Optional["SimpleVectorIndex"]:
        """Load an index from *vector_path*.

        Returns:
            The index, or ``None`` when the file cannot be read as UTF-8
            JSON or is not tagged ``"format": "simple_l2_vector_index"``.
        """
        try:
            with open(vector_path, "r", encoding="utf-8") as file:
                payload = json.load(file)
        except (OSError, UnicodeDecodeError, json.JSONDecodeError):
            return None
        if not isinstance(payload, dict) or payload.get("format") != "simple_l2_vector_index":
            return None
        return cls(payload.get("vectors") or [], int(payload.get("dimension") or 0))

    def _stored_matrix(self) -> Any:
        """Return the cached float32 matrix of stored vectors (numpy path only)."""
        matrix = getattr(self, "_matrix", None)
        if matrix is None:
            matrix = self._np.array(self.vectors, dtype="float32")
            self._matrix = matrix
        return matrix

    def search(self, vector: Any, top_k: int) -> Any:
        """Return ``(distances, indices)`` of the *top_k* nearest vectors.

        Args:
            vector: Query as a single row nested in an outer sequence
                (``[[...]]``) or as a flat row (``[...]``).
            top_k: Maximum number of neighbours to return.

        Returns:
            A pair of one-row containers (numpy arrays when numpy is
            available, nested lists otherwise) holding squared L2
            distances in ascending order and the matching row indices.

        Raises:
            ValueError: If the query dimension differs from ``self.d``.
        """
        np = self._np
        if self.ntotal == 0:
            if np is not None:
                return np.array([[]], dtype="float32"), np.array([[]], dtype="int64")
            return [[]], [[]]

        if np is not None:
            query = np.asarray(vector, dtype="float32")
            if query.ndim == 1:
                # Accept a flat query, as the pure-Python path below does.
                query = query.reshape(1, -1)
            if query.ndim != 2 or query.shape[1] != self.d:
                raise ValueError(f"Query embedding dimension does not match index dimension {self.d}.")
            distances = np.sum((self._stored_matrix() - query[0]) ** 2, axis=1)
            order = np.argsort(distances)[:top_k]
            return np.array([distances[order]], dtype="float32"), np.array([order], dtype="int64")

        nested = isinstance(vector, (list, tuple)) and vector and isinstance(vector[0], (list, tuple))
        query_row = vector[0] if nested else vector
        query_values = [float(value) for value in query_row]
        if len(query_values) != self.d:
            raise ValueError(f"Query embedding dimension does not match index dimension {self.d}.")
        scored = [
            (sum((left - right) ** 2 for left, right in zip(stored, query_values)), index)
            for index, stored in enumerate(self.vectors)
        ]
        scored.sort(key=lambda item: item[0])
        selected = scored[:top_k]
        return [[item[0] for item in selected]], [[item[1] for item in selected]]

    def has_nonzero_vectors(self) -> bool:
        """Return True when at least one stored component exceeds 1e-12 in magnitude."""
        return any(
            any(abs(float(value)) > 1e-12 for value in row)
            for row in self.vectors
        )
class UnavailableVectorIndex:
    """Stand-in index used when the real vector index cannot be read.

    It reports zero dimension and zero vectors, and every search raises
    ``RuntimeError`` carrying the stored reason.
    """

    d = ntotal = 0

    def __init__(self, reason: str) -> None:
        # Kept so callers can surface why vector search is unavailable.
        self.reason = reason

    def search(self, vector: Any, top_k: int) -> Any:
        """Always raise ``RuntimeError`` with the stored reason."""
        raise RuntimeError(self.reason)
def _first_value(source: Dict[str, Any], aliases: Iterable[str], default: Any = None) -> Any:
for alias in aliases:
value = source.get(alias)
if value not in (None, ""):
return value
return default
def _as_list(value: Any) -> List[str]:
if value is None:
return []
if isinstance(value, list):
return [str(item) for item in value if item not in (None, "")]
if isinstance(value, tuple):
return [str(item) for item in value if item not in (None, "")]
if isinstance(value, str):
return [value] if value else []
return [str(value)]
def _as_int(value: Any) -> Optional[int]:
if value in (None, ""):
return None
try:
return int(value)
except (TypeError, ValueError):
return None
def _as_bool(value: Any, default: bool = True) -> bool:
if value is None:
return default
if isinstance(value, bool):
return value
if isinstance(value, (int, float)):
return bool(value)
if isinstance(value, str):
normalized = value.strip().lower()
if normalized in {"1", "true", "yes", "y"}:
return True
if normalized in {"0", "false", "no", "n"}:
return False
return default
def _load_json(path: str) -> Any:
with open(path, "r", encoding="utf-8") as file:
return json.load(file)
def read_code_kb_summary(metadata_path: str, graph_path: str) -> Dict[str, Any]:
    """Summarise a code knowledge base without loading its vector index.

    Args:
        metadata_path: JSON file holding either a list of function records
            or a dict that keeps that list under ``functions`` or ``items``.
        graph_path: Optional call-graph JSON; ignored when empty or missing.

    Returns:
        A dict with ``function_count`` plus ``graph_nodes``, ``graph_edges``,
        ``project_root`` and ``generated_at`` copied from the graph's
        ``metadata`` object (``None`` for any value that is absent).
    """
    metadata = _load_json(metadata_path)
    if isinstance(metadata, dict):
        # Accept the same containers as CodeKnowledgeBaseAdapter._read_metadata,
        # so both report the same number of functions for one file.
        metadata_items = metadata.get("functions", metadata.get("items", []))
    else:
        metadata_items = metadata
    function_count = len(metadata_items) if isinstance(metadata_items, list) else 0

    graph_data = _load_json(graph_path) if graph_path and Path(graph_path).exists() else {}
    graph_meta = graph_data.get("metadata") if isinstance(graph_data, dict) else None
    if not isinstance(graph_meta, dict):
        # Covers a missing key as well as "metadata": null or a non-object value.
        graph_meta = {}

    return {
        "function_count": function_count,
        "graph_nodes": graph_meta.get("total_nodes"),
        "graph_edges": graph_meta.get("total_edges"),
        "project_root": graph_meta.get("project_root"),
        "generated_at": graph_meta.get("generated_at"),
    }
class CodeKnowledgeBaseAdapter:
    """Loads a code knowledge base and answers function-level searches.

    A knowledge base consists of three files: a vector index, a function
    metadata JSON and a call-graph JSON.  Vector search is used when the
    index passes the checks in ``_configure_vector_search``; when it is
    disabled, fails, or yields no hit, a lexical scorer is used instead.
    """

    def __init__(self, embedding_function: Any = None) -> None:
        """Create an unloaded adapter.

        Args:
            embedding_function: Object exposing ``embed_query`` or a plain
                callable.  When ``None``, one is created through
                ``EmbeddingsFactory.create()`` on the first vector search.
        """
        self.embedding_function = embedding_function
        self.faiss_index: Any = None
        self.metadata: List[Dict[str, Any]] = []
        self.graph_data: Dict[str, Any] = {}
        self.functions: List[CodeFunctionEvidence] = []
        self.functions_by_id: Dict[str, CodeFunctionEvidence] = {}
        self.call_graph: Optional[CodeCallGraph] = None
        self.vector_search_enabled = False
        self.vector_search_disabled_reason = "Code knowledge base has not been loaded."

    def load(self, vector_path: str, metadata_path: str, graph_path: str) -> None:
        """Read all three knowledge-base files and build the in-memory state.

        Raises:
            FileNotFoundError: If any of the three paths is empty or missing.
            ValueError: If the metadata file has an unsupported shape.
        """
        self._validate_paths(vector_path, metadata_path, graph_path)
        self.faiss_index = self._read_faiss_index(vector_path)
        self.metadata = self._read_metadata(metadata_path)

        graph_data = _load_json(graph_path)
        if not isinstance(graph_data, dict):
            # Everything downstream calls .get() on the graph payload; treat
            # any other JSON shape as an empty graph instead of crashing.
            logger.warning("Code KB graph file %s is not a JSON object; ignoring it.", graph_path)
            graph_data = {}
        self.graph_data = graph_data

        graph_nodes = self._index_graph_nodes(self.graph_data)
        self.functions = [
            self._normalize_metadata(row, graph_nodes, index_dimension=self.faiss_index.d)
            for row in self.metadata
        ]
        self.functions_by_id = {item.node_id: item for item in self.functions}
        self.call_graph = CodeCallGraph(self.graph_data, self.functions)
        self._configure_vector_search()

    def _validate_paths(self, vector_path: str, metadata_path: str, graph_path: str) -> None:
        """Raise ``FileNotFoundError`` listing every path that is empty or absent."""
        missing = [
            path
            for path in (vector_path, metadata_path, graph_path)
            if not path or not Path(path).exists()
        ]
        if missing:
            raise FileNotFoundError(f"Code knowledge base files not found: {missing}")

    def _read_faiss_index(self, vector_path: str) -> Any:
        """Return a searchable index for *vector_path*.

        The JSON ``SimpleVectorIndex`` format is tried first.  Otherwise the
        file is handed to ``faiss.read_index``; when faiss is not installed
        an ``UnavailableVectorIndex`` is returned instead.
        """
        simple_index = SimpleVectorIndex.from_file(vector_path)
        if simple_index is not None:
            return simple_index
        try:
            import faiss  # type: ignore
        except ImportError:
            logger.warning("faiss is not installed; code search will use lexical fallback.")
            return UnavailableVectorIndex(
                "faiss is required to read this vector index. Install faiss-cpu or rebuild the code KB."
            )
        return faiss.read_index(vector_path)

    def _read_metadata(self, metadata_path: str) -> List[Dict[str, Any]]:
        """Load function records, accepting a list or a dict with ``functions``/``items``.

        Non-dict entries are dropped.

        Raises:
            ValueError: If no list of records can be located.
        """
        metadata = _load_json(metadata_path)
        if isinstance(metadata, dict):
            metadata = metadata.get("functions", metadata.get("items", []))
        if not isinstance(metadata, list):
            raise ValueError("Code KB metadata must be a list or a dict with functions/items.")
        return [item for item in metadata if isinstance(item, dict)]

    def _index_graph_nodes(self, graph_data: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
        """Map graph node id -> node dict, skipping non-dict nodes and nodes without an id."""
        return {
            item.get("id"): item
            for item in graph_data.get("nodes", []) or []
            if isinstance(item, dict) and item.get("id")
        }

    def _normalize_metadata(
        self,
        metadata: Dict[str, Any],
        graph_nodes: Dict[str, Dict[str, Any]],
        index_dimension: int,
    ) -> CodeFunctionEvidence:
        """Merge one metadata record with its graph node into a ``CodeFunctionEvidence``.

        Metadata values take precedence; the graph node (and its
        ``raw_attributes``) supplies fallbacks for fields that are missing.
        """
        name = _first_value(metadata, FIELD_ALIASES["name"], "")
        node_id = metadata.get("node_id") or (f"Function:{name}" if name else "")
        graph_node = graph_nodes.get(node_id, {})
        raw_attributes = graph_node.get("raw_attributes")
        if not isinstance(raw_attributes, dict):
            # Missing, null or malformed attributes all behave as "no fallbacks".
            raw_attributes = {}
        if not name:
            name = graph_node.get("name") or node_id.removeprefix("Function:")
        file_path = _first_value(metadata, FIELD_ALIASES["file"], graph_node.get("file_path", ""))
        embedding_dim = metadata.get("embedding_dim") or index_dimension
        embedding_available = _as_bool(
            metadata.get("embedding_available", raw_attributes.get("embedding_available")),
            default=True,
        )
        return CodeFunctionEvidence(
            node_id=node_id,
            name=name,
            qualified_name=metadata.get("qualified_name") or node_id.removeprefix("Function:") or name,
            file=file_path or "",
            start_line=_as_int(metadata.get("start_line") or graph_node.get("start_line")),
            end_line=_as_int(metadata.get("end_line") or graph_node.get("end_line")),
            signature=metadata.get("signature") or graph_node.get("signature") or "",
            summary=_first_value(metadata, FIELD_ALIASES["summary"], graph_node.get("summary", "")) or "",
            logic_flow=_first_value(
                metadata, FIELD_ALIASES["logic_flow"], graph_node.get("logic_flow", "")
            )
            or "",
            code_snippet=_first_value(
                metadata, FIELD_ALIASES["code_snippet"], raw_attributes.get("code_snippet", "")
            )
            or "",
            calls=_as_list(_first_value(metadata, FIELD_ALIASES["calls"], raw_attributes.get("calls", []))),
            called_by=_as_list(
                _first_value(metadata, FIELD_ALIASES["called_by"], raw_attributes.get("called_by", []))
            ),
            includes=_as_list(metadata.get("includes") or raw_attributes.get("includes")),
            embedding_model=metadata.get("embedding_model") or "",
            # _as_int tolerates a non-numeric value instead of raising mid-load.
            embedding_dim=_as_int(embedding_dim) or 0,
            embedding_available=embedding_available,
            raw=metadata,
        )

    def _configure_vector_search(self) -> None:
        """Decide whether vector search is usable and record the reason when it is not."""
        index_total = int(getattr(self.faiss_index, "ntotal", 0))
        self.vector_search_enabled = False
        self.vector_search_disabled_reason = ""
        if isinstance(self.faiss_index, UnavailableVectorIndex):
            self.vector_search_disabled_reason = self.faiss_index.reason
            return
        if index_total == 0:
            self.vector_search_disabled_reason = "Vector index is empty."
            return
        if index_total != len(self.functions):
            # Index rows are mapped to functions by position, so a size
            # mismatch would pair vectors with the wrong functions.
            self.vector_search_disabled_reason = (
                f"Vector index size ({index_total}) differs from metadata size ({len(self.functions)})."
            )
            logger.warning(
                "FAISS index size (%s) differs from metadata size (%s)",
                index_total,
                len(self.functions),
            )
            return
        if not self._index_has_nonzero_vectors():
            self.vector_search_disabled_reason = "Vector index contains only zero embeddings."
            logger.info("Code KB vector search disabled: %s", self.vector_search_disabled_reason)
            return
        if not any(item.embedding_available for item in self.functions):
            self.vector_search_disabled_reason = "No function metadata has usable embeddings."
            logger.info("Code KB vector search disabled: %s", self.vector_search_disabled_reason)
            return
        self.vector_search_enabled = True

    def _index_has_nonzero_vectors(self) -> bool:
        """Return False only when the index is known to hold nothing but zero vectors.

        Uses the index's own ``has_nonzero_vectors`` when present, otherwise
        samples up to 1000 vectors through ``reconstruct_n``.  When the
        vectors cannot be inspected the answer is optimistically True.
        """
        has_nonzero_vectors = getattr(self.faiss_index, "has_nonzero_vectors", None)
        if callable(has_nonzero_vectors):
            return bool(has_nonzero_vectors())
        reconstruct_n = getattr(self.faiss_index, "reconstruct_n", None)
        if not callable(reconstruct_n):
            return True
        try:
            import numpy as np

            index_total = int(getattr(self.faiss_index, "ntotal", 0) or 0)
            sample_size = min(index_total, 1000)
            if sample_size <= 0:
                return False
            vectors = reconstruct_n(0, sample_size)
            return bool(np.any(np.abs(vectors) > 1e-12))
        except Exception as exc:  # pragma: no cover - depends on FAISS capabilities
            logger.debug("Could not inspect FAISS vectors for zero embeddings: %s", exc)
            return True

    def _get_embedding_function(self) -> Any:
        """Return the embedding function, creating the default one on first use."""
        if self.embedding_function is None:
            from app.services.embedding.embedding_factory import EmbeddingsFactory

            self.embedding_function = EmbeddingsFactory.create()
        return self.embedding_function

    def _embed_query(self, query: str) -> Any:
        """Embed *query* and return it as a ``(1, dim)`` float32 numpy array.

        Raises:
            RuntimeError: If numpy is not installed.
            TypeError: If the embedding function is neither callable nor
                exposes ``embed_query``.
            ValueError: If the embedding dimension differs from the index's.
        """
        try:
            import numpy as np
        except ImportError as exc:
            raise RuntimeError("numpy is required to search a code knowledge base.") from exc
        embedding_function = self._get_embedding_function()
        if hasattr(embedding_function, "embed_query"):
            vector = embedding_function.embed_query(query)
        elif callable(embedding_function):
            vector = embedding_function(query)
        else:
            raise TypeError("Unsupported embedding function.")
        query_vector = np.array([vector], dtype="float32")
        index_dimension = int(getattr(self.faiss_index, "d", 0) or 0)
        if index_dimension and query_vector.shape[1] != index_dimension:
            raise ValueError(
                f"Query embedding dimension {query_vector.shape[1]} does not match "
                f"FAISS index dimension {index_dimension}."
            )
        return query_vector

    @staticmethod
    def distance_to_similarity(distance: float) -> float:
        """Map an index distance to a similarity in ``[0, 1]``.

        NaN and negative distances give 0.  Distances up to 2 use
        ``1 - d / 2``; larger ones use ``1 / (1 + d)``.
        """
        # NOTE(review): the two branches do not meet at d == 2 (0.0 vs ~0.33),
        # so similarity is not monotonic in distance across that boundary.
        # Left unchanged because callers may have tuned thresholds against it.
        if math.isnan(distance) or distance < 0:
            return 0.0
        if distance <= 2.0:
            return max(0.0, min(1.0, 1.0 - distance / 2.0))
        return max(0.0, min(1.0, 1.0 / (1.0 + distance)))

    def search_functions(
        self,
        query: str,
        top_k: int = 8,
        min_similarity: float = 0.0,
    ) -> List[CodeSearchHit]:
        """Return up to *top_k* functions matching *query*.

        Vector search is attempted when enabled; lexical search is used when
        it is disabled, raises, or produces no hit above *min_similarity*.

        Raises:
            RuntimeError: If ``load`` has not been called.
        """
        if self.faiss_index is None:
            raise RuntimeError("Code knowledge base has not been loaded.")
        if not (query or "").strip() or top_k <= 0:
            # A non-positive top_k can never produce a meaningful result set.
            return []
        if not self.vector_search_enabled:
            logger.info(
                "Vector code search disabled, using lexical search: %s",
                self.vector_search_disabled_reason,
            )
            return self._lexical_search_functions(query, top_k=top_k, min_similarity=min_similarity)

        try:
            vector = self._embed_query(query)
            distances, indices = self.faiss_index.search(vector, top_k)
        except Exception as exc:
            logger.warning("Vector code search failed, falling back to lexical search: %s", exc)
            return self._lexical_search_functions(query, top_k=top_k, min_similarity=min_similarity)

        hits: List[CodeSearchHit] = []
        # rank is the position in the raw index result, so it may have gaps
        # where candidates were filtered out below.
        for rank, (raw_index, raw_distance) in enumerate(zip(indices[0], distances[0]), start=1):
            index = int(raw_index)
            if index < 0 or index >= len(self.functions):
                continue
            evidence = self.functions[index]
            if not evidence.embedding_available:
                continue
            distance = float(raw_distance)
            similarity = self.distance_to_similarity(distance)
            if similarity < min_similarity:
                continue
            hits.append(
                CodeSearchHit(
                    evidence=evidence,
                    similarity=similarity,
                    distance=distance,
                    rank=rank,
                )
            )
        if not hits:
            return self._lexical_search_functions(query, top_k=top_k, min_similarity=min_similarity)
        return hits

    def _lexical_search_functions(
        self,
        query: str,
        top_k: int = 8,
        min_similarity: float = 0.0,
    ) -> List[CodeSearchHit]:
        """Score every function against the query tokens and return the best *top_k*."""
        query_tokens = self._tokens(query)
        if not query_tokens:
            return []
        scored: List[CodeSearchHit] = []
        for evidence in self.functions:
            text = self._function_search_text(evidence)
            score = self._lexical_similarity(query_tokens, text, evidence)
            if score < min_similarity:
                continue
            scored.append(
                CodeSearchHit(
                    evidence=evidence,
                    similarity=score,
                    distance=max(0.0, 1.0 - score),
                    rank=0,
                )
            )
        scored.sort(key=lambda item: item.similarity, reverse=True)
        # Clamp so a negative top_k cannot turn into a "drop the tail" slice.
        best = scored[: max(0, top_k)]
        for rank, item in enumerate(best, start=1):
            item.rank = rank
        return best

    @staticmethod
    def _tokens(text: str) -> List[str]:
        """Split *text* into lower-cased identifier, number and CJK tokens.

        CJK runs of three or more characters are additionally expanded into
        their overlapping two-character substrings.
        """
        normalized = (text or "").lower()
        tokens = re.findall(r"[a-z_][a-z0-9_]*|\d+(?:\.\d+)?|[\u4e00-\u9fff]{2,}", normalized)
        expanded: List[str] = []
        for token in tokens:
            expanded.append(token)
            if re.fullmatch(r"[\u4e00-\u9fff]{3,}", token):
                expanded.extend(token[index : index + 2] for index in range(len(token) - 1))
        return expanded

    @staticmethod
    def _function_search_text(evidence: CodeFunctionEvidence) -> str:
        """Concatenate the searchable fields of *evidence* into one lower-cased text."""
        return "\n".join(
            [
                evidence.node_id,
                evidence.name,
                evidence.qualified_name,
                evidence.file,
                evidence.signature,
                evidence.summary,
                evidence.logic_flow,
                evidence.code_snippet[:4000],
                " ".join(evidence.calls),
                " ".join(evidence.called_by),
                " ".join(evidence.includes),
            ]
        ).lower()

    def _lexical_similarity(
        self,
        query_tokens: List[str],
        text: str,
        evidence: CodeFunctionEvidence,
    ) -> float:
        """Return a lexical score in ``[0, 1]`` for one function.

        The score is a weighted sum of token overlap (0.45), substring
        matches (0.35) and name matches (0.12), plus small fixed bonuses for
        records that carry a summary, logic flow, code snippet or location.
        """
        text_tokens = set(self._tokens(text))
        if not text_tokens:
            return 0.0
        unique_query_tokens = list(dict.fromkeys(query_tokens))
        total = max(1, len(unique_query_tokens))
        overlap_count = sum(1 for token in unique_query_tokens if token in text_tokens)
        substring_count = sum(1 for token in unique_query_tokens if token in text)
        overlap_score = overlap_count / total
        substring_score = substring_count / total

        name_text = f"{evidence.name} {evidence.qualified_name}".lower()
        name_hits = sum(1 for token in unique_query_tokens if token in name_text)
        name_score = min(1.0, name_hits / max(1, min(len(unique_query_tokens), 4)))

        evidence_score = 0.0
        if evidence.summary:
            evidence_score += 0.08
        if evidence.logic_flow:
            evidence_score += 0.05
        if evidence.code_snippet:
            evidence_score += 0.04
        if evidence.start_line is not None and evidence.file:
            evidence_score += 0.03

        return max(
            0.0,
            min(
                1.0,
                overlap_score * 0.45
                + substring_score * 0.35
                + name_score * 0.12
                + evidence_score,
            ),
        )

    def get_function(self, node_id: str) -> Optional[CodeFunctionEvidence]:
        """Return the loaded function with *node_id*, or ``None``."""
        return self.functions_by_id.get(node_id)

    def expand_call_context(self, node_id: str, max_hops: int = 2) -> CodeGraphContext:
        """Return callers/callees around *node_id*; an empty context when no graph is loaded."""
        if not self.call_graph:
            return CodeGraphContext(node_id=node_id)
        return self.call_graph.expand(node_id, max_hops=max_hops)

    def summary(self) -> Dict[str, Any]:
        """Return counters describing the loaded knowledge base and its search mode."""
        return {
            "function_count": len(self.functions),
            "index_size": int(getattr(self.faiss_index, "ntotal", 0) or 0),
            "embedding_dim": int(getattr(self.faiss_index, "d", 0) or 0),
            "embedding_available_count": sum(1 for item in self.functions if item.embedding_available),
            "vector_search_enabled": self.vector_search_enabled,
            "vector_search_disabled_reason": self.vector_search_disabled_reason,
            "graph_nodes": len(self.graph_data.get("nodes", []) or []),
            "graph_edges": len(self.graph_data.get("edges", []) or []),
        }

View File

@@ -0,0 +1,46 @@
from __future__ import annotations
from typing import Iterable, List
from app.services.code_kb.schema import CodeGraphContext, CodeSearchHit
def _clip(value: str, limit: int) -> str:
value = value or ""
if len(value) <= limit:
return value
return value[:limit].rstrip() + "\n...[truncated]"
def format_evidence_context(
    hits: Iterable[CodeSearchHit],
    graph_contexts: Iterable[CodeGraphContext],
    max_code_chars: int = 1000,
    max_text_chars: int = 800,
) -> str:
    """Render search hits as plain-text evidence blocks.

    Args:
        hits: Search hits to render, in output order.
        graph_contexts: Call-graph contexts, matched to hits by ``node_id``.
        max_code_chars: Clip limit for each code snippet.
        max_text_chars: Clip limit for summary and logic-flow text.

    Returns:
        One block per hit, joined by a ``---`` separator line; the empty
        string when there are no hits.
    """
    graph_lookup = {}
    for context in graph_contexts:
        graph_lookup[context.node_id] = context

    rendered: List[str] = []
    for hit in hits:
        evidence = hit.evidence
        calls_text = ", ".join(evidence.calls[:20]) or "-"
        called_by_text = ", ".join(evidence.called_by[:20]) or "-"
        section = [
            f"[Function Evidence #{hit.rank}]",
            f"node_id: {evidence.node_id}",
            f"name: {evidence.name}",
            f"qualified_name: {evidence.qualified_name}",
            f"file: {evidence.file}",
            f"lines: {evidence.start_line}-{evidence.end_line}",
            f"similarity: {hit.similarity:.4f}",
            f"signature: {_clip(evidence.signature, 300)}",
            f"summary: {_clip(evidence.summary, max_text_chars)}",
            f"logic_flow: {_clip(evidence.logic_flow, max_text_chars)}",
            f"calls: {calls_text}",
            f"called_by: {called_by_text}",
        ]

        call_graph = graph_lookup.get(evidence.node_id)
        if call_graph:
            chains_text = "; ".join(call_graph.call_chains[:12]) or "-"
            section.append(f"call_chain_evidence: {chains_text}")

        if evidence.code_snippet:
            section.append("code_snippet:")
            section.append(_clip(evidence.code_snippet, max_code_chars))

        rendered.append("\n".join(section))

    return "\n\n---\n\n".join(rendered)

View File

@@ -0,0 +1,105 @@
from __future__ import annotations
from collections import defaultdict, deque
from typing import Callable, Dict, Iterable, List, Optional, Set
from app.services.code_kb.schema import CodeFunctionEvidence, CodeGraphContext
class CodeCallGraph:
    """Directed call graph over the known functions of a code knowledge base.

    Edges come from two sources: ``CALLS`` edges of the graph payload and the
    ``calls`` / ``called_by`` name lists carried by each function record.
    Only edges whose two ends are known functions are kept.
    """

    def __init__(
        self,
        graph_data: Optional[dict],
        functions: Iterable[CodeFunctionEvidence],
    ) -> None:
        """Index *functions* and build caller/callee adjacency sets.

        Args:
            graph_data: Decoded graph JSON; ``None`` or an unexpected shape
                is treated as a graph without edges.
            functions: Function records exposing ``node_id``, ``name``,
                ``qualified_name``, ``calls`` and ``called_by``.
        """
        self.functions_by_id: Dict[str, CodeFunctionEvidence] = {
            item.node_id: item for item in functions
        }
        # Every spelling a function may be referred to by -> matching node ids.
        self.name_to_ids: Dict[str, List[str]] = defaultdict(list)
        for item in self.functions_by_id.values():
            for name in {item.name, item.qualified_name, item.node_id.removeprefix("Function:")}:
                if name:
                    self.name_to_ids[name].append(item.node_id)
        self.calls_by_id: Dict[str, Set[str]] = defaultdict(set)
        self.called_by_id: Dict[str, Set[str]] = defaultdict(set)
        self._load_graph_edges(graph_data or {})
        self._load_metadata_edges()

    def _load_graph_edges(self, graph_data: dict) -> None:
        """Add every ``CALLS`` edge whose source and target are known functions."""
        edges = graph_data.get("edges") if isinstance(graph_data, dict) else None
        for edge in edges or []:
            # Malformed entries (strings, nulls, ...) are skipped, not fatal.
            if not isinstance(edge, dict) or edge.get("type") != "CALLS":
                continue
            source_id = edge.get("source_id")
            target_id = edge.get("target_id")
            if source_id in self.functions_by_id and target_id in self.functions_by_id:
                self.calls_by_id[source_id].add(target_id)
                self.called_by_id[target_id].add(source_id)

    def _resolve_name(self, value: str) -> List[str]:
        """Return the node ids that *value* (a node id or any known name) refers to."""
        if not value:
            return []
        if value in self.functions_by_id:
            return [value]
        # .get() avoids creating an empty defaultdict entry for unknown names;
        # a copy is returned so callers cannot mutate the index.
        return list(self.name_to_ids.get(value, []))

    def _load_metadata_edges(self) -> None:
        """Add edges declared by each function's ``calls`` / ``called_by`` lists.

        Self-references are ignored.
        """
        for function in self.functions_by_id.values():
            for callee in function.calls:
                for target_id in self._resolve_name(callee):
                    if target_id != function.node_id:
                        self.calls_by_id[function.node_id].add(target_id)
                        self.called_by_id[target_id].add(function.node_id)
            for caller in function.called_by:
                for source_id in self._resolve_name(caller):
                    if source_id != function.node_id:
                        self.called_by_id[function.node_id].add(source_id)
                        self.calls_by_id[source_id].add(function.node_id)

    def _bfs(
        self,
        start_id: str,
        max_hops: int,
        next_nodes: Callable[[str], Iterable[str]],
    ) -> List[CodeFunctionEvidence]:
        """Collect functions reachable from *start_id* within *max_hops* steps.

        Neighbours are visited in sorted id order, breadth first; the start
        node itself is never included in the result.
        """
        seen = {start_id}
        result: List[CodeFunctionEvidence] = []
        queue = deque([(start_id, 0)])
        while queue:
            current_id, depth = queue.popleft()
            if depth >= max_hops:
                continue
            for next_id in sorted(next_nodes(current_id)):
                if next_id in seen:
                    continue
                seen.add(next_id)
                function = self.functions_by_id.get(next_id)
                if function:
                    result.append(function)
                queue.append((next_id, depth + 1))
        return result

    def expand(self, node_id: str, max_hops: int = 2) -> CodeGraphContext:
        """Return callers, callees and readable call chains around *node_id*.

        Chains list up to five ``caller -> center`` and five
        ``center -> callee`` pairs, plus the 3x3 combinations of the first
        callers and callees, with duplicates removed in first-seen order.
        """
        callers = self._bfs(node_id, max_hops, lambda item: self.called_by_id.get(item, set()))
        callees = self._bfs(node_id, max_hops, lambda item: self.calls_by_id.get(item, set()))
        center = self.functions_by_id.get(node_id)
        center_name = center.name if center else node_id
        call_chains: List[str] = []
        for caller in callers[:5]:
            call_chains.append(f"{caller.name} -> {center_name}")
        for callee in callees[:5]:
            call_chains.append(f"{center_name} -> {callee.name}")
        for caller in callers[:3]:
            for callee in callees[:3]:
                call_chains.append(f"{caller.name} -> {center_name} -> {callee.name}")
        return CodeGraphContext(
            node_id=node_id,
            callers=callers,
            callees=callees,
            call_chains=list(dict.fromkeys(call_chains)),
        )

View File

@@ -0,0 +1,24 @@
from __future__ import annotations
from typing import List
from app.services.code_kb.adapter import CodeKnowledgeBaseAdapter
from app.services.code_kb.schema import CodeSearchHit
class CodeFunctionRetriever:
    """Thin retrieval facade over a ``CodeKnowledgeBaseAdapter``."""

    def __init__(self, adapter: CodeKnowledgeBaseAdapter) -> None:
        # The adapter performs the search; this class only forwards to it.
        self.adapter = adapter

    def retrieve(
        self,
        query: str,
        top_k: int = 8,
        min_similarity: float = 0.0,
    ) -> List[CodeSearchHit]:
        """Forward the arguments to ``adapter.search_functions`` and return its hits."""
        search_kwargs = {
            "query": query,
            "top_k": top_k,
            "min_similarity": min_similarity,
        }
        return self.adapter.search_functions(**search_kwargs)

View File

@@ -0,0 +1,63 @@
from __future__ import annotations
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, List, Optional
@dataclass
class CodeFunctionEvidence:
    """Everything known about one function in the code knowledge base."""

    # Identity and location.
    node_id: str
    name: str
    qualified_name: str
    file: str
    start_line: Optional[int] = None
    end_line: Optional[int] = None

    # Descriptive text.
    signature: str = ""
    summary: str = ""
    logic_flow: str = ""
    code_snippet: str = ""

    # Relations to other code, as recorded in the metadata.
    calls: List[str] = field(default_factory=list)
    called_by: List[str] = field(default_factory=list)
    includes: List[str] = field(default_factory=list)

    # Embedding bookkeeping.
    embedding_model: str = ""
    embedding_dim: int = 0
    embedding_available: bool = True

    # The metadata record this evidence was built from.
    raw: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return all fields as a plain dict built by ``dataclasses.asdict``."""
        as_mapping = asdict(self)
        return as_mapping
@dataclass
class CodeSearchHit:
    """One search result: the matched function plus its score and rank."""

    evidence: CodeFunctionEvidence
    similarity: float
    distance: float
    rank: int

    def to_dict(self) -> Dict[str, Any]:
        """Return the evidence dict extended with ``similarity``, ``distance`` and ``rank``.

        The three score keys override same-named keys of the evidence dict.
        """
        return {
            **self.evidence.to_dict(),
            "similarity": self.similarity,
            "distance": self.distance,
            "rank": self.rank,
        }
@dataclass
class CodeGraphContext:
    """Call-graph neighbourhood of one function."""

    node_id: str
    callers: List[CodeFunctionEvidence] = field(default_factory=list)
    callees: List[CodeFunctionEvidence] = field(default_factory=list)
    call_chains: List[str] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain dict with callers and callees converted through their ``to_dict``."""
        caller_dicts = [caller.to_dict() for caller in self.callers]
        callee_dicts = [callee.to_dict() for callee in self.callees]
        return {
            "node_id": self.node_id,
            "callers": caller_dicts,
            "callees": callee_dicts,
            "call_chains": self.call_chains,
        }

View File

@@ -0,0 +1,4 @@
from app.services.consistency.comparator import ConsistencyComparator
__all__ = ["ConsistencyComparator"]

View File

@@ -0,0 +1,258 @@
from __future__ import annotations
import json
import logging
import re
from typing import Any, Dict, Iterable, List, Optional
from app.services.code_kb.adapter import CodeKnowledgeBaseAdapter
from app.services.code_kb.formatter import format_evidence_context
from app.services.code_kb.schema import CodeGraphContext, CodeSearchHit
from app.services.consistency.prompt import build_judgment_prompt, build_requirement_query
from app.services.consistency.schema import ConsistencyResultItem, RequirementSnapshot, VERDICTS
from app.services.consistency.scorer import coverage_score
logger = logging.getLogger(__name__)
def _clip(value: str, limit: int) -> str:
text = value or ""
if len(text) <= limit:
return text
return text[:limit].rstrip() + "\n...[truncated]"
def _as_list(value: Any) -> List[str]:
if value is None:
return []
if isinstance(value, list):
return [str(item) for item in value if str(item).strip()]
if isinstance(value, tuple):
return [str(item) for item in value if str(item).strip()]
if isinstance(value, str):
text = value.strip()
if not text:
return []
try:
parsed = json.loads(text)
return _as_list(parsed)
except json.JSONDecodeError:
return [line.strip() for line in text.splitlines() if line.strip()]
return [str(value)]
def requirement_to_snapshot(requirement: Any) -> RequirementSnapshot:
    """Build a ``RequirementSnapshot`` from a dict or an attribute-style object.

    Each field is read under its snake_case name first and, where the code
    lists one, under a fallback name (``id`` for the uid, camelCase for the
    rest).  Missing uid, title and description become empty strings.
    """
    if isinstance(requirement, dict):
        getter = requirement.get
    else:

        def getter(key, default=None):
            return getattr(requirement, key, default)

    criteria = _as_list(getter("acceptance_criteria") or getter("acceptanceCriteria"))
    return RequirementSnapshot(
        requirement_uid=getter("requirement_uid") or getter("id") or "",
        title=getter("title") or "",
        description=getter("description") or "",
        acceptance_criteria=criteria,
        requirement_type=getter("requirement_type") or getter("requirementType"),
        section_title=getter("section_title") or getter("sectionTitle"),
        interface_name=getter("interface_name") or getter("interfaceName"),
        interface_type=getter("interface_type") or getter("interfaceType"),
        data_source=getter("data_source") or getter("dataSource"),
        data_destination=getter("data_destination") or getter("dataDestination"),
    )
class ConsistencyComparator:
def __init__(
self,
code_kb_adapter: CodeKnowledgeBaseAdapter,
llm: Any = None,
use_llm: bool = True,
) -> None:
self.code_kb_adapter = code_kb_adapter
self.llm = llm
self.use_llm = use_llm
def compare_requirements(
self,
requirements: Iterable[Any],
top_k: int = 8,
max_call_hops: int = 2,
min_similarity: float = 0.55,
) -> List[ConsistencyResultItem]:
return [
self.compare_requirement(
requirement,
top_k=top_k,
max_call_hops=max_call_hops,
min_similarity=min_similarity,
)
for requirement in requirements
]
def compare_requirement(
    self,
    requirement: Any,
    top_k: int = 8,
    max_call_hops: int = 2,
    min_similarity: float = 0.55,
) -> ConsistencyResultItem:
    """Judge how well the code base covers one requirement.

    Steps: snapshot the requirement, search matching functions, expand the
    call context of every hit, obtain a judgment (missing / heuristic /
    LLM), normalise it, score coverage and assemble the result item.
    """
    snapshot = requirement_to_snapshot(requirement)
    adapter = self.code_kb_adapter
    hits = adapter.search_functions(
        query=build_requirement_query(snapshot),
        top_k=top_k,
        min_similarity=min_similarity,
    )
    contexts = []
    for hit in hits:
        contexts.append(adapter.expand_call_context(hit.evidence.node_id, max_hops=max_call_hops))

    # No evidence -> "missing"; evidence without a model -> heuristic verdict.
    if hits and self.use_llm:
        raw_judgment = self._llm_judgment(snapshot, hits, contexts)
    elif hits:
        raw_judgment = self._heuristic_judgment(hits, contexts)
    else:
        raw_judgment = self._missing_judgment("未找到满足相似度阈值的函数证据。")

    judgment = self._normalize_judgment(raw_judgment)
    judgment["requirement_snapshot"] = snapshot.to_dict()
    score = coverage_score(snapshot, hits, contexts, judgment)
    matched_functions = [self._matched_function_payload(hit) for hit in hits]
    call_chains = self._collect_call_chains(contexts)

    return ConsistencyResultItem(
        requirement_uid=snapshot.requirement_uid,
        requirement_title=snapshot.title,
        requirement_type=snapshot.requirement_type,
        requirement_text=snapshot.description,
        verdict=judgment["verdict"],
        coverage_score=score,
        confidence=float(judgment.get("confidence") or 0.0),
        matched_functions=matched_functions,
        covered_points=_as_list(judgment.get("covered_points")),
        missing_points=_as_list(judgment.get("missing_points")),
        conflict_points=_as_list(judgment.get("conflict_points")),
        call_chain_evidence=call_chains,
        suggestion=str(judgment.get("suggestion") or ""),
        raw_judgment=judgment,
    )
def _llm_judgment(
    self,
    requirement: RequirementSnapshot,
    hits: List[CodeSearchHit],
    contexts: List[CodeGraphContext],
) -> Dict[str, Any]:
    """Ask the language model for a judgment, falling back on any failure.

    The prompt is built from the requirement and the formatted evidence.
    ``self.llm`` is used when set; otherwise a model is created through
    ``LLMFactory``. Any exception (prompt building, model call, JSON
    parsing) is logged and converted into an ``uncertain`` judgment marked
    with ``"fallback": True``.

    Returns:
        The parsed judgment, or the fallback judgment.
    """
    try:
        evidence_text = format_evidence_context(hits, contexts)
        prompt = build_judgment_prompt(requirement, evidence_text)
        from app.services.llm.llm_factory import LLMFactory

        model = self.llm
        if not model:
            model = LLMFactory.create(temperature=0, streaming=False)
        # Accept both objects exposing ``invoke`` and plain callables.
        if hasattr(model, "invoke"):
            reply = model.invoke(prompt)
        else:
            reply = model(prompt)
        reply_text = getattr(reply, "content", reply)
        return self.parse_json_judgment(str(reply_text))
    except Exception as exc:
        logger.exception("LLM consistency judgment failed: %s", exc)
        top_node_ids = [hit.evidence.node_id for hit in hits[:3]]
        return {
            "verdict": "uncertain",
            "confidence": 0.2,
            "covered_points": [],
            "missing_points": ["模型判定失败,无法可靠确认覆盖情况。"],
            "conflict_points": [],
            "primary_evidence": top_node_ids,
            "reasoning": f"LLM judgment failed: {exc}",
            "suggestion": "请检查模型配置,或人工复核匹配函数证据。",
            "fallback": True,
        }
@staticmethod
def parse_json_judgment(raw_text: str) -> Dict[str, Any]:
text = raw_text.strip()
if text.startswith("```"):
text = re.sub(r"^```(?:json)?", "", text, flags=re.IGNORECASE).strip()
text = re.sub(r"```$", "", text).strip()
try:
return json.loads(text)
except json.JSONDecodeError:
match = re.search(r"\{.*\}", text, flags=re.DOTALL)
if match:
return json.loads(match.group(0))
raise
def _heuristic_judgment(
self,
hits: List[CodeSearchHit],
contexts: List[CodeGraphContext],
) -> Dict[str, Any]:
best = hits[0].similarity if hits else 0.0
if best >= 0.78:
verdict = "partial"
confidence = min(0.68, best)
else:
verdict = "uncertain"
confidence = min(0.5, best)
return {
"verdict": verdict,
"confidence": confidence,
"covered_points": [],
"missing_points": ["未启用 LLM 判定,无法细分验收准则覆盖点。"],
"conflict_points": [],
"primary_evidence": [hit.evidence.node_id for hit in hits[:3]],
"reasoning": "仅基于向量召回和调用图生成保守判定。",
"suggestion": "启用模型判定或人工复核主要匹配函数。",
"call_context_count": len(contexts),
}
def _missing_judgment(self, reason: str) -> Dict[str, Any]:
return {
"verdict": "missing",
"confidence": 0.75,
"covered_points": [],
"missing_points": [reason],
"conflict_points": [],
"primary_evidence": [],
"reasoning": reason,
"suggestion": "补充代码实现或降低阈值后重新召回,并人工确认是否存在命名差异。",
}
def _normalize_judgment(self, judgment: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of *judgment* with a valid verdict and confidence.

    The verdict is lower-cased and replaced by ``uncertain`` when it is not
    in ``VERDICTS``. Confidence is coerced to a float in ``[0, 1]`` (0.0
    when it cannot be converted). Missing list and text fields are filled
    with empty defaults. The input dictionary is not modified.
    """
    normalized = dict(judgment)

    verdict = str(judgment.get("verdict") or "uncertain").strip().lower()
    normalized["verdict"] = verdict if verdict in VERDICTS else "uncertain"

    try:
        bounded = max(0.0, min(1.0, float(judgment.get("confidence", 0.0))))
    except (TypeError, ValueError):
        bounded = 0.0
    normalized["confidence"] = bounded

    for list_key in ("covered_points", "missing_points", "conflict_points", "primary_evidence"):
        normalized.setdefault(list_key, [])
    for text_key in ("reasoning", "suggestion"):
        normalized.setdefault(text_key, "")
    return normalized
def _matched_function_payload(self, hit: CodeSearchHit) -> Dict[str, Any]:
    """Flatten a search hit into the dictionary stored with a result.

    Long text fields are shortened: ``role`` is the first 120 characters of
    the summary, ``logic_flow`` and ``code_snippet`` go through ``_clip``
    (limits 1200 and 2000), and the call lists keep at most 20 entries.
    """
    evidence = hit.evidence
    summary = evidence.summary
    role = summary[:120] if summary else ""
    payload: Dict[str, Any] = {
        "node_id": evidence.node_id,
        "name": evidence.name,
        "file": evidence.file,
        "start_line": evidence.start_line,
        "end_line": evidence.end_line,
        "similarity": round(hit.similarity, 4),
        "role": role,
        "evidence_summary": summary,
        "logic_flow": _clip(evidence.logic_flow, 1200),
        "code_snippet": _clip(evidence.code_snippet, 2000),
        "calls": evidence.calls[:20],
        "called_by": evidence.called_by[:20],
        "signature": evidence.signature,
    }
    return payload
def _collect_call_chains(self, contexts: List[CodeGraphContext]) -> List[str]:
chains: List[str] = []
for context in contexts:
chains.extend(context.call_chains)
return list(dict.fromkeys(chains))[:30]

View File

@@ -0,0 +1,134 @@
from __future__ import annotations
import io
import json
from typing import Any, Dict, Iterable, List
def normalize_result_dicts(results: Iterable[Any]) -> List[Dict[str, Any]]:
    """Convert heterogeneous result items into plain dictionaries.

    Items exposing ``to_dict`` are converted through it, dictionaries are
    passed through unchanged, and any other object is read attribute by
    attribute with empty defaults.

    The attribute fallback also carries ``requirement_title``,
    ``requirement_type`` and ``requirement_text``, because the Markdown and
    Excel exporters in this module read those keys.

    Args:
        results: Result items to normalise.

    Returns:
        One dictionary per item, in input order.
    """
    normalized: List[Dict[str, Any]] = []
    for item in results:
        if hasattr(item, "to_dict"):
            normalized.append(item.to_dict())
        elif isinstance(item, dict):
            normalized.append(item)
        else:
            normalized.append(
                {
                    "requirement_uid": getattr(item, "requirement_uid", ""),
                    "requirement_title": getattr(item, "requirement_title", ""),
                    "requirement_type": getattr(item, "requirement_type", None),
                    "requirement_text": getattr(item, "requirement_text", ""),
                    "verdict": getattr(item, "verdict", ""),
                    "coverage_score": getattr(item, "coverage_score", 0.0),
                    "confidence": getattr(item, "confidence", 0.0),
                    "matched_functions": getattr(item, "matched_functions", []),
                    "covered_points": getattr(item, "covered_points", []),
                    "missing_points": getattr(item, "missing_points", []),
                    "conflict_points": getattr(item, "conflict_points", []),
                    "call_chain_evidence": getattr(item, "call_chain_evidence", []),
                    "suggestion": getattr(item, "suggestion", ""),
                    "raw_judgment": getattr(item, "raw_judgment", {}),
                }
            )
    return normalized
def export_json(results: Iterable[Any]) -> bytes:
    """Serialise the results as UTF-8 encoded, indented JSON.

    The document has a single top-level ``results`` key holding the
    normalised result dictionaries; non-ASCII text is written verbatim.
    """
    document = {"results": normalize_result_dicts(results)}
    rendered = json.dumps(document, ensure_ascii=False, indent=2)
    return rendered.encode("utf-8")
def _markdown_table_cell(value: Any) -> str:
    """Render *value* as text that cannot break a Markdown table row."""
    text = str(value or "")
    # A line break ends the table row and "|" starts a new cell, so both
    # are replaced before the text is placed in the summary table.
    for line_break in ("\r\n", "\n", "\r"):
        text = text.replace(line_break, " ")
    return text.replace("|", "/")


def export_markdown(results: Iterable[Any]) -> str:
    """Render the results as a Markdown report.

    The report starts with a summary table (one row per requirement) and
    continues with one section per requirement listing the verdict, scores,
    suggestion, matched functions, missing points and, when present,
    conflict points.

    Text placed in the summary table has line breaks replaced by spaces and
    ``|`` replaced by ``/`` so a multi-line suggestion cannot split a row.

    Args:
        results: Result items accepted by ``normalize_result_dicts``.

    Returns:
        The Markdown document as a single string.
    """
    rows = normalize_result_dicts(results)
    lines = [
        "# 需求代码一致性比对报告",
        "",
        "| 需求 ID | 判定 | 覆盖分 | 置信度 | 匹配函数 | 缺失点 | 建议 |",
        "| --- | --- | ---: | ---: | ---: | ---: | --- |",
    ]
    for item in rows:
        lines.append(
            "| {uid} | {verdict} | {score:.2f} | {confidence:.2f} | {functions} | {missing} | {suggestion} |".format(
                uid=_markdown_table_cell(item.get("requirement_uid", "")),
                verdict=_markdown_table_cell(item.get("verdict", "")),
                score=float(item.get("coverage_score") or 0),
                confidence=float(item.get("confidence") or 0),
                functions=len(item.get("matched_functions") or []),
                missing=len(item.get("missing_points") or []),
                suggestion=_markdown_table_cell(item.get("suggestion")),
            )
        )
    for item in rows:
        lines.extend(
            [
                "",
                f"## {item.get('requirement_uid', '')} {item.get('requirement_title', '')}",
                "",
                f"- 判定: `{item.get('verdict', '')}`",
                f"- 覆盖分: {float(item.get('coverage_score') or 0):.2f}",
                f"- 置信度: {float(item.get('confidence') or 0):.2f}",
                f"- 建议: {item.get('suggestion') or '-'}",
                "",
                "### 匹配函数",
            ]
        )
        for function in item.get("matched_functions") or []:
            lines.append(
                f"- `{function.get('name')}` {function.get('file')}:{function.get('start_line')} "
                f"(similarity={float(function.get('similarity') or 0):.2f})"
            )
        lines.extend(["", "### 缺失点"])
        # A lone "-" bullet is emitted when there are no missing points.
        for point in item.get("missing_points") or ["-"]:
            lines.append(f"- {point}")
        if item.get("conflict_points"):
            lines.extend(["", "### 冲突点"])
            for point in item.get("conflict_points") or []:
                lines.append(f"- {point}")
    return "\n".join(lines)
def export_excel(results: Iterable[Any]) -> bytes:
    """Render the results as an Excel workbook and return its bytes.

    The workbook has one sheet named ``Consistency`` with a header row and
    one row per requirement. ``主要文件`` is the file of the first matched
    function, or an empty string when nothing matched.

    Raises:
        RuntimeError: If ``openpyxl`` is not installed.
    """
    try:
        from openpyxl import Workbook
    except ImportError as exc:
        raise RuntimeError("openpyxl is required to export Excel reports.") from exc

    records = normalize_result_dicts(results)
    workbook = Workbook()
    sheet = workbook.active
    sheet.title = "Consistency"
    sheet.append(
        [
            "需求ID",
            "需求标题",
            "需求类型",
            "判定",
            "覆盖分",
            "置信度",
            "匹配函数数量",
            "主要文件",
            "缺失点数量",
            "建议",
        ]
    )

    for record in records:
        functions = record.get("matched_functions") or []
        if functions:
            primary_file = functions[0].get("file", "")
        else:
            primary_file = ""
        missing_count = len(record.get("missing_points") or [])
        row = [
            record.get("requirement_uid", ""),
            record.get("requirement_title", ""),
            record.get("requirement_type", ""),
            record.get("verdict", ""),
            record.get("coverage_score", 0),
            record.get("confidence", 0),
            len(functions),
            primary_file,
            missing_count,
            record.get("suggestion", ""),
        ]
        sheet.append(row)

    # Save into memory so the caller decides where the bytes go.
    buffer = io.BytesIO()
    workbook.save(buffer)
    return buffer.getvalue()

View File

@@ -0,0 +1,58 @@
from __future__ import annotations
import json
from app.services.consistency.schema import RequirementSnapshot
# Instruction text placed at the start of every judgment prompt (see
# build_judgment_prompt). In Chinese, it tells the model to act as a
# requirement/code consistency reviewer: judge only from the supplied
# requirement, acceptance criteria, function summaries, code snippets and
# call-chain evidence; do not add code facts that were not given; answer
# "uncertain" when evidence is insufficient; output strict JSON, not Markdown.
SYSTEM_INSTRUCTION = """你是需求代码一致性审查助手。
只能基于输入的需求、验收准则、函数摘要、代码片段、调用链证据判断。
不得补充未给出的代码事实。
证据不足时输出 uncertain。
输出严格 JSON不要 Markdown。"""
def build_requirement_query(requirement: RequirementSnapshot) -> str:
    """Build the retrieval query text for a requirement.

    Interface requirements (type ``interface``, case-insensitive) lead with
    the interface name, type, data source and destination, followed by the
    description. All other requirements lead with the description and the
    acceptance criteria, followed by section and interface details. Empty
    parts are dropped and the rest joined with newlines.
    """
    kind = (requirement.requirement_type or "").lower()
    if kind == "interface":
        candidates = [
            requirement.interface_name or "",
            requirement.interface_type or "",
            requirement.data_source or "",
            requirement.data_destination or "",
            requirement.description,
        ]
    else:
        candidates = [
            requirement.description,
            "\n".join(requirement.acceptance_criteria),
            requirement.section_title or "",
            requirement.interface_name or "",
            requirement.data_source or "",
            requirement.data_destination or "",
        ]
    # filter(None, ...) keeps only the truthy (non-empty) parts.
    return "\n".join(filter(None, candidates)).strip()
def build_judgment_prompt(requirement: RequirementSnapshot, evidence_context: str) -> str:
    """Compose the full judgment prompt sent to the language model.

    The prompt is ``SYSTEM_INSTRUCTION``, a blank line, and an indented JSON
    payload holding the requirement, the evidence text and a template of
    the expected output fields.
    """
    expected_output = {
        "verdict": "implemented | partial | missing | conflict | uncertain",
        "confidence": 0.0,
        "covered_points": [],
        "missing_points": [],
        "conflict_points": [],
        "primary_evidence": [],
        "reasoning": "brief reason based only on evidence",
        "suggestion": "next action",
    }
    request = {
        "requirement": requirement.to_dict(),
        "evidence": evidence_context,
        "output_schema": expected_output,
    }
    body = json.dumps(request, ensure_ascii=False, indent=2)
    return f"{SYSTEM_INSTRUCTION}\n\n{body}"

View File

@@ -0,0 +1,61 @@
from __future__ import annotations
import argparse
from pathlib import Path
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Run requirement-code consistency comparison.")
parser.add_argument("--srs-extraction-id", type=int, required=True)
parser.add_argument("--vector-path", required=True)
parser.add_argument("--metadata-path", required=True)
parser.add_argument("--graph-path", required=True)
parser.add_argument("--output", required=True)
parser.add_argument("--output-excel", default=None)
parser.add_argument("--output-markdown", default=None)
parser.add_argument("--top-k", type=int, default=8)
parser.add_argument("--max-call-hops", type=int, default=2)
parser.add_argument("--min-similarity", type=float, default=0.55)
parser.add_argument("--no-llm", action="store_true")
return parser.parse_args()
def main() -> int:
    """Run the comparison from the command line and write the reports.

    Loads the code knowledge base from the given artifact paths, reads the
    requirements of the selected SRS extraction ordered by ``sort_order``,
    compares them, and writes the JSON report plus the optional Markdown
    and Excel reports.

    Returns:
        ``0`` on completion.
    """
    args = parse_args()

    # Project imports stay inside main(), as in the original layout.
    from app.db.session import SessionLocal
    from app.models.tooling import SRSRequirement
    from app.services.code_kb.adapter import CodeKnowledgeBaseAdapter
    from app.services.consistency.comparator import ConsistencyComparator
    from app.services.consistency.exporter import export_excel, export_json, export_markdown

    adapter = CodeKnowledgeBaseAdapter()
    adapter.load(args.vector_path, args.metadata_path, args.graph_path)
    comparator = ConsistencyComparator(adapter, use_llm=not args.no_llm)

    session = SessionLocal()
    try:
        ordered = (
            session.query(SRSRequirement)
            .filter(SRSRequirement.extraction_id == args.srs_extraction_id)
            .order_by(SRSRequirement.sort_order)
        )
        requirements = ordered.all()
        results = comparator.compare_requirements(
            requirements,
            top_k=args.top_k,
            max_call_hops=args.max_call_hops,
            min_similarity=args.min_similarity,
        )
    finally:
        session.close()

    json_target = Path(args.output)
    json_target.write_bytes(export_json(results))
    if args.output_markdown:
        markdown_target = Path(args.output_markdown)
        markdown_target.write_text(export_markdown(results), encoding="utf-8")
    if args.output_excel:
        excel_target = Path(args.output_excel)
        excel_target.write_bytes(export_excel(results))
    return 0
# Script entry point: the value returned by main() becomes the exit status.
if __name__ == "__main__":
    raise SystemExit(main())

View File

@@ -0,0 +1,46 @@
from __future__ import annotations
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, List, Optional
# Closed set of verdict labels a consistency judgment may carry.
VERDICTS = {"implemented", "partial", "missing", "conflict", "uncertain"}
@dataclass
class RequirementSnapshot:
    """Plain-data view of one requirement used during comparison.

    Only ``requirement_uid``, ``title`` and ``description`` are mandatory;
    the remaining fields default to an empty list or ``None``.
    """

    requirement_uid: str  # identifier of the requirement
    title: str
    description: str
    acceptance_criteria: List[str] = field(default_factory=list)
    requirement_type: Optional[str] = None
    section_title: Optional[str] = None
    # Interface-related details; presumably only filled for interface
    # requirements — confirm against the code that builds snapshots.
    interface_name: Optional[str] = None
    interface_type: Optional[str] = None
    data_source: Optional[str] = None
    data_destination: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return the snapshot as a dictionary via ``dataclasses.asdict``."""
        return asdict(self)
@dataclass
class ConsistencyResultItem:
    """Outcome of comparing one requirement against the code knowledge base.

    The first eight fields are mandatory; the point lists, call-chain
    evidence, suggestion and raw judgment default to empty values.
    """

    requirement_uid: str
    requirement_title: str
    requirement_type: Optional[str]
    requirement_text: str
    verdict: str  # presumably one of VERDICTS — not enforced by this class
    coverage_score: float
    confidence: float
    matched_functions: List[Dict[str, Any]]  # one dictionary per matched function
    covered_points: List[str] = field(default_factory=list)
    missing_points: List[str] = field(default_factory=list)
    conflict_points: List[str] = field(default_factory=list)
    call_chain_evidence: List[str] = field(default_factory=list)
    suggestion: str = ""
    raw_judgment: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return the result as a dictionary via ``dataclasses.asdict``."""
        return asdict(self)

View File

@@ -0,0 +1,120 @@
from __future__ import annotations
import re
from typing import Any, Dict, Iterable, List
from app.services.code_kb.schema import CodeGraphContext, CodeSearchHit
from app.services.consistency.schema import RequirementSnapshot
def _clamp(value: float) -> float:
return max(0.0, min(1.0, value))
def _tokens(*values: str) -> List[str]:
text = " ".join(value or "" for value in values).lower()
return [item for item in re.split(r"[^a-z0-9_\u4e00-\u9fff]+", text) if len(item) >= 2]
def semantic_score(hits: List[CodeSearchHit]) -> float:
    """Score retrieval similarity as 70% best hit plus 30% mean of the first three.

    Returns ``0.0`` when there are no hits; the result is clamped to
    ``[0, 1]``.
    """
    if not hits:
        return 0.0
    best = max(hit.similarity for hit in hits)
    leading = hits[:3]
    leading_total = sum(hit.similarity for hit in leading)
    mean_leading = leading_total / min(3, len(hits))
    return _clamp(best * 0.7 + mean_leading * 0.3)
def acceptance_coverage_score(requirement: RequirementSnapshot, judgment: Dict[str, Any]) -> float:
    """Estimate how much of the acceptance criteria the judgment reports as met.

    Without acceptance criteria the score is looked up from the verdict
    (0.35 for anything unlisted). With criteria, missing points take
    precedence, then covered points, and finally the verdict alone.
    """
    verdict = judgment.get("verdict")
    criteria = requirement.acceptance_criteria or []
    if not criteria:
        by_verdict = {"implemented": 1.0, "partial": 0.55, "conflict": 0.25, "missing": 0.0}
        return by_verdict.get(verdict, 0.35)

    total = len(criteria)
    missing = judgment.get("missing_points") or []
    if missing:
        # Never subtract more missing points than there are criteria.
        return _clamp((total - min(len(missing), total)) / total)

    covered = judgment.get("covered_points") or []
    if covered:
        return _clamp(len(covered) / total)

    if verdict == "implemented":
        return 1.0
    if verdict == "partial":
        return 0.4
    return 0.0
def evidence_strength_score(hits: List[CodeSearchHit]) -> float:
    """Score how complete the evidence of the first five hits is.

    Each hit earns the fraction of six details that are present (file,
    start line, end line, summary, logic flow, code snippet); the result is
    the mean over the inspected hits, or ``0.0`` without hits.
    """
    if not hits:
        return 0.0
    ratios: List[float] = []
    for hit in hits[:5]:
        evidence = hit.evidence
        present = (
            bool(evidence.file),
            evidence.start_line is not None,
            evidence.end_line is not None,
            bool(evidence.summary),
            bool(evidence.logic_flow),
            bool(evidence.code_snippet),
        )
        # Booleans sum as 0/1, giving the number of details present.
        ratios.append(sum(present) / len(present))
    return _clamp(sum(ratios) / len(ratios))
def call_graph_score(contexts: Iterable[CodeGraphContext]) -> float:
    """Score how much call-graph information the first five contexts carry.

    A context earns 0.35 for having callers, 0.35 for callees and 0.30 for
    call chains; the result is the mean over the inspected contexts, or
    ``0.0`` when there are none.
    """
    inspected = list(contexts)[:5]
    if not inspected:
        return 0.0
    per_context: List[float] = []
    for context in inspected:
        weighted = (
            (context.callers, 0.35),
            (context.callees, 0.35),
            (context.call_chains, 0.30),
        )
        total = 0.0
        for present, weight in weighted:
            if present:
                total += weight
        per_context.append(total)
    return _clamp(sum(per_context) / len(per_context))
def exact_match_score(requirement: RequirementSnapshot, hits: List[CodeSearchHit]) -> float:
    """Score the share of requirement keywords found in the evidence text.

    Keywords come from the interface fields and title; when those yield
    nothing, the first 12 tokens of the description are used. The evidence
    text is built from the name, qualified name, summary and logic flow of
    the first five hits. Returns ``0.0`` without hits or keywords.
    """
    if not hits:
        return 0.0

    keywords = _tokens(
        requirement.interface_name or "",
        requirement.interface_type or "",
        requirement.data_source or "",
        requirement.data_destination or "",
        requirement.title or "",
    )
    if not keywords:
        keywords = _tokens(requirement.description)[:12]
        if not keywords:
            return 0.0

    fragments: List[str] = []
    for hit in hits[:5]:
        fragments.append(
            f"{hit.evidence.name} {hit.evidence.qualified_name} {hit.evidence.summary} {hit.evidence.logic_flow}"
        )
    haystack = " ".join(fragments).lower()

    found = 0
    for token in keywords:
        if token.lower() in haystack:
            found += 1
    return _clamp(found / len(keywords))
def coverage_score(
    requirement: RequirementSnapshot,
    hits: List[CodeSearchHit],
    contexts: List[CodeGraphContext],
    judgment: Dict[str, Any],
) -> float:
    """Combine the component scores into one coverage score in ``[0, 1]``.

    Weights: semantic 0.25, acceptance coverage 0.30, evidence strength
    0.20, call graph 0.15, exact match 0.10. The total is capped by verdict
    (``missing`` 0.25, ``uncertain`` 0.55, ``conflict`` 0.45) and rounded
    to four decimals.
    """
    semantic = semantic_score(hits) * 0.25
    acceptance = acceptance_coverage_score(requirement, judgment) * 0.30
    strength = evidence_strength_score(hits) * 0.20
    call_graph = call_graph_score(contexts) * 0.15
    exact = exact_match_score(requirement, hits) * 0.10
    total = semantic + acceptance + strength + call_graph + exact

    verdict = judgment.get("verdict")
    # Weak verdicts must not end up with a high score.
    for capped_verdict, ceiling in (("missing", 0.25), ("uncertain", 0.55), ("conflict", 0.45)):
        if verdict == capped_verdict:
            total = min(total, ceiling)
            break
    return round(_clamp(total), 4)

View File

@@ -0,0 +1,711 @@
from __future__ import annotations
import json
import os
import shutil
import subprocess
import sys
import zipfile
from collections import Counter
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional
from sqlalchemy.orm import Session
from app.db.session import SessionLocal
from app.models.tooling import (
CodeKnowledgeBase,
ConsistencyJob,
ConsistencyResult,
SRSExtraction,
SRSRequirement,
ToolJob,
)
from app.schemas.consistency import CodeKnowledgeBaseCreate, ConsistencyJobCreate
from app.services.code_kb.adapter import CodeKnowledgeBaseAdapter, read_code_kb_summary
from app.services.code_kb.formatter import format_evidence_context
from app.services.consistency.comparator import ConsistencyComparator
from app.services.embedding.embedding_factory import EmbeddingsFactory
from app.services.llm.llm_factory import LLMFactory
from app.services.model_config import ModelConfigService
from app.services.srs_job_service import _build_internal_title, _parse_generated_at
from app.tools.srs_reqs_qwen import get_srs_tool
# Relative upload locations; they resolve against the process working directory.
# CODE_UPLOAD_ROOT: directory for code knowledge base uploads.
CODE_UPLOAD_ROOT = Path("uploads") / "code_kbs"
# AUTO_UPLOAD_ROOT: directory for automatic consistency jobs.
AUTO_UPLOAD_ROOT = Path("uploads") / "consistency_auto"
def _workspace_root() -> Path:
    """Locate the directory that holds both ``rag-web-ui`` and ``RAG-TEST-TOOLS``.

    Walks up from this file and returns the first ancestor containing both
    entries; falls back to the fifth ancestor when none does.
    """
    here = Path(__file__).resolve()
    found = next(
        (
            ancestor
            for ancestor in here.parents
            if (ancestor / "rag-web-ui").exists() and (ancestor / "RAG-TEST-TOOLS").exists()
        ),
        None,
    )
    if found is not None:
        return found
    return here.parents[4]
def _rag_test_tools_root() -> Path:
    """Return the ``RAG-TEST-TOOLS`` directory, trying three locations.

    Candidates, in order: under the workspace root, under the fourth
    ancestor of this file, and next to the current working directory. The
    first existing candidate is returned resolved; when none exists the
    first candidate is returned as-is so the caller can report it.
    """
    folder = "RAG-TEST-TOOLS"
    candidates = (
        _workspace_root() / folder,
        Path(__file__).resolve().parents[3] / folder,
        Path.cwd().parent / folder,
    )
    existing = next((path for path in candidates if path.exists()), None)
    if existing is not None:
        return existing.resolve()
    return candidates[0]
def safe_upload_name(file_name: Optional[str], fallback: str = "upload.bin") -> str:
    """Reduce an uploaded file name to a bare, safe final path component.

    Directory parts are discarded. A name that reduces to nothing, or to
    the special entries ``.`` or ``..``, is replaced by *fallback*;
    ``Path("..").name`` is ``".."``, so it has to be rejected explicitly.

    Args:
        file_name: Name supplied by the client; may be ``None`` or empty.
        fallback: Name used when no usable name remains.

    Returns:
        The final path component of *file_name*, or *fallback*.
    """
    safe_name = Path(file_name or fallback).name
    if safe_name in ("", ".", ".."):
        return fallback
    return safe_name
def ensure_upload_dir(*parts: str) -> Path:
    """Create (if needed) and return the directory ``uploads/<parts...>``.

    The path is relative to the current working directory; missing parents
    are created and an existing directory is accepted.
    """
    directory = Path("uploads", *parts)
    directory.mkdir(parents=True, exist_ok=True)
    return directory
def save_uploaded_bytes(target_dir: Path, file_name: str, content: bytes) -> Path:
    """Store uploaded bytes under *target_dir*, unpacking zip archives.

    The file is written under its sanitised name. When that name ends in
    ``.zip`` (case-insensitive) the archive is extracted into a sibling
    directory named after the archive stem, and that directory is returned
    instead of the file path.
    """
    target_dir.mkdir(parents=True, exist_ok=True)
    stored = target_dir / safe_upload_name(file_name)
    stored.write_bytes(content)
    if stored.suffix.lower() != ".zip":
        return stored
    unpacked = target_dir / stored.stem
    extract_zip_safe(stored, unpacked)
    return unpacked
def extract_zip_safe(zip_path: Path, target_dir: Path) -> None:
    """Extract a zip archive after checking every entry stays inside *target_dir*.

    All entries are validated before anything is extracted, so an archive
    with a single unsafe entry is rejected as a whole.

    Args:
        zip_path: Archive to extract.
        target_dir: Destination directory; created when missing.

    Raises:
        ValueError: If an entry would resolve outside *target_dir*. The
            underlying path error is chained as ``__cause__``.
    """
    target_dir.mkdir(parents=True, exist_ok=True)
    target_root = target_dir.resolve()
    with zipfile.ZipFile(zip_path) as archive:
        for member in archive.infolist():
            member_path = (target_dir / member.filename).resolve()
            try:
                # relative_to raises when member_path is not below target_root.
                member_path.relative_to(target_root)
            except ValueError as exc:
                raise ValueError(f"Unsafe zip entry: {member.filename}") from exc
        archive.extractall(target_dir)
def _build_code_kb_artifacts(
    project_path: str,
    output_dir: str,
    base_name: str,
    use_semantic: bool,
    model_profile: Any = None,
) -> Dict[str, str]:
    """Run the external code knowledge base builder and return its JSON output.

    The builder is started as ``python -m rag_test_tools.build_code_kb``
    from the RAG-TEST-TOOLS directory with a one-hour timeout. When a model
    profile is given, its API key, API base and model names are exported
    to the child process through environment variables.

    Args:
        project_path: Source project to index.
        output_dir: Directory for the generated artifacts; created if missing.
        base_name: Value passed as ``--base-name``.
        use_semantic: When false, ``--skip-semantic`` is passed.
        model_profile: Optional object providing ``api_key``, ``api_base``,
            ``chat_model`` and ``embedding_model`` attributes.

    Returns:
        The JSON object the builder printed on stdout.

    Raises:
        FileNotFoundError: If the RAG-TEST-TOOLS directory does not exist.
        RuntimeError: If the builder exits non-zero or prints invalid JSON.
    """
    tools_root = _rag_test_tools_root()
    if not tools_root.exists():
        raise FileNotFoundError(f"RAG-TEST-TOOLS not found: {tools_root}")

    destination = Path(output_dir).resolve()
    destination.mkdir(parents=True, exist_ok=True)

    argv = [
        sys.executable,
        "-m",
        "rag_test_tools.build_code_kb",
        "--project",
        str(Path(project_path).resolve()),
        "--output",
        str(destination),
        "--base-name",
        base_name,
    ]
    if not use_semantic:
        argv += ["--skip-semantic"]

    child_env = os.environ.copy()
    if model_profile is not None:
        api_key = getattr(model_profile, "api_key", "") or ""
        api_base = getattr(model_profile, "api_base", "") or ""
        # The same credential is published under every variable name set here.
        if api_key:
            for variable in ("DASHSCOPE_API_KEY", "DASH_SCOPE_API_KEY", "QWEN_API_KEY"):
                child_env[variable] = api_key
        if api_base:
            for variable in ("DASH_SCOPE_API_BASE", "QWEN_API_URL"):
                child_env[variable] = api_base
        if getattr(model_profile, "chat_model", None):
            child_env["QWEN_CHAT_MODEL"] = model_profile.chat_model
        if getattr(model_profile, "embedding_model", None):
            child_env["QWEN_EMBEDDING_MODEL"] = model_profile.embedding_model

    finished = subprocess.run(
        argv,
        cwd=str(tools_root),
        env=child_env,
        capture_output=True,
        text=True,
        timeout=3600,
        check=False,
    )
    if finished.returncode != 0:
        detail = finished.stderr or finished.stdout or finished.returncode
        raise RuntimeError(f"Code knowledge base build failed: {detail}")
    try:
        return json.loads(finished.stdout)
    except json.JSONDecodeError as exc:
        raise RuntimeError(f"Code KB build returned invalid JSON: {finished.stdout}") from exc
def _ensure_paths_exist(paths: Iterable[str]) -> None:
missing = [path for path in paths if not path or not Path(path).exists()]
if missing:
raise FileNotFoundError(f"Code knowledge base file path does not exist: {missing}")
def create_code_kb(db: Session, user_id: int, payload: CodeKnowledgeBaseCreate) -> CodeKnowledgeBase:
    """Register an existing set of code knowledge base artifacts.

    The three artifact paths must exist. They are loaded once through the
    adapter, and the stored summary merges the file-based summary with the
    adapter summary (adapter values win on key clashes). The record is
    created with status ``active`` and committed.

    Raises:
        FileNotFoundError: If any artifact path is missing.
    """
    artifact_paths = [payload.vector_path, payload.metadata_path, payload.graph_path]
    _ensure_paths_exist(artifact_paths)

    adapter = CodeKnowledgeBaseAdapter()
    adapter.load(payload.vector_path, payload.metadata_path, payload.graph_path)

    summary = dict(read_code_kb_summary(payload.metadata_path, payload.graph_path))
    summary.update(adapter.summary())

    record = CodeKnowledgeBase(
        user_id=user_id,
        name=payload.name,
        project_path=payload.project_path,
        vector_path=payload.vector_path,
        metadata_path=payload.metadata_path,
        graph_path=payload.graph_path,
        status="active",
        metadata_summary=summary,
    )
    db.add(record)
    db.commit()
    db.refresh(record)
    return record
def create_uploaded_code_kb(
    db: Session,
    user_id: int,
    name: str,
    project_path: str,
    output_dir: str,
) -> CodeKnowledgeBase:
    """Create a ``pending`` code knowledge base record for an uploaded project.

    The artifact paths are predicted from a timestamp-based base name inside
    *output_dir*; the base name and resolved output directory are kept in
    ``metadata_summary`` for the later build step.

    Args:
        db: Database session; the new record is committed and refreshed.
        user_id: Owner of the knowledge base.
        name: Display name.
        project_path: Location of the uploaded source project.
        output_dir: Directory that will receive the build artifacts.

    Returns:
        The persisted ``CodeKnowledgeBase`` record.
    """
    # datetime.utcnow() is deprecated; an aware UTC timestamp formats to the
    # same string because the format has no timezone field.
    from datetime import timezone

    base_name = f"code_kb_{datetime.now(timezone.utc).strftime('%Y%m%d%H%M%S%f')}"
    output_path = Path(output_dir).resolve()
    code_kb = CodeKnowledgeBase(
        user_id=user_id,
        name=name,
        project_path=project_path,
        vector_path=str(output_path / f"{base_name}_rag.faiss"),
        metadata_path=str(output_path / f"{base_name}_rag_metadata.json"),
        graph_path=str(output_path / f"{base_name}_code_knowledge_graph.json"),
        status="pending",
        metadata_summary={"base_name": base_name, "output_dir": str(output_path)},
    )
    db.add(code_kb)
    db.commit()
    db.refresh(code_kb)
    return code_kb
def run_code_kb_build(code_kb_id: int, use_semantic: bool = True) -> None:
    """Build the artifacts of a code knowledge base and record the outcome.

    Opens its own database session. The record moves to ``processing``,
    the external builder runs, the resulting artifacts are loaded once to
    produce a summary, and the record becomes ``active``. On any failure
    the record is marked ``failed`` with the error text (first 2000
    characters) stored under ``error_message``; the exception itself is not
    re-raised.

    Args:
        code_kb_id: Primary key of the knowledge base; unknown ids are ignored.
        use_semantic: When true, the owner's active model configuration is
            required and passed to the builder and the embedding factory.
    """
    db = SessionLocal()
    code_kb = None
    try:
        code_kb = db.query(CodeKnowledgeBase).filter(CodeKnowledgeBase.id == code_kb_id).first()
        if not code_kb:
            return
        model_profile = None
        if use_semantic:
            model_profile = ModelConfigService.require_active_config(db, code_kb.user_id)
            ModelConfigService.touch_last_used(db, model_profile)
        code_kb.status = "processing"
        db.add(code_kb)
        db.commit()

        summary = code_kb.metadata_summary or {}
        base_name = summary.get("base_name") or f"code_kb_{code_kb.id}"
        output_dir = summary.get("output_dir") or str(Path(code_kb.vector_path).parent)
        artifact_paths = _build_code_kb_artifacts(
            project_path=code_kb.project_path or "",
            output_dir=output_dir,
            base_name=base_name,
            use_semantic=use_semantic,
            model_profile=model_profile,
        )
        code_kb.graph_path = artifact_paths["graph_path"]
        code_kb.vector_path = artifact_paths["vector_path"]
        code_kb.metadata_path = artifact_paths["metadata_path"]

        embedding_function = (
            EmbeddingsFactory.create(model_profile=model_profile)
            if model_profile is not None
            else None
        )
        adapter = CodeKnowledgeBaseAdapter(embedding_function=embedding_function)
        adapter.load(code_kb.vector_path, code_kb.metadata_path, code_kb.graph_path)
        code_kb.status = "active"
        code_kb.metadata_summary = {
            **read_code_kb_summary(code_kb.metadata_path, code_kb.graph_path),
            **adapter.summary(),
            "source": "upload",
        }
        db.add(code_kb)
        db.commit()
    except Exception as exc:
        # If the failure came from a flush or commit, the session cannot be
        # used again until it is rolled back. This also discards uncommitted
        # path changes, so a failed build keeps the previously stored paths.
        db.rollback()
        if code_kb is not None:
            code_kb.status = "failed"
            code_kb.metadata_summary = {
                **(code_kb.metadata_summary or {}),
                "error_message": str(exc)[:2000],
            }
            db.add(code_kb)
            db.commit()
    finally:
        db.close()
def list_code_kbs(db: Session, user_id: int) -> List[CodeKnowledgeBase]:
    """Return the user's code knowledge bases, newest first."""
    owned = db.query(CodeKnowledgeBase).filter(CodeKnowledgeBase.user_id == user_id)
    newest_first = owned.order_by(CodeKnowledgeBase.created_at.desc())
    return newest_first.all()
def get_owned_code_kb(db: Session, user_id: int, code_kb_id: int) -> Optional[CodeKnowledgeBase]:
    """Return the code knowledge base with this id if the user owns it, else ``None``."""
    matching = db.query(CodeKnowledgeBase).filter(
        CodeKnowledgeBase.id == code_kb_id,
        CodeKnowledgeBase.user_id == user_id,
    )
    return matching.first()
def get_owned_srs_extraction(db: Session, user_id: int, extraction_id: int) -> Optional[SRSExtraction]:
    """Return the SRS extraction with this id if its tool job belongs to the user.

    Ownership is checked through the joined ``ToolJob`` row; ``None`` is
    returned when nothing matches.
    """
    joined = db.query(SRSExtraction).join(ToolJob, SRSExtraction.job_id == ToolJob.id)
    matching = joined.filter(SRSExtraction.id == extraction_id, ToolJob.user_id == user_id)
    return matching.first()
def create_consistency_job(
    db: Session,
    user_id: int,
    payload: ConsistencyJobCreate,
) -> ConsistencyJob:
    """Validate the request and create a ``pending`` consistency job.

    The SRS extraction and the code knowledge base must belong to the user,
    the knowledge base must be ``active``, and at least one requirement must
    match the optional ``requirement_uids`` filter. The comparison options
    are stored in ``output_summary``.

    Raises:
        ValueError: If any of the checks above fails.
    """
    extraction = get_owned_srs_extraction(db, user_id, payload.srs_extraction_id)
    if not extraction:
        raise ValueError("SRS extraction does not exist.")

    code_kb = get_owned_code_kb(db, user_id, payload.code_kb_id)
    if not code_kb:
        raise ValueError("Code knowledge base does not exist.")
    if code_kb.status != "active":
        raise ValueError("Code knowledge base is not active.")

    in_scope = db.query(SRSRequirement).filter(SRSRequirement.extraction_id == extraction.id)
    if payload.requirement_uids:
        in_scope = in_scope.filter(SRSRequirement.requirement_uid.in_(payload.requirement_uids))
    requirement_count = in_scope.count()
    if requirement_count == 0:
        raise ValueError("No SRS requirements matched the selected scope.")

    options = {
        "requirement_uids": payload.requirement_uids,
        "top_k": payload.top_k,
        "max_call_hops": payload.max_call_hops,
        "min_similarity": payload.min_similarity,
        "use_llm": payload.use_llm,
    }
    job = ConsistencyJob(
        user_id=user_id,
        srs_extraction_id=extraction.id,
        code_kb_id=code_kb.id,
        status="pending",
        total_requirements=requirement_count,
        completed_requirements=0,
        output_summary=options,
    )
    db.add(job)
    db.commit()
    db.refresh(job)
    return job
def list_consistency_jobs(db: Session, user_id: int) -> List[ConsistencyJob]:
    """Return the user's consistency jobs, newest first."""
    owned = db.query(ConsistencyJob).filter(ConsistencyJob.user_id == user_id)
    newest_first = owned.order_by(ConsistencyJob.created_at.desc())
    return newest_first.all()
def get_owned_consistency_job(db: Session, user_id: int, job_id: int) -> Optional[ConsistencyJob]:
    """Return the consistency job with this id if the user owns it, else ``None``."""
    matching = db.query(ConsistencyJob).filter(
        ConsistencyJob.id == job_id,
        ConsistencyJob.user_id == user_id,
    )
    return matching.first()
def ask_code_kb(
    code_kb: CodeKnowledgeBase,
    question: str,
    top_k: int = 6,
    min_similarity: float = 0.0,
    use_llm: bool = True,
    model_profile: Any = None,
) -> Dict[str, Any]:
    """Answer a question from the code knowledge base.

    Retrieves matching functions and their call context (two hops). Without
    hits a fixed "no evidence" answer is returned. With ``use_llm`` false
    only the evidence is returned. Otherwise the model is asked to answer
    from the evidence; a model failure is reported in the answer text and
    the evidence is still returned.

    Returns:
        A dictionary with ``answer``, ``evidence`` and ``raw_response``.

    Raises:
        ValueError: If the knowledge base is not ``active`` or no model
            profile is supplied (one is needed for the embeddings).
    """
    if code_kb.status != "active":
        raise ValueError("Code knowledge base is not active.")
    if model_profile is None:
        raise ValueError("请先在 API 密钥页面新增并启用模型配置。")

    embedder = EmbeddingsFactory.create(model_profile=model_profile)
    adapter = CodeKnowledgeBaseAdapter(embedding_function=embedder)
    adapter.load(code_kb.vector_path, code_kb.metadata_path, code_kb.graph_path)

    hits = adapter.search_functions(question, top_k=top_k, min_similarity=min_similarity)
    contexts = []
    for hit in hits:
        contexts.append(adapter.expand_call_context(hit.evidence.node_id, max_hops=2))
    evidence = [hit.to_dict() for hit in hits]

    def reply(answer: str, items: List[Any], raw: Any = None) -> Dict[str, Any]:
        return {"answer": answer, "evidence": items, "raw_response": raw}

    if not hits:
        return reply("未检索到相关函数证据,无法基于代码知识库回答。", [])

    evidence_context = format_evidence_context(hits, contexts)
    if not use_llm:
        return reply(
            "已检索到相关函数证据,请查看 evidence 字段中的函数摘要、文件位置和调用链。",
            evidence,
        )

    instructions = (
        "你是代码知识库问答助手。只能基于给定代码证据回答问题;"
        "如果证据不足,请明确说明不足。回答需要包含关键函数名和文件位置。\n\n"
    )
    prompt = instructions + f"问题:{question}\n\n代码证据:\n{evidence_context}"
    try:
        llm = LLMFactory.create(temperature=0, streaming=False, model_profile=model_profile)
        # Accept both objects exposing ``invoke`` and plain callables.
        if hasattr(llm, "invoke"):
            response = llm.invoke(prompt)
        else:
            response = llm(prompt)
        answer = str(getattr(response, "content", response))
        return reply(answer, evidence, answer)
    except Exception as exc:
        return reply(f"模型问答失败,已返回检索证据供人工查看。错误:{exc}", evidence)
def result_model_to_export_dict(result: ConsistencyResult) -> Dict[str, Any]:
    """Convert a stored result row into the dictionary shape used for export.

    Requirement title, type and text are read back from ``raw_judgment``;
    empty list and text columns are replaced by ``[]`` and ``""``.
    """
    raw = result.raw_judgment or {}
    exported: Dict[str, Any] = {"requirement_uid": result.requirement_uid}
    exported["requirement_title"] = raw.get("requirement_title", "")
    exported["requirement_type"] = raw.get("requirement_type")
    exported["requirement_text"] = raw.get("requirement_text", "")
    exported["verdict"] = result.verdict
    exported["coverage_score"] = result.coverage_score
    exported["confidence"] = result.confidence
    exported["matched_functions"] = result.matched_functions or []
    exported["covered_points"] = result.covered_points or []
    exported["missing_points"] = result.missing_points or []
    exported["conflict_points"] = result.conflict_points or []
    exported["call_chain_evidence"] = result.call_chain_evidence or []
    exported["suggestion"] = result.suggestion or ""
    exported["raw_judgment"] = raw
    return exported
def _store_result(db: Session, job: ConsistencyJob, result: Any) -> None:
    """Add one comparison result to the session as a ``ConsistencyResult`` row.

    Requirement title, type and text are copied into the stored
    ``raw_judgment`` so they can be read back at export time. The row is
    only added to the session; committing is left to the caller.
    """
    as_dict = result.to_dict()
    stored_judgment = dict(as_dict.get("raw_judgment") or {})
    for key in ("requirement_title", "requirement_type", "requirement_text"):
        stored_judgment[key] = as_dict.get(key)

    row = ConsistencyResult(
        job_id=job.id,
        requirement_uid=result.requirement_uid,
        verdict=result.verdict,
        coverage_score=result.coverage_score,
        confidence=result.confidence,
        matched_functions=result.matched_functions,
        covered_points=result.covered_points,
        missing_points=result.missing_points,
        conflict_points=result.conflict_points,
        call_chain_evidence=result.call_chain_evidence,
        suggestion=result.suggestion,
        raw_judgment=stored_judgment,
    )
    db.add(row)
def _create_srs_extraction_for_job(db: Session, job: ToolJob) -> SRSExtraction:
    """Run the SRS tool on the job's input file and stage the extraction rows.

    The owner's active model configuration is required. An ``SRSExtraction``
    is added and flushed to obtain its id, then one ``SRSRequirement`` is
    added per extracted requirement. Nothing is committed here.

    Returns:
        The flushed, not yet committed ``SRSExtraction``.
    """
    profile = ModelConfigService.require_active_config(db, job.user_id)
    ModelConfigService.touch_last_used(db, profile)
    payload = get_srs_tool().run(job.input_file_path, model_profile=profile)

    document_name = payload["document_name"]
    extracted = payload.get("requirements", [])
    extraction = SRSExtraction(
        job_id=job.id,
        document_name=document_name,
        document_title=payload.get("document_title") or document_name,
        generated_at=_parse_generated_at(payload.get("generated_at")),
        total_requirements=len(extracted),
        statistics=payload.get("statistics", {}),
        raw_output=payload.get("raw_output", {}),
    )
    db.add(extraction)
    # Flush so extraction.id is available for the requirement rows below.
    db.flush()

    for position, entry in enumerate(extracted):
        source_id = entry.get("id")
        db.add(
            SRSRequirement(
                extraction_id=extraction.id,
                requirement_uid=source_id or f"REQ-{position + 1:03d}",
                title=_build_internal_title(entry.get("description"), source_id or "", position),
                description=entry.get("description") or "",
                priority=entry.get("priority") or "",
                acceptance_criteria=entry.get("acceptance_criteria") or ["待补充验收标准"],
                source_field=entry.get("source_field") or "文档解析",
                section_uid=entry.get("section_uid"),
                section_number=entry.get("section_number"),
                section_title=entry.get("section_title"),
                requirement_type=entry.get("requirement_type"),
                interface_name=entry.get("interface_name"),
                interface_type=entry.get("interface_type"),
                data_source=entry.get("data_source"),
                data_destination=entry.get("data_destination"),
                sort_order=int(entry.get("sort_order") or position),
            )
        )
    return extraction
def create_auto_consistency_tool_job(
    db: Session,
    user_id: int,
    requirement_file_path: str,
    requirement_file_name: str,
    code_source_dir: str,
    code_kb_name: str,
    top_k: int,
    max_call_hops: int,
    min_similarity: float,
    use_llm: bool,
    use_semantic: bool,
) -> ToolJob:
    """Create a ``pending`` tool job for the automatic compare pipeline.

    The job is stored under the tool name ``consistency.auto_compare``; the
    code source location and all comparison options are kept in
    ``output_summary`` together with ``current_step`` set to ``pending``.

    Returns:
        The committed and refreshed ``ToolJob``.
    """
    options = {
        "current_step": "pending",
        "code_source_dir": code_source_dir,
        "code_kb_name": code_kb_name,
        "top_k": top_k,
        "max_call_hops": max_call_hops,
        "min_similarity": min_similarity,
        "use_llm": use_llm,
        "use_semantic": use_semantic,
    }
    tool_job = ToolJob(
        user_id=user_id,
        tool_name="consistency.auto_compare",
        status="pending",
        input_file_name=requirement_file_name,
        input_file_path=requirement_file_path,
        output_summary=options,
    )
    db.add(tool_job)
    db.commit()
    db.refresh(tool_job)
    return tool_job
def get_owned_auto_job(db: Session, user_id: int, job_id: int) -> Optional[ToolJob]:
    """Return the user's ``consistency.auto_compare`` tool job with this id, else ``None``."""
    criteria = (
        ToolJob.id == job_id,
        ToolJob.user_id == user_id,
        ToolJob.tool_name == "consistency.auto_compare",
    )
    return db.query(ToolJob).filter(*criteria).first()
def run_auto_consistency_job(tool_job_id: int) -> None:
    """Run the automatic requirement/code consistency pipeline for a tool job.

    Progress is published through ``output_summary["current_step"]``:

    1. ``extracting_requirements`` - the SRS extraction is created.
    2. ``building_code_kb`` - the uploaded code knowledge base is registered
       and then built by ``run_code_kb_build``.
    3. ``comparing`` - a consistency job is created and executed by
       ``run_consistency_job``.
    4. ``completed`` or ``failed``.

    The session is closed before each long-running stage and a fresh one is
    opened afterwards. Primary keys needed across those boundaries are copied
    into plain integers first, because ORM instances detached by ``close()``
    cannot reload attributes expired by the preceding ``commit()``.

    Any exception marks the tool job as ``failed`` and stores the (truncated)
    message in ``error_message``; nothing is re-raised to the caller.

    Args:
        tool_job_id: Primary key of the ``ToolJob`` to execute. If the job no
            longer exists the function returns without doing anything.
    """
    db = SessionLocal()
    try:
        tool_job = db.query(ToolJob).filter(ToolJob.id == tool_job_id).first()
        if not tool_job:
            return
        options = tool_job.output_summary or {}
        tool_job.status = "processing"
        tool_job.started_at = datetime.utcnow()
        tool_job.output_summary = {**options, "current_step": "extracting_requirements"}
        db.add(tool_job)
        db.commit()

        extraction = _create_srs_extraction_for_job(db, tool_job)
        db.commit()
        # Keep the key as a plain int: ``extraction`` is detached once this
        # session is closed, and is still needed for the comparison stage.
        extraction_id = extraction.id

        # The extraction step may have updated the summary; prefer the stored one.
        options = tool_job.output_summary or options
        code_output_dir = str((AUTO_UPLOAD_ROOT / str(tool_job.id) / "code_kb").resolve())
        code_kb = create_uploaded_code_kb(
            db,
            tool_job.user_id,
            options.get("code_kb_name") or f"auto-code-kb-{tool_job.id}",
            options["code_source_dir"],
            code_output_dir,
        )
        code_kb_id = code_kb.id
        tool_job.output_summary = {
            **options,
            "current_step": "building_code_kb",
            "srs_extraction_id": extraction_id,
            "code_kb_id": code_kb_id,
        }
        db.add(tool_job)
        db.commit()
        db.close()

        run_code_kb_build(code_kb_id, use_semantic=bool(options.get("use_semantic", True)))

        db = SessionLocal()
        code_kb = db.query(CodeKnowledgeBase).filter(CodeKnowledgeBase.id == code_kb_id).first()
        tool_job = db.query(ToolJob).filter(ToolJob.id == tool_job_id).first()
        if not tool_job:
            # The job row is gone; there is nothing left to report on.
            return
        if not code_kb:
            # Fail the job instead of leaving it in "processing" forever.
            raise RuntimeError("Code knowledge base does not exist.")
        if code_kb.status != "active":
            raise RuntimeError((code_kb.metadata_summary or {}).get("error_message") or "Code KB build failed.")

        consistency_payload = ConsistencyJobCreate(
            srs_extraction_id=extraction_id,
            code_kb_id=code_kb_id,
            top_k=int(options.get("top_k", 8)),
            max_call_hops=int(options.get("max_call_hops", 2)),
            min_similarity=float(options.get("min_similarity", 0.55)),
            use_llm=bool(options.get("use_llm", True)),
        )
        consistency_job = create_consistency_job(db, tool_job.user_id, consistency_payload)
        consistency_job_id = consistency_job.id
        tool_job.output_summary = {
            **(tool_job.output_summary or {}),
            "current_step": "comparing",
            "consistency_job_id": consistency_job_id,
        }
        db.add(tool_job)
        db.commit()
        db.close()

        run_consistency_job(consistency_job_id)

        db = SessionLocal()
        consistency_job = (
            db.query(ConsistencyJob).filter(ConsistencyJob.id == consistency_job_id).first()
        )
        tool_job = db.query(ToolJob).filter(ToolJob.id == tool_job_id).first()
        if not tool_job:
            return
        if not consistency_job:
            # Fail the job instead of leaving it in "processing" forever.
            raise RuntimeError("Consistency job does not exist.")
        if consistency_job.status == "failed":
            raise RuntimeError(consistency_job.error_message or "Consistency comparison failed.")

        tool_job.status = "completed"
        tool_job.completed_at = datetime.utcnow()
        tool_job.output_summary = {
            **(tool_job.output_summary or {}),
            "current_step": "completed",
            "consistency_job_id": consistency_job_id,
        }
        db.add(tool_job)
        db.commit()
    except Exception as exc:
        # A failed flush/commit leaves the session unusable until it is rolled
        # back; without this the status update below could itself fail.
        db.rollback()
        tool_job = db.query(ToolJob).filter(ToolJob.id == tool_job_id).first()
        if tool_job:
            tool_job.status = "failed"
            tool_job.error_message = str(exc)[:2000]
            tool_job.completed_at = datetime.utcnow()
            tool_job.output_summary = {
                **(tool_job.output_summary or {}),
                "current_step": "failed",
            }
            db.add(tool_job)
            db.commit()
    finally:
        db.close()
def run_consistency_job(job_id: int) -> None:
    """Compare every selected SRS requirement of a job against its code KB.

    The job is moved to ``processing``, the code knowledge base and the
    owner's active model configuration are loaded, and each requirement is
    compared in ``sort_order``. After every requirement the result is stored
    and the running ``verdict_counts`` are committed, so progress is visible
    while the job runs.

    On any exception the session is rolled back and the job is marked
    ``failed`` with the exception text in ``error_message``; nothing is
    re-raised to the caller.

    Args:
        job_id: Primary key of the ``ConsistencyJob`` to run. If the job does
            not exist the function returns without doing anything.
    """
    db = SessionLocal()
    job = None
    try:
        job = db.query(ConsistencyJob).filter(ConsistencyJob.id == job_id).first()
        if not job:
            return
        job.status = "processing"
        job.started_at = datetime.utcnow()
        db.add(job)
        db.commit()

        options = job.output_summary or {}
        code_kb = db.query(CodeKnowledgeBase).filter(CodeKnowledgeBase.id == job.code_kb_id).first()
        if not code_kb:
            raise RuntimeError("Code knowledge base does not exist.")

        model_profile = ModelConfigService.require_active_config(db, job.user_id)
        ModelConfigService.touch_last_used(db, model_profile)
        adapter = CodeKnowledgeBaseAdapter(
            embedding_function=EmbeddingsFactory.create(model_profile=model_profile)
        )
        adapter.load(code_kb.vector_path, code_kb.metadata_path, code_kb.graph_path)

        # The LLM is only instantiated when the job asks for it.
        use_llm = bool(options.get("use_llm", True))
        llm = None
        if use_llm:
            llm = LLMFactory.create(temperature=0, streaming=False, model_profile=model_profile)
        comparator = ConsistencyComparator(
            adapter,
            llm=llm,
            use_llm=use_llm,
        )

        query = (
            db.query(SRSRequirement)
            .filter(SRSRequirement.extraction_id == job.srs_extraction_id)
            .order_by(SRSRequirement.sort_order)
        )
        # Optional subset selection; an empty/missing list means "all requirements".
        requirement_uids = options.get("requirement_uids")
        if requirement_uids:
            query = query.filter(SRSRequirement.requirement_uid.in_(requirement_uids))
        requirements = query.all()
        job.total_requirements = len(requirements)
        db.add(job)
        db.commit()

        # Comparison settings do not change between requirements; parse them once.
        top_k = int(options.get("top_k", 8))
        max_call_hops = int(options.get("max_call_hops", 2))
        min_similarity = float(options.get("min_similarity", 0.55))

        verdict_counter: Counter[str] = Counter()
        for requirement in requirements:
            result = comparator.compare_requirement(
                requirement,
                top_k=top_k,
                max_call_hops=max_call_hops,
                min_similarity=min_similarity,
            )
            _store_result(db, job, result)
            verdict_counter[result.verdict] += 1
            job.completed_requirements += 1
            job.output_summary = {**options, "verdict_counts": dict(verdict_counter)}
            db.add(job)
            # Commit per requirement so partial progress survives a later failure.
            db.commit()

        job.status = "completed"
        job.completed_at = datetime.utcnow()
        job.output_summary = {**options, "verdict_counts": dict(verdict_counter)}
        db.add(job)
        db.commit()
    except Exception as exc:
        # Discard the broken transaction first; a session whose flush/commit
        # failed rejects further work until it has been rolled back.
        db.rollback()
        if job is not None:
            job.status = "failed"
            job.error_message = str(exc)
            job.completed_at = datetime.utcnow()
            db.add(job)
            db.commit()
    finally:
        db.close()

View File

@@ -6,6 +6,7 @@ import traceback
import json
from app.db.session import SessionLocal
from io import BytesIO
from types import SimpleNamespace
from typing import Optional, List, Dict, Any
from fastapi import UploadFile
from langchain_community.document_loaders import (
@@ -26,6 +27,7 @@ from minio.error import MinioException
from minio.commonconfig import CopySource
from app.services.vector_store import VectorStoreFactory
from app.services.embedding.embedding_factory import EmbeddingsFactory
from app.services.model_config import ModelConfigService
class UploadResult(BaseModel):
file_path: str
@@ -120,7 +122,45 @@ def _sanitize_metadata_for_vector_store(metadata: Optional[Dict[str, Any]]) -> D
return sanitized
async def process_document(file_path: str, file_name: str, kb_id: int, document_id: int, chunk_size: int = 1000, chunk_overlap: int = 200) -> None:
def _resolve_model_profile(db: Session, user_id: Optional[int]) -> Any:
    """Look up the active model configuration for ``user_id`` on ``db``.

    Returns ``None`` when no user id is supplied; otherwise delegates to
    ``ModelConfigService.require_active_config`` and returns its result.
    """
    if user_id is not None:
        return ModelConfigService.require_active_config(db, user_id)
    return None
def _model_profile_snapshot(model_profile: Any) -> Any:
    """Copy the connection fields of a model profile into a plain namespace.

    Args:
        model_profile: Object exposing ``provider``, ``api_key``, ``api_base``,
            ``chat_model`` and ``embedding_model``, or ``None``.

    Returns:
        A ``SimpleNamespace`` holding those five attributes, or ``None`` when
        ``model_profile`` is ``None``.
    """
    if model_profile is None:
        return None
    copied_fields = ("provider", "api_key", "api_base", "chat_model", "embedding_model")
    return SimpleNamespace(
        **{field: getattr(model_profile, field) for field in copied_fields}
    )
def _load_model_profile_for_user(user_id: Optional[int]) -> Any:
    """Load a snapshot of the user's active model configuration.

    Opens its own session, requires an active configuration for ``user_id``,
    records the usage via ``ModelConfigService.touch_last_used`` and returns
    the result of ``_model_profile_snapshot``. The snapshot is taken before
    the session is closed.

    Returns ``None`` without opening a session when ``user_id`` is ``None``.
    """
    if user_id is None:
        return None
    session = SessionLocal()
    try:
        active_profile = ModelConfigService.require_active_config(session, user_id)
        ModelConfigService.touch_last_used(session, active_profile)
        snapshot = _model_profile_snapshot(active_profile)
    finally:
        session.close()
    return snapshot
async def process_document(
file_path: str,
file_name: str,
kb_id: int,
document_id: int,
chunk_size: int = 1000,
chunk_overlap: int = 200,
user_id: Optional[int] = None,
) -> None:
"""Process document and store in vector database with incremental updates"""
logger = logging.getLogger(__name__)
@@ -129,7 +169,8 @@ async def process_document(file_path: str, file_name: str, kb_id: int, document_
# Initialize embeddings
logger.info("Initializing OpenAI embeddings...")
embeddings = EmbeddingsFactory.create()
model_profile = _load_model_profile_for_user(user_id)
embeddings = EmbeddingsFactory.create(model_profile=model_profile)
logger.info(f"Initializing vector store with collection: kb_{kb_id}")
vector_store = VectorStoreFactory.create(
@@ -202,7 +243,7 @@ async def process_document(file_path: str, file_name: str, kb_id: int, document_
try:
from app.services.graph.graphrag_adapter import GraphRAGAdapter
graph_adapter = GraphRAGAdapter()
graph_adapter = GraphRAGAdapter(model_profile=model_profile)
source_texts = [doc.page_content for doc in documents_to_update if doc.page_content.strip()]
await graph_adapter.ingest_texts(kb_id, source_texts)
logger.info("GraphRAG ingestion completed in incremental processing")
@@ -323,7 +364,8 @@ async def process_document_background(
task_id: int,
db: Session = None,
chunk_size: int = 1000,
chunk_overlap: int = 200
chunk_overlap: int = 200,
user_id: Optional[int] = None,
) -> None:
"""Process document in background"""
logger = logging.getLogger(__name__)
@@ -348,6 +390,9 @@ async def process_document_background(
logger.info(f"Task {task_id}: Setting status to processing")
task.status = "processing"
db.commit()
model_profile = _resolve_model_profile(db, user_id)
if model_profile is not None:
ModelConfigService.touch_last_used(db, model_profile)
# 1. 从临时目录下载文件
minio_client = get_minio_client()
@@ -416,7 +461,7 @@ async def process_document_background(
# 3. 创建向量存储
logger.info(f"Task {task_id}: Initializing vector store")
embeddings = EmbeddingsFactory.create()
embeddings = EmbeddingsFactory.create(model_profile=model_profile)
vector_store = VectorStoreFactory.create(
store_type=settings.VECTOR_STORE_TYPE,
@@ -520,7 +565,7 @@ async def process_document_background(
from app.services.graph.graphrag_adapter import GraphRAGAdapter
logger.info(f"Task {task_id}: Starting GraphRAG ingestion")
graph_adapter = GraphRAGAdapter()
graph_adapter = GraphRAGAdapter(model_profile=model_profile)
source_texts = [doc.page_content for doc in documents if doc.page_content.strip()]
await graph_adapter.ingest_texts(kb_id, source_texts)
logger.info(f"Task {task_id}: GraphRAG ingestion completed")

View File

@@ -1,30 +1,39 @@
from app.core.config import settings
from langchain_openai import OpenAIEmbeddings
from langchain_ollama import OllamaEmbeddings
from typing import Optional
# If you plan on adding other embeddings, import them here
# from some_other_module import AnotherEmbeddingClass
class EmbeddingsFactory:
@staticmethod
def create():
def create(provider: Optional[str] = None, model_profile: Optional[object] = None):
"""
Factory method to create an embeddings instance based on .env config.
"""
# Suppose your .env has a value like EMBEDDINGS_PROVIDER=openai
embeddings_provider = settings.EMBEDDINGS_PROVIDER.lower()
if model_profile is not None:
embeddings_provider = (provider or getattr(model_profile, "provider", None) or "dashscope").lower()
api_key = getattr(model_profile, "api_key", "") or ""
api_base = getattr(model_profile, "api_base", None) or _default_api_base(embeddings_provider)
model = getattr(model_profile, "embedding_model", None) or _default_embedding_model(embeddings_provider)
else:
embeddings_provider = (provider or settings.EMBEDDINGS_PROVIDER).lower()
api_key = _default_api_key(embeddings_provider)
api_base = _default_api_base(embeddings_provider)
model = _default_embedding_model(embeddings_provider)
if embeddings_provider == "openai":
return OpenAIEmbeddings(
openai_api_key=settings.OPENAI_API_KEY,
openai_api_base=settings.OPENAI_API_BASE,
model=settings.OPENAI_EMBEDDINGS_MODEL
openai_api_key=api_key,
openai_api_base=api_base,
model=model
)
elif embeddings_provider == "dashscope":
elif embeddings_provider in {"dashscope", "openai_compatible"}:
return OpenAIEmbeddings(
openai_api_key=settings.DASH_SCOPE_API_KEY,
openai_api_base=settings.DASH_SCOPE_API_BASE,
model=settings.DASH_SCOPE_EMBEDDINGS_MODEL,
openai_api_key=api_key,
openai_api_base=api_base,
model=model,
# DashScope OpenAI-compatible embedding expects string input,
# while LangChain's len-safe path may send token ids.
check_embedding_ctx_length=False,
@@ -35,8 +44,8 @@ class EmbeddingsFactory:
)
elif embeddings_provider == "ollama":
return OllamaEmbeddings(
model=settings.OLLAMA_EMBEDDINGS_MODEL,
base_url=settings.OLLAMA_API_BASE
model=model,
base_url=api_base
)
# Extend with other providers:
@@ -44,3 +53,34 @@ class EmbeddingsFactory:
# return AnotherEmbeddingClass(...)
else:
raise ValueError(f"Unsupported embeddings provider: {embeddings_provider}")
def _default_embedding_model(provider: Optional[str]) -> str:
    """Return the embedding model configured in settings for ``provider``.

    ``provider`` falls back to ``settings.EMBEDDINGS_PROVIDER`` when empty and
    is compared case-insensitively. ``openai`` and ``ollama`` read their own
    setting. ``dashscope`` uses the DashScope setting or ``"text-embedding-v4"``.
    Any other provider uses the DashScope setting, then the OpenAI one.
    """
    name = (provider or settings.EMBEDDINGS_PROVIDER).lower()
    # Providers that map straight onto a single settings attribute.
    direct_attr = {
        "openai": "OPENAI_EMBEDDINGS_MODEL",
        "ollama": "OLLAMA_EMBEDDINGS_MODEL",
    }.get(name)
    if direct_attr is not None:
        return getattr(settings, direct_attr)
    dashscope_model = settings.DASH_SCOPE_EMBEDDINGS_MODEL
    if dashscope_model:
        return dashscope_model
    if name == "dashscope":
        return "text-embedding-v4"
    return settings.OPENAI_EMBEDDINGS_MODEL
def _default_api_key(provider: Optional[str]) -> str:
    """Return the API key configured in settings for an embeddings provider.

    ``provider`` falls back to ``settings.EMBEDDINGS_PROVIDER`` when empty and
    is compared case-insensitively. Providers other than ``openai`` and
    ``dashscope`` use ``settings.API_KEY``.
    """
    name = (provider or settings.EMBEDDINGS_PROVIDER).lower()
    key_attr = {
        "openai": "OPENAI_API_KEY",
        "dashscope": "DASH_SCOPE_API_KEY",
    }.get(name, "API_KEY")
    # Resolve lazily so only the selected setting is ever read.
    return getattr(settings, key_attr)
def _default_api_base(provider: Optional[str]) -> str:
    """Return the API base URL configured in settings for an embeddings provider.

    ``provider`` falls back to ``settings.EMBEDDINGS_PROVIDER`` when empty and
    is compared case-insensitively. Providers other than ``openai``,
    ``dashscope`` and ``ollama`` use the DashScope base URL.
    """
    name = (provider or settings.EMBEDDINGS_PROVIDER).lower()
    base_attr = {
        "openai": "OPENAI_API_BASE",
        "dashscope": "DASH_SCOPE_API_BASE",
        "ollama": "OLLAMA_API_BASE",
    }.get(name, "DASH_SCOPE_API_BASE")
    # Resolve lazily so only the selected setting is ever read.
    return getattr(settings, base_attr)

View File

@@ -13,11 +13,11 @@ from app.services.llm.llm_factory import LLMFactory
class GraphRAGAdapter:
_instance_lock = asyncio.Lock()
def __init__(self):
def __init__(self, model_profile: Any = None):
self._graphrag_instances: Dict[int, Any] = {}
self._kb_locks: Dict[int, asyncio.Lock] = {}
self._embedding_model = EmbeddingsFactory.create()
self._llm_model = LLMFactory.create(streaming=False)
self._embedding_model = EmbeddingsFactory.create(model_profile=model_profile)
self._llm_model = LLMFactory.create(streaming=False, model_profile=model_profile)
self._symbols = self._load_symbols()
def _load_symbols(self) -> Dict[str, Any]:

View File

@@ -11,42 +11,51 @@ class LLMFactory:
provider: Optional[str] = None,
temperature: float = 0,
streaming: bool = True,
model_profile: Optional[object] = None,
) -> BaseChatModel:
"""
Create a LLM instance based on the provider
"""
# If no provider specified, use the one from settings
provider = provider or settings.CHAT_PROVIDER
if model_profile is not None:
provider = (provider or getattr(model_profile, "provider", None) or "dashscope").lower()
model = getattr(model_profile, "chat_model", None) or _default_chat_model(provider)
api_key = getattr(model_profile, "api_key", "") or ""
api_base = getattr(model_profile, "api_base", None) or _default_api_base(provider)
else:
provider = provider or settings.CHAT_PROVIDER
model = _default_chat_model(provider)
api_key = _default_api_key(provider)
api_base = _default_api_base(provider)
if provider.lower() == "openai":
return ChatOpenAI(
temperature=temperature,
streaming=streaming,
model=settings.OPENAI_MODEL,
openai_api_key=settings.OPENAI_API_KEY,
openai_api_base=settings.OPENAI_API_BASE
model=model,
openai_api_key=api_key,
openai_api_base=api_base
)
elif provider.lower() == "deepseek":
return ChatDeepSeek(
temperature=temperature,
streaming=streaming,
model=settings.DEEPSEEK_MODEL,
api_key=settings.DEEPSEEK_API_KEY,
api_base=settings.DEEPSEEK_API_BASE
model=model,
api_key=api_key,
api_base=api_base
)
elif provider.lower() == "dashscope":
elif provider.lower() in {"dashscope", "openai_compatible"}:
return ChatOpenAI(
temperature=temperature,
streaming=streaming,
model=settings.DASH_SCOPE_CHAT_MODEL,
openai_api_key=settings.DASH_SCOPE_API_KEY,
openai_api_base=settings.DASH_SCOPE_API_BASE,
model=model,
openai_api_key=api_key,
openai_api_base=api_base,
)
elif provider.lower() == "ollama":
# Initialize Ollama model
return OllamaLLM(
model=settings.OLLAMA_MODEL,
base_url=settings.OLLAMA_API_BASE,
model=model,
base_url=api_base,
temperature=temperature,
streaming=streaming
)
@@ -54,4 +63,41 @@ class LLMFactory:
# elif provider.lower() == "anthropic":
# return ChatAnthropic(...)
else:
raise ValueError(f"Unsupported LLM provider: {provider}")
raise ValueError(f"Unsupported LLM provider: {provider}")
def _default_chat_model(provider: Optional[str]) -> str:
    """Return the chat model configured in settings for ``provider``.

    ``provider`` falls back to ``settings.CHAT_PROVIDER`` when empty and is
    compared case-insensitively. Providers other than ``openai``,
    ``deepseek``, ``dashscope`` and ``ollama`` use the DashScope chat model,
    then the OpenAI one.
    """
    name = (provider or settings.CHAT_PROVIDER).lower()
    model_attr = {
        "openai": "OPENAI_MODEL",
        "deepseek": "DEEPSEEK_MODEL",
        "dashscope": "DASH_SCOPE_CHAT_MODEL",
        "ollama": "OLLAMA_MODEL",
    }.get(name)
    if model_attr is None:
        # Unknown provider: prefer the DashScope model, then the OpenAI one.
        return settings.DASH_SCOPE_CHAT_MODEL or settings.OPENAI_MODEL
    return getattr(settings, model_attr)
def _default_api_key(provider: Optional[str]) -> str:
    """Return the API key configured in settings for a chat provider.

    ``provider`` falls back to ``settings.CHAT_PROVIDER`` when empty and is
    compared case-insensitively. Providers other than ``openai``,
    ``deepseek`` and ``dashscope`` use ``settings.API_KEY``.
    """
    name = (provider or settings.CHAT_PROVIDER).lower()
    key_attr = {
        "openai": "OPENAI_API_KEY",
        "deepseek": "DEEPSEEK_API_KEY",
        "dashscope": "DASH_SCOPE_API_KEY",
    }.get(name, "API_KEY")
    # Resolve lazily so only the selected setting is ever read.
    return getattr(settings, key_attr)
def _default_api_base(provider: Optional[str]) -> str:
    """Return the API base URL configured in settings for a chat provider.

    ``provider`` falls back to ``settings.CHAT_PROVIDER`` when empty and is
    compared case-insensitively. Providers other than ``openai``,
    ``deepseek``, ``dashscope`` and ``ollama`` use the DashScope base URL.
    """
    name = (provider or settings.CHAT_PROVIDER).lower()
    base_attr = {
        "openai": "OPENAI_API_BASE",
        "deepseek": "DEEPSEEK_API_BASE",
        "dashscope": "DASH_SCOPE_API_BASE",
        "ollama": "OLLAMA_API_BASE",
    }.get(name, "DASH_SCOPE_API_BASE")
    # Resolve lazily so only the selected setting is ever read.
    return getattr(settings, base_attr)

View File

@@ -0,0 +1,211 @@
from __future__ import annotations
from datetime import datetime
from typing import Any, Dict, List, Optional
from sqlalchemy.orm import Session
from app.core.config import settings
from app.models.model_config import UserModelConfig
from app.schemas.model_config import ModelConfigCreate, ModelConfigUpdate
# Catalogue of selectable model providers.
#
# Each entry describes one provider:
#   provider                 - identifier matched (lower-cased) by _provider_option
#   label                    - display name; also the fallback config name
#   default_api_base         - base URL used when a config supplies none
#   default_chat_model       - chat model used when a config supplies none
#   default_embedding_model  - embedding model used when a config supplies none
#   chat_models              - suggested chat model names
#   embedding_models         - suggested embedding model names
#   requires_api_key         - when True, _normalized_payload rejects an empty key
#   supports_custom_api_base - whether the base URL may be overridden
#
# The first entry supplies the "defaults" returned by provider_options_response.
PROVIDER_OPTIONS: List[Dict[str, Any]] = [
    {
        "provider": "dashscope",
        "label": "DashScope",
        "default_api_base": settings.DASH_SCOPE_API_BASE,
        "default_chat_model": settings.DASH_SCOPE_CHAT_MODEL or "qwen3-max",
        "default_embedding_model": settings.DASH_SCOPE_EMBEDDINGS_MODEL or "text-embedding-v4",
        "chat_models": ["qwen3-max", "qwen-plus", "qwen-turbo", "qwen-max"],
        "embedding_models": ["text-embedding-v4", "text-embedding-v3", "text-embedding-v2"],
        "requires_api_key": True,
        "supports_custom_api_base": True,
    },
    {
        "provider": "openai",
        "label": "OpenAI",
        "default_api_base": settings.OPENAI_API_BASE,
        "default_chat_model": settings.OPENAI_MODEL or "gpt-4o",
        "default_embedding_model": settings.OPENAI_EMBEDDINGS_MODEL or "text-embedding-3-small",
        "chat_models": ["gpt-4o", "gpt-4o-mini", "gpt-4.1", "gpt-4.1-mini"],
        "embedding_models": ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"],
        "requires_api_key": True,
        "supports_custom_api_base": True,
    },
    {
        # No default base URL: _normalized_payload requires one for this provider.
        "provider": "openai_compatible",
        "label": "OpenAI Compatible",
        "default_api_base": "",
        "default_chat_model": "qwen3-max",
        "default_embedding_model": "text-embedding-v4",
        "chat_models": ["qwen3-max", "deepseek-chat", "gpt-4o-mini"],
        "embedding_models": ["text-embedding-v4", "text-embedding-3-small"],
        "requires_api_key": True,
        "supports_custom_api_base": True,
    },
    {
        # The only entry that does not require an API key.
        "provider": "ollama",
        "label": "Ollama",
        "default_api_base": settings.OLLAMA_API_BASE,
        "default_chat_model": settings.OLLAMA_MODEL,
        "default_embedding_model": settings.OLLAMA_EMBEDDINGS_MODEL,
        "chat_models": [settings.OLLAMA_MODEL, "llama3.1", "qwen2.5", "deepseek-r1:7b"],
        "embedding_models": [settings.OLLAMA_EMBEDDINGS_MODEL, "nomic-embed-text", "mxbai-embed-large"],
        "requires_api_key": False,
        "supports_custom_api_base": True,
    },
]
def provider_options_response() -> Dict[str, Any]:
    """Build the provider catalogue payload.

    Returns:
        A dict with ``providers`` (the full ``PROVIDER_OPTIONS`` list) and
        ``defaults`` (provider, API base, chat model and embedding model taken
        from the first catalogue entry).
    """
    leading = PROVIDER_OPTIONS[0]
    # (response key, catalogue key) pairs, in response order.
    default_fields = (
        ("provider", "provider"),
        ("api_base", "default_api_base"),
        ("chat_model", "default_chat_model"),
        ("embedding_model", "default_embedding_model"),
    )
    defaults = {target: leading[source] for target, source in default_fields}
    return {"providers": PROVIDER_OPTIONS, "defaults": defaults}
def _provider_option(provider: str) -> Dict[str, Any]:
    """Return the ``PROVIDER_OPTIONS`` entry for ``provider``.

    The name is stripped and lower-cased before matching; an empty value is
    treated as ``"dashscope"``.

    Raises:
        ValueError: If no catalogue entry matches.
    """
    wanted = (provider or "dashscope").strip().lower()
    found = next(
        (entry for entry in PROVIDER_OPTIONS if entry["provider"] == wanted),
        None,
    )
    if found is None:
        raise ValueError(f"Unsupported model provider: {provider}")
    return found
def _normalized_payload(payload: Dict[str, Any], existing: Optional[UserModelConfig] = None) -> Dict[str, Any]:
    """Validate and normalise the fields of a model configuration.

    Values come from ``payload`` first, then from ``existing`` (when
    updating), then from the provider's catalogue defaults.

    Args:
        payload: Raw field values supplied by the caller.
        existing: Current configuration when updating, otherwise ``None``.

    Returns:
        A dict with ``name``, ``provider``, ``api_key``, ``api_base``,
        ``chat_model``, ``embedding_model`` and ``is_active``.

    Raises:
        ValueError: For an unsupported provider, a missing API base on the
            ``openai_compatible`` provider, an empty chat or embedding model,
            or a missing API key on a provider that requires one.
    """

    def inherited(field: str) -> Any:
        # Only an absent (None) payload value falls back to the stored one;
        # an explicit empty string is kept and handled by the caller.
        supplied = payload.get(field)
        if supplied is None and existing is not None:
            return getattr(existing, field)
        return supplied

    provider_name = str(
        payload.get("provider") or getattr(existing, "provider", "dashscope")
    ).strip().lower()
    spec = _provider_option(provider_name)

    base_url = inherited("api_base") or spec["default_api_base"]
    if spec["supports_custom_api_base"] and provider_name == "openai_compatible" and not base_url:
        raise ValueError("OpenAI Compatible provider requires an API base URL.")

    chat_name = str(
        payload.get("chat_model")
        or getattr(existing, "chat_model", "")
        or spec["default_chat_model"]
    ).strip()
    embedding_name = str(
        payload.get("embedding_model")
        or getattr(existing, "embedding_model", "")
        or spec["default_embedding_model"]
    ).strip()
    if not chat_name:
        raise ValueError("Chat model is required.")
    if not embedding_name:
        raise ValueError("Embedding model is required.")

    secret = str(inherited("api_key") or "").strip()
    if spec["requires_api_key"] and not secret:
        raise ValueError("API key is required for this provider.")

    # A blank name falls back to the provider's display label.
    display_name = str(inherited("name") or "").strip() or spec["label"]

    active_flag = bool(payload.get("is_active", getattr(existing, "is_active", True)))
    normalized_base = str(base_url).strip() if base_url else None
    return {
        "name": display_name,
        "provider": provider_name,
        "api_key": secret,
        "api_base": normalized_base,
        "chat_model": chat_name,
        "embedding_model": embedding_name,
        "is_active": active_flag,
    }
class ModelConfigService:
    """Persistence helpers for per-user model configurations."""

    @staticmethod
    def list_configs(db: Session, user_id: int) -> List[UserModelConfig]:
        """Return the user's configurations, active first, then most recently updated."""
        owned = db.query(UserModelConfig).filter(UserModelConfig.user_id == user_id)
        ordered = owned.order_by(
            UserModelConfig.is_active.desc(),
            UserModelConfig.updated_at.desc(),
        )
        return ordered.all()

    @staticmethod
    def get_config(db: Session, user_id: int, config_id: int) -> Optional[UserModelConfig]:
        """Return the configuration ``config_id`` if owned by ``user_id``, else ``None``."""
        criteria = (
            UserModelConfig.id == config_id,
            UserModelConfig.user_id == user_id,
        )
        return db.query(UserModelConfig).filter(*criteria).first()

    @staticmethod
    def get_active_config(db: Session, user_id: int) -> Optional[UserModelConfig]:
        """Return the user's most recently updated active configuration, or ``None``."""
        active = db.query(UserModelConfig).filter(
            UserModelConfig.user_id == user_id,
            UserModelConfig.is_active.is_(True),
        )
        return active.order_by(UserModelConfig.updated_at.desc()).first()

    @staticmethod
    def require_active_config(db: Session, user_id: int) -> UserModelConfig:
        """Return the user's active configuration.

        Raises:
            ValueError: If the user has no active configuration.
        """
        active = ModelConfigService.get_active_config(db, user_id)
        if active is not None:
            return active
        raise ValueError("请先在 API 密钥页面新增并启用模型配置。")

    @staticmethod
    def create_config(db: Session, user_id: int, payload: ModelConfigCreate) -> UserModelConfig:
        """Create a configuration for ``user_id`` from a validated payload.

        When the new configuration is active, the user's other
        configurations are deactivated in the same transaction.
        """
        fields = _normalized_payload(payload.model_dump())
        if fields["is_active"]:
            ModelConfigService._deactivate_user_configs(db, user_id)
        created = UserModelConfig(user_id=user_id, **fields)
        db.add(created)
        db.commit()
        db.refresh(created)
        return created

    @staticmethod
    def update_config(
        db: Session,
        item: UserModelConfig,
        payload: ModelConfigUpdate,
    ) -> UserModelConfig:
        """Apply the explicitly set fields of ``payload`` to ``item``.

        An empty-string ``api_key`` is dropped so the stored key is kept.
        When the result is active, the user's other configurations are
        deactivated in the same transaction.
        """
        changes = {
            field: value
            for field, value in payload.model_dump(exclude_unset=True).items()
            if not (field == "api_key" and value == "")
        }
        fields = _normalized_payload(changes, existing=item)
        if fields["is_active"]:
            ModelConfigService._deactivate_user_configs(db, item.user_id, exclude_id=item.id)
        for attribute, value in fields.items():
            setattr(item, attribute, value)
        db.add(item)
        db.commit()
        db.refresh(item)
        return item

    @staticmethod
    def delete_config(db: Session, item: UserModelConfig) -> None:
        """Delete ``item`` and commit."""
        db.delete(item)
        db.commit()

    @staticmethod
    def touch_last_used(db: Session, item: UserModelConfig) -> UserModelConfig:
        """Stamp ``item.last_used_at`` with the current UTC time and commit."""
        item.last_used_at = datetime.utcnow()
        db.add(item)
        db.commit()
        db.refresh(item)
        return item

    @staticmethod
    def _deactivate_user_configs(db: Session, user_id: int, exclude_id: Optional[int] = None) -> None:
        """Set ``is_active`` to False on the user's configurations.

        ``exclude_id``, when given, is left untouched. The update is issued
        without committing; the caller commits.
        """
        criteria = [UserModelConfig.user_id == user_id]
        if exclude_id is not None:
            criteria.append(UserModelConfig.id != exclude_id)
        db.query(UserModelConfig).filter(*criteria).update(
            {UserModelConfig.is_active: False},
            synchronize_session=False,
        )

View File

@@ -9,6 +9,7 @@ from sqlalchemy.orm import Session
from app.db.session import SessionLocal
from app.models.tooling import SRSExtraction, SRSRequirement, ToolJob
from app.services.model_config import ModelConfigService
from app.tools.srs_reqs_qwen import get_srs_tool
TYPE_TO_CHINESE = {
@@ -63,7 +64,9 @@ def run_srs_job(job_id: int) -> None:
job.error_message = None
db.commit()
payload = get_srs_tool().run(job.input_file_path)
model_profile = ModelConfigService.require_active_config(db, job.user_id)
ModelConfigService.touch_last_used(db, model_profile)
payload = get_srs_tool().run(job.input_file_path, model_profile=model_profile)
extraction = SRSExtraction(
job_id=job.id,

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import asyncio
import logging
from datetime import datetime
from typing import Any, Dict, List
@@ -11,10 +12,15 @@ from app.db.session import SessionLocal
from app.models.knowledge import Document, KnowledgeBase
from app.models.tooling import TestingGeneration, ToolJob
from app.services.embedding.embedding_factory import EmbeddingsFactory
from app.services.llm.llm_factory import LLMFactory
from app.services.model_config import ModelConfigService
from app.services.retrieval.multi_kb_retriever import MultiKBRetriever, format_retrieval_context
from app.services.testing_pipeline import run_testing_pipeline
from app.services.vector_store import VectorStoreFactory
logger = logging.getLogger(__name__)
def _flatten_record(value: Dict[str, List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
items: List[Dict[str, Any]] = []
for current in value.values():
@@ -22,8 +28,15 @@ def _flatten_record(value: Dict[str, List[Dict[str, Any]]]) -> List[Dict[str, An
return items
def _build_kb_vector_stores(db: Session, knowledge_bases: List[KnowledgeBase]) -> List[Dict[str, Any]]:
embeddings = EmbeddingsFactory.create()
def _build_kb_vector_stores(
db: Session,
knowledge_bases: List[KnowledgeBase],
model_profile: Any,
) -> List[Dict[str, Any]]:
if model_profile is None:
return []
embeddings = EmbeddingsFactory.create(model_profile=model_profile)
kb_vector_stores: List[Dict[str, Any]] = []
for kb in knowledge_bases:
@@ -47,8 +60,9 @@ def _resolve_knowledge_context(
user_id: int,
requirement_text: str,
knowledge_base_id: int | None,
model_profile: Any,
) -> str:
if knowledge_base_id is None:
if knowledge_base_id is None or model_profile is None:
return ""
try:
@@ -60,7 +74,7 @@ def _resolve_knowledge_context(
)
.all()
)
kb_vector_stores = _build_kb_vector_stores(db, knowledge_bases)
kb_vector_stores = _build_kb_vector_stores(db, knowledge_bases, model_profile)
if not kb_vector_stores:
return ""
@@ -143,6 +157,27 @@ def run_testing_generation_job(job_id: int, payload: Dict[str, Any]) -> None:
source_document_name = str(payload.get("source_document_name") or job.input_file_name or "")
source_job_id = payload.get("source_job_id")
knowledge_base_id = payload.get("knowledge_base_id")
model_profile = ModelConfigService.get_active_config(db, job.user_id)
if model_profile is not None:
ModelConfigService.touch_last_used(db, model_profile)
use_model_generation = model_profile is not None
llm_model = None
if use_model_generation:
try:
llm_model = LLMFactory.create(streaming=False, model_profile=model_profile)
except Exception as exc:
logger.exception(
"Testing generation LLM initialization failed for job=%s, falling back to rule-based output: %s",
job_id,
exc,
)
use_model_generation = False
else:
logger.info(
"Testing generation job=%s has no active model config; using rule-based output.",
job_id,
)
job.status = "processing"
job.started_at = datetime.utcnow()
@@ -183,6 +218,7 @@ def run_testing_generation_job(job_id: int, payload: Dict[str, Any]) -> None:
user_id=job.user_id,
requirement_text=description,
knowledge_base_id=knowledge_base_id,
model_profile=model_profile,
)
pipeline_result = run_testing_pipeline(
@@ -190,7 +226,8 @@ def run_testing_generation_job(job_id: int, payload: Dict[str, Any]) -> None:
requirement_type_input=req.get("requirementType"),
debug=False,
knowledge_context=knowledge_context,
use_model_generation=True,
use_model_generation=use_model_generation,
llm_model=llm_model,
max_items_per_group=12,
cases_per_item=2,
max_focus_points=6,

View File

@@ -33,13 +33,13 @@ def run_testing_pipeline(
debug: bool = False,
knowledge_context: Optional[str] = None,
use_model_generation: bool = False,
llm_model: Any = None,
max_items_per_group: int = 12,
cases_per_item: int = 2,
max_focus_points: int = 6,
max_llm_calls: int = 10,
) -> Dict[str, Any]:
llm_model = None
if use_model_generation:
if use_model_generation and llm_model is None:
try:
from app.services.llm.llm_factory import LLMFactory