增加代码知识库;修复文档处理内容;增加API设置
This commit is contained in:
@@ -17,6 +17,7 @@ from app.services.fusion_prompts import (
|
||||
from app.services.graph.graphrag_adapter import GraphRAGAdapter
|
||||
from app.services.intent_router import route_intent
|
||||
from app.services.llm.llm_factory import LLMFactory
|
||||
from app.services.model_config import ModelConfigService
|
||||
from app.services.reranker.external_api import ExternalRerankerClient
|
||||
from app.services.retrieval.multi_kb_retriever import MultiKBRetriever, format_retrieval_context
|
||||
from app.services.testing_pipeline.pipeline import run_testing_pipeline
|
||||
@@ -202,8 +203,12 @@ def _build_global_community_context_fallback(rows: List[Dict[str, Any]]) -> str:
|
||||
return "\n\n".join(lines)
|
||||
|
||||
|
||||
async def _build_kb_vector_stores(db: Any, knowledge_bases: List[KnowledgeBase]) -> List[Dict[str, Any]]:
|
||||
embeddings = EmbeddingsFactory.create()
|
||||
async def _build_kb_vector_stores(
|
||||
db: Any,
|
||||
knowledge_bases: List[KnowledgeBase],
|
||||
model_profile: Any,
|
||||
) -> List[Dict[str, Any]]:
|
||||
embeddings = EmbeddingsFactory.create(model_profile=model_profile)
|
||||
kb_vector_stores: List[Dict[str, Any]] = []
|
||||
|
||||
for kb in knowledge_bases:
|
||||
@@ -221,10 +226,13 @@ async def _build_kb_vector_stores(db: Any, knowledge_bases: List[KnowledgeBase])
|
||||
return kb_vector_stores
|
||||
|
||||
|
||||
def _build_reranker_client() -> ExternalRerankerClient:
|
||||
def _build_reranker_client(model_profile: Any = None) -> ExternalRerankerClient:
|
||||
api_key = settings.RERANKER_API_KEY
|
||||
if model_profile is not None and getattr(model_profile, "provider", "") == "dashscope":
|
||||
api_key = getattr(model_profile, "api_key", "") or api_key
|
||||
return ExternalRerankerClient(
|
||||
api_url=settings.RERANKER_API_URL,
|
||||
api_key=settings.RERANKER_API_KEY,
|
||||
api_key=api_key,
|
||||
model=settings.RERANKER_MODEL,
|
||||
timeout_seconds=settings.RERANKER_TIMEOUT_SECONDS,
|
||||
)
|
||||
@@ -287,6 +295,7 @@ async def generate_response(
|
||||
knowledge_base_ids: List[int],
|
||||
chat_id: int,
|
||||
db: Any,
|
||||
user_id: int,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
try:
|
||||
user_message = Message(content=query, role="user", chat_id=chat_id)
|
||||
@@ -297,6 +306,9 @@ async def generate_response(
|
||||
db.add(bot_message)
|
||||
db.commit()
|
||||
|
||||
model_profile = ModelConfigService.require_active_config(db, user_id)
|
||||
ModelConfigService.touch_last_used(db, model_profile)
|
||||
|
||||
if _is_testing_generation_request(query):
|
||||
explicit_type = _extract_requirement_type_from_query(query)
|
||||
|
||||
@@ -309,7 +321,7 @@ async def generate_response(
|
||||
.filter(KnowledgeBase.id.in_(knowledge_base_ids))
|
||||
.all()
|
||||
)
|
||||
kb_vector_stores = await _build_kb_vector_stores(db, testing_kbs)
|
||||
kb_vector_stores = await _build_kb_vector_stores(db, testing_kbs, model_profile)
|
||||
|
||||
if kb_vector_stores:
|
||||
testing_retriever = MultiKBRetriever(
|
||||
@@ -330,6 +342,7 @@ async def generate_response(
|
||||
debug=True,
|
||||
knowledge_context=knowledge_context,
|
||||
use_model_generation=True,
|
||||
llm_model=LLMFactory.create(streaming=False, model_profile=model_profile),
|
||||
max_items_per_group=6,
|
||||
cases_per_item=1,
|
||||
max_focus_points=6,
|
||||
@@ -391,11 +404,11 @@ async def generate_response(
|
||||
)
|
||||
kb_ids = [kb.id for kb in knowledge_bases]
|
||||
|
||||
llm = LLMFactory.create()
|
||||
llm = LLMFactory.create(model_profile=model_profile)
|
||||
decision = await route_intent(llm=llm, query=query, messages=messages)
|
||||
intent = decision["intent"]
|
||||
|
||||
kb_vector_stores = await _build_kb_vector_stores(db, knowledge_bases)
|
||||
kb_vector_stores = await _build_kb_vector_stores(db, knowledge_bases, model_profile)
|
||||
if intent in {"B", "C", "D"} and not kb_vector_stores:
|
||||
intent = "A"
|
||||
decision = {
|
||||
@@ -403,7 +416,7 @@ async def generate_response(
|
||||
"reason": "未发现可用知识库向量集合,已降级为通用对话路。",
|
||||
}
|
||||
|
||||
reranker_client = _build_reranker_client()
|
||||
reranker_client = _build_reranker_client(model_profile)
|
||||
retriever = MultiKBRetriever(
|
||||
reranker_client=reranker_client,
|
||||
reranker_weight=settings.RERANKER_WEIGHT,
|
||||
@@ -432,7 +445,7 @@ async def generate_response(
|
||||
used_kb_ids: List[int] = []
|
||||
if settings.GRAPHRAG_ENABLED and kb_ids:
|
||||
try:
|
||||
adapter = GraphRAGAdapter()
|
||||
adapter = GraphRAGAdapter(model_profile=model_profile)
|
||||
graph_context, used_kb_ids = await adapter.local_context_multi(
|
||||
kb_ids,
|
||||
query,
|
||||
@@ -465,7 +478,7 @@ async def generate_response(
|
||||
community_context = ""
|
||||
if settings.GRAPHRAG_ENABLED and kb_ids:
|
||||
try:
|
||||
adapter = GraphRAGAdapter()
|
||||
adapter = GraphRAGAdapter(model_profile=model_profile)
|
||||
community_context, used_kb_ids = await adapter.global_context_multi(
|
||||
kb_ids,
|
||||
query,
|
||||
|
||||
10
rag-web-ui/backend/app/services/code_kb/__init__.py
Normal file
10
rag-web-ui/backend/app/services/code_kb/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from app.services.code_kb.adapter import CodeKnowledgeBaseAdapter
|
||||
from app.services.code_kb.schema import CodeFunctionEvidence, CodeGraphContext, CodeSearchHit
|
||||
|
||||
__all__ = [
|
||||
"CodeFunctionEvidence",
|
||||
"CodeGraphContext",
|
||||
"CodeKnowledgeBaseAdapter",
|
||||
"CodeSearchHit",
|
||||
]
|
||||
|
||||
517
rag-web-ui/backend/app/services/code_kb/adapter.py
Normal file
517
rag-web-ui/backend/app/services/code_kb/adapter.py
Normal file
@@ -0,0 +1,517 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
|
||||
from app.services.code_kb.graph import CodeCallGraph
|
||||
from app.services.code_kb.schema import CodeFunctionEvidence, CodeGraphContext, CodeSearchHit
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
FIELD_ALIASES = {
|
||||
"name": ("name", "function_name"),
|
||||
"file": ("file", "file_path"),
|
||||
"summary": ("summary",),
|
||||
"logic_flow": ("logic_flow", "logic"),
|
||||
"code_snippet": ("code_snippet", "source", "code"),
|
||||
"calls": ("calls", "called_functions"),
|
||||
"called_by": ("called_by", "caller_functions", "callers"),
|
||||
}
|
||||
|
||||
|
||||
class SimpleVectorIndex:
|
||||
def __init__(self, vectors: Any, dimension: int) -> None:
|
||||
try:
|
||||
import numpy as np
|
||||
except ImportError: # pragma: no cover - depends on deployment environment
|
||||
np = None
|
||||
self._np = np
|
||||
self.vectors = [[float(value) for value in vector] for vector in (vectors or [])]
|
||||
self.d = int(dimension or (len(self.vectors[0]) if self.vectors else 0))
|
||||
self.ntotal = len(self.vectors)
|
||||
|
||||
@classmethod
|
||||
def from_file(cls, vector_path: str) -> Optional["SimpleVectorIndex"]:
|
||||
try:
|
||||
with open(vector_path, "r", encoding="utf-8") as file:
|
||||
payload = json.load(file)
|
||||
except (OSError, UnicodeDecodeError, json.JSONDecodeError):
|
||||
return None
|
||||
if not isinstance(payload, dict) or payload.get("format") != "simple_l2_vector_index":
|
||||
return None
|
||||
return cls(payload.get("vectors") or [], int(payload.get("dimension") or 0))
|
||||
|
||||
def search(self, vector: Any, top_k: int) -> Any:
|
||||
if self.ntotal == 0:
|
||||
if self._np is not None:
|
||||
return self._np.array([[]], dtype="float32"), self._np.array([[]], dtype="int64")
|
||||
return [[]], [[]]
|
||||
if self._np is not None:
|
||||
query = self._np.array(vector, dtype="float32")
|
||||
if query.ndim != 2 or query.shape[1] != self.d:
|
||||
raise ValueError(f"Query embedding dimension does not match index dimension {self.d}.")
|
||||
distances = self._np.sum((self._np.array(self.vectors, dtype="float32") - query[0]) ** 2, axis=1)
|
||||
order = self._np.argsort(distances)[:top_k]
|
||||
return self._np.array([distances[order]], dtype="float32"), self._np.array([order], dtype="int64")
|
||||
|
||||
query_row = vector[0] if isinstance(vector, list) and vector and isinstance(vector[0], list) else vector
|
||||
query_values = [float(value) for value in query_row]
|
||||
if len(query_values) != self.d:
|
||||
raise ValueError(f"Query embedding dimension does not match index dimension {self.d}.")
|
||||
scored = [
|
||||
(sum((left - right) ** 2 for left, right in zip(stored, query_values)), index)
|
||||
for index, stored in enumerate(self.vectors)
|
||||
]
|
||||
scored.sort(key=lambda item: item[0])
|
||||
selected = scored[:top_k]
|
||||
return [[item[0] for item in selected]], [[item[1] for item in selected]]
|
||||
|
||||
def has_nonzero_vectors(self) -> bool:
|
||||
return any(
|
||||
any(abs(float(value)) > 1e-12 for value in vector)
|
||||
for vector in self.vectors
|
||||
)
|
||||
|
||||
|
||||
class UnavailableVectorIndex:
|
||||
d = 0
|
||||
ntotal = 0
|
||||
|
||||
def __init__(self, reason: str) -> None:
|
||||
self.reason = reason
|
||||
|
||||
def search(self, vector: Any, top_k: int) -> Any:
|
||||
raise RuntimeError(self.reason)
|
||||
|
||||
|
||||
def _first_value(source: Dict[str, Any], aliases: Iterable[str], default: Any = None) -> Any:
|
||||
for alias in aliases:
|
||||
value = source.get(alias)
|
||||
if value not in (None, ""):
|
||||
return value
|
||||
return default
|
||||
|
||||
|
||||
def _as_list(value: Any) -> List[str]:
|
||||
if value is None:
|
||||
return []
|
||||
if isinstance(value, list):
|
||||
return [str(item) for item in value if item not in (None, "")]
|
||||
if isinstance(value, tuple):
|
||||
return [str(item) for item in value if item not in (None, "")]
|
||||
if isinstance(value, str):
|
||||
return [value] if value else []
|
||||
return [str(value)]
|
||||
|
||||
|
||||
def _as_int(value: Any) -> Optional[int]:
|
||||
if value in (None, ""):
|
||||
return None
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
|
||||
def _as_bool(value: Any, default: bool = True) -> bool:
|
||||
if value is None:
|
||||
return default
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if isinstance(value, (int, float)):
|
||||
return bool(value)
|
||||
if isinstance(value, str):
|
||||
normalized = value.strip().lower()
|
||||
if normalized in {"1", "true", "yes", "y"}:
|
||||
return True
|
||||
if normalized in {"0", "false", "no", "n"}:
|
||||
return False
|
||||
return default
|
||||
|
||||
|
||||
def _load_json(path: str) -> Any:
|
||||
with open(path, "r", encoding="utf-8") as file:
|
||||
return json.load(file)
|
||||
|
||||
|
||||
def read_code_kb_summary(metadata_path: str, graph_path: str) -> Dict[str, Any]:
|
||||
metadata = _load_json(metadata_path)
|
||||
graph_data = _load_json(graph_path) if graph_path and Path(graph_path).exists() else {}
|
||||
metadata_items = metadata.get("functions", metadata) if isinstance(metadata, dict) else metadata
|
||||
metadata_items = metadata_items or []
|
||||
graph_meta = graph_data.get("metadata", {}) if isinstance(graph_data, dict) else {}
|
||||
return {
|
||||
"function_count": len(metadata_items) if isinstance(metadata_items, list) else 0,
|
||||
"graph_nodes": graph_meta.get("total_nodes"),
|
||||
"graph_edges": graph_meta.get("total_edges"),
|
||||
"project_root": graph_meta.get("project_root"),
|
||||
"generated_at": graph_meta.get("generated_at"),
|
||||
}
|
||||
|
||||
|
||||
class CodeKnowledgeBaseAdapter:
|
||||
def __init__(self, embedding_function: Any = None) -> None:
|
||||
self.embedding_function = embedding_function
|
||||
self.faiss_index: Any = None
|
||||
self.metadata: List[Dict[str, Any]] = []
|
||||
self.graph_data: Dict[str, Any] = {}
|
||||
self.functions: List[CodeFunctionEvidence] = []
|
||||
self.functions_by_id: Dict[str, CodeFunctionEvidence] = {}
|
||||
self.call_graph: Optional[CodeCallGraph] = None
|
||||
self.vector_search_enabled = False
|
||||
self.vector_search_disabled_reason = "Code knowledge base has not been loaded."
|
||||
|
||||
def load(self, vector_path: str, metadata_path: str, graph_path: str) -> None:
|
||||
self._validate_paths(vector_path, metadata_path, graph_path)
|
||||
self.faiss_index = self._read_faiss_index(vector_path)
|
||||
self.metadata = self._read_metadata(metadata_path)
|
||||
self.graph_data = _load_json(graph_path)
|
||||
graph_nodes = self._index_graph_nodes(self.graph_data)
|
||||
|
||||
self.functions = [
|
||||
self._normalize_metadata(row, graph_nodes, index_dimension=self.faiss_index.d)
|
||||
for row in self.metadata
|
||||
]
|
||||
self.functions_by_id = {item.node_id: item for item in self.functions}
|
||||
self.call_graph = CodeCallGraph(self.graph_data, self.functions)
|
||||
self._configure_vector_search()
|
||||
|
||||
def _validate_paths(self, vector_path: str, metadata_path: str, graph_path: str) -> None:
|
||||
missing = [
|
||||
path
|
||||
for path in [vector_path, metadata_path, graph_path]
|
||||
if not path or not Path(path).exists()
|
||||
]
|
||||
if missing:
|
||||
raise FileNotFoundError(f"Code knowledge base files not found: {missing}")
|
||||
|
||||
def _read_faiss_index(self, vector_path: str) -> Any:
|
||||
simple_index = SimpleVectorIndex.from_file(vector_path)
|
||||
if simple_index is not None:
|
||||
return simple_index
|
||||
try:
|
||||
import faiss # type: ignore
|
||||
except ImportError as exc:
|
||||
logger.warning("faiss is not installed; code search will use lexical fallback.")
|
||||
return UnavailableVectorIndex(
|
||||
"faiss is required to read this vector index. Install faiss-cpu or rebuild the code KB."
|
||||
)
|
||||
return faiss.read_index(vector_path)
|
||||
|
||||
def _read_metadata(self, metadata_path: str) -> List[Dict[str, Any]]:
|
||||
metadata = _load_json(metadata_path)
|
||||
if isinstance(metadata, dict):
|
||||
metadata = metadata.get("functions", metadata.get("items", []))
|
||||
if not isinstance(metadata, list):
|
||||
raise ValueError("Code KB metadata must be a list or a dict with functions/items.")
|
||||
return [item for item in metadata if isinstance(item, dict)]
|
||||
|
||||
def _index_graph_nodes(self, graph_data: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
|
||||
return {
|
||||
item.get("id"): item
|
||||
for item in graph_data.get("nodes", []) or []
|
||||
if isinstance(item, dict) and item.get("id")
|
||||
}
|
||||
|
||||
def _normalize_metadata(
|
||||
self,
|
||||
metadata: Dict[str, Any],
|
||||
graph_nodes: Dict[str, Dict[str, Any]],
|
||||
index_dimension: int,
|
||||
) -> CodeFunctionEvidence:
|
||||
name = _first_value(metadata, FIELD_ALIASES["name"], "")
|
||||
node_id = metadata.get("node_id") or (f"Function:{name}" if name else "")
|
||||
graph_node = graph_nodes.get(node_id, {})
|
||||
raw_attributes = graph_node.get("raw_attributes") or {}
|
||||
if not name:
|
||||
name = graph_node.get("name") or node_id.removeprefix("Function:")
|
||||
|
||||
file_path = _first_value(metadata, FIELD_ALIASES["file"], graph_node.get("file_path", ""))
|
||||
embedding_dim = metadata.get("embedding_dim") or index_dimension
|
||||
embedding_available = _as_bool(
|
||||
metadata.get("embedding_available", raw_attributes.get("embedding_available")),
|
||||
default=True,
|
||||
)
|
||||
return CodeFunctionEvidence(
|
||||
node_id=node_id,
|
||||
name=name,
|
||||
qualified_name=metadata.get("qualified_name") or node_id.removeprefix("Function:") or name,
|
||||
file=file_path or "",
|
||||
start_line=_as_int(metadata.get("start_line") or graph_node.get("start_line")),
|
||||
end_line=_as_int(metadata.get("end_line") or graph_node.get("end_line")),
|
||||
signature=metadata.get("signature") or graph_node.get("signature") or "",
|
||||
summary=_first_value(metadata, FIELD_ALIASES["summary"], graph_node.get("summary", "")) or "",
|
||||
logic_flow=_first_value(
|
||||
metadata, FIELD_ALIASES["logic_flow"], graph_node.get("logic_flow", "")
|
||||
)
|
||||
or "",
|
||||
code_snippet=_first_value(
|
||||
metadata, FIELD_ALIASES["code_snippet"], raw_attributes.get("code_snippet", "")
|
||||
)
|
||||
or "",
|
||||
calls=_as_list(_first_value(metadata, FIELD_ALIASES["calls"], raw_attributes.get("calls", []))),
|
||||
called_by=_as_list(
|
||||
_first_value(metadata, FIELD_ALIASES["called_by"], raw_attributes.get("called_by", []))
|
||||
),
|
||||
includes=_as_list(metadata.get("includes") or raw_attributes.get("includes")),
|
||||
embedding_model=metadata.get("embedding_model") or "",
|
||||
embedding_dim=int(embedding_dim or 0),
|
||||
embedding_available=embedding_available,
|
||||
raw=metadata,
|
||||
)
|
||||
|
||||
def _configure_vector_search(self) -> None:
|
||||
index_total = int(getattr(self.faiss_index, "ntotal", 0))
|
||||
self.vector_search_enabled = False
|
||||
self.vector_search_disabled_reason = ""
|
||||
if isinstance(self.faiss_index, UnavailableVectorIndex):
|
||||
self.vector_search_disabled_reason = self.faiss_index.reason
|
||||
return
|
||||
if index_total == 0:
|
||||
self.vector_search_disabled_reason = "Vector index is empty."
|
||||
return
|
||||
if index_total != len(self.functions):
|
||||
self.vector_search_disabled_reason = (
|
||||
f"Vector index size ({index_total}) differs from metadata size ({len(self.functions)})."
|
||||
)
|
||||
logger.warning(
|
||||
"FAISS index size (%s) differs from metadata size (%s)",
|
||||
index_total,
|
||||
len(self.functions),
|
||||
)
|
||||
return
|
||||
if not self._index_has_nonzero_vectors():
|
||||
self.vector_search_disabled_reason = "Vector index contains only zero embeddings."
|
||||
logger.info("Code KB vector search disabled: %s", self.vector_search_disabled_reason)
|
||||
return
|
||||
if not any(item.embedding_available for item in self.functions):
|
||||
self.vector_search_disabled_reason = "No function metadata has usable embeddings."
|
||||
logger.info("Code KB vector search disabled: %s", self.vector_search_disabled_reason)
|
||||
return
|
||||
self.vector_search_enabled = True
|
||||
|
||||
def _index_has_nonzero_vectors(self) -> bool:
|
||||
has_nonzero_vectors = getattr(self.faiss_index, "has_nonzero_vectors", None)
|
||||
if callable(has_nonzero_vectors):
|
||||
return bool(has_nonzero_vectors())
|
||||
|
||||
reconstruct_n = getattr(self.faiss_index, "reconstruct_n", None)
|
||||
if not callable(reconstruct_n):
|
||||
return True
|
||||
try:
|
||||
import numpy as np
|
||||
|
||||
index_total = int(getattr(self.faiss_index, "ntotal", 0) or 0)
|
||||
sample_size = min(index_total, 1000)
|
||||
if sample_size <= 0:
|
||||
return False
|
||||
vectors = reconstruct_n(0, sample_size)
|
||||
return bool(np.any(np.abs(vectors) > 1e-12))
|
||||
except Exception as exc: # pragma: no cover - depends on FAISS capabilities
|
||||
logger.debug("Could not inspect FAISS vectors for zero embeddings: %s", exc)
|
||||
return True
|
||||
|
||||
def _get_embedding_function(self) -> Any:
|
||||
if self.embedding_function is None:
|
||||
from app.services.embedding.embedding_factory import EmbeddingsFactory
|
||||
|
||||
self.embedding_function = EmbeddingsFactory.create()
|
||||
return self.embedding_function
|
||||
|
||||
def _embed_query(self, query: str) -> np.ndarray:
|
||||
try:
|
||||
import numpy as np
|
||||
except ImportError as exc:
|
||||
raise RuntimeError("numpy is required to search a code knowledge base.") from exc
|
||||
embedding_function = self._get_embedding_function()
|
||||
if hasattr(embedding_function, "embed_query"):
|
||||
vector = embedding_function.embed_query(query)
|
||||
elif callable(embedding_function):
|
||||
vector = embedding_function(query)
|
||||
else:
|
||||
raise TypeError("Unsupported embedding function.")
|
||||
query_vector = np.array([vector], dtype="float32")
|
||||
index_dimension = int(getattr(self.faiss_index, "d", 0) or 0)
|
||||
if index_dimension and query_vector.shape[1] != index_dimension:
|
||||
raise ValueError(
|
||||
f"Query embedding dimension {query_vector.shape[1]} does not match "
|
||||
f"FAISS index dimension {index_dimension}."
|
||||
)
|
||||
return query_vector
|
||||
|
||||
@staticmethod
|
||||
def distance_to_similarity(distance: float) -> float:
|
||||
if math.isnan(distance) or distance < 0:
|
||||
return 0.0
|
||||
if distance <= 2.0:
|
||||
return max(0.0, min(1.0, 1.0 - distance / 2.0))
|
||||
return max(0.0, min(1.0, 1.0 / (1.0 + distance)))
|
||||
|
||||
def search_functions(
|
||||
self,
|
||||
query: str,
|
||||
top_k: int = 8,
|
||||
min_similarity: float = 0.0,
|
||||
) -> List[CodeSearchHit]:
|
||||
if self.faiss_index is None:
|
||||
raise RuntimeError("Code knowledge base has not been loaded.")
|
||||
if not query.strip():
|
||||
return []
|
||||
if not self.vector_search_enabled:
|
||||
logger.info(
|
||||
"Vector code search disabled, using lexical search: %s",
|
||||
self.vector_search_disabled_reason,
|
||||
)
|
||||
return self._lexical_search_functions(query, top_k=top_k, min_similarity=min_similarity)
|
||||
|
||||
try:
|
||||
vector = self._embed_query(query)
|
||||
distances, indices = self.faiss_index.search(vector, top_k)
|
||||
except Exception as exc:
|
||||
logger.warning("Vector code search failed, falling back to lexical search: %s", exc)
|
||||
return self._lexical_search_functions(query, top_k=top_k, min_similarity=min_similarity)
|
||||
|
||||
hits: List[CodeSearchHit] = []
|
||||
for rank, raw_index in enumerate(indices[0], start=1):
|
||||
index = int(raw_index)
|
||||
if index < 0 or index >= len(self.functions):
|
||||
continue
|
||||
evidence = self.functions[index]
|
||||
if not evidence.embedding_available:
|
||||
continue
|
||||
distance = float(distances[0][rank - 1])
|
||||
similarity = self.distance_to_similarity(distance)
|
||||
if similarity < min_similarity:
|
||||
continue
|
||||
hits.append(
|
||||
CodeSearchHit(
|
||||
evidence=evidence,
|
||||
similarity=similarity,
|
||||
distance=distance,
|
||||
rank=rank,
|
||||
)
|
||||
)
|
||||
if not hits:
|
||||
return self._lexical_search_functions(query, top_k=top_k, min_similarity=min_similarity)
|
||||
return hits
|
||||
|
||||
def _lexical_search_functions(
|
||||
self,
|
||||
query: str,
|
||||
top_k: int = 8,
|
||||
min_similarity: float = 0.0,
|
||||
) -> List[CodeSearchHit]:
|
||||
query_tokens = self._tokens(query)
|
||||
if not query_tokens:
|
||||
return []
|
||||
scored: List[CodeSearchHit] = []
|
||||
for evidence in self.functions:
|
||||
text = self._function_search_text(evidence)
|
||||
score = self._lexical_similarity(query_tokens, text, evidence)
|
||||
if score < min_similarity:
|
||||
continue
|
||||
scored.append(
|
||||
CodeSearchHit(
|
||||
evidence=evidence,
|
||||
similarity=score,
|
||||
distance=max(0.0, 1.0 - score),
|
||||
rank=0,
|
||||
)
|
||||
)
|
||||
scored.sort(key=lambda item: item.similarity, reverse=True)
|
||||
for rank, item in enumerate(scored[:top_k], start=1):
|
||||
item.rank = rank
|
||||
return scored[:top_k]
|
||||
|
||||
@staticmethod
|
||||
def _tokens(text: str) -> List[str]:
|
||||
normalized = (text or "").lower()
|
||||
tokens = re.findall(r"[a-z_][a-z0-9_]*|\d+(?:\.\d+)?|[\u4e00-\u9fff]{2,}", normalized)
|
||||
expanded: List[str] = []
|
||||
for token in tokens:
|
||||
expanded.append(token)
|
||||
if re.fullmatch(r"[\u4e00-\u9fff]{3,}", token):
|
||||
expanded.extend(token[index : index + 2] for index in range(len(token) - 1))
|
||||
return expanded
|
||||
|
||||
@staticmethod
|
||||
def _function_search_text(evidence: CodeFunctionEvidence) -> str:
|
||||
return "\n".join(
|
||||
[
|
||||
evidence.node_id,
|
||||
evidence.name,
|
||||
evidence.qualified_name,
|
||||
evidence.file,
|
||||
evidence.signature,
|
||||
evidence.summary,
|
||||
evidence.logic_flow,
|
||||
evidence.code_snippet[:4000],
|
||||
" ".join(evidence.calls),
|
||||
" ".join(evidence.called_by),
|
||||
" ".join(evidence.includes),
|
||||
]
|
||||
).lower()
|
||||
|
||||
def _lexical_similarity(
|
||||
self,
|
||||
query_tokens: List[str],
|
||||
text: str,
|
||||
evidence: CodeFunctionEvidence,
|
||||
) -> float:
|
||||
text_tokens = set(self._tokens(text))
|
||||
if not text_tokens:
|
||||
return 0.0
|
||||
unique_query_tokens = list(dict.fromkeys(query_tokens))
|
||||
overlap_count = sum(1 for token in unique_query_tokens if token in text_tokens)
|
||||
substring_count = sum(1 for token in unique_query_tokens if token in text)
|
||||
overlap_score = overlap_count / max(1, len(unique_query_tokens))
|
||||
substring_score = substring_count / max(1, len(unique_query_tokens))
|
||||
|
||||
name_text = f"{evidence.name} {evidence.qualified_name}".lower()
|
||||
name_hits = sum(1 for token in unique_query_tokens if token in name_text)
|
||||
name_score = min(1.0, name_hits / max(1, min(len(unique_query_tokens), 4)))
|
||||
evidence_score = 0.0
|
||||
if evidence.summary:
|
||||
evidence_score += 0.08
|
||||
if evidence.logic_flow:
|
||||
evidence_score += 0.05
|
||||
if evidence.code_snippet:
|
||||
evidence_score += 0.04
|
||||
if evidence.start_line is not None and evidence.file:
|
||||
evidence_score += 0.03
|
||||
|
||||
return max(
|
||||
0.0,
|
||||
min(
|
||||
1.0,
|
||||
overlap_score * 0.45
|
||||
+ substring_score * 0.35
|
||||
+ name_score * 0.12
|
||||
+ evidence_score,
|
||||
),
|
||||
)
|
||||
|
||||
def get_function(self, node_id: str) -> Optional[CodeFunctionEvidence]:
|
||||
return self.functions_by_id.get(node_id)
|
||||
|
||||
def expand_call_context(self, node_id: str, max_hops: int = 2) -> CodeGraphContext:
|
||||
if not self.call_graph:
|
||||
return CodeGraphContext(node_id=node_id)
|
||||
return self.call_graph.expand(node_id, max_hops=max_hops)
|
||||
|
||||
def summary(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"function_count": len(self.functions),
|
||||
"index_size": int(getattr(self.faiss_index, "ntotal", 0) or 0),
|
||||
"embedding_dim": int(getattr(self.faiss_index, "d", 0) or 0),
|
||||
"embedding_available_count": sum(1 for item in self.functions if item.embedding_available),
|
||||
"vector_search_enabled": self.vector_search_enabled,
|
||||
"vector_search_disabled_reason": self.vector_search_disabled_reason,
|
||||
"graph_nodes": len(self.graph_data.get("nodes", []) or []),
|
||||
"graph_edges": len(self.graph_data.get("edges", []) or []),
|
||||
}
|
||||
46
rag-web-ui/backend/app/services/code_kb/formatter.py
Normal file
46
rag-web-ui/backend/app/services/code_kb/formatter.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Iterable, List
|
||||
|
||||
from app.services.code_kb.schema import CodeGraphContext, CodeSearchHit
|
||||
|
||||
|
||||
def _clip(value: str, limit: int) -> str:
|
||||
value = value or ""
|
||||
if len(value) <= limit:
|
||||
return value
|
||||
return value[:limit].rstrip() + "\n...[truncated]"
|
||||
|
||||
|
||||
def format_evidence_context(
|
||||
hits: Iterable[CodeSearchHit],
|
||||
graph_contexts: Iterable[CodeGraphContext],
|
||||
max_code_chars: int = 1000,
|
||||
max_text_chars: int = 800,
|
||||
) -> str:
|
||||
context_by_node = {item.node_id: item for item in graph_contexts}
|
||||
blocks: List[str] = []
|
||||
for hit in hits:
|
||||
item = hit.evidence
|
||||
graph = context_by_node.get(item.node_id)
|
||||
lines = [
|
||||
f"[Function Evidence #{hit.rank}]",
|
||||
f"node_id: {item.node_id}",
|
||||
f"name: {item.name}",
|
||||
f"qualified_name: {item.qualified_name}",
|
||||
f"file: {item.file}",
|
||||
f"lines: {item.start_line}-{item.end_line}",
|
||||
f"similarity: {hit.similarity:.4f}",
|
||||
f"signature: {_clip(item.signature, 300)}",
|
||||
f"summary: {_clip(item.summary, max_text_chars)}",
|
||||
f"logic_flow: {_clip(item.logic_flow, max_text_chars)}",
|
||||
f"calls: {', '.join(item.calls[:20]) or '-'}",
|
||||
f"called_by: {', '.join(item.called_by[:20]) or '-'}",
|
||||
]
|
||||
if graph:
|
||||
lines.append(f"call_chain_evidence: {'; '.join(graph.call_chains[:12]) or '-'}")
|
||||
if item.code_snippet:
|
||||
lines.extend(["code_snippet:", _clip(item.code_snippet, max_code_chars)])
|
||||
blocks.append("\n".join(lines))
|
||||
return "\n\n---\n\n".join(blocks)
|
||||
|
||||
105
rag-web-ui/backend/app/services/code_kb/graph.py
Normal file
105
rag-web-ui/backend/app/services/code_kb/graph.py
Normal file
@@ -0,0 +1,105 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict, deque
|
||||
from typing import Callable, Dict, Iterable, List, Optional, Set
|
||||
|
||||
from app.services.code_kb.schema import CodeFunctionEvidence, CodeGraphContext
|
||||
|
||||
|
||||
class CodeCallGraph:
|
||||
def __init__(
|
||||
self,
|
||||
graph_data: Optional[dict],
|
||||
functions: Iterable[CodeFunctionEvidence],
|
||||
) -> None:
|
||||
self.functions_by_id: Dict[str, CodeFunctionEvidence] = {
|
||||
item.node_id: item for item in functions
|
||||
}
|
||||
self.name_to_ids: Dict[str, List[str]] = defaultdict(list)
|
||||
for item in self.functions_by_id.values():
|
||||
for name in {item.name, item.qualified_name, item.node_id.removeprefix("Function:")}:
|
||||
if name:
|
||||
self.name_to_ids[name].append(item.node_id)
|
||||
|
||||
self.calls_by_id: Dict[str, Set[str]] = defaultdict(set)
|
||||
self.called_by_id: Dict[str, Set[str]] = defaultdict(set)
|
||||
self._load_graph_edges(graph_data or {})
|
||||
self._load_metadata_edges()
|
||||
|
||||
def _load_graph_edges(self, graph_data: dict) -> None:
|
||||
for edge in graph_data.get("edges", []) or []:
|
||||
if edge.get("type") != "CALLS":
|
||||
continue
|
||||
source_id = edge.get("source_id")
|
||||
target_id = edge.get("target_id")
|
||||
if source_id in self.functions_by_id and target_id in self.functions_by_id:
|
||||
self.calls_by_id[source_id].add(target_id)
|
||||
self.called_by_id[target_id].add(source_id)
|
||||
|
||||
def _resolve_name(self, value: str) -> List[str]:
|
||||
if not value:
|
||||
return []
|
||||
if value in self.functions_by_id:
|
||||
return [value]
|
||||
if value.startswith("Function:") and value in self.functions_by_id:
|
||||
return [value]
|
||||
return self.name_to_ids.get(value, [])
|
||||
|
||||
def _load_metadata_edges(self) -> None:
|
||||
for function in self.functions_by_id.values():
|
||||
for callee in function.calls:
|
||||
for target_id in self._resolve_name(callee):
|
||||
if target_id != function.node_id:
|
||||
self.calls_by_id[function.node_id].add(target_id)
|
||||
self.called_by_id[target_id].add(function.node_id)
|
||||
for caller in function.called_by:
|
||||
for source_id in self._resolve_name(caller):
|
||||
if source_id != function.node_id:
|
||||
self.called_by_id[function.node_id].add(source_id)
|
||||
self.calls_by_id[source_id].add(function.node_id)
|
||||
|
||||
def _bfs(
|
||||
self,
|
||||
start_id: str,
|
||||
max_hops: int,
|
||||
next_nodes: Callable[[str], Iterable[str]],
|
||||
) -> List[CodeFunctionEvidence]:
|
||||
seen = {start_id}
|
||||
result: List[CodeFunctionEvidence] = []
|
||||
queue = deque([(start_id, 0)])
|
||||
while queue:
|
||||
current_id, depth = queue.popleft()
|
||||
if depth >= max_hops:
|
||||
continue
|
||||
for next_id in sorted(next_nodes(current_id)):
|
||||
if next_id in seen:
|
||||
continue
|
||||
seen.add(next_id)
|
||||
function = self.functions_by_id.get(next_id)
|
||||
if function:
|
||||
result.append(function)
|
||||
queue.append((next_id, depth + 1))
|
||||
return result
|
||||
|
||||
def expand(self, node_id: str, max_hops: int = 2) -> CodeGraphContext:
|
||||
callers = self._bfs(node_id, max_hops, lambda item: self.called_by_id.get(item, set()))
|
||||
callees = self._bfs(node_id, max_hops, lambda item: self.calls_by_id.get(item, set()))
|
||||
center = self.functions_by_id.get(node_id)
|
||||
center_name = center.name if center else node_id
|
||||
|
||||
call_chains: List[str] = []
|
||||
for caller in callers[:5]:
|
||||
call_chains.append(f"{caller.name} -> {center_name}")
|
||||
for callee in callees[:5]:
|
||||
call_chains.append(f"{center_name} -> {callee.name}")
|
||||
for caller in callers[:3]:
|
||||
for callee in callees[:3]:
|
||||
call_chains.append(f"{caller.name} -> {center_name} -> {callee.name}")
|
||||
|
||||
return CodeGraphContext(
|
||||
node_id=node_id,
|
||||
callers=callers,
|
||||
callees=callees,
|
||||
call_chains=list(dict.fromkeys(call_chains)),
|
||||
)
|
||||
|
||||
24
rag-web-ui/backend/app/services/code_kb/retriever.py
Normal file
24
rag-web-ui/backend/app/services/code_kb/retriever.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
|
||||
from app.services.code_kb.adapter import CodeKnowledgeBaseAdapter
|
||||
from app.services.code_kb.schema import CodeSearchHit
|
||||
|
||||
|
||||
class CodeFunctionRetriever:
|
||||
def __init__(self, adapter: CodeKnowledgeBaseAdapter) -> None:
|
||||
self.adapter = adapter
|
||||
|
||||
def retrieve(
|
||||
self,
|
||||
query: str,
|
||||
top_k: int = 8,
|
||||
min_similarity: float = 0.0,
|
||||
) -> List[CodeSearchHit]:
|
||||
return self.adapter.search_functions(
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
min_similarity=min_similarity,
|
||||
)
|
||||
|
||||
63
rag-web-ui/backend/app/services/code_kb/schema.py
Normal file
63
rag-web-ui/backend/app/services/code_kb/schema.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class CodeFunctionEvidence:
|
||||
node_id: str
|
||||
name: str
|
||||
qualified_name: str
|
||||
file: str
|
||||
start_line: Optional[int] = None
|
||||
end_line: Optional[int] = None
|
||||
signature: str = ""
|
||||
summary: str = ""
|
||||
logic_flow: str = ""
|
||||
code_snippet: str = ""
|
||||
calls: List[str] = field(default_factory=list)
|
||||
called_by: List[str] = field(default_factory=list)
|
||||
includes: List[str] = field(default_factory=list)
|
||||
embedding_model: str = ""
|
||||
embedding_dim: int = 0
|
||||
embedding_available: bool = True
|
||||
raw: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return asdict(self)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CodeSearchHit:
|
||||
evidence: CodeFunctionEvidence
|
||||
similarity: float
|
||||
distance: float
|
||||
rank: int
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
data = self.evidence.to_dict()
|
||||
data.update(
|
||||
{
|
||||
"similarity": self.similarity,
|
||||
"distance": self.distance,
|
||||
"rank": self.rank,
|
||||
}
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
@dataclass
|
||||
class CodeGraphContext:
|
||||
node_id: str
|
||||
callers: List[CodeFunctionEvidence] = field(default_factory=list)
|
||||
callees: List[CodeFunctionEvidence] = field(default_factory=list)
|
||||
call_chains: List[str] = field(default_factory=list)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"node_id": self.node_id,
|
||||
"callers": [item.to_dict() for item in self.callers],
|
||||
"callees": [item.to_dict() for item in self.callees],
|
||||
"call_chains": self.call_chains,
|
||||
}
|
||||
4
rag-web-ui/backend/app/services/consistency/__init__.py
Normal file
4
rag-web-ui/backend/app/services/consistency/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from app.services.consistency.comparator import ConsistencyComparator
|
||||
|
||||
__all__ = ["ConsistencyComparator"]
|
||||
|
||||
258
rag-web-ui/backend/app/services/consistency/comparator.py
Normal file
258
rag-web-ui/backend/app/services/consistency/comparator.py
Normal file
@@ -0,0 +1,258 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
|
||||
from app.services.code_kb.adapter import CodeKnowledgeBaseAdapter
|
||||
from app.services.code_kb.formatter import format_evidence_context
|
||||
from app.services.code_kb.schema import CodeGraphContext, CodeSearchHit
|
||||
from app.services.consistency.prompt import build_judgment_prompt, build_requirement_query
|
||||
from app.services.consistency.schema import ConsistencyResultItem, RequirementSnapshot, VERDICTS
|
||||
from app.services.consistency.scorer import coverage_score
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _clip(value: str, limit: int) -> str:
|
||||
text = value or ""
|
||||
if len(text) <= limit:
|
||||
return text
|
||||
return text[:limit].rstrip() + "\n...[truncated]"
|
||||
|
||||
|
||||
def _as_list(value: Any) -> List[str]:
|
||||
if value is None:
|
||||
return []
|
||||
if isinstance(value, list):
|
||||
return [str(item) for item in value if str(item).strip()]
|
||||
if isinstance(value, tuple):
|
||||
return [str(item) for item in value if str(item).strip()]
|
||||
if isinstance(value, str):
|
||||
text = value.strip()
|
||||
if not text:
|
||||
return []
|
||||
try:
|
||||
parsed = json.loads(text)
|
||||
return _as_list(parsed)
|
||||
except json.JSONDecodeError:
|
||||
return [line.strip() for line in text.splitlines() if line.strip()]
|
||||
return [str(value)]
|
||||
|
||||
|
||||
def requirement_to_snapshot(requirement: Any) -> RequirementSnapshot:
|
||||
getter = requirement.get if isinstance(requirement, dict) else lambda key, default=None: getattr(requirement, key, default)
|
||||
return RequirementSnapshot(
|
||||
requirement_uid=getter("requirement_uid") or getter("id") or "",
|
||||
title=getter("title") or "",
|
||||
description=getter("description") or "",
|
||||
acceptance_criteria=_as_list(getter("acceptance_criteria") or getter("acceptanceCriteria")),
|
||||
requirement_type=getter("requirement_type") or getter("requirementType"),
|
||||
section_title=getter("section_title") or getter("sectionTitle"),
|
||||
interface_name=getter("interface_name") or getter("interfaceName"),
|
||||
interface_type=getter("interface_type") or getter("interfaceType"),
|
||||
data_source=getter("data_source") or getter("dataSource"),
|
||||
data_destination=getter("data_destination") or getter("dataDestination"),
|
||||
)
|
||||
|
||||
|
||||
class ConsistencyComparator:
|
||||
def __init__(
|
||||
self,
|
||||
code_kb_adapter: CodeKnowledgeBaseAdapter,
|
||||
llm: Any = None,
|
||||
use_llm: bool = True,
|
||||
) -> None:
|
||||
self.code_kb_adapter = code_kb_adapter
|
||||
self.llm = llm
|
||||
self.use_llm = use_llm
|
||||
|
||||
def compare_requirements(
|
||||
self,
|
||||
requirements: Iterable[Any],
|
||||
top_k: int = 8,
|
||||
max_call_hops: int = 2,
|
||||
min_similarity: float = 0.55,
|
||||
) -> List[ConsistencyResultItem]:
|
||||
return [
|
||||
self.compare_requirement(
|
||||
requirement,
|
||||
top_k=top_k,
|
||||
max_call_hops=max_call_hops,
|
||||
min_similarity=min_similarity,
|
||||
)
|
||||
for requirement in requirements
|
||||
]
|
||||
|
||||
def compare_requirement(
|
||||
self,
|
||||
requirement: Any,
|
||||
top_k: int = 8,
|
||||
max_call_hops: int = 2,
|
||||
min_similarity: float = 0.55,
|
||||
) -> ConsistencyResultItem:
|
||||
snapshot = requirement_to_snapshot(requirement)
|
||||
query = build_requirement_query(snapshot)
|
||||
hits = self.code_kb_adapter.search_functions(
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
min_similarity=min_similarity,
|
||||
)
|
||||
contexts = [
|
||||
self.code_kb_adapter.expand_call_context(hit.evidence.node_id, max_hops=max_call_hops)
|
||||
for hit in hits
|
||||
]
|
||||
|
||||
if not hits:
|
||||
judgment = self._missing_judgment("未找到满足相似度阈值的函数证据。")
|
||||
elif not self.use_llm:
|
||||
judgment = self._heuristic_judgment(hits, contexts)
|
||||
else:
|
||||
judgment = self._llm_judgment(snapshot, hits, contexts)
|
||||
|
||||
judgment = self._normalize_judgment(judgment)
|
||||
judgment["requirement_snapshot"] = snapshot.to_dict()
|
||||
score = coverage_score(snapshot, hits, contexts, judgment)
|
||||
matched_functions = [self._matched_function_payload(hit) for hit in hits]
|
||||
call_chains = self._collect_call_chains(contexts)
|
||||
|
||||
return ConsistencyResultItem(
|
||||
requirement_uid=snapshot.requirement_uid,
|
||||
requirement_title=snapshot.title,
|
||||
requirement_type=snapshot.requirement_type,
|
||||
requirement_text=snapshot.description,
|
||||
verdict=judgment["verdict"],
|
||||
coverage_score=score,
|
||||
confidence=float(judgment.get("confidence") or 0.0),
|
||||
matched_functions=matched_functions,
|
||||
covered_points=_as_list(judgment.get("covered_points")),
|
||||
missing_points=_as_list(judgment.get("missing_points")),
|
||||
conflict_points=_as_list(judgment.get("conflict_points")),
|
||||
call_chain_evidence=call_chains,
|
||||
suggestion=str(judgment.get("suggestion") or ""),
|
||||
raw_judgment=judgment,
|
||||
)
|
||||
|
||||
def _llm_judgment(
|
||||
self,
|
||||
requirement: RequirementSnapshot,
|
||||
hits: List[CodeSearchHit],
|
||||
contexts: List[CodeGraphContext],
|
||||
) -> Dict[str, Any]:
|
||||
try:
|
||||
evidence_context = format_evidence_context(hits, contexts)
|
||||
prompt = build_judgment_prompt(requirement, evidence_context)
|
||||
from app.services.llm.llm_factory import LLMFactory
|
||||
|
||||
llm = self.llm or LLMFactory.create(temperature=0, streaming=False)
|
||||
response = llm.invoke(prompt) if hasattr(llm, "invoke") else llm(prompt)
|
||||
text = getattr(response, "content", response)
|
||||
return self.parse_json_judgment(str(text))
|
||||
except Exception as exc:
|
||||
logger.exception("LLM consistency judgment failed: %s", exc)
|
||||
return {
|
||||
"verdict": "uncertain",
|
||||
"confidence": 0.2,
|
||||
"covered_points": [],
|
||||
"missing_points": ["模型判定失败,无法可靠确认覆盖情况。"],
|
||||
"conflict_points": [],
|
||||
"primary_evidence": [hit.evidence.node_id for hit in hits[:3]],
|
||||
"reasoning": f"LLM judgment failed: {exc}",
|
||||
"suggestion": "请检查模型配置,或人工复核匹配函数证据。",
|
||||
"fallback": True,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def parse_json_judgment(raw_text: str) -> Dict[str, Any]:
|
||||
text = raw_text.strip()
|
||||
if text.startswith("```"):
|
||||
text = re.sub(r"^```(?:json)?", "", text, flags=re.IGNORECASE).strip()
|
||||
text = re.sub(r"```$", "", text).strip()
|
||||
try:
|
||||
return json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
match = re.search(r"\{.*\}", text, flags=re.DOTALL)
|
||||
if match:
|
||||
return json.loads(match.group(0))
|
||||
raise
|
||||
|
||||
def _heuristic_judgment(
|
||||
self,
|
||||
hits: List[CodeSearchHit],
|
||||
contexts: List[CodeGraphContext],
|
||||
) -> Dict[str, Any]:
|
||||
best = hits[0].similarity if hits else 0.0
|
||||
if best >= 0.78:
|
||||
verdict = "partial"
|
||||
confidence = min(0.68, best)
|
||||
else:
|
||||
verdict = "uncertain"
|
||||
confidence = min(0.5, best)
|
||||
return {
|
||||
"verdict": verdict,
|
||||
"confidence": confidence,
|
||||
"covered_points": [],
|
||||
"missing_points": ["未启用 LLM 判定,无法细分验收准则覆盖点。"],
|
||||
"conflict_points": [],
|
||||
"primary_evidence": [hit.evidence.node_id for hit in hits[:3]],
|
||||
"reasoning": "仅基于向量召回和调用图生成保守判定。",
|
||||
"suggestion": "启用模型判定或人工复核主要匹配函数。",
|
||||
"call_context_count": len(contexts),
|
||||
}
|
||||
|
||||
def _missing_judgment(self, reason: str) -> Dict[str, Any]:
|
||||
return {
|
||||
"verdict": "missing",
|
||||
"confidence": 0.75,
|
||||
"covered_points": [],
|
||||
"missing_points": [reason],
|
||||
"conflict_points": [],
|
||||
"primary_evidence": [],
|
||||
"reasoning": reason,
|
||||
"suggestion": "补充代码实现或降低阈值后重新召回,并人工确认是否存在命名差异。",
|
||||
}
|
||||
|
||||
def _normalize_judgment(self, judgment: Dict[str, Any]) -> Dict[str, Any]:
|
||||
verdict = str(judgment.get("verdict") or "uncertain").strip().lower()
|
||||
if verdict not in VERDICTS:
|
||||
verdict = "uncertain"
|
||||
confidence = judgment.get("confidence", 0.0)
|
||||
try:
|
||||
confidence = max(0.0, min(1.0, float(confidence)))
|
||||
except (TypeError, ValueError):
|
||||
confidence = 0.0
|
||||
normalized = dict(judgment)
|
||||
normalized["verdict"] = verdict
|
||||
normalized["confidence"] = confidence
|
||||
normalized.setdefault("covered_points", [])
|
||||
normalized.setdefault("missing_points", [])
|
||||
normalized.setdefault("conflict_points", [])
|
||||
normalized.setdefault("primary_evidence", [])
|
||||
normalized.setdefault("reasoning", "")
|
||||
normalized.setdefault("suggestion", "")
|
||||
return normalized
|
||||
|
||||
def _matched_function_payload(self, hit: CodeSearchHit) -> Dict[str, Any]:
|
||||
item = hit.evidence
|
||||
return {
|
||||
"node_id": item.node_id,
|
||||
"name": item.name,
|
||||
"file": item.file,
|
||||
"start_line": item.start_line,
|
||||
"end_line": item.end_line,
|
||||
"similarity": round(hit.similarity, 4),
|
||||
"role": item.summary[:120] if item.summary else "",
|
||||
"evidence_summary": item.summary,
|
||||
"logic_flow": _clip(item.logic_flow, 1200),
|
||||
"code_snippet": _clip(item.code_snippet, 2000),
|
||||
"calls": item.calls[:20],
|
||||
"called_by": item.called_by[:20],
|
||||
"signature": item.signature,
|
||||
}
|
||||
|
||||
def _collect_call_chains(self, contexts: List[CodeGraphContext]) -> List[str]:
|
||||
chains: List[str] = []
|
||||
for context in contexts:
|
||||
chains.extend(context.call_chains)
|
||||
return list(dict.fromkeys(chains))[:30]
|
||||
134
rag-web-ui/backend/app/services/consistency/exporter.py
Normal file
134
rag-web-ui/backend/app/services/consistency/exporter.py
Normal file
@@ -0,0 +1,134 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import json
|
||||
from typing import Any, Dict, Iterable, List
|
||||
|
||||
|
||||
def normalize_result_dicts(results: Iterable[Any]) -> List[Dict[str, Any]]:
|
||||
normalized: List[Dict[str, Any]] = []
|
||||
for item in results:
|
||||
if hasattr(item, "to_dict"):
|
||||
normalized.append(item.to_dict())
|
||||
elif isinstance(item, dict):
|
||||
normalized.append(item)
|
||||
else:
|
||||
normalized.append(
|
||||
{
|
||||
"requirement_uid": getattr(item, "requirement_uid", ""),
|
||||
"verdict": getattr(item, "verdict", ""),
|
||||
"coverage_score": getattr(item, "coverage_score", 0.0),
|
||||
"confidence": getattr(item, "confidence", 0.0),
|
||||
"matched_functions": getattr(item, "matched_functions", []),
|
||||
"covered_points": getattr(item, "covered_points", []),
|
||||
"missing_points": getattr(item, "missing_points", []),
|
||||
"conflict_points": getattr(item, "conflict_points", []),
|
||||
"call_chain_evidence": getattr(item, "call_chain_evidence", []),
|
||||
"suggestion": getattr(item, "suggestion", ""),
|
||||
"raw_judgment": getattr(item, "raw_judgment", {}),
|
||||
}
|
||||
)
|
||||
return normalized
|
||||
|
||||
|
||||
def export_json(results: Iterable[Any]) -> bytes:
|
||||
return json.dumps(
|
||||
{"results": normalize_result_dicts(results)},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
).encode("utf-8")
|
||||
|
||||
|
||||
def export_markdown(results: Iterable[Any]) -> str:
|
||||
rows = normalize_result_dicts(results)
|
||||
lines = [
|
||||
"# 需求代码一致性比对报告",
|
||||
"",
|
||||
"| 需求 ID | 判定 | 覆盖分 | 置信度 | 匹配函数 | 缺失点 | 建议 |",
|
||||
"| --- | --- | ---: | ---: | ---: | ---: | --- |",
|
||||
]
|
||||
for item in rows:
|
||||
lines.append(
|
||||
"| {uid} | {verdict} | {score:.2f} | {confidence:.2f} | {functions} | {missing} | {suggestion} |".format(
|
||||
uid=item.get("requirement_uid", ""),
|
||||
verdict=item.get("verdict", ""),
|
||||
score=float(item.get("coverage_score") or 0),
|
||||
confidence=float(item.get("confidence") or 0),
|
||||
functions=len(item.get("matched_functions") or []),
|
||||
missing=len(item.get("missing_points") or []),
|
||||
suggestion=str(item.get("suggestion") or "").replace("|", "/"),
|
||||
)
|
||||
)
|
||||
|
||||
for item in rows:
|
||||
lines.extend(
|
||||
[
|
||||
"",
|
||||
f"## {item.get('requirement_uid', '')} {item.get('requirement_title', '')}",
|
||||
"",
|
||||
f"- 判定: `{item.get('verdict', '')}`",
|
||||
f"- 覆盖分: {float(item.get('coverage_score') or 0):.2f}",
|
||||
f"- 置信度: {float(item.get('confidence') or 0):.2f}",
|
||||
f"- 建议: {item.get('suggestion') or '-'}",
|
||||
"",
|
||||
"### 匹配函数",
|
||||
]
|
||||
)
|
||||
for function in item.get("matched_functions") or []:
|
||||
lines.append(
|
||||
f"- `{function.get('name')}` {function.get('file')}:{function.get('start_line')} "
|
||||
f"(similarity={float(function.get('similarity') or 0):.2f})"
|
||||
)
|
||||
lines.extend(["", "### 缺失点"])
|
||||
for point in item.get("missing_points") or ["-"]:
|
||||
lines.append(f"- {point}")
|
||||
if item.get("conflict_points"):
|
||||
lines.extend(["", "### 冲突点"])
|
||||
for point in item.get("conflict_points") or []:
|
||||
lines.append(f"- {point}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def export_excel(results: Iterable[Any]) -> bytes:
|
||||
try:
|
||||
from openpyxl import Workbook
|
||||
except ImportError as exc:
|
||||
raise RuntimeError("openpyxl is required to export Excel reports.") from exc
|
||||
|
||||
rows = normalize_result_dicts(results)
|
||||
workbook = Workbook()
|
||||
sheet = workbook.active
|
||||
sheet.title = "Consistency"
|
||||
headers = [
|
||||
"需求ID",
|
||||
"需求标题",
|
||||
"需求类型",
|
||||
"判定",
|
||||
"覆盖分",
|
||||
"置信度",
|
||||
"匹配函数数量",
|
||||
"主要文件",
|
||||
"缺失点数量",
|
||||
"建议",
|
||||
]
|
||||
sheet.append(headers)
|
||||
for item in rows:
|
||||
functions = item.get("matched_functions") or []
|
||||
sheet.append(
|
||||
[
|
||||
item.get("requirement_uid", ""),
|
||||
item.get("requirement_title", ""),
|
||||
item.get("requirement_type", ""),
|
||||
item.get("verdict", ""),
|
||||
item.get("coverage_score", 0),
|
||||
item.get("confidence", 0),
|
||||
len(functions),
|
||||
functions[0].get("file", "") if functions else "",
|
||||
len(item.get("missing_points") or []),
|
||||
item.get("suggestion", ""),
|
||||
]
|
||||
)
|
||||
output = io.BytesIO()
|
||||
workbook.save(output)
|
||||
return output.getvalue()
|
||||
|
||||
58
rag-web-ui/backend/app/services/consistency/prompt.py
Normal file
58
rag-web-ui/backend/app/services/consistency/prompt.py
Normal file
@@ -0,0 +1,58 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from app.services.consistency.schema import RequirementSnapshot
|
||||
|
||||
|
||||
SYSTEM_INSTRUCTION = """你是需求代码一致性审查助手。
|
||||
只能基于输入的需求、验收准则、函数摘要、代码片段、调用链证据判断。
|
||||
不得补充未给出的代码事实。
|
||||
证据不足时输出 uncertain。
|
||||
输出严格 JSON,不要 Markdown。"""
|
||||
|
||||
|
||||
def build_requirement_query(requirement: RequirementSnapshot) -> str:
|
||||
parts = []
|
||||
req_type = (requirement.requirement_type or "").lower()
|
||||
if req_type == "interface":
|
||||
parts.extend(
|
||||
[
|
||||
requirement.interface_name or "",
|
||||
requirement.interface_type or "",
|
||||
requirement.data_source or "",
|
||||
requirement.data_destination or "",
|
||||
requirement.description,
|
||||
]
|
||||
)
|
||||
else:
|
||||
parts.extend(
|
||||
[
|
||||
requirement.description,
|
||||
"\n".join(requirement.acceptance_criteria),
|
||||
requirement.section_title or "",
|
||||
requirement.interface_name or "",
|
||||
requirement.data_source or "",
|
||||
requirement.data_destination or "",
|
||||
]
|
||||
)
|
||||
return "\n".join(part for part in parts if part).strip()
|
||||
|
||||
|
||||
def build_judgment_prompt(requirement: RequirementSnapshot, evidence_context: str) -> str:
|
||||
payload = {
|
||||
"requirement": requirement.to_dict(),
|
||||
"evidence": evidence_context,
|
||||
"output_schema": {
|
||||
"verdict": "implemented | partial | missing | conflict | uncertain",
|
||||
"confidence": 0.0,
|
||||
"covered_points": [],
|
||||
"missing_points": [],
|
||||
"conflict_points": [],
|
||||
"primary_evidence": [],
|
||||
"reasoning": "brief reason based only on evidence",
|
||||
"suggestion": "next action",
|
||||
},
|
||||
}
|
||||
return SYSTEM_INSTRUCTION + "\n\n" + json.dumps(payload, ensure_ascii=False, indent=2)
|
||||
|
||||
61
rag-web-ui/backend/app/services/consistency/run_compare.py
Normal file
61
rag-web-ui/backend/app/services/consistency/run_compare.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="Run requirement-code consistency comparison.")
|
||||
parser.add_argument("--srs-extraction-id", type=int, required=True)
|
||||
parser.add_argument("--vector-path", required=True)
|
||||
parser.add_argument("--metadata-path", required=True)
|
||||
parser.add_argument("--graph-path", required=True)
|
||||
parser.add_argument("--output", required=True)
|
||||
parser.add_argument("--output-excel", default=None)
|
||||
parser.add_argument("--output-markdown", default=None)
|
||||
parser.add_argument("--top-k", type=int, default=8)
|
||||
parser.add_argument("--max-call-hops", type=int, default=2)
|
||||
parser.add_argument("--min-similarity", type=float, default=0.55)
|
||||
parser.add_argument("--no-llm", action="store_true")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main() -> int:
|
||||
args = parse_args()
|
||||
from app.db.session import SessionLocal
|
||||
from app.models.tooling import SRSRequirement
|
||||
from app.services.code_kb.adapter import CodeKnowledgeBaseAdapter
|
||||
from app.services.consistency.comparator import ConsistencyComparator
|
||||
from app.services.consistency.exporter import export_excel, export_json, export_markdown
|
||||
|
||||
adapter = CodeKnowledgeBaseAdapter()
|
||||
adapter.load(args.vector_path, args.metadata_path, args.graph_path)
|
||||
comparator = ConsistencyComparator(adapter, use_llm=not args.no_llm)
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
requirements = (
|
||||
db.query(SRSRequirement)
|
||||
.filter(SRSRequirement.extraction_id == args.srs_extraction_id)
|
||||
.order_by(SRSRequirement.sort_order)
|
||||
.all()
|
||||
)
|
||||
results = comparator.compare_requirements(
|
||||
requirements,
|
||||
top_k=args.top_k,
|
||||
max_call_hops=args.max_call_hops,
|
||||
min_similarity=args.min_similarity,
|
||||
)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
Path(args.output).write_bytes(export_json(results))
|
||||
if args.output_markdown:
|
||||
Path(args.output_markdown).write_text(export_markdown(results), encoding="utf-8")
|
||||
if args.output_excel:
|
||||
Path(args.output_excel).write_bytes(export_excel(results))
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
46
rag-web-ui/backend/app/services/consistency/schema.py
Normal file
46
rag-web-ui/backend/app/services/consistency/schema.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
VERDICTS = {"implemented", "partial", "missing", "conflict", "uncertain"}
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequirementSnapshot:
|
||||
requirement_uid: str
|
||||
title: str
|
||||
description: str
|
||||
acceptance_criteria: List[str] = field(default_factory=list)
|
||||
requirement_type: Optional[str] = None
|
||||
section_title: Optional[str] = None
|
||||
interface_name: Optional[str] = None
|
||||
interface_type: Optional[str] = None
|
||||
data_source: Optional[str] = None
|
||||
data_destination: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return asdict(self)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConsistencyResultItem:
|
||||
requirement_uid: str
|
||||
requirement_title: str
|
||||
requirement_type: Optional[str]
|
||||
requirement_text: str
|
||||
verdict: str
|
||||
coverage_score: float
|
||||
confidence: float
|
||||
matched_functions: List[Dict[str, Any]]
|
||||
covered_points: List[str] = field(default_factory=list)
|
||||
missing_points: List[str] = field(default_factory=list)
|
||||
conflict_points: List[str] = field(default_factory=list)
|
||||
call_chain_evidence: List[str] = field(default_factory=list)
|
||||
suggestion: str = ""
|
||||
raw_judgment: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return asdict(self)
|
||||
|
||||
120
rag-web-ui/backend/app/services/consistency/scorer.py
Normal file
120
rag-web-ui/backend/app/services/consistency/scorer.py
Normal file
@@ -0,0 +1,120 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Any, Dict, Iterable, List
|
||||
|
||||
from app.services.code_kb.schema import CodeGraphContext, CodeSearchHit
|
||||
from app.services.consistency.schema import RequirementSnapshot
|
||||
|
||||
|
||||
def _clamp(value: float) -> float:
|
||||
return max(0.0, min(1.0, value))
|
||||
|
||||
|
||||
def _tokens(*values: str) -> List[str]:
|
||||
text = " ".join(value or "" for value in values).lower()
|
||||
return [item for item in re.split(r"[^a-z0-9_\u4e00-\u9fff]+", text) if len(item) >= 2]
|
||||
|
||||
|
||||
def semantic_score(hits: List[CodeSearchHit]) -> float:
|
||||
if not hits:
|
||||
return 0.0
|
||||
top = max(hit.similarity for hit in hits)
|
||||
avg = sum(hit.similarity for hit in hits[:3]) / min(3, len(hits))
|
||||
return _clamp(top * 0.7 + avg * 0.3)
|
||||
|
||||
|
||||
def acceptance_coverage_score(requirement: RequirementSnapshot, judgment: Dict[str, Any]) -> float:
|
||||
criteria = requirement.acceptance_criteria or []
|
||||
covered = judgment.get("covered_points") or []
|
||||
missing = judgment.get("missing_points") or []
|
||||
verdict = judgment.get("verdict")
|
||||
if criteria:
|
||||
if missing:
|
||||
return _clamp((len(criteria) - min(len(missing), len(criteria))) / len(criteria))
|
||||
if covered:
|
||||
return _clamp(len(covered) / len(criteria))
|
||||
return 1.0 if verdict == "implemented" else 0.4 if verdict == "partial" else 0.0
|
||||
return {"implemented": 1.0, "partial": 0.55, "conflict": 0.25, "missing": 0.0}.get(verdict, 0.35)
|
||||
|
||||
|
||||
def evidence_strength_score(hits: List[CodeSearchHit]) -> float:
|
||||
if not hits:
|
||||
return 0.0
|
||||
scores: List[float] = []
|
||||
for hit in hits[:5]:
|
||||
item = hit.evidence
|
||||
checks = [
|
||||
bool(item.file),
|
||||
item.start_line is not None,
|
||||
item.end_line is not None,
|
||||
bool(item.summary),
|
||||
bool(item.logic_flow),
|
||||
bool(item.code_snippet),
|
||||
]
|
||||
scores.append(sum(1 for value in checks if value) / len(checks))
|
||||
return _clamp(sum(scores) / len(scores))
|
||||
|
||||
|
||||
def call_graph_score(contexts: Iterable[CodeGraphContext]) -> float:
|
||||
contexts = list(contexts)
|
||||
if not contexts:
|
||||
return 0.0
|
||||
scored = []
|
||||
for context in contexts[:5]:
|
||||
score = 0.0
|
||||
if context.callers:
|
||||
score += 0.35
|
||||
if context.callees:
|
||||
score += 0.35
|
||||
if context.call_chains:
|
||||
score += 0.30
|
||||
scored.append(score)
|
||||
return _clamp(sum(scored) / len(scored))
|
||||
|
||||
|
||||
def exact_match_score(requirement: RequirementSnapshot, hits: List[CodeSearchHit]) -> float:
|
||||
if not hits:
|
||||
return 0.0
|
||||
important = _tokens(
|
||||
requirement.interface_name or "",
|
||||
requirement.interface_type or "",
|
||||
requirement.data_source or "",
|
||||
requirement.data_destination or "",
|
||||
requirement.title or "",
|
||||
)
|
||||
if not important:
|
||||
important = _tokens(requirement.description)[:12]
|
||||
if not important:
|
||||
return 0.0
|
||||
|
||||
evidence_text = " ".join(
|
||||
f"{hit.evidence.name} {hit.evidence.qualified_name} {hit.evidence.summary} {hit.evidence.logic_flow}"
|
||||
for hit in hits[:5]
|
||||
).lower()
|
||||
matched = sum(1 for token in important if token.lower() in evidence_text)
|
||||
return _clamp(matched / len(important))
|
||||
|
||||
|
||||
def coverage_score(
|
||||
requirement: RequirementSnapshot,
|
||||
hits: List[CodeSearchHit],
|
||||
contexts: List[CodeGraphContext],
|
||||
judgment: Dict[str, Any],
|
||||
) -> float:
|
||||
score = (
|
||||
semantic_score(hits) * 0.25
|
||||
+ acceptance_coverage_score(requirement, judgment) * 0.30
|
||||
+ evidence_strength_score(hits) * 0.20
|
||||
+ call_graph_score(contexts) * 0.15
|
||||
+ exact_match_score(requirement, hits) * 0.10
|
||||
)
|
||||
verdict = judgment.get("verdict")
|
||||
if verdict == "missing":
|
||||
score = min(score, 0.25)
|
||||
elif verdict == "uncertain":
|
||||
score = min(score, 0.55)
|
||||
elif verdict == "conflict":
|
||||
score = min(score, 0.45)
|
||||
return round(_clamp(score), 4)
|
||||
|
||||
711
rag-web-ui/backend/app/services/consistency_job_service.py
Normal file
711
rag-web-ui/backend/app/services/consistency_job_service.py
Normal file
@@ -0,0 +1,711 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import zipfile
|
||||
from collections import Counter
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db.session import SessionLocal
|
||||
from app.models.tooling import (
|
||||
CodeKnowledgeBase,
|
||||
ConsistencyJob,
|
||||
ConsistencyResult,
|
||||
SRSExtraction,
|
||||
SRSRequirement,
|
||||
ToolJob,
|
||||
)
|
||||
from app.schemas.consistency import CodeKnowledgeBaseCreate, ConsistencyJobCreate
|
||||
from app.services.code_kb.adapter import CodeKnowledgeBaseAdapter, read_code_kb_summary
|
||||
from app.services.code_kb.formatter import format_evidence_context
|
||||
from app.services.consistency.comparator import ConsistencyComparator
|
||||
from app.services.embedding.embedding_factory import EmbeddingsFactory
|
||||
from app.services.llm.llm_factory import LLMFactory
|
||||
from app.services.model_config import ModelConfigService
|
||||
from app.services.srs_job_service import _build_internal_title, _parse_generated_at
|
||||
from app.tools.srs_reqs_qwen import get_srs_tool
|
||||
|
||||
CODE_UPLOAD_ROOT = Path("uploads") / "code_kbs"
|
||||
AUTO_UPLOAD_ROOT = Path("uploads") / "consistency_auto"
|
||||
|
||||
|
||||
def _workspace_root() -> Path:
|
||||
current = Path(__file__).resolve()
|
||||
for parent in current.parents:
|
||||
if (parent / "rag-web-ui").exists() and (parent / "RAG-TEST-TOOLS").exists():
|
||||
return parent
|
||||
return current.parents[4]
|
||||
|
||||
|
||||
def _rag_test_tools_root() -> Path:
|
||||
candidates = [
|
||||
_workspace_root() / "RAG-TEST-TOOLS",
|
||||
Path(__file__).resolve().parents[3] / "RAG-TEST-TOOLS",
|
||||
Path.cwd().parent / "RAG-TEST-TOOLS",
|
||||
]
|
||||
for candidate in candidates:
|
||||
if candidate.exists():
|
||||
return candidate.resolve()
|
||||
return candidates[0]
|
||||
|
||||
|
||||
def safe_upload_name(file_name: str | None, fallback: str = "upload.bin") -> str:
|
||||
safe_name = Path(file_name or fallback).name
|
||||
return safe_name or fallback
|
||||
|
||||
|
||||
def ensure_upload_dir(*parts: str) -> Path:
|
||||
path = Path("uploads").joinpath(*parts)
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
return path
|
||||
|
||||
|
||||
def save_uploaded_bytes(target_dir: Path, file_name: str, content: bytes) -> Path:
|
||||
target_dir.mkdir(parents=True, exist_ok=True)
|
||||
path = target_dir / safe_upload_name(file_name)
|
||||
path.write_bytes(content)
|
||||
if path.suffix.lower() == ".zip":
|
||||
extract_dir = target_dir / path.stem
|
||||
extract_zip_safe(path, extract_dir)
|
||||
return extract_dir
|
||||
return path
|
||||
|
||||
|
||||
def extract_zip_safe(zip_path: Path, target_dir: Path) -> None:
|
||||
target_dir.mkdir(parents=True, exist_ok=True)
|
||||
target_root = target_dir.resolve()
|
||||
with zipfile.ZipFile(zip_path) as archive:
|
||||
for member in archive.infolist():
|
||||
member_path = (target_dir / member.filename).resolve()
|
||||
try:
|
||||
member_path.relative_to(target_root)
|
||||
except ValueError as exc:
|
||||
raise ValueError(f"Unsafe zip entry: {member.filename}")
|
||||
archive.extractall(target_dir)
|
||||
|
||||
|
||||
def _build_code_kb_artifacts(
|
||||
project_path: str,
|
||||
output_dir: str,
|
||||
base_name: str,
|
||||
use_semantic: bool,
|
||||
model_profile: Any = None,
|
||||
) -> Dict[str, str]:
|
||||
tools_root = _rag_test_tools_root()
|
||||
if not tools_root.exists():
|
||||
raise FileNotFoundError(f"RAG-TEST-TOOLS not found: {tools_root}")
|
||||
|
||||
output_path = Path(output_dir).resolve()
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
command = [
|
||||
sys.executable,
|
||||
"-m",
|
||||
"rag_test_tools.build_code_kb",
|
||||
"--project",
|
||||
str(Path(project_path).resolve()),
|
||||
"--output",
|
||||
str(output_path),
|
||||
"--base-name",
|
||||
base_name,
|
||||
]
|
||||
if not use_semantic:
|
||||
command.append("--skip-semantic")
|
||||
|
||||
env = os.environ.copy()
|
||||
if model_profile is not None:
|
||||
api_key = getattr(model_profile, "api_key", "") or ""
|
||||
api_base = getattr(model_profile, "api_base", "") or ""
|
||||
if api_key:
|
||||
env["DASHSCOPE_API_KEY"] = api_key
|
||||
env["DASH_SCOPE_API_KEY"] = api_key
|
||||
env["QWEN_API_KEY"] = api_key
|
||||
if api_base:
|
||||
env["DASH_SCOPE_API_BASE"] = api_base
|
||||
env["QWEN_API_URL"] = api_base
|
||||
if getattr(model_profile, "chat_model", None):
|
||||
env["QWEN_CHAT_MODEL"] = model_profile.chat_model
|
||||
if getattr(model_profile, "embedding_model", None):
|
||||
env["QWEN_EMBEDDING_MODEL"] = model_profile.embedding_model
|
||||
|
||||
completed = subprocess.run(
|
||||
command,
|
||||
cwd=str(tools_root),
|
||||
env=env,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=3600,
|
||||
check=False,
|
||||
)
|
||||
if completed.returncode != 0:
|
||||
raise RuntimeError(
|
||||
"Code knowledge base build failed: "
|
||||
f"{completed.stderr or completed.stdout or completed.returncode}"
|
||||
)
|
||||
try:
|
||||
return json.loads(completed.stdout)
|
||||
except json.JSONDecodeError as exc:
|
||||
raise RuntimeError(f"Code KB build returned invalid JSON: {completed.stdout}") from exc
|
||||
|
||||
|
||||
def _ensure_paths_exist(paths: Iterable[str]) -> None:
|
||||
missing = [path for path in paths if not path or not Path(path).exists()]
|
||||
if missing:
|
||||
raise FileNotFoundError(f"Code knowledge base file path does not exist: {missing}")
|
||||
|
||||
|
||||
def create_code_kb(db: Session, user_id: int, payload: CodeKnowledgeBaseCreate) -> CodeKnowledgeBase:
|
||||
_ensure_paths_exist([payload.vector_path, payload.metadata_path, payload.graph_path])
|
||||
adapter = CodeKnowledgeBaseAdapter()
|
||||
adapter.load(payload.vector_path, payload.metadata_path, payload.graph_path)
|
||||
summary = {
|
||||
**read_code_kb_summary(payload.metadata_path, payload.graph_path),
|
||||
**adapter.summary(),
|
||||
}
|
||||
code_kb = CodeKnowledgeBase(
|
||||
user_id=user_id,
|
||||
name=payload.name,
|
||||
project_path=payload.project_path,
|
||||
vector_path=payload.vector_path,
|
||||
metadata_path=payload.metadata_path,
|
||||
graph_path=payload.graph_path,
|
||||
status="active",
|
||||
metadata_summary=summary,
|
||||
)
|
||||
db.add(code_kb)
|
||||
db.commit()
|
||||
db.refresh(code_kb)
|
||||
return code_kb
|
||||
|
||||
|
||||
def create_uploaded_code_kb(
|
||||
db: Session,
|
||||
user_id: int,
|
||||
name: str,
|
||||
project_path: str,
|
||||
output_dir: str,
|
||||
) -> CodeKnowledgeBase:
|
||||
base_name = f"code_kb_{datetime.utcnow().strftime('%Y%m%d%H%M%S%f')}"
|
||||
output_path = Path(output_dir).resolve()
|
||||
code_kb = CodeKnowledgeBase(
|
||||
user_id=user_id,
|
||||
name=name,
|
||||
project_path=project_path,
|
||||
vector_path=str(output_path / f"{base_name}_rag.faiss"),
|
||||
metadata_path=str(output_path / f"{base_name}_rag_metadata.json"),
|
||||
graph_path=str(output_path / f"{base_name}_code_knowledge_graph.json"),
|
||||
status="pending",
|
||||
metadata_summary={"base_name": base_name, "output_dir": str(output_path)},
|
||||
)
|
||||
db.add(code_kb)
|
||||
db.commit()
|
||||
db.refresh(code_kb)
|
||||
return code_kb
|
||||
|
||||
|
||||
def run_code_kb_build(code_kb_id: int, use_semantic: bool = True) -> None:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
code_kb = db.query(CodeKnowledgeBase).filter(CodeKnowledgeBase.id == code_kb_id).first()
|
||||
if not code_kb:
|
||||
return
|
||||
model_profile = None
|
||||
if use_semantic:
|
||||
model_profile = ModelConfigService.require_active_config(db, code_kb.user_id)
|
||||
ModelConfigService.touch_last_used(db, model_profile)
|
||||
code_kb.status = "processing"
|
||||
db.add(code_kb)
|
||||
db.commit()
|
||||
|
||||
summary = code_kb.metadata_summary or {}
|
||||
base_name = summary.get("base_name") or f"code_kb_{code_kb.id}"
|
||||
output_dir = summary.get("output_dir") or str(Path(code_kb.vector_path).parent)
|
||||
artifact_paths = _build_code_kb_artifacts(
|
||||
project_path=code_kb.project_path or "",
|
||||
output_dir=output_dir,
|
||||
base_name=base_name,
|
||||
use_semantic=use_semantic,
|
||||
model_profile=model_profile,
|
||||
)
|
||||
code_kb.graph_path = artifact_paths["graph_path"]
|
||||
code_kb.vector_path = artifact_paths["vector_path"]
|
||||
code_kb.metadata_path = artifact_paths["metadata_path"]
|
||||
|
||||
embedding_function = (
|
||||
EmbeddingsFactory.create(model_profile=model_profile)
|
||||
if model_profile is not None
|
||||
else None
|
||||
)
|
||||
adapter = CodeKnowledgeBaseAdapter(embedding_function=embedding_function)
|
||||
adapter.load(code_kb.vector_path, code_kb.metadata_path, code_kb.graph_path)
|
||||
code_kb.status = "active"
|
||||
code_kb.metadata_summary = {
|
||||
**read_code_kb_summary(code_kb.metadata_path, code_kb.graph_path),
|
||||
**adapter.summary(),
|
||||
"source": "upload",
|
||||
}
|
||||
db.add(code_kb)
|
||||
db.commit()
|
||||
except Exception as exc:
|
||||
if "code_kb" in locals() and code_kb:
|
||||
code_kb.status = "failed"
|
||||
code_kb.metadata_summary = {
|
||||
**(code_kb.metadata_summary or {}),
|
||||
"error_message": str(exc)[:2000],
|
||||
}
|
||||
db.add(code_kb)
|
||||
db.commit()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def list_code_kbs(db: Session, user_id: int) -> List[CodeKnowledgeBase]:
|
||||
return (
|
||||
db.query(CodeKnowledgeBase)
|
||||
.filter(CodeKnowledgeBase.user_id == user_id)
|
||||
.order_by(CodeKnowledgeBase.created_at.desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def get_owned_code_kb(db: Session, user_id: int, code_kb_id: int) -> Optional[CodeKnowledgeBase]:
|
||||
return (
|
||||
db.query(CodeKnowledgeBase)
|
||||
.filter(CodeKnowledgeBase.id == code_kb_id, CodeKnowledgeBase.user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
|
||||
def get_owned_srs_extraction(db: Session, user_id: int, extraction_id: int) -> Optional[SRSExtraction]:
|
||||
return (
|
||||
db.query(SRSExtraction)
|
||||
.join(ToolJob, SRSExtraction.job_id == ToolJob.id)
|
||||
.filter(SRSExtraction.id == extraction_id, ToolJob.user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
|
||||
def create_consistency_job(
|
||||
db: Session,
|
||||
user_id: int,
|
||||
payload: ConsistencyJobCreate,
|
||||
) -> ConsistencyJob:
|
||||
extraction = get_owned_srs_extraction(db, user_id, payload.srs_extraction_id)
|
||||
if not extraction:
|
||||
raise ValueError("SRS extraction does not exist.")
|
||||
code_kb = get_owned_code_kb(db, user_id, payload.code_kb_id)
|
||||
if not code_kb:
|
||||
raise ValueError("Code knowledge base does not exist.")
|
||||
if code_kb.status != "active":
|
||||
raise ValueError("Code knowledge base is not active.")
|
||||
|
||||
requirement_query = db.query(SRSRequirement).filter(SRSRequirement.extraction_id == extraction.id)
|
||||
if payload.requirement_uids:
|
||||
requirement_query = requirement_query.filter(SRSRequirement.requirement_uid.in_(payload.requirement_uids))
|
||||
total = requirement_query.count()
|
||||
if total == 0:
|
||||
raise ValueError("No SRS requirements matched the selected scope.")
|
||||
|
||||
job = ConsistencyJob(
|
||||
user_id=user_id,
|
||||
srs_extraction_id=extraction.id,
|
||||
code_kb_id=code_kb.id,
|
||||
status="pending",
|
||||
total_requirements=total,
|
||||
completed_requirements=0,
|
||||
output_summary={
|
||||
"requirement_uids": payload.requirement_uids,
|
||||
"top_k": payload.top_k,
|
||||
"max_call_hops": payload.max_call_hops,
|
||||
"min_similarity": payload.min_similarity,
|
||||
"use_llm": payload.use_llm,
|
||||
},
|
||||
)
|
||||
db.add(job)
|
||||
db.commit()
|
||||
db.refresh(job)
|
||||
return job
|
||||
|
||||
|
||||
def list_consistency_jobs(db: Session, user_id: int) -> List[ConsistencyJob]:
|
||||
return (
|
||||
db.query(ConsistencyJob)
|
||||
.filter(ConsistencyJob.user_id == user_id)
|
||||
.order_by(ConsistencyJob.created_at.desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def get_owned_consistency_job(db: Session, user_id: int, job_id: int) -> Optional[ConsistencyJob]:
|
||||
return (
|
||||
db.query(ConsistencyJob)
|
||||
.filter(ConsistencyJob.id == job_id, ConsistencyJob.user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
|
||||
def ask_code_kb(
|
||||
code_kb: CodeKnowledgeBase,
|
||||
question: str,
|
||||
top_k: int = 6,
|
||||
min_similarity: float = 0.0,
|
||||
use_llm: bool = True,
|
||||
model_profile: Any = None,
|
||||
) -> Dict[str, Any]:
|
||||
if code_kb.status != "active":
|
||||
raise ValueError("Code knowledge base is not active.")
|
||||
if model_profile is None:
|
||||
raise ValueError("请先在 API 密钥页面新增并启用模型配置。")
|
||||
adapter = CodeKnowledgeBaseAdapter(
|
||||
embedding_function=EmbeddingsFactory.create(model_profile=model_profile)
|
||||
)
|
||||
adapter.load(code_kb.vector_path, code_kb.metadata_path, code_kb.graph_path)
|
||||
hits = adapter.search_functions(question, top_k=top_k, min_similarity=min_similarity)
|
||||
contexts = [adapter.expand_call_context(hit.evidence.node_id, max_hops=2) for hit in hits]
|
||||
evidence = [hit.to_dict() for hit in hits]
|
||||
if not hits:
|
||||
return {
|
||||
"answer": "未检索到相关函数证据,无法基于代码知识库回答。",
|
||||
"evidence": [],
|
||||
"raw_response": None,
|
||||
}
|
||||
|
||||
evidence_context = format_evidence_context(hits, contexts)
|
||||
if not use_llm:
|
||||
return {
|
||||
"answer": "已检索到相关函数证据,请查看 evidence 字段中的函数摘要、文件位置和调用链。",
|
||||
"evidence": evidence,
|
||||
"raw_response": None,
|
||||
}
|
||||
|
||||
prompt = (
|
||||
"你是代码知识库问答助手。只能基于给定代码证据回答问题;"
|
||||
"如果证据不足,请明确说明不足。回答需要包含关键函数名和文件位置。\n\n"
|
||||
f"问题:{question}\n\n代码证据:\n{evidence_context}"
|
||||
)
|
||||
try:
|
||||
llm = LLMFactory.create(temperature=0, streaming=False, model_profile=model_profile)
|
||||
response = llm.invoke(prompt) if hasattr(llm, "invoke") else llm(prompt)
|
||||
answer = str(getattr(response, "content", response))
|
||||
return {"answer": answer, "evidence": evidence, "raw_response": answer}
|
||||
except Exception as exc:
|
||||
return {
|
||||
"answer": f"模型问答失败,已返回检索证据供人工查看。错误:{exc}",
|
||||
"evidence": evidence,
|
||||
"raw_response": None,
|
||||
}
|
||||
|
||||
|
||||
def result_model_to_export_dict(result: ConsistencyResult) -> Dict[str, Any]:
|
||||
raw = result.raw_judgment or {}
|
||||
return {
|
||||
"requirement_uid": result.requirement_uid,
|
||||
"requirement_title": raw.get("requirement_title", ""),
|
||||
"requirement_type": raw.get("requirement_type"),
|
||||
"requirement_text": raw.get("requirement_text", ""),
|
||||
"verdict": result.verdict,
|
||||
"coverage_score": result.coverage_score,
|
||||
"confidence": result.confidence,
|
||||
"matched_functions": result.matched_functions or [],
|
||||
"covered_points": result.covered_points or [],
|
||||
"missing_points": result.missing_points or [],
|
||||
"conflict_points": result.conflict_points or [],
|
||||
"call_chain_evidence": result.call_chain_evidence or [],
|
||||
"suggestion": result.suggestion or "",
|
||||
"raw_judgment": raw,
|
||||
}
|
||||
|
||||
|
||||
def _store_result(db: Session, job: ConsistencyJob, result: Any) -> None:
|
||||
result_dict = result.to_dict()
|
||||
raw_judgment = dict(result_dict.get("raw_judgment") or {})
|
||||
raw_judgment.update(
|
||||
{
|
||||
"requirement_title": result_dict.get("requirement_title"),
|
||||
"requirement_type": result_dict.get("requirement_type"),
|
||||
"requirement_text": result_dict.get("requirement_text"),
|
||||
}
|
||||
)
|
||||
db.add(
|
||||
ConsistencyResult(
|
||||
job_id=job.id,
|
||||
requirement_uid=result.requirement_uid,
|
||||
verdict=result.verdict,
|
||||
coverage_score=result.coverage_score,
|
||||
confidence=result.confidence,
|
||||
matched_functions=result.matched_functions,
|
||||
covered_points=result.covered_points,
|
||||
missing_points=result.missing_points,
|
||||
conflict_points=result.conflict_points,
|
||||
call_chain_evidence=result.call_chain_evidence,
|
||||
suggestion=result.suggestion,
|
||||
raw_judgment=raw_judgment,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _create_srs_extraction_for_job(db: Session, job: ToolJob) -> SRSExtraction:
|
||||
model_profile = ModelConfigService.require_active_config(db, job.user_id)
|
||||
ModelConfigService.touch_last_used(db, model_profile)
|
||||
payload = get_srs_tool().run(job.input_file_path, model_profile=model_profile)
|
||||
extraction = SRSExtraction(
|
||||
job_id=job.id,
|
||||
document_name=payload["document_name"],
|
||||
document_title=payload.get("document_title") or payload["document_name"],
|
||||
generated_at=_parse_generated_at(payload.get("generated_at")),
|
||||
total_requirements=len(payload.get("requirements", [])),
|
||||
statistics=payload.get("statistics", {}),
|
||||
raw_output=payload.get("raw_output", {}),
|
||||
)
|
||||
db.add(extraction)
|
||||
db.flush()
|
||||
|
||||
for index, item in enumerate(payload.get("requirements", [])):
|
||||
requirement = SRSRequirement(
|
||||
extraction_id=extraction.id,
|
||||
requirement_uid=item.get("id") or f"REQ-{index + 1:03d}",
|
||||
title=_build_internal_title(item.get("description"), item.get("id") or "", index),
|
||||
description=item.get("description") or "",
|
||||
priority=item.get("priority") or "中",
|
||||
acceptance_criteria=item.get("acceptance_criteria") or ["待补充验收标准"],
|
||||
source_field=item.get("source_field") or "文档解析",
|
||||
section_uid=item.get("section_uid"),
|
||||
section_number=item.get("section_number"),
|
||||
section_title=item.get("section_title"),
|
||||
requirement_type=item.get("requirement_type"),
|
||||
interface_name=item.get("interface_name"),
|
||||
interface_type=item.get("interface_type"),
|
||||
data_source=item.get("data_source"),
|
||||
data_destination=item.get("data_destination"),
|
||||
sort_order=int(item.get("sort_order") or index),
|
||||
)
|
||||
db.add(requirement)
|
||||
return extraction
|
||||
|
||||
|
||||
def create_auto_consistency_tool_job(
|
||||
db: Session,
|
||||
user_id: int,
|
||||
requirement_file_path: str,
|
||||
requirement_file_name: str,
|
||||
code_source_dir: str,
|
||||
code_kb_name: str,
|
||||
top_k: int,
|
||||
max_call_hops: int,
|
||||
min_similarity: float,
|
||||
use_llm: bool,
|
||||
use_semantic: bool,
|
||||
) -> ToolJob:
|
||||
job = ToolJob(
|
||||
user_id=user_id,
|
||||
tool_name="consistency.auto_compare",
|
||||
status="pending",
|
||||
input_file_name=requirement_file_name,
|
||||
input_file_path=requirement_file_path,
|
||||
output_summary={
|
||||
"current_step": "pending",
|
||||
"code_source_dir": code_source_dir,
|
||||
"code_kb_name": code_kb_name,
|
||||
"top_k": top_k,
|
||||
"max_call_hops": max_call_hops,
|
||||
"min_similarity": min_similarity,
|
||||
"use_llm": use_llm,
|
||||
"use_semantic": use_semantic,
|
||||
},
|
||||
)
|
||||
db.add(job)
|
||||
db.commit()
|
||||
db.refresh(job)
|
||||
return job
|
||||
|
||||
|
||||
def get_owned_auto_job(db: Session, user_id: int, job_id: int) -> Optional[ToolJob]:
|
||||
return (
|
||||
db.query(ToolJob)
|
||||
.filter(
|
||||
ToolJob.id == job_id,
|
||||
ToolJob.user_id == user_id,
|
||||
ToolJob.tool_name == "consistency.auto_compare",
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
|
||||
def run_auto_consistency_job(tool_job_id: int) -> None:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
tool_job = db.query(ToolJob).filter(ToolJob.id == tool_job_id).first()
|
||||
if not tool_job:
|
||||
return
|
||||
options = tool_job.output_summary or {}
|
||||
tool_job.status = "processing"
|
||||
tool_job.started_at = datetime.utcnow()
|
||||
tool_job.output_summary = {**options, "current_step": "extracting_requirements"}
|
||||
db.add(tool_job)
|
||||
db.commit()
|
||||
|
||||
extraction = _create_srs_extraction_for_job(db, tool_job)
|
||||
db.commit()
|
||||
|
||||
options = tool_job.output_summary or options
|
||||
code_output_dir = str((AUTO_UPLOAD_ROOT / str(tool_job.id) / "code_kb").resolve())
|
||||
code_kb = create_uploaded_code_kb(
|
||||
db,
|
||||
tool_job.user_id,
|
||||
options.get("code_kb_name") or f"auto-code-kb-{tool_job.id}",
|
||||
options["code_source_dir"],
|
||||
code_output_dir,
|
||||
)
|
||||
tool_job.output_summary = {
|
||||
**options,
|
||||
"current_step": "building_code_kb",
|
||||
"srs_extraction_id": extraction.id,
|
||||
"code_kb_id": code_kb.id,
|
||||
}
|
||||
db.add(tool_job)
|
||||
db.commit()
|
||||
|
||||
db.close()
|
||||
run_code_kb_build(code_kb.id, use_semantic=bool(options.get("use_semantic", True)))
|
||||
db = SessionLocal()
|
||||
code_kb = db.query(CodeKnowledgeBase).filter(CodeKnowledgeBase.id == code_kb.id).first()
|
||||
tool_job = db.query(ToolJob).filter(ToolJob.id == tool_job_id).first()
|
||||
if not tool_job or not code_kb:
|
||||
return
|
||||
if code_kb.status != "active":
|
||||
raise RuntimeError((code_kb.metadata_summary or {}).get("error_message") or "Code KB build failed.")
|
||||
|
||||
consistency_payload = ConsistencyJobCreate(
|
||||
srs_extraction_id=extraction.id,
|
||||
code_kb_id=code_kb.id,
|
||||
top_k=int(options.get("top_k", 8)),
|
||||
max_call_hops=int(options.get("max_call_hops", 2)),
|
||||
min_similarity=float(options.get("min_similarity", 0.55)),
|
||||
use_llm=bool(options.get("use_llm", True)),
|
||||
)
|
||||
consistency_job = create_consistency_job(db, tool_job.user_id, consistency_payload)
|
||||
tool_job.output_summary = {
|
||||
**(tool_job.output_summary or {}),
|
||||
"current_step": "comparing",
|
||||
"consistency_job_id": consistency_job.id,
|
||||
}
|
||||
db.add(tool_job)
|
||||
db.commit()
|
||||
|
||||
db.close()
|
||||
run_consistency_job(consistency_job.id)
|
||||
db = SessionLocal()
|
||||
consistency_job = db.query(ConsistencyJob).filter(ConsistencyJob.id == consistency_job.id).first()
|
||||
tool_job = db.query(ToolJob).filter(ToolJob.id == tool_job_id).first()
|
||||
if not tool_job or not consistency_job:
|
||||
return
|
||||
if consistency_job.status == "failed":
|
||||
raise RuntimeError(consistency_job.error_message or "Consistency comparison failed.")
|
||||
|
||||
tool_job.status = "completed"
|
||||
tool_job.completed_at = datetime.utcnow()
|
||||
tool_job.output_summary = {
|
||||
**(tool_job.output_summary or {}),
|
||||
"current_step": "completed",
|
||||
"consistency_job_id": consistency_job.id,
|
||||
}
|
||||
db.add(tool_job)
|
||||
db.commit()
|
||||
except Exception as exc:
|
||||
if "db" not in locals() or db is None:
|
||||
db = SessionLocal()
|
||||
tool_job = db.query(ToolJob).filter(ToolJob.id == tool_job_id).first()
|
||||
if tool_job:
|
||||
tool_job.status = "failed"
|
||||
tool_job.error_message = str(exc)[:2000]
|
||||
tool_job.completed_at = datetime.utcnow()
|
||||
tool_job.output_summary = {
|
||||
**(tool_job.output_summary or {}),
|
||||
"current_step": "failed",
|
||||
}
|
||||
db.add(tool_job)
|
||||
db.commit()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def run_consistency_job(job_id: int) -> None:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
job = db.query(ConsistencyJob).filter(ConsistencyJob.id == job_id).first()
|
||||
if not job:
|
||||
return
|
||||
job.status = "processing"
|
||||
job.started_at = datetime.utcnow()
|
||||
db.add(job)
|
||||
db.commit()
|
||||
|
||||
options = job.output_summary or {}
|
||||
code_kb = db.query(CodeKnowledgeBase).filter(CodeKnowledgeBase.id == job.code_kb_id).first()
|
||||
if not code_kb:
|
||||
raise RuntimeError("Code knowledge base does not exist.")
|
||||
|
||||
model_profile = ModelConfigService.require_active_config(db, job.user_id)
|
||||
ModelConfigService.touch_last_used(db, model_profile)
|
||||
adapter = CodeKnowledgeBaseAdapter(
|
||||
embedding_function=EmbeddingsFactory.create(model_profile=model_profile)
|
||||
)
|
||||
adapter.load(code_kb.vector_path, code_kb.metadata_path, code_kb.graph_path)
|
||||
llm = None
|
||||
if bool(options.get("use_llm", True)):
|
||||
llm = LLMFactory.create(temperature=0, streaming=False, model_profile=model_profile)
|
||||
comparator = ConsistencyComparator(
|
||||
adapter,
|
||||
llm=llm,
|
||||
use_llm=bool(options.get("use_llm", True)),
|
||||
)
|
||||
|
||||
query = (
|
||||
db.query(SRSRequirement)
|
||||
.filter(SRSRequirement.extraction_id == job.srs_extraction_id)
|
||||
.order_by(SRSRequirement.sort_order)
|
||||
)
|
||||
requirement_uids = options.get("requirement_uids")
|
||||
if requirement_uids:
|
||||
query = query.filter(SRSRequirement.requirement_uid.in_(requirement_uids))
|
||||
requirements = query.all()
|
||||
job.total_requirements = len(requirements)
|
||||
db.add(job)
|
||||
db.commit()
|
||||
|
||||
verdict_counter: Counter[str] = Counter()
|
||||
for requirement in requirements:
|
||||
result = comparator.compare_requirement(
|
||||
requirement,
|
||||
top_k=int(options.get("top_k", 8)),
|
||||
max_call_hops=int(options.get("max_call_hops", 2)),
|
||||
min_similarity=float(options.get("min_similarity", 0.55)),
|
||||
)
|
||||
_store_result(db, job, result)
|
||||
verdict_counter[result.verdict] += 1
|
||||
job.completed_requirements += 1
|
||||
job.output_summary = {**options, "verdict_counts": dict(verdict_counter)}
|
||||
db.add(job)
|
||||
db.commit()
|
||||
|
||||
job.status = "completed"
|
||||
job.completed_at = datetime.utcnow()
|
||||
job.output_summary = {**options, "verdict_counts": dict(verdict_counter)}
|
||||
db.add(job)
|
||||
db.commit()
|
||||
except Exception as exc:
|
||||
if "job" in locals() and job:
|
||||
job.status = "failed"
|
||||
job.error_message = str(exc)
|
||||
job.completed_at = datetime.utcnow()
|
||||
db.add(job)
|
||||
db.commit()
|
||||
finally:
|
||||
db.close()
|
||||
@@ -6,6 +6,7 @@ import traceback
|
||||
import json
|
||||
from app.db.session import SessionLocal
|
||||
from io import BytesIO
|
||||
from types import SimpleNamespace
|
||||
from typing import Optional, List, Dict, Any
|
||||
from fastapi import UploadFile
|
||||
from langchain_community.document_loaders import (
|
||||
@@ -26,6 +27,7 @@ from minio.error import MinioException
|
||||
from minio.commonconfig import CopySource
|
||||
from app.services.vector_store import VectorStoreFactory
|
||||
from app.services.embedding.embedding_factory import EmbeddingsFactory
|
||||
from app.services.model_config import ModelConfigService
|
||||
|
||||
class UploadResult(BaseModel):
|
||||
file_path: str
|
||||
@@ -120,7 +122,45 @@ def _sanitize_metadata_for_vector_store(metadata: Optional[Dict[str, Any]]) -> D
|
||||
|
||||
return sanitized
|
||||
|
||||
async def process_document(file_path: str, file_name: str, kb_id: int, document_id: int, chunk_size: int = 1000, chunk_overlap: int = 200) -> None:
|
||||
def _resolve_model_profile(db: Session, user_id: Optional[int]) -> Any:
|
||||
if user_id is None:
|
||||
return None
|
||||
return ModelConfigService.require_active_config(db, user_id)
|
||||
|
||||
|
||||
def _model_profile_snapshot(model_profile: Any) -> Any:
|
||||
if model_profile is None:
|
||||
return None
|
||||
return SimpleNamespace(
|
||||
provider=model_profile.provider,
|
||||
api_key=model_profile.api_key,
|
||||
api_base=model_profile.api_base,
|
||||
chat_model=model_profile.chat_model,
|
||||
embedding_model=model_profile.embedding_model,
|
||||
)
|
||||
|
||||
|
||||
def _load_model_profile_for_user(user_id: Optional[int]) -> Any:
|
||||
if user_id is None:
|
||||
return None
|
||||
db = SessionLocal()
|
||||
try:
|
||||
model_profile = ModelConfigService.require_active_config(db, user_id)
|
||||
ModelConfigService.touch_last_used(db, model_profile)
|
||||
return _model_profile_snapshot(model_profile)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
async def process_document(
|
||||
file_path: str,
|
||||
file_name: str,
|
||||
kb_id: int,
|
||||
document_id: int,
|
||||
chunk_size: int = 1000,
|
||||
chunk_overlap: int = 200,
|
||||
user_id: Optional[int] = None,
|
||||
) -> None:
|
||||
"""Process document and store in vector database with incremental updates"""
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -129,7 +169,8 @@ async def process_document(file_path: str, file_name: str, kb_id: int, document_
|
||||
|
||||
# Initialize embeddings
|
||||
logger.info("Initializing OpenAI embeddings...")
|
||||
embeddings = EmbeddingsFactory.create()
|
||||
model_profile = _load_model_profile_for_user(user_id)
|
||||
embeddings = EmbeddingsFactory.create(model_profile=model_profile)
|
||||
|
||||
logger.info(f"Initializing vector store with collection: kb_{kb_id}")
|
||||
vector_store = VectorStoreFactory.create(
|
||||
@@ -202,7 +243,7 @@ async def process_document(file_path: str, file_name: str, kb_id: int, document_
|
||||
try:
|
||||
from app.services.graph.graphrag_adapter import GraphRAGAdapter
|
||||
|
||||
graph_adapter = GraphRAGAdapter()
|
||||
graph_adapter = GraphRAGAdapter(model_profile=model_profile)
|
||||
source_texts = [doc.page_content for doc in documents_to_update if doc.page_content.strip()]
|
||||
await graph_adapter.ingest_texts(kb_id, source_texts)
|
||||
logger.info("GraphRAG ingestion completed in incremental processing")
|
||||
@@ -323,7 +364,8 @@ async def process_document_background(
|
||||
task_id: int,
|
||||
db: Session = None,
|
||||
chunk_size: int = 1000,
|
||||
chunk_overlap: int = 200
|
||||
chunk_overlap: int = 200,
|
||||
user_id: Optional[int] = None,
|
||||
) -> None:
|
||||
"""Process document in background"""
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -348,6 +390,9 @@ async def process_document_background(
|
||||
logger.info(f"Task {task_id}: Setting status to processing")
|
||||
task.status = "processing"
|
||||
db.commit()
|
||||
model_profile = _resolve_model_profile(db, user_id)
|
||||
if model_profile is not None:
|
||||
ModelConfigService.touch_last_used(db, model_profile)
|
||||
|
||||
# 1. 从临时目录下载文件
|
||||
minio_client = get_minio_client()
|
||||
@@ -416,7 +461,7 @@ async def process_document_background(
|
||||
|
||||
# 3. 创建向量存储
|
||||
logger.info(f"Task {task_id}: Initializing vector store")
|
||||
embeddings = EmbeddingsFactory.create()
|
||||
embeddings = EmbeddingsFactory.create(model_profile=model_profile)
|
||||
|
||||
vector_store = VectorStoreFactory.create(
|
||||
store_type=settings.VECTOR_STORE_TYPE,
|
||||
@@ -520,7 +565,7 @@ async def process_document_background(
|
||||
from app.services.graph.graphrag_adapter import GraphRAGAdapter
|
||||
|
||||
logger.info(f"Task {task_id}: Starting GraphRAG ingestion")
|
||||
graph_adapter = GraphRAGAdapter()
|
||||
graph_adapter = GraphRAGAdapter(model_profile=model_profile)
|
||||
source_texts = [doc.page_content for doc in documents if doc.page_content.strip()]
|
||||
await graph_adapter.ingest_texts(kb_id, source_texts)
|
||||
logger.info(f"Task {task_id}: GraphRAG ingestion completed")
|
||||
|
||||
@@ -1,30 +1,39 @@
|
||||
from app.core.config import settings
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
from langchain_ollama import OllamaEmbeddings
|
||||
from typing import Optional
|
||||
# If you plan on adding other embeddings, import them here
|
||||
# from some_other_module import AnotherEmbeddingClass
|
||||
|
||||
|
||||
class EmbeddingsFactory:
|
||||
@staticmethod
|
||||
def create():
|
||||
def create(provider: Optional[str] = None, model_profile: Optional[object] = None):
|
||||
"""
|
||||
Factory method to create an embeddings instance based on .env config.
|
||||
"""
|
||||
# Suppose your .env has a value like EMBEDDINGS_PROVIDER=openai
|
||||
embeddings_provider = settings.EMBEDDINGS_PROVIDER.lower()
|
||||
if model_profile is not None:
|
||||
embeddings_provider = (provider or getattr(model_profile, "provider", None) or "dashscope").lower()
|
||||
api_key = getattr(model_profile, "api_key", "") or ""
|
||||
api_base = getattr(model_profile, "api_base", None) or _default_api_base(embeddings_provider)
|
||||
model = getattr(model_profile, "embedding_model", None) or _default_embedding_model(embeddings_provider)
|
||||
else:
|
||||
embeddings_provider = (provider or settings.EMBEDDINGS_PROVIDER).lower()
|
||||
api_key = _default_api_key(embeddings_provider)
|
||||
api_base = _default_api_base(embeddings_provider)
|
||||
model = _default_embedding_model(embeddings_provider)
|
||||
|
||||
if embeddings_provider == "openai":
|
||||
return OpenAIEmbeddings(
|
||||
openai_api_key=settings.OPENAI_API_KEY,
|
||||
openai_api_base=settings.OPENAI_API_BASE,
|
||||
model=settings.OPENAI_EMBEDDINGS_MODEL
|
||||
openai_api_key=api_key,
|
||||
openai_api_base=api_base,
|
||||
model=model
|
||||
)
|
||||
elif embeddings_provider == "dashscope":
|
||||
elif embeddings_provider in {"dashscope", "openai_compatible"}:
|
||||
return OpenAIEmbeddings(
|
||||
openai_api_key=settings.DASH_SCOPE_API_KEY,
|
||||
openai_api_base=settings.DASH_SCOPE_API_BASE,
|
||||
model=settings.DASH_SCOPE_EMBEDDINGS_MODEL,
|
||||
openai_api_key=api_key,
|
||||
openai_api_base=api_base,
|
||||
model=model,
|
||||
# DashScope OpenAI-compatible embedding expects string input,
|
||||
# while LangChain's len-safe path may send token ids.
|
||||
check_embedding_ctx_length=False,
|
||||
@@ -35,8 +44,8 @@ class EmbeddingsFactory:
|
||||
)
|
||||
elif embeddings_provider == "ollama":
|
||||
return OllamaEmbeddings(
|
||||
model=settings.OLLAMA_EMBEDDINGS_MODEL,
|
||||
base_url=settings.OLLAMA_API_BASE
|
||||
model=model,
|
||||
base_url=api_base
|
||||
)
|
||||
|
||||
# Extend with other providers:
|
||||
@@ -44,3 +53,34 @@ class EmbeddingsFactory:
|
||||
# return AnotherEmbeddingClass(...)
|
||||
else:
|
||||
raise ValueError(f"Unsupported embeddings provider: {embeddings_provider}")
|
||||
|
||||
|
||||
def _default_embedding_model(provider: Optional[str]) -> str:
|
||||
provider = (provider or settings.EMBEDDINGS_PROVIDER).lower()
|
||||
if provider == "openai":
|
||||
return settings.OPENAI_EMBEDDINGS_MODEL
|
||||
if provider == "dashscope":
|
||||
return settings.DASH_SCOPE_EMBEDDINGS_MODEL or "text-embedding-v4"
|
||||
if provider == "ollama":
|
||||
return settings.OLLAMA_EMBEDDINGS_MODEL
|
||||
return settings.DASH_SCOPE_EMBEDDINGS_MODEL or settings.OPENAI_EMBEDDINGS_MODEL
|
||||
|
||||
|
||||
def _default_api_key(provider: Optional[str]) -> str:
|
||||
provider = (provider or settings.EMBEDDINGS_PROVIDER).lower()
|
||||
if provider == "openai":
|
||||
return settings.OPENAI_API_KEY
|
||||
if provider == "dashscope":
|
||||
return settings.DASH_SCOPE_API_KEY
|
||||
return settings.API_KEY
|
||||
|
||||
|
||||
def _default_api_base(provider: Optional[str]) -> str:
|
||||
provider = (provider or settings.EMBEDDINGS_PROVIDER).lower()
|
||||
if provider == "openai":
|
||||
return settings.OPENAI_API_BASE
|
||||
if provider == "dashscope":
|
||||
return settings.DASH_SCOPE_API_BASE
|
||||
if provider == "ollama":
|
||||
return settings.OLLAMA_API_BASE
|
||||
return settings.DASH_SCOPE_API_BASE
|
||||
|
||||
@@ -13,11 +13,11 @@ from app.services.llm.llm_factory import LLMFactory
|
||||
class GraphRAGAdapter:
|
||||
_instance_lock = asyncio.Lock()
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, model_profile: Any = None):
|
||||
self._graphrag_instances: Dict[int, Any] = {}
|
||||
self._kb_locks: Dict[int, asyncio.Lock] = {}
|
||||
self._embedding_model = EmbeddingsFactory.create()
|
||||
self._llm_model = LLMFactory.create(streaming=False)
|
||||
self._embedding_model = EmbeddingsFactory.create(model_profile=model_profile)
|
||||
self._llm_model = LLMFactory.create(streaming=False, model_profile=model_profile)
|
||||
self._symbols = self._load_symbols()
|
||||
|
||||
def _load_symbols(self) -> Dict[str, Any]:
|
||||
|
||||
@@ -11,42 +11,51 @@ class LLMFactory:
|
||||
provider: Optional[str] = None,
|
||||
temperature: float = 0,
|
||||
streaming: bool = True,
|
||||
model_profile: Optional[object] = None,
|
||||
) -> BaseChatModel:
|
||||
"""
|
||||
Create a LLM instance based on the provider
|
||||
"""
|
||||
# If no provider specified, use the one from settings
|
||||
provider = provider or settings.CHAT_PROVIDER
|
||||
if model_profile is not None:
|
||||
provider = (provider or getattr(model_profile, "provider", None) or "dashscope").lower()
|
||||
model = getattr(model_profile, "chat_model", None) or _default_chat_model(provider)
|
||||
api_key = getattr(model_profile, "api_key", "") or ""
|
||||
api_base = getattr(model_profile, "api_base", None) or _default_api_base(provider)
|
||||
else:
|
||||
provider = provider or settings.CHAT_PROVIDER
|
||||
model = _default_chat_model(provider)
|
||||
api_key = _default_api_key(provider)
|
||||
api_base = _default_api_base(provider)
|
||||
|
||||
if provider.lower() == "openai":
|
||||
return ChatOpenAI(
|
||||
temperature=temperature,
|
||||
streaming=streaming,
|
||||
model=settings.OPENAI_MODEL,
|
||||
openai_api_key=settings.OPENAI_API_KEY,
|
||||
openai_api_base=settings.OPENAI_API_BASE
|
||||
model=model,
|
||||
openai_api_key=api_key,
|
||||
openai_api_base=api_base
|
||||
)
|
||||
elif provider.lower() == "deepseek":
|
||||
return ChatDeepSeek(
|
||||
temperature=temperature,
|
||||
streaming=streaming,
|
||||
model=settings.DEEPSEEK_MODEL,
|
||||
api_key=settings.DEEPSEEK_API_KEY,
|
||||
api_base=settings.DEEPSEEK_API_BASE
|
||||
model=model,
|
||||
api_key=api_key,
|
||||
api_base=api_base
|
||||
)
|
||||
elif provider.lower() == "dashscope":
|
||||
elif provider.lower() in {"dashscope", "openai_compatible"}:
|
||||
return ChatOpenAI(
|
||||
temperature=temperature,
|
||||
streaming=streaming,
|
||||
model=settings.DASH_SCOPE_CHAT_MODEL,
|
||||
openai_api_key=settings.DASH_SCOPE_API_KEY,
|
||||
openai_api_base=settings.DASH_SCOPE_API_BASE,
|
||||
model=model,
|
||||
openai_api_key=api_key,
|
||||
openai_api_base=api_base,
|
||||
)
|
||||
elif provider.lower() == "ollama":
|
||||
# Initialize Ollama model
|
||||
return OllamaLLM(
|
||||
model=settings.OLLAMA_MODEL,
|
||||
base_url=settings.OLLAMA_API_BASE,
|
||||
model=model,
|
||||
base_url=api_base,
|
||||
temperature=temperature,
|
||||
streaming=streaming
|
||||
)
|
||||
@@ -54,4 +63,41 @@ class LLMFactory:
|
||||
# elif provider.lower() == "anthropic":
|
||||
# return ChatAnthropic(...)
|
||||
else:
|
||||
raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||
raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||
|
||||
|
||||
def _default_chat_model(provider: Optional[str]) -> str:
|
||||
provider = (provider or settings.CHAT_PROVIDER).lower()
|
||||
if provider == "openai":
|
||||
return settings.OPENAI_MODEL
|
||||
if provider == "deepseek":
|
||||
return settings.DEEPSEEK_MODEL
|
||||
if provider == "dashscope":
|
||||
return settings.DASH_SCOPE_CHAT_MODEL
|
||||
if provider == "ollama":
|
||||
return settings.OLLAMA_MODEL
|
||||
return settings.DASH_SCOPE_CHAT_MODEL or settings.OPENAI_MODEL
|
||||
|
||||
|
||||
def _default_api_key(provider: Optional[str]) -> str:
|
||||
provider = (provider or settings.CHAT_PROVIDER).lower()
|
||||
if provider == "openai":
|
||||
return settings.OPENAI_API_KEY
|
||||
if provider == "deepseek":
|
||||
return settings.DEEPSEEK_API_KEY
|
||||
if provider == "dashscope":
|
||||
return settings.DASH_SCOPE_API_KEY
|
||||
return settings.API_KEY
|
||||
|
||||
|
||||
def _default_api_base(provider: Optional[str]) -> str:
|
||||
provider = (provider or settings.CHAT_PROVIDER).lower()
|
||||
if provider == "openai":
|
||||
return settings.OPENAI_API_BASE
|
||||
if provider == "deepseek":
|
||||
return settings.DEEPSEEK_API_BASE
|
||||
if provider == "dashscope":
|
||||
return settings.DASH_SCOPE_API_BASE
|
||||
if provider == "ollama":
|
||||
return settings.OLLAMA_API_BASE
|
||||
return settings.DASH_SCOPE_API_BASE
|
||||
|
||||
211
rag-web-ui/backend/app/services/model_config.py
Normal file
211
rag-web-ui/backend/app/services/model_config.py
Normal file
@@ -0,0 +1,211 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings
|
||||
from app.models.model_config import UserModelConfig
|
||||
from app.schemas.model_config import ModelConfigCreate, ModelConfigUpdate
|
||||
|
||||
|
||||
PROVIDER_OPTIONS: List[Dict[str, Any]] = [
|
||||
{
|
||||
"provider": "dashscope",
|
||||
"label": "DashScope",
|
||||
"default_api_base": settings.DASH_SCOPE_API_BASE,
|
||||
"default_chat_model": settings.DASH_SCOPE_CHAT_MODEL or "qwen3-max",
|
||||
"default_embedding_model": settings.DASH_SCOPE_EMBEDDINGS_MODEL or "text-embedding-v4",
|
||||
"chat_models": ["qwen3-max", "qwen-plus", "qwen-turbo", "qwen-max"],
|
||||
"embedding_models": ["text-embedding-v4", "text-embedding-v3", "text-embedding-v2"],
|
||||
"requires_api_key": True,
|
||||
"supports_custom_api_base": True,
|
||||
},
|
||||
{
|
||||
"provider": "openai",
|
||||
"label": "OpenAI",
|
||||
"default_api_base": settings.OPENAI_API_BASE,
|
||||
"default_chat_model": settings.OPENAI_MODEL or "gpt-4o",
|
||||
"default_embedding_model": settings.OPENAI_EMBEDDINGS_MODEL or "text-embedding-3-small",
|
||||
"chat_models": ["gpt-4o", "gpt-4o-mini", "gpt-4.1", "gpt-4.1-mini"],
|
||||
"embedding_models": ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"],
|
||||
"requires_api_key": True,
|
||||
"supports_custom_api_base": True,
|
||||
},
|
||||
{
|
||||
"provider": "openai_compatible",
|
||||
"label": "OpenAI Compatible",
|
||||
"default_api_base": "",
|
||||
"default_chat_model": "qwen3-max",
|
||||
"default_embedding_model": "text-embedding-v4",
|
||||
"chat_models": ["qwen3-max", "deepseek-chat", "gpt-4o-mini"],
|
||||
"embedding_models": ["text-embedding-v4", "text-embedding-3-small"],
|
||||
"requires_api_key": True,
|
||||
"supports_custom_api_base": True,
|
||||
},
|
||||
{
|
||||
"provider": "ollama",
|
||||
"label": "Ollama",
|
||||
"default_api_base": settings.OLLAMA_API_BASE,
|
||||
"default_chat_model": settings.OLLAMA_MODEL,
|
||||
"default_embedding_model": settings.OLLAMA_EMBEDDINGS_MODEL,
|
||||
"chat_models": [settings.OLLAMA_MODEL, "llama3.1", "qwen2.5", "deepseek-r1:7b"],
|
||||
"embedding_models": [settings.OLLAMA_EMBEDDINGS_MODEL, "nomic-embed-text", "mxbai-embed-large"],
|
||||
"requires_api_key": False,
|
||||
"supports_custom_api_base": True,
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def provider_options_response() -> Dict[str, Any]:
|
||||
first = PROVIDER_OPTIONS[0]
|
||||
return {
|
||||
"providers": PROVIDER_OPTIONS,
|
||||
"defaults": {
|
||||
"provider": first["provider"],
|
||||
"api_base": first["default_api_base"],
|
||||
"chat_model": first["default_chat_model"],
|
||||
"embedding_model": first["default_embedding_model"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _provider_option(provider: str) -> Dict[str, Any]:
|
||||
normalized = (provider or "dashscope").strip().lower()
|
||||
for option in PROVIDER_OPTIONS:
|
||||
if option["provider"] == normalized:
|
||||
return option
|
||||
raise ValueError(f"Unsupported model provider: {provider}")
|
||||
|
||||
|
||||
def _normalized_payload(payload: Dict[str, Any], existing: Optional[UserModelConfig] = None) -> Dict[str, Any]:
|
||||
provider = str(payload.get("provider") or getattr(existing, "provider", "dashscope")).strip().lower()
|
||||
option = _provider_option(provider)
|
||||
api_base = payload.get("api_base")
|
||||
if api_base is None and existing is not None:
|
||||
api_base = existing.api_base
|
||||
if not api_base:
|
||||
api_base = option["default_api_base"]
|
||||
if option["supports_custom_api_base"] and provider == "openai_compatible" and not api_base:
|
||||
raise ValueError("OpenAI Compatible provider requires an API base URL.")
|
||||
|
||||
chat_model = str(payload.get("chat_model") or getattr(existing, "chat_model", "") or option["default_chat_model"]).strip()
|
||||
embedding_model = str(
|
||||
payload.get("embedding_model")
|
||||
or getattr(existing, "embedding_model", "")
|
||||
or option["default_embedding_model"]
|
||||
).strip()
|
||||
if not chat_model:
|
||||
raise ValueError("Chat model is required.")
|
||||
if not embedding_model:
|
||||
raise ValueError("Embedding model is required.")
|
||||
|
||||
api_key = payload.get("api_key")
|
||||
if api_key is None and existing is not None:
|
||||
api_key = existing.api_key
|
||||
api_key = str(api_key or "").strip()
|
||||
if option["requires_api_key"] and not api_key:
|
||||
raise ValueError("API key is required for this provider.")
|
||||
|
||||
name = payload.get("name")
|
||||
if name is None and existing is not None:
|
||||
name = existing.name
|
||||
name = str(name or "").strip()
|
||||
if not name:
|
||||
name = option["label"]
|
||||
|
||||
return {
|
||||
"name": name,
|
||||
"provider": provider,
|
||||
"api_key": api_key,
|
||||
"api_base": str(api_base).strip() if api_base else None,
|
||||
"chat_model": chat_model,
|
||||
"embedding_model": embedding_model,
|
||||
"is_active": bool(payload.get("is_active", getattr(existing, "is_active", True))),
|
||||
}
|
||||
|
||||
|
||||
class ModelConfigService:
|
||||
@staticmethod
|
||||
def list_configs(db: Session, user_id: int) -> List[UserModelConfig]:
|
||||
return (
|
||||
db.query(UserModelConfig)
|
||||
.filter(UserModelConfig.user_id == user_id)
|
||||
.order_by(UserModelConfig.is_active.desc(), UserModelConfig.updated_at.desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_config(db: Session, user_id: int, config_id: int) -> Optional[UserModelConfig]:
|
||||
return (
|
||||
db.query(UserModelConfig)
|
||||
.filter(UserModelConfig.id == config_id, UserModelConfig.user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_active_config(db: Session, user_id: int) -> Optional[UserModelConfig]:
|
||||
return (
|
||||
db.query(UserModelConfig)
|
||||
.filter(UserModelConfig.user_id == user_id, UserModelConfig.is_active.is_(True))
|
||||
.order_by(UserModelConfig.updated_at.desc())
|
||||
.first()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def require_active_config(db: Session, user_id: int) -> UserModelConfig:
|
||||
config = ModelConfigService.get_active_config(db, user_id)
|
||||
if config is None:
|
||||
raise ValueError("请先在 API 密钥页面新增并启用模型配置。")
|
||||
return config
|
||||
|
||||
@staticmethod
|
||||
def create_config(db: Session, user_id: int, payload: ModelConfigCreate) -> UserModelConfig:
|
||||
data = _normalized_payload(payload.model_dump())
|
||||
if data["is_active"]:
|
||||
ModelConfigService._deactivate_user_configs(db, user_id)
|
||||
item = UserModelConfig(user_id=user_id, **data)
|
||||
db.add(item)
|
||||
db.commit()
|
||||
db.refresh(item)
|
||||
return item
|
||||
|
||||
@staticmethod
|
||||
def update_config(
|
||||
db: Session,
|
||||
item: UserModelConfig,
|
||||
payload: ModelConfigUpdate,
|
||||
) -> UserModelConfig:
|
||||
raw = payload.model_dump(exclude_unset=True)
|
||||
if raw.get("api_key") == "":
|
||||
raw.pop("api_key")
|
||||
data = _normalized_payload(raw, existing=item)
|
||||
if data["is_active"]:
|
||||
ModelConfigService._deactivate_user_configs(db, item.user_id, exclude_id=item.id)
|
||||
for field, value in data.items():
|
||||
setattr(item, field, value)
|
||||
db.add(item)
|
||||
db.commit()
|
||||
db.refresh(item)
|
||||
return item
|
||||
|
||||
@staticmethod
|
||||
def delete_config(db: Session, item: UserModelConfig) -> None:
|
||||
db.delete(item)
|
||||
db.commit()
|
||||
|
||||
@staticmethod
|
||||
def touch_last_used(db: Session, item: UserModelConfig) -> UserModelConfig:
|
||||
item.last_used_at = datetime.utcnow()
|
||||
db.add(item)
|
||||
db.commit()
|
||||
db.refresh(item)
|
||||
return item
|
||||
|
||||
@staticmethod
|
||||
def _deactivate_user_configs(db: Session, user_id: int, exclude_id: Optional[int] = None) -> None:
|
||||
query = db.query(UserModelConfig).filter(UserModelConfig.user_id == user_id)
|
||||
if exclude_id is not None:
|
||||
query = query.filter(UserModelConfig.id != exclude_id)
|
||||
query.update({UserModelConfig.is_active: False}, synchronize_session=False)
|
||||
@@ -9,6 +9,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from app.db.session import SessionLocal
|
||||
from app.models.tooling import SRSExtraction, SRSRequirement, ToolJob
|
||||
from app.services.model_config import ModelConfigService
|
||||
from app.tools.srs_reqs_qwen import get_srs_tool
|
||||
|
||||
TYPE_TO_CHINESE = {
|
||||
@@ -63,7 +64,9 @@ def run_srs_job(job_id: int) -> None:
|
||||
job.error_message = None
|
||||
db.commit()
|
||||
|
||||
payload = get_srs_tool().run(job.input_file_path)
|
||||
model_profile = ModelConfigService.require_active_config(db, job.user_id)
|
||||
ModelConfigService.touch_last_used(db, model_profile)
|
||||
payload = get_srs_tool().run(job.input_file_path, model_profile=model_profile)
|
||||
|
||||
extraction = SRSExtraction(
|
||||
job_id=job.id,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List
|
||||
|
||||
@@ -11,10 +12,15 @@ from app.db.session import SessionLocal
|
||||
from app.models.knowledge import Document, KnowledgeBase
|
||||
from app.models.tooling import TestingGeneration, ToolJob
|
||||
from app.services.embedding.embedding_factory import EmbeddingsFactory
|
||||
from app.services.llm.llm_factory import LLMFactory
|
||||
from app.services.model_config import ModelConfigService
|
||||
from app.services.retrieval.multi_kb_retriever import MultiKBRetriever, format_retrieval_context
|
||||
from app.services.testing_pipeline import run_testing_pipeline
|
||||
from app.services.vector_store import VectorStoreFactory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _flatten_record(value: Dict[str, List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
|
||||
items: List[Dict[str, Any]] = []
|
||||
for current in value.values():
|
||||
@@ -22,8 +28,15 @@ def _flatten_record(value: Dict[str, List[Dict[str, Any]]]) -> List[Dict[str, An
|
||||
return items
|
||||
|
||||
|
||||
def _build_kb_vector_stores(db: Session, knowledge_bases: List[KnowledgeBase]) -> List[Dict[str, Any]]:
|
||||
embeddings = EmbeddingsFactory.create()
|
||||
def _build_kb_vector_stores(
|
||||
db: Session,
|
||||
knowledge_bases: List[KnowledgeBase],
|
||||
model_profile: Any,
|
||||
) -> List[Dict[str, Any]]:
|
||||
if model_profile is None:
|
||||
return []
|
||||
|
||||
embeddings = EmbeddingsFactory.create(model_profile=model_profile)
|
||||
kb_vector_stores: List[Dict[str, Any]] = []
|
||||
|
||||
for kb in knowledge_bases:
|
||||
@@ -47,8 +60,9 @@ def _resolve_knowledge_context(
|
||||
user_id: int,
|
||||
requirement_text: str,
|
||||
knowledge_base_id: int | None,
|
||||
model_profile: Any,
|
||||
) -> str:
|
||||
if knowledge_base_id is None:
|
||||
if knowledge_base_id is None or model_profile is None:
|
||||
return ""
|
||||
|
||||
try:
|
||||
@@ -60,7 +74,7 @@ def _resolve_knowledge_context(
|
||||
)
|
||||
.all()
|
||||
)
|
||||
kb_vector_stores = _build_kb_vector_stores(db, knowledge_bases)
|
||||
kb_vector_stores = _build_kb_vector_stores(db, knowledge_bases, model_profile)
|
||||
if not kb_vector_stores:
|
||||
return ""
|
||||
|
||||
@@ -143,6 +157,27 @@ def run_testing_generation_job(job_id: int, payload: Dict[str, Any]) -> None:
|
||||
source_document_name = str(payload.get("source_document_name") or job.input_file_name or "")
|
||||
source_job_id = payload.get("source_job_id")
|
||||
knowledge_base_id = payload.get("knowledge_base_id")
|
||||
model_profile = ModelConfigService.get_active_config(db, job.user_id)
|
||||
if model_profile is not None:
|
||||
ModelConfigService.touch_last_used(db, model_profile)
|
||||
|
||||
use_model_generation = model_profile is not None
|
||||
llm_model = None
|
||||
if use_model_generation:
|
||||
try:
|
||||
llm_model = LLMFactory.create(streaming=False, model_profile=model_profile)
|
||||
except Exception as exc:
|
||||
logger.exception(
|
||||
"Testing generation LLM initialization failed for job=%s, falling back to rule-based output: %s",
|
||||
job_id,
|
||||
exc,
|
||||
)
|
||||
use_model_generation = False
|
||||
else:
|
||||
logger.info(
|
||||
"Testing generation job=%s has no active model config; using rule-based output.",
|
||||
job_id,
|
||||
)
|
||||
|
||||
job.status = "processing"
|
||||
job.started_at = datetime.utcnow()
|
||||
@@ -183,6 +218,7 @@ def run_testing_generation_job(job_id: int, payload: Dict[str, Any]) -> None:
|
||||
user_id=job.user_id,
|
||||
requirement_text=description,
|
||||
knowledge_base_id=knowledge_base_id,
|
||||
model_profile=model_profile,
|
||||
)
|
||||
|
||||
pipeline_result = run_testing_pipeline(
|
||||
@@ -190,7 +226,8 @@ def run_testing_generation_job(job_id: int, payload: Dict[str, Any]) -> None:
|
||||
requirement_type_input=req.get("requirementType"),
|
||||
debug=False,
|
||||
knowledge_context=knowledge_context,
|
||||
use_model_generation=True,
|
||||
use_model_generation=use_model_generation,
|
||||
llm_model=llm_model,
|
||||
max_items_per_group=12,
|
||||
cases_per_item=2,
|
||||
max_focus_points=6,
|
||||
|
||||
@@ -33,13 +33,13 @@ def run_testing_pipeline(
|
||||
debug: bool = False,
|
||||
knowledge_context: Optional[str] = None,
|
||||
use_model_generation: bool = False,
|
||||
llm_model: Any = None,
|
||||
max_items_per_group: int = 12,
|
||||
cases_per_item: int = 2,
|
||||
max_focus_points: int = 6,
|
||||
max_llm_calls: int = 10,
|
||||
) -> Dict[str, Any]:
|
||||
llm_model = None
|
||||
if use_model_generation:
|
||||
if use_model_generation and llm_model is None:
|
||||
try:
|
||||
from app.services.llm.llm_factory import LLMFactory
|
||||
|
||||
|
||||
Reference in New Issue
Block a user