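"""Code knowledge base adapter.

Loads a vector index, per-function metadata and a call graph from disk and
serves function search over them (vector search with a lexical fallback).
"""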
from __future__ import annotations

import json
import logging
import math
import re
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional

from app.services.code_kb.graph import CodeCallGraph
from app.services.code_kb.schema import CodeFunctionEvidence, CodeGraphContext, CodeSearchHit

# Module-level logger shared by every helper and class in this file.
logger = logging.getLogger(__name__)


# Canonical field name -> metadata keys that may carry that field, listed in
# lookup priority order (the first key holding a non-empty value wins).
FIELD_ALIASES = {
    "name": ("name", "function_name"),
    "file": ("file", "file_path"),
    "summary": ("summary",),
    "logic_flow": ("logic_flow", "logic"),
    "code_snippet": ("code_snippet", "source", "code"),
    "calls": ("calls", "called_functions"),
    "called_by": ("called_by", "caller_functions", "callers"),
}


class SimpleVectorIndex:
    """Brute-force squared-L2 vector index with a FAISS-like surface.

    Exposes ``d`` (dimension), ``ntotal`` (vector count) and ``search`` so it
    can stand in for a FAISS index.  Uses numpy when it can be imported and a
    pure-Python fallback otherwise.  The index is treated as immutable after
    construction: ``ntotal`` and the cached numpy matrix are not refreshed if
    ``vectors`` is mutated later.
    """

    def __init__(self, vectors: Any, dimension: int) -> None:
        """Store ``vectors`` as lists of floats.

        Args:
            vectors: Iterable of numeric sequences; a falsy value means empty.
            dimension: Vector dimension; when falsy it is taken from the first
                stored vector (or 0 for an empty index).
        """
        try:
            import numpy as np
        except ImportError:  # pragma: no cover - depends on deployment environment
            np = None
        self._np = np
        self.vectors = [[float(value) for value in vector] for vector in (vectors or [])]
        self.d = int(dimension or (len(self.vectors[0]) if self.vectors else 0))
        self.ntotal = len(self.vectors)
        # Built lazily on the first numpy search so malformed (ragged) vectors
        # still fail at search time rather than at construction time.
        self._matrix: Any = None

    @classmethod
    def from_file(cls, vector_path: str) -> Optional["SimpleVectorIndex"]:
        """Load an index from a JSON file in ``simple_l2_vector_index`` format.

        Returns:
            The index, or ``None`` when the file cannot be read or decoded as
            JSON, or when it does not carry the expected ``format`` marker.
        """
        try:
            with open(vector_path, "r", encoding="utf-8") as file:
                payload = json.load(file)
        except (OSError, UnicodeDecodeError, json.JSONDecodeError):
            return None
        if not isinstance(payload, dict) or payload.get("format") != "simple_l2_vector_index":
            return None
        return cls(payload.get("vectors") or [], int(payload.get("dimension") or 0))

    def search(self, vector: Any, top_k: int) -> Any:
        """Return the ``top_k`` nearest stored vectors by squared L2 distance.

        Args:
            vector: A single query, either flat (``[x, y, ...]``) or wrapped in
                one row (``[[x, y, ...]]``).  Only the first row is used.
            top_k: Maximum number of results; values <= 0 yield no results and
                ``None`` means "no limit".

        Returns:
            ``(distances, indices)``, each holding exactly one row.  They are
            numpy arrays (float32 / int64) when numpy is available, otherwise
            nested lists.

        Raises:
            ValueError: If the query dimension differs from ``self.d``.
        """
        # A negative top_k must not turn into a "drop the last k" slice.
        limit = None if top_k is None else max(0, top_k)

        if self.ntotal == 0:
            if self._np is not None:
                return self._np.array([[]], dtype="float32"), self._np.array([[]], dtype="int64")
            return [[]], [[]]

        if self._np is not None:
            np = self._np
            query = np.asarray(vector, dtype="float32")
            if query.ndim == 1:
                # Accept a flat query, as the pure-Python path below does.
                query = query.reshape(1, -1)
            if query.ndim != 2 or query.shape[1] != self.d:
                raise ValueError(f"Query embedding dimension does not match index dimension {self.d}.")
            if self._matrix is None:
                self._matrix = np.array(self.vectors, dtype="float32")
            distances = np.sum((self._matrix - query[0]) ** 2, axis=1)
            # Stable sort keeps ties in storage order, matching the fallback path.
            order = np.argsort(distances, kind="stable")[:limit]
            return np.array([distances[order]], dtype="float32"), np.array([order], dtype="int64")

        is_nested = (
            isinstance(vector, (list, tuple))
            and bool(vector)
            and isinstance(vector[0], (list, tuple))
        )
        query_row = vector[0] if is_nested else vector
        query_values = [float(value) for value in query_row]
        if len(query_values) != self.d:
            raise ValueError(f"Query embedding dimension does not match index dimension {self.d}.")
        scored = [
            (sum((left - right) ** 2 for left, right in zip(stored, query_values)), index)
            for index, stored in enumerate(self.vectors)
        ]
        scored.sort(key=lambda item: item[0])
        selected = scored[:limit]
        return [[item[0] for item in selected]], [[item[1] for item in selected]]

    def has_nonzero_vectors(self) -> bool:
        """Return True if any stored component has magnitude above 1e-12."""
        return any(
            any(abs(float(value)) > 1e-12 for value in vector)
            for vector in self.vectors
        )


class UnavailableVectorIndex:
    """Stand-in for a vector index that could not be opened.

    It reports itself as empty and zero-dimensional, remembers why the real
    index is unavailable, and refuses every search with that reason.
    """

    # Class-level values: the placeholder holds no vectors and has no dimension.
    ntotal = 0
    d = 0

    def __init__(self, reason: str) -> None:
        """Remember the human-readable ``reason`` the index is unavailable."""
        self.reason = reason

    def search(self, vector: Any, top_k: int) -> Any:
        """Always raise ``RuntimeError`` carrying the stored reason."""
        raise RuntimeError(self.reason)


def _first_value(source: Dict[str, Any], aliases: Iterable[str], default: Any = None) -> Any:
|
|
for alias in aliases:
|
|
value = source.get(alias)
|
|
if value not in (None, ""):
|
|
return value
|
|
return default
|
|
|
|
|
|
def _as_list(value: Any) -> List[str]:
|
|
if value is None:
|
|
return []
|
|
if isinstance(value, list):
|
|
return [str(item) for item in value if item not in (None, "")]
|
|
if isinstance(value, tuple):
|
|
return [str(item) for item in value if item not in (None, "")]
|
|
if isinstance(value, str):
|
|
return [value] if value else []
|
|
return [str(value)]
|
|
|
|
|
|
def _as_int(value: Any) -> Optional[int]:
|
|
if value in (None, ""):
|
|
return None
|
|
try:
|
|
return int(value)
|
|
except (TypeError, ValueError):
|
|
return None
|
|
|
|
|
|
def _as_bool(value: Any, default: bool = True) -> bool:
|
|
if value is None:
|
|
return default
|
|
if isinstance(value, bool):
|
|
return value
|
|
if isinstance(value, (int, float)):
|
|
return bool(value)
|
|
if isinstance(value, str):
|
|
normalized = value.strip().lower()
|
|
if normalized in {"1", "true", "yes", "y"}:
|
|
return True
|
|
if normalized in {"0", "false", "no", "n"}:
|
|
return False
|
|
return default
|
|
|
|
|
|
def _load_json(path: str) -> Any:
|
|
with open(path, "r", encoding="utf-8") as file:
|
|
return json.load(file)
|
|
|
|
|
|
def read_code_kb_summary(metadata_path: str, graph_path: str) -> Dict[str, Any]:
    """Summarise a code knowledge base without loading its vector index.

    Args:
        metadata_path: JSON file holding the function metadata; either a list
            of rows or a dict carrying the rows under ``functions`` or
            ``items``.
        graph_path: JSON call-graph file.  It is optional: a falsy or
            non-existent path is treated as an empty graph.

    Returns:
        A dict with ``function_count`` (number of metadata rows, 0 when the
        rows are not a list) plus ``graph_nodes``, ``graph_edges``,
        ``project_root`` and ``generated_at`` copied from the graph's
        ``metadata`` section (``None`` when absent).
    """
    metadata = _load_json(metadata_path)
    graph_data = _load_json(graph_path) if graph_path and Path(graph_path).exists() else {}

    if isinstance(metadata, dict):
        # Accept the same container keys as CodeKnowledgeBaseAdapter._read_metadata.
        metadata_items = metadata.get("functions", metadata.get("items", []))
    else:
        metadata_items = metadata
    function_count = len(metadata_items) if isinstance(metadata_items, list) else 0

    graph_meta = graph_data.get("metadata") if isinstance(graph_data, dict) else None
    if not isinstance(graph_meta, dict):
        # Guards against e.g. `"metadata": null`, which would otherwise crash on .get().
        graph_meta = {}

    return {
        "function_count": function_count,
        "graph_nodes": graph_meta.get("total_nodes"),
        "graph_edges": graph_meta.get("total_edges"),
        "project_root": graph_meta.get("project_root"),
        "generated_at": graph_meta.get("generated_at"),
    }


class CodeKnowledgeBaseAdapter:
    """Load a code knowledge base and answer function searches over it.

    A knowledge base consists of three files: a vector index (the JSON
    ``simple_l2_vector_index`` format or an index readable by FAISS), a JSON
    collection of per-function metadata rows, and a JSON call graph.  Searches
    use the vector index when it is usable and otherwise fall back to a
    token-overlap lexical ranking.
    """

    # Compiled once: ``_tokens`` runs for every indexed function on each lexical search.
    _TOKEN_PATTERN = re.compile(r"[a-z_][a-z0-9_]*|\d+(?:\.\d+)?|[\u4e00-\u9fff]{2,}")
    _CJK_RUN_PATTERN = re.compile(r"[\u4e00-\u9fff]{3,}")

    def __init__(self, embedding_function: Any = None) -> None:
        """Create an empty adapter.

        Args:
            embedding_function: Object with an ``embed_query`` method, or a
                callable taking the query text.  When ``None`` it is created
                lazily through ``EmbeddingsFactory`` on first vector search.
        """
        self.embedding_function = embedding_function
        self.faiss_index: Any = None
        self.metadata: List[Dict[str, Any]] = []
        self.graph_data: Dict[str, Any] = {}
        self.functions: List[CodeFunctionEvidence] = []
        self.functions_by_id: Dict[str, CodeFunctionEvidence] = {}
        self.call_graph: Optional[CodeCallGraph] = None
        self.vector_search_enabled = False
        self.vector_search_disabled_reason = "Code knowledge base has not been loaded."

    def load(self, vector_path: str, metadata_path: str, graph_path: str) -> None:
        """Load the three knowledge-base files and configure vector search.

        Everything is built in local variables first and only then assigned,
        so a failure part-way through leaves the previous state untouched.

        Raises:
            FileNotFoundError: If any of the three paths is empty or missing.
            ValueError: If the metadata or graph file has an unexpected shape.
        """
        self._validate_paths(vector_path, metadata_path, graph_path)
        faiss_index = self._read_faiss_index(vector_path)
        metadata = self._read_metadata(metadata_path)
        graph_data = _load_json(graph_path)
        if not isinstance(graph_data, dict):
            raise ValueError("Code KB graph must be a JSON object.")
        graph_nodes = self._index_graph_nodes(graph_data)

        index_dimension = int(getattr(faiss_index, "d", 0) or 0)
        functions = [
            self._normalize_metadata(row, graph_nodes, index_dimension=index_dimension)
            for row in metadata
        ]
        call_graph = CodeCallGraph(graph_data, functions)

        self.faiss_index = faiss_index
        self.metadata = metadata
        self.graph_data = graph_data
        self.functions = functions
        self.functions_by_id = {item.node_id: item for item in functions}
        self.call_graph = call_graph
        self._configure_vector_search()

    def _validate_paths(self, vector_path: str, metadata_path: str, graph_path: str) -> None:
        """Raise ``FileNotFoundError`` listing every empty or non-existent path."""
        missing = [
            path
            for path in [vector_path, metadata_path, graph_path]
            if not path or not Path(path).exists()
        ]
        if missing:
            raise FileNotFoundError(f"Code knowledge base files not found: {missing}")

    def _read_faiss_index(self, vector_path: str) -> Any:
        """Open the vector index at ``vector_path``.

        The JSON ``SimpleVectorIndex`` format is tried first.  Otherwise the
        file is read with FAISS; when FAISS is not installed an
        ``UnavailableVectorIndex`` describing the problem is returned instead.
        """
        simple_index = SimpleVectorIndex.from_file(vector_path)
        if simple_index is not None:
            return simple_index
        try:
            import faiss  # type: ignore
        except ImportError:
            logger.warning("faiss is not installed; code search will use lexical fallback.")
            return UnavailableVectorIndex(
                "faiss is required to read this vector index. Install faiss-cpu or rebuild the code KB."
            )
        return faiss.read_index(vector_path)

    def _read_metadata(self, metadata_path: str) -> List[Dict[str, Any]]:
        """Load the metadata rows, keeping only entries that are dicts.

        The file may be a list, or a dict holding the list under ``functions``
        or ``items``.

        Raises:
            ValueError: If no list can be found.
        """
        metadata = _load_json(metadata_path)
        if isinstance(metadata, dict):
            metadata = metadata.get("functions", metadata.get("items", []))
        if not isinstance(metadata, list):
            raise ValueError("Code KB metadata must be a list or a dict with functions/items.")
        return [item for item in metadata if isinstance(item, dict)]

    def _index_graph_nodes(self, graph_data: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
        """Map node id -> node dict for every graph node that has a truthy ``id``."""
        return {
            item.get("id"): item
            for item in graph_data.get("nodes", []) or []
            if isinstance(item, dict) and item.get("id")
        }

    def _normalize_metadata(
        self,
        metadata: Dict[str, Any],
        graph_nodes: Dict[str, Dict[str, Any]],
        index_dimension: int,
    ) -> CodeFunctionEvidence:
        """Build a ``CodeFunctionEvidence`` from one metadata row.

        Fields are read from the row through ``FIELD_ALIASES`` and fall back
        to the matching graph node (or its ``raw_attributes``) when the row
        does not provide them.  The node id defaults to ``Function:<name>``.
        """
        name = _first_value(metadata, FIELD_ALIASES["name"], "")
        node_id = metadata.get("node_id") or (f"Function:{name}" if name else "")
        graph_node = graph_nodes.get(node_id, {})
        raw_attributes = graph_node.get("raw_attributes") or {}
        if not name:
            name = graph_node.get("name") or node_id.removeprefix("Function:")

        file_path = _first_value(metadata, FIELD_ALIASES["file"], graph_node.get("file_path", ""))
        # Tolerate a malformed per-row dimension instead of aborting the whole
        # load: fall back to the index dimension, then to 0.
        embedding_dim = (
            _as_int(metadata.get("embedding_dim")) or _as_int(index_dimension) or 0
        )
        embedding_available = _as_bool(
            metadata.get("embedding_available", raw_attributes.get("embedding_available")),
            default=True,
        )
        return CodeFunctionEvidence(
            node_id=node_id,
            name=name,
            qualified_name=metadata.get("qualified_name") or node_id.removeprefix("Function:") or name,
            file=file_path or "",
            start_line=_as_int(metadata.get("start_line") or graph_node.get("start_line")),
            end_line=_as_int(metadata.get("end_line") or graph_node.get("end_line")),
            signature=metadata.get("signature") or graph_node.get("signature") or "",
            summary=_first_value(metadata, FIELD_ALIASES["summary"], graph_node.get("summary", "")) or "",
            logic_flow=_first_value(
                metadata, FIELD_ALIASES["logic_flow"], graph_node.get("logic_flow", "")
            )
            or "",
            code_snippet=_first_value(
                metadata, FIELD_ALIASES["code_snippet"], raw_attributes.get("code_snippet", "")
            )
            or "",
            calls=_as_list(_first_value(metadata, FIELD_ALIASES["calls"], raw_attributes.get("calls", []))),
            called_by=_as_list(
                _first_value(metadata, FIELD_ALIASES["called_by"], raw_attributes.get("called_by", []))
            ),
            includes=_as_list(metadata.get("includes") or raw_attributes.get("includes")),
            embedding_model=metadata.get("embedding_model") or "",
            embedding_dim=embedding_dim,
            embedding_available=embedding_available,
            raw=metadata,
        )

    def _configure_vector_search(self) -> None:
        """Decide whether vector search can be used and record why not.

        Vector search is disabled when the index is unavailable, empty, sized
        differently from the metadata, all-zero, or when no function reports a
        usable embedding.
        """
        # `or 0` keeps this consistent with summary() when ntotal is None.
        index_total = int(getattr(self.faiss_index, "ntotal", 0) or 0)
        self.vector_search_enabled = False
        self.vector_search_disabled_reason = ""
        if isinstance(self.faiss_index, UnavailableVectorIndex):
            self.vector_search_disabled_reason = self.faiss_index.reason
            return
        if index_total == 0:
            self.vector_search_disabled_reason = "Vector index is empty."
            return
        if index_total != len(self.functions):
            self.vector_search_disabled_reason = (
                f"Vector index size ({index_total}) differs from metadata size ({len(self.functions)})."
            )
            logger.warning(
                "FAISS index size (%s) differs from metadata size (%s)",
                index_total,
                len(self.functions),
            )
            return
        if not self._index_has_nonzero_vectors():
            self.vector_search_disabled_reason = "Vector index contains only zero embeddings."
            logger.info("Code KB vector search disabled: %s", self.vector_search_disabled_reason)
            return
        if not any(item.embedding_available for item in self.functions):
            self.vector_search_disabled_reason = "No function metadata has usable embeddings."
            logger.info("Code KB vector search disabled: %s", self.vector_search_disabled_reason)
            return
        self.vector_search_enabled = True

    def _index_has_nonzero_vectors(self) -> bool:
        """Best-effort check that the index is not made of zero vectors only.

        Uses the index's own ``has_nonzero_vectors`` when present.  Otherwise
        reconstructs at most the first 1000 vectors via ``reconstruct_n``.
        Returns ``True`` whenever the index cannot be inspected, so an
        uninspectable index is never disabled by this check.
        """
        has_nonzero_vectors = getattr(self.faiss_index, "has_nonzero_vectors", None)
        if callable(has_nonzero_vectors):
            return bool(has_nonzero_vectors())

        reconstruct_n = getattr(self.faiss_index, "reconstruct_n", None)
        if not callable(reconstruct_n):
            return True
        try:
            import numpy as np

            index_total = int(getattr(self.faiss_index, "ntotal", 0) or 0)
            # Only a prefix is sampled; later non-zero vectors are not seen.
            sample_size = min(index_total, 1000)
            if sample_size <= 0:
                return False
            vectors = reconstruct_n(0, sample_size)
            return bool(np.any(np.abs(vectors) > 1e-12))
        except Exception as exc:  # pragma: no cover - depends on FAISS capabilities
            # Deliberately broad: inspection is optional and must never block loading.
            logger.debug("Could not inspect FAISS vectors for zero embeddings: %s", exc)
            return True

    def _get_embedding_function(self) -> Any:
        """Return the embedding function, creating it via ``EmbeddingsFactory`` on first use."""
        if self.embedding_function is None:
            from app.services.embedding.embedding_factory import EmbeddingsFactory

            self.embedding_function = EmbeddingsFactory.create()
        return self.embedding_function

    def _embed_query(self, query: str) -> Any:
        """Embed ``query`` and return it as a float32 numpy array of shape ``(1, d)``.

        The embedding function may return either a flat vector or a single
        row wrapped in an outer sequence; both are normalised to one row.

        Raises:
            RuntimeError: If numpy is not installed.
            TypeError: If the embedding function is neither callable nor
                provides ``embed_query``.
            ValueError: If the result is not a single vector, or its length
                differs from the index dimension.
        """
        try:
            import numpy as np
        except ImportError as exc:
            raise RuntimeError("numpy is required to search a code knowledge base.") from exc
        embedding_function = self._get_embedding_function()
        if hasattr(embedding_function, "embed_query"):
            vector = embedding_function.embed_query(query)
        elif callable(embedding_function):
            vector = embedding_function(query)
        else:
            raise TypeError("Unsupported embedding function.")

        # C-contiguous float32 is the layout index searches expect.
        query_vector = np.ascontiguousarray(vector, dtype="float32")
        if query_vector.ndim == 1:
            query_vector = query_vector.reshape(1, -1)
        if query_vector.ndim != 2 or query_vector.shape[0] != 1:
            raise ValueError("Embedding function must return a single vector for a single query.")

        index_dimension = int(getattr(self.faiss_index, "d", 0) or 0)
        if index_dimension and query_vector.shape[1] != index_dimension:
            raise ValueError(
                f"Query embedding dimension {query_vector.shape[1]} does not match "
                f"FAISS index dimension {index_dimension}."
            )
        return query_vector

    @staticmethod
    def distance_to_similarity(distance: float) -> float:
        """Map an index distance to a similarity in ``[0, 1]``.

        NaN and negative distances map to 0.  Distances up to 2 map linearly
        (``1 - distance / 2``); larger ones map to ``1 / (1 + distance)``.

        NOTE(review): the mapping is not monotonic across distance 2.0 (for
        example 1.9 -> 0.05 while 2.5 -> ~0.29) — confirm this is intended
        before relying on ``min_similarity`` thresholds near that boundary.
        """
        if math.isnan(distance) or distance < 0:
            return 0.0
        if distance <= 2.0:
            return max(0.0, min(1.0, 1.0 - distance / 2.0))
        return max(0.0, min(1.0, 1.0 / (1.0 + distance)))

    def search_functions(
        self,
        query: str,
        top_k: int = 8,
        min_similarity: float = 0.0,
    ) -> List[CodeSearchHit]:
        """Return up to ``top_k`` functions matching ``query``.

        Vector search is used when enabled.  Lexical search is used instead
        when vector search is disabled, raises, or produces no usable hit.
        ``rank`` on vector hits is the position in the raw index result, so it
        may have gaps when hits are filtered out.

        Args:
            query: Free-text query; blank queries return no hits.
            top_k: Maximum number of hits; values <= 0 return no hits.
            min_similarity: Hits scoring below this value are dropped.

        Raises:
            RuntimeError: If no knowledge base has been loaded.
        """
        if self.faiss_index is None:
            raise RuntimeError("Code knowledge base has not been loaded.")
        if top_k <= 0 or not query or not query.strip():
            return []
        if not self.vector_search_enabled:
            logger.info(
                "Vector code search disabled, using lexical search: %s",
                self.vector_search_disabled_reason,
            )
            return self._lexical_search_functions(query, top_k=top_k, min_similarity=min_similarity)

        try:
            vector = self._embed_query(query)
            distances, indices = self.faiss_index.search(vector, top_k)
        except Exception as exc:
            # Broad on purpose: any embedding/index failure degrades to lexical search.
            logger.warning("Vector code search failed, falling back to lexical search: %s", exc)
            return self._lexical_search_functions(query, top_k=top_k, min_similarity=min_similarity)

        hits: List[CodeSearchHit] = []
        for rank, (raw_index, raw_distance) in enumerate(zip(indices[0], distances[0]), start=1):
            index = int(raw_index)
            # Skips out-of-range positions, including negative "no result" markers.
            if index < 0 or index >= len(self.functions):
                continue
            evidence = self.functions[index]
            if not evidence.embedding_available:
                continue
            distance = float(raw_distance)
            similarity = self.distance_to_similarity(distance)
            if similarity < min_similarity:
                continue
            hits.append(
                CodeSearchHit(
                    evidence=evidence,
                    similarity=similarity,
                    distance=distance,
                    rank=rank,
                )
            )
        if not hits:
            return self._lexical_search_functions(query, top_k=top_k, min_similarity=min_similarity)
        return hits

    def _lexical_search_functions(
        self,
        query: str,
        top_k: int = 8,
        min_similarity: float = 0.0,
    ) -> List[CodeSearchHit]:
        """Rank all functions by ``_lexical_similarity`` and return the best ``top_k``.

        Returns no hits when the query yields no tokens or ``top_k`` <= 0.
        Hits are ordered by descending similarity (ties keep metadata order)
        and ranked from 1; ``distance`` is reported as ``1 - similarity``.
        """
        query_tokens = self._tokens(query)
        # Guard top_k explicitly: a negative slice bound would drop items from
        # the end instead of returning nothing.
        if not query_tokens or top_k <= 0:
            return []
        scored: List[CodeSearchHit] = []
        for evidence in self.functions:
            text = self._function_search_text(evidence)
            score = self._lexical_similarity(query_tokens, text, evidence)
            if score < min_similarity:
                continue
            scored.append(
                CodeSearchHit(
                    evidence=evidence,
                    similarity=score,
                    distance=max(0.0, 1.0 - score),
                    rank=0,
                )
            )
        scored.sort(key=lambda item: item.similarity, reverse=True)
        top_hits = scored[:top_k]
        for rank, item in enumerate(top_hits, start=1):
            item.rank = rank
        return top_hits

    @staticmethod
    def _tokens(text: str) -> List[str]:
        """Split ``text`` into lower-cased search tokens.

        Tokens are identifiers, numbers (optionally with a decimal part) and
        runs of two or more CJK characters.  Each CJK run of three or more
        characters is additionally followed by all of its two-character
        windows, so partial phrases can match.
        """
        normalized = (text or "").lower()
        expanded: List[str] = []
        for token in CodeKnowledgeBaseAdapter._TOKEN_PATTERN.findall(normalized):
            expanded.append(token)
            if CodeKnowledgeBaseAdapter._CJK_RUN_PATTERN.fullmatch(token):
                expanded.extend(token[index : index + 2] for index in range(len(token) - 1))
        return expanded

    @staticmethod
    def _function_search_text(evidence: CodeFunctionEvidence) -> str:
        """Join the searchable fields of ``evidence`` into one lower-cased text.

        Only the first 4000 characters of the code snippet are included.
        """
        return "\n".join(
            [
                evidence.node_id,
                evidence.name,
                evidence.qualified_name,
                evidence.file,
                evidence.signature,
                evidence.summary,
                evidence.logic_flow,
                evidence.code_snippet[:4000],
                " ".join(evidence.calls),
                " ".join(evidence.called_by),
                " ".join(evidence.includes),
            ]
        ).lower()

    def _lexical_similarity(
        self,
        query_tokens: List[str],
        text: str,
        evidence: CodeFunctionEvidence,
    ) -> float:
        """Score how well ``query_tokens`` match a function, in ``[0, 1]``.

        The score is a weighted sum of: exact token overlap (0.45), substring
        presence in ``text`` (0.35), hits in the function's name or qualified
        name (0.12), plus small fixed bonuses for having a summary, logic
        flow, code snippet and a source location.

        NOTE(review): the bonuses (up to 0.20) apply even when no query token
        matches, so with ``min_similarity=0`` unrelated functions still score
        above zero — confirm callers pass a meaningful threshold.
        """
        text_tokens = set(self._tokens(text))
        if not text_tokens:
            return 0.0
        unique_query_tokens = list(dict.fromkeys(query_tokens))
        overlap_count = sum(1 for token in unique_query_tokens if token in text_tokens)
        substring_count = sum(1 for token in unique_query_tokens if token in text)
        overlap_score = overlap_count / max(1, len(unique_query_tokens))
        substring_score = substring_count / max(1, len(unique_query_tokens))

        name_text = f"{evidence.name} {evidence.qualified_name}".lower()
        name_hits = sum(1 for token in unique_query_tokens if token in name_text)
        # Normalised by at most 4 tokens so long queries can still reach a full name score.
        name_score = min(1.0, name_hits / max(1, min(len(unique_query_tokens), 4)))
        evidence_score = 0.0
        if evidence.summary:
            evidence_score += 0.08
        if evidence.logic_flow:
            evidence_score += 0.05
        if evidence.code_snippet:
            evidence_score += 0.04
        if evidence.start_line is not None and evidence.file:
            evidence_score += 0.03

        return max(
            0.0,
            min(
                1.0,
                overlap_score * 0.45
                + substring_score * 0.35
                + name_score * 0.12
                + evidence_score,
            ),
        )

    def get_function(self, node_id: str) -> Optional[CodeFunctionEvidence]:
        """Return the loaded function with this node id, or ``None``."""
        return self.functions_by_id.get(node_id)

    def expand_call_context(self, node_id: str, max_hops: int = 2) -> CodeGraphContext:
        """Expand the call graph around ``node_id`` up to ``max_hops``.

        Returns a bare ``CodeGraphContext`` for the node when no call graph is
        available.
        """
        if not self.call_graph:
            return CodeGraphContext(node_id=node_id)
        return self.call_graph.expand(node_id, max_hops=max_hops)

    def summary(self) -> Dict[str, Any]:
        """Return counts and vector-search status for the loaded knowledge base."""
        return {
            "function_count": len(self.functions),
            "index_size": int(getattr(self.faiss_index, "ntotal", 0) or 0),
            "embedding_dim": int(getattr(self.faiss_index, "d", 0) or 0),
            "embedding_available_count": sum(1 for item in self.functions if item.embedding_available),
            "vector_search_enabled": self.vector_search_enabled,
            "vector_search_disabled_reason": self.vector_search_disabled_reason,
            "graph_nodes": len(self.graph_data.get("nodes", []) or []),
            "graph_edges": len(self.graph_data.get("edges", []) or []),
        }