Files

518 lines
21 KiB
Python
Raw Permalink Normal View History

from __future__ import annotations
import json
import logging
import math
import re
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional
from app.services.code_kb.graph import CodeCallGraph
from app.services.code_kb.schema import CodeFunctionEvidence, CodeGraphContext, CodeSearchHit
logger = logging.getLogger(__name__)

# Accepted metadata key names for each normalized function field. The aliases
# for a field are tried left to right by _first_value(), so the first key that
# holds a non-empty value wins.
FIELD_ALIASES = dict(
    name=("name", "function_name"),
    file=("file", "file_path"),
    summary=("summary",),
    logic_flow=("logic_flow", "logic"),
    code_snippet=("code_snippet", "source", "code"),
    calls=("calls", "called_functions"),
    called_by=("called_by", "caller_functions", "callers"),
)
class SimpleVectorIndex:
    """Brute-force squared-L2 vector index persisted as a small JSON document.

    Exposes the subset of the FAISS index interface this module relies on:
    the ``d`` (dimension) and ``ntotal`` (vector count) attributes and
    ``search(query, top_k)`` returning ``(distances, indices)`` as 2-D,
    single-row results. NumPy is used when importable; otherwise an
    equivalent pure-Python loop computes the same distances.
    """

    def __init__(self, vectors: Any, dimension: int) -> None:
        """Store *vectors* as float lists.

        Args:
            vectors: Sequence of embedding rows (``None``/empty means no rows).
            dimension: Declared width; when falsy, the first row's length is
                used (or 0 for an empty index).
        """
        try:
            import numpy as np
        except ImportError:  # pragma: no cover - depends on deployment environment
            np = None
        self._np = np
        self.vectors = [[float(value) for value in vector] for vector in (vectors or [])]
        self.d = int(dimension or (len(self.vectors[0]) if self.vectors else 0))
        self.ntotal = len(self.vectors)
        # Lazily built float32 matrix reused across searches. Like ``ntotal`` it
        # is a snapshot, so ``vectors`` is treated as read-only after __init__.
        # Building it lazily keeps malformed (ragged) data failing at search
        # time, as before, rather than at load time.
        self._matrix: Any = None

    @classmethod
    def from_file(cls, vector_path: str) -> Optional["SimpleVectorIndex"]:
        """Load an index written in the ``simple_l2_vector_index`` JSON format.

        Returns ``None`` instead of raising when the file cannot be opened or
        decoded as JSON, or is not a dict tagged with that format. This lets
        the caller fall back to another index reader.
        """
        try:
            with open(vector_path, "r", encoding="utf-8") as file:
                payload = json.load(file)
        except (OSError, UnicodeDecodeError, json.JSONDecodeError):
            return None
        if not isinstance(payload, dict) or payload.get("format") != "simple_l2_vector_index":
            return None
        return cls(payload.get("vectors") or [], int(payload.get("dimension") or 0))

    def _stored_matrix(self) -> Any:
        """Return the stored vectors as a cached NumPy float32 matrix."""
        if self._matrix is None:
            self._matrix = self._np.array(self.vectors, dtype="float32")
        return self._matrix

    def search(self, vector: Any, top_k: int) -> Any:
        """Return the *top_k* stored vectors closest to the first query row.

        *vector* may be a single embedding or a batch (list of rows or a 2-D
        array); only the first row is searched. The result is
        ``([[distance, ...]], [[index, ...]])``, ordered by ascending squared
        L2 distance, with ties kept in insertion order. It is made of NumPy
        arrays when NumPy is available and nested lists otherwise.

        Raises:
            ValueError: If the query width differs from ``self.d``.
        """
        if self.ntotal == 0:
            if self._np is not None:
                return self._np.array([[]], dtype="float32"), self._np.array([[]], dtype="int64")
            return [[]], [[]]
        if self._np is not None:
            np = self._np
            # atleast_2d lets a bare 1-D embedding be searched on this path too,
            # matching the pure-Python fallback below.
            query = np.atleast_2d(np.asarray(vector, dtype="float32"))
            if query.ndim != 2 or query.shape[1] != self.d:
                raise ValueError(f"Query embedding dimension does not match index dimension {self.d}.")
            distances = np.sum((self._stored_matrix() - query[0]) ** 2, axis=1)
            # A stable sort keeps equal-distance ties in insertion order, the
            # same order list.sort() produces in the fallback path.
            order = np.argsort(distances, kind="stable")[:top_k]
            return np.array([distances[order]], dtype="float32"), np.array([order], dtype="int64")
        query_row = vector[0] if isinstance(vector, list) and vector and isinstance(vector[0], list) else vector
        query_values = [float(value) for value in query_row]
        if len(query_values) != self.d:
            raise ValueError(f"Query embedding dimension does not match index dimension {self.d}.")
        scored = [
            (sum((left - right) ** 2 for left, right in zip(stored, query_values)), index)
            for index, stored in enumerate(self.vectors)
        ]
        scored.sort(key=lambda item: item[0])
        selected = scored[:top_k]
        return [[item[0] for item in selected]], [[item[1] for item in selected]]

    def has_nonzero_vectors(self) -> bool:
        """Return True if any stored component has magnitude above 1e-12."""
        return any(
            any(abs(float(value)) > 1e-12 for value in vector)
            for vector in self.vectors
        )
class UnavailableVectorIndex:
    """Stand-in index used when no vector backend could be loaded.

    It advertises an empty, zero-dimensional index. Every ``search`` call
    raises ``RuntimeError`` carrying the stored explanation.
    """

    d = 0
    ntotal = 0

    def __init__(self, reason: str) -> None:
        # Human-readable explanation of why vector search is unavailable.
        self.reason = reason

    def search(self, vector: Any, top_k: int) -> Any:
        """Unconditionally fail; the arguments are ignored."""
        message = self.reason
        raise RuntimeError(message)
def _first_value(source: Dict[str, Any], aliases: Iterable[str], default: Any = None) -> Any:
for alias in aliases:
value = source.get(alias)
if value not in (None, ""):
return value
return default
def _as_list(value: Any) -> List[str]:
if value is None:
return []
if isinstance(value, list):
return [str(item) for item in value if item not in (None, "")]
if isinstance(value, tuple):
return [str(item) for item in value if item not in (None, "")]
if isinstance(value, str):
return [value] if value else []
return [str(value)]
def _as_int(value: Any) -> Optional[int]:
if value in (None, ""):
return None
try:
return int(value)
except (TypeError, ValueError):
return None
def _as_bool(value: Any, default: bool = True) -> bool:
if value is None:
return default
if isinstance(value, bool):
return value
if isinstance(value, (int, float)):
return bool(value)
if isinstance(value, str):
normalized = value.strip().lower()
if normalized in {"1", "true", "yes", "y"}:
return True
if normalized in {"0", "false", "no", "n"}:
return False
return default
def _load_json(path: str) -> Any:
with open(path, "r", encoding="utf-8") as file:
return json.load(file)
def read_code_kb_summary(metadata_path: str, graph_path: str) -> Dict[str, Any]:
    """Summarise a code KB from its files without building the full adapter.

    The metadata container is resolved with the same rules as
    ``CodeKnowledgeBaseAdapter._read_metadata``: a list, or a dict holding
    ``functions`` (preferred) or ``items``. Only dict rows are counted, since
    those are the rows the adapter keeps. The graph file is optional; a
    missing file, or a missing or non-dict ``metadata`` section, yields
    ``None`` for the graph fields.

    Raises:
        OSError / ValueError: If *metadata_path* cannot be read or parsed.
    """
    metadata = _load_json(metadata_path)
    if isinstance(metadata, dict):
        metadata = metadata.get("functions", metadata.get("items", []))
    function_count = (
        sum(1 for item in metadata if isinstance(item, dict)) if isinstance(metadata, list) else 0
    )
    graph_data = _load_json(graph_path) if graph_path and Path(graph_path).exists() else {}
    graph_meta = graph_data.get("metadata") if isinstance(graph_data, dict) else None
    if not isinstance(graph_meta, dict):
        # Also covers an explicit ``"metadata": null``, which previously
        # crashed with AttributeError.
        graph_meta = {}
    return {
        "function_count": function_count,
        "graph_nodes": graph_meta.get("total_nodes"),
        "graph_edges": graph_meta.get("total_edges"),
        "project_root": graph_meta.get("project_root"),
        "generated_at": graph_meta.get("generated_at"),
    }
class CodeKnowledgeBaseAdapter:
    """Load and query a prebuilt code knowledge base.

    A knowledge base is three files: a vector index (a ``SimpleVectorIndex``
    JSON file or a FAISS index), JSON per-function metadata, and a JSON call
    graph. Row ``i`` of the metadata is treated as describing vector ``i`` of
    the index. ``search_functions`` uses the vector index when
    ``_configure_vector_search`` has enabled it. Otherwise, or when vector
    search fails or finds nothing, it falls back to a lexical scorer.
    """

    # Identifier-like words, numbers, and runs of two or more CJK ideographs.
    _TOKEN_PATTERN = re.compile(r"[a-z_][a-z0-9_]*|\d+(?:\.\d+)?|[\u4e00-\u9fff]{2,}")
    # CJK runs long enough to also be split into overlapping bigrams.
    _CJK_SPLIT_PATTERN = re.compile(r"[\u4e00-\u9fff]{3,}")

    def __init__(self, embedding_function: Any = None) -> None:
        """Create an empty adapter; call ``load`` before searching.

        Args:
            embedding_function: Object with ``embed_query(str)``, or a plain
                callable. When ``None``, one is created on first use via
                ``EmbeddingsFactory.create()``.
        """
        self.embedding_function = embedding_function
        self.faiss_index: Any = None
        self.metadata: List[Dict[str, Any]] = []
        self.graph_data: Dict[str, Any] = {}
        self.functions: List[CodeFunctionEvidence] = []
        self.functions_by_id: Dict[str, CodeFunctionEvidence] = {}
        self.call_graph: Optional[CodeCallGraph] = None
        self.vector_search_enabled = False
        self.vector_search_disabled_reason = "Code knowledge base has not been loaded."

    def load(self, vector_path: str, metadata_path: str, graph_path: str) -> None:
        """Load the three KB files and rebuild all in-memory state.

        Everything is parsed into locals first and assigned to ``self`` only
        after every step has succeeded. A failed reload therefore leaves the
        previously loaded knowledge base intact, instead of pairing a new
        index with stale function metadata and a stale vector-search flag.

        Raises:
            FileNotFoundError: If any path is empty or does not exist.
            ValueError: If the metadata or graph JSON is malformed or has an
                unexpected top-level shape.
        """
        self._validate_paths(vector_path, metadata_path, graph_path)
        faiss_index = self._read_faiss_index(vector_path)
        metadata = self._read_metadata(metadata_path)
        graph_data = _load_json(graph_path)
        if not isinstance(graph_data, dict):
            raise ValueError("Code KB graph must be a JSON object.")
        graph_nodes = self._index_graph_nodes(graph_data)
        index_dimension = int(getattr(faiss_index, "d", 0) or 0)
        functions = [
            self._normalize_metadata(row, graph_nodes, index_dimension=index_dimension)
            for row in metadata
        ]
        call_graph = CodeCallGraph(graph_data, functions)

        # Commit only after all parsing succeeded.
        self.faiss_index = faiss_index
        self.metadata = metadata
        self.graph_data = graph_data
        self.functions = functions
        self.functions_by_id = {item.node_id: item for item in functions}
        self.call_graph = call_graph
        self._configure_vector_search()

    def _validate_paths(self, vector_path: str, metadata_path: str, graph_path: str) -> None:
        """Raise FileNotFoundError listing every empty or non-existent path."""
        missing = [
            path
            for path in [vector_path, metadata_path, graph_path]
            if not path or not Path(path).exists()
        ]
        if missing:
            raise FileNotFoundError(f"Code knowledge base files not found: {missing}")

    def _read_faiss_index(self, vector_path: str) -> Any:
        """Open the vector index, preferring the JSON ``SimpleVectorIndex`` format.

        If the file is not a simple index, FAISS is used. When FAISS is not
        installed, an ``UnavailableVectorIndex`` is returned so that vector
        search ends up disabled instead of ``load`` failing.
        """
        simple_index = SimpleVectorIndex.from_file(vector_path)
        if simple_index is not None:
            return simple_index
        try:
            import faiss  # type: ignore
        except ImportError:
            logger.warning("faiss is not installed; code search will use lexical fallback.")
            return UnavailableVectorIndex(
                "faiss is required to read this vector index. Install faiss-cpu or rebuild the code KB."
            )
        return faiss.read_index(vector_path)

    def _read_metadata(self, metadata_path: str) -> List[Dict[str, Any]]:
        """Return metadata rows from a list, or a dict keyed by functions/items.

        Non-dict rows are dropped.

        Raises:
            ValueError: If the resolved container is not a list.
        """
        metadata = _load_json(metadata_path)
        if isinstance(metadata, dict):
            metadata = metadata.get("functions", metadata.get("items", []))
        if not isinstance(metadata, list):
            raise ValueError("Code KB metadata must be a list or a dict with functions/items.")
        return [item for item in metadata if isinstance(item, dict)]

    def _index_graph_nodes(self, graph_data: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
        """Map node ``id`` -> node dict for every dict node with a truthy id."""
        return {
            item.get("id"): item
            for item in graph_data.get("nodes", []) or []
            if isinstance(item, dict) and item.get("id")
        }

    def _normalize_metadata(
        self,
        metadata: Dict[str, Any],
        graph_nodes: Dict[str, Dict[str, Any]],
        index_dimension: int,
    ) -> CodeFunctionEvidence:
        """Merge one metadata row with its graph node into a CodeFunctionEvidence.

        Each field is taken from the metadata row first, via the
        ``FIELD_ALIASES`` key names. It falls back to the matching graph node,
        or that node's ``raw_attributes``. When the row has no ``node_id``,
        the node id is derived as ``"Function:<name>"``.
        """
        name = _first_value(metadata, FIELD_ALIASES["name"], "")
        node_id = metadata.get("node_id") or (f"Function:{name}" if name else "")
        graph_node = graph_nodes.get(node_id, {})
        raw_attributes = graph_node.get("raw_attributes") or {}
        if not name:
            name = graph_node.get("name") or node_id.removeprefix("Function:")
        file_path = _first_value(metadata, FIELD_ALIASES["file"], graph_node.get("file_path", ""))
        # Non-numeric or missing values fall back to the index width instead
        # of raising from int().
        embedding_dim = _as_int(metadata.get("embedding_dim")) or index_dimension
        embedding_available = _as_bool(
            metadata.get("embedding_available", raw_attributes.get("embedding_available")),
            default=True,
        )
        return CodeFunctionEvidence(
            node_id=node_id,
            name=name,
            qualified_name=metadata.get("qualified_name") or node_id.removeprefix("Function:") or name,
            file=file_path or "",
            start_line=_as_int(metadata.get("start_line") or graph_node.get("start_line")),
            end_line=_as_int(metadata.get("end_line") or graph_node.get("end_line")),
            signature=metadata.get("signature") or graph_node.get("signature") or "",
            summary=_first_value(metadata, FIELD_ALIASES["summary"], graph_node.get("summary", "")) or "",
            logic_flow=_first_value(
                metadata, FIELD_ALIASES["logic_flow"], graph_node.get("logic_flow", "")
            )
            or "",
            code_snippet=_first_value(
                metadata, FIELD_ALIASES["code_snippet"], raw_attributes.get("code_snippet", "")
            )
            or "",
            calls=_as_list(_first_value(metadata, FIELD_ALIASES["calls"], raw_attributes.get("calls", []))),
            called_by=_as_list(
                _first_value(metadata, FIELD_ALIASES["called_by"], raw_attributes.get("called_by", []))
            ),
            includes=_as_list(metadata.get("includes") or raw_attributes.get("includes")),
            embedding_model=metadata.get("embedding_model") or "",
            embedding_dim=int(embedding_dim or 0),
            embedding_available=embedding_available,
            raw=metadata,
        )

    def _configure_vector_search(self) -> None:
        """Decide whether vector search is usable and record why if not.

        Vector search is disabled when the index is unavailable or empty,
        when its size differs from the number of functions, when it holds
        only zero vectors, or when no function is flagged as having an
        embedding.
        """
        index_total = int(getattr(self.faiss_index, "ntotal", 0))
        self.vector_search_enabled = False
        self.vector_search_disabled_reason = ""
        if isinstance(self.faiss_index, UnavailableVectorIndex):
            self.vector_search_disabled_reason = self.faiss_index.reason
            return
        if index_total == 0:
            self.vector_search_disabled_reason = "Vector index is empty."
            return
        if index_total != len(self.functions):
            # Positional row<->vector alignment cannot be trusted.
            self.vector_search_disabled_reason = (
                f"Vector index size ({index_total}) differs from metadata size ({len(self.functions)})."
            )
            logger.warning(
                "FAISS index size (%s) differs from metadata size (%s)",
                index_total,
                len(self.functions),
            )
            return
        if not self._index_has_nonzero_vectors():
            self.vector_search_disabled_reason = "Vector index contains only zero embeddings."
            logger.info("Code KB vector search disabled: %s", self.vector_search_disabled_reason)
            return
        if not any(item.embedding_available for item in self.functions):
            self.vector_search_disabled_reason = "No function metadata has usable embeddings."
            logger.info("Code KB vector search disabled: %s", self.vector_search_disabled_reason)
            return
        self.vector_search_enabled = True

    def _index_has_nonzero_vectors(self) -> bool:
        """Return False only when the index is known to hold all-zero vectors.

        The index's own ``has_nonzero_vectors`` is used when present.
        Otherwise up to the first 1000 vectors are sampled via
        ``reconstruct_n``. When neither is possible, or inspection fails, this
        optimistically returns True.
        """
        has_nonzero_vectors = getattr(self.faiss_index, "has_nonzero_vectors", None)
        if callable(has_nonzero_vectors):
            return bool(has_nonzero_vectors())
        reconstruct_n = getattr(self.faiss_index, "reconstruct_n", None)
        if not callable(reconstruct_n):
            return True
        try:
            import numpy as np

            index_total = int(getattr(self.faiss_index, "ntotal", 0) or 0)
            sample_size = min(index_total, 1000)
            if sample_size <= 0:
                return False
            vectors = reconstruct_n(0, sample_size)
            return bool(np.any(np.abs(vectors) > 1e-12))
        except Exception as exc:  # pragma: no cover - depends on FAISS capabilities
            logger.debug("Could not inspect FAISS vectors for zero embeddings: %s", exc)
            return True

    def _get_embedding_function(self) -> Any:
        """Return the embedding function, creating it from the factory on first use."""
        if self.embedding_function is None:
            from app.services.embedding.embedding_factory import EmbeddingsFactory

            self.embedding_function = EmbeddingsFactory.create()
        return self.embedding_function

    def _embed_query(self, query: str) -> Any:
        """Embed *query* as a single-row float32 NumPy matrix.

        Presumably the embedding function returns one flat vector, giving a
        ``(1, d)`` array (TODO confirm).

        Raises:
            RuntimeError: If NumPy is not installed.
            TypeError: If the embedding function is neither ``embed_query``-capable
                nor callable.
            ValueError: If the embedding width differs from the index dimension.
        """
        try:
            import numpy as np
        except ImportError as exc:
            raise RuntimeError("numpy is required to search a code knowledge base.") from exc
        embedding_function = self._get_embedding_function()
        if hasattr(embedding_function, "embed_query"):
            vector = embedding_function.embed_query(query)
        elif callable(embedding_function):
            vector = embedding_function(query)
        else:
            raise TypeError("Unsupported embedding function.")
        query_vector = np.array([vector], dtype="float32")
        index_dimension = int(getattr(self.faiss_index, "d", 0) or 0)
        if index_dimension and query_vector.shape[1] != index_dimension:
            raise ValueError(
                f"Query embedding dimension {query_vector.shape[1]} does not match "
                f"FAISS index dimension {index_dimension}."
            )
        return query_vector

    @staticmethod
    def distance_to_similarity(distance: float) -> float:
        """Map an index distance to a similarity in ``[0, 1]``.

        Assuming squared-L2 distances (what ``SimpleVectorIndex`` returns) on
        unit-length vectors, ``d = 2 - 2*cos``. The first branch then equals
        the cosine similarity. Larger distances use ``1 / (1 + d)``. NaN and
        negative inputs map to 0.0.

        NOTE(review): the two branches do not meet at ``d == 2.0``:
        ``1 - 2/2 = 0.0``, but ``1 / (1 + 2) ~= 0.33``. The mapping is
        therefore not monotonic; a result just beyond 2.0 scores higher than
        one at 1.9. Left unchanged because callers with un-normalized
        embeddings may rely on the non-zero tail. Confirm the intended curve.
        """
        if math.isnan(distance) or distance < 0:
            return 0.0
        if distance <= 2.0:
            return max(0.0, min(1.0, 1.0 - distance / 2.0))
        return max(0.0, min(1.0, 1.0 / (1.0 + distance)))

    def search_functions(
        self,
        query: str,
        top_k: int = 8,
        min_similarity: float = 0.0,
    ) -> List[CodeSearchHit]:
        """Return up to *top_k* functions relevant to *query*.

        Uses vector search when enabled. It falls back to lexical search when
        vector search is disabled, raises, or yields no hit at or above
        *min_similarity*. Vector hits keep their raw position in the index
        results as ``rank``, so ranks may skip filtered entries.

        Raises:
            RuntimeError: If ``load`` has not been called.
        """
        if self.faiss_index is None:
            raise RuntimeError("Code knowledge base has not been loaded.")
        if not query.strip():
            return []
        if top_k <= 0:
            # A non-positive limit would otherwise reach scored[:top_k] in the
            # lexical path, where a negative value returns almost everything.
            return []
        if not self.vector_search_enabled:
            logger.info(
                "Vector code search disabled, using lexical search: %s",
                self.vector_search_disabled_reason,
            )
            return self._lexical_search_functions(query, top_k=top_k, min_similarity=min_similarity)
        try:
            vector = self._embed_query(query)
            distances, indices = self.faiss_index.search(vector, top_k)
        except Exception as exc:
            logger.warning("Vector code search failed, falling back to lexical search: %s", exc)
            return self._lexical_search_functions(query, top_k=top_k, min_similarity=min_similarity)
        hits: List[CodeSearchHit] = []
        for rank, (raw_index, raw_distance) in enumerate(zip(indices[0], distances[0]), start=1):
            index = int(raw_index)
            # FAISS pads missing results with -1; also guard against stale indices.
            if index < 0 or index >= len(self.functions):
                continue
            evidence = self.functions[index]
            if not evidence.embedding_available:
                continue
            distance = float(raw_distance)
            similarity = self.distance_to_similarity(distance)
            if similarity < min_similarity:
                continue
            hits.append(
                CodeSearchHit(
                    evidence=evidence,
                    similarity=similarity,
                    distance=distance,
                    rank=rank,
                )
            )
        if not hits:
            return self._lexical_search_functions(query, top_k=top_k, min_similarity=min_similarity)
        return hits

    def _lexical_search_functions(
        self,
        query: str,
        top_k: int = 8,
        min_similarity: float = 0.0,
    ) -> List[CodeSearchHit]:
        """Score every function lexically against *query*; return the best *top_k*.

        Hits get dense ranks starting at 1 and ``distance = 1 - similarity``.
        Ties keep metadata order, because the sort is stable.
        """
        if top_k <= 0:
            return []
        query_tokens = self._tokens(query)
        if not query_tokens:
            return []
        scored: List[CodeSearchHit] = []
        for evidence in self.functions:
            text = self._function_search_text(evidence)
            score = self._lexical_similarity(query_tokens, text, evidence)
            if score < min_similarity:
                continue
            scored.append(
                CodeSearchHit(
                    evidence=evidence,
                    similarity=score,
                    distance=max(0.0, 1.0 - score),
                    rank=0,
                )
            )
        scored.sort(key=lambda item: item.similarity, reverse=True)
        best = scored[:top_k]
        for rank, item in enumerate(best, start=1):
            item.rank = rank
        return best

    @classmethod
    def _tokens(cls, text: str) -> List[str]:
        """Tokenize *text* in lowercase for lexical matching.

        Yields identifier-like words, integers and decimals, and runs of two
        or more CJK ideographs. Runs of three or more ideographs are also
        expanded into their overlapping two-character bigrams.
        """
        normalized = (text or "").lower()
        expanded: List[str] = []
        for token in cls._TOKEN_PATTERN.findall(normalized):
            expanded.append(token)
            if cls._CJK_SPLIT_PATTERN.fullmatch(token):
                expanded.extend(token[index : index + 2] for index in range(len(token) - 1))
        return expanded

    @staticmethod
    def _function_search_text(evidence: CodeFunctionEvidence) -> str:
        """Concatenate the searchable fields of *evidence*, lowercased.

        The code snippet is truncated to its first 4000 characters.
        """
        return "\n".join(
            [
                evidence.node_id,
                evidence.name,
                evidence.qualified_name,
                evidence.file,
                evidence.signature,
                evidence.summary,
                evidence.logic_flow,
                evidence.code_snippet[:4000],
                " ".join(evidence.calls),
                " ".join(evidence.called_by),
                " ".join(evidence.includes),
            ]
        ).lower()

    def _lexical_similarity(
        self,
        query_tokens: List[str],
        text: str,
        evidence: CodeFunctionEvidence,
    ) -> float:
        """Blend token overlap, substring and name matches into a ``[0, 1]`` score.

        The weights are: 0.45 for exact token overlap, 0.35 for substring
        hits, and 0.12 for name/qualified-name hits. A bonus of up to 0.20 is
        added for evidence richness (summary, logic flow, snippet, and file
        plus start line).

        NOTE(review): the richness bonus is added even when no query token
        matches. With the default ``min_similarity=0.0``, lexical search can
        therefore return functions unrelated to the query. Confirm whether
        that is intended.
        """
        text_tokens = set(self._tokens(text))
        if not text_tokens:
            return 0.0
        unique_query_tokens = list(dict.fromkeys(query_tokens))
        token_count = max(1, len(unique_query_tokens))
        overlap_count = sum(1 for token in unique_query_tokens if token in text_tokens)
        substring_count = sum(1 for token in unique_query_tokens if token in text)
        overlap_score = overlap_count / token_count
        substring_score = substring_count / token_count
        name_text = f"{evidence.name} {evidence.qualified_name}".lower()
        name_hits = sum(1 for token in unique_query_tokens if token in name_text)
        # Saturates once four distinct query tokens hit the name.
        name_score = min(1.0, name_hits / max(1, min(len(unique_query_tokens), 4)))
        evidence_score = 0.0
        if evidence.summary:
            evidence_score += 0.08
        if evidence.logic_flow:
            evidence_score += 0.05
        if evidence.code_snippet:
            evidence_score += 0.04
        if evidence.start_line is not None and evidence.file:
            evidence_score += 0.03
        return max(
            0.0,
            min(
                1.0,
                overlap_score * 0.45
                + substring_score * 0.35
                + name_score * 0.12
                + evidence_score,
            ),
        )

    def get_function(self, node_id: str) -> Optional[CodeFunctionEvidence]:
        """Return the loaded function with *node_id*, or None."""
        return self.functions_by_id.get(node_id)

    def expand_call_context(self, node_id: str, max_hops: int = 2) -> CodeGraphContext:
        """Return call-graph context up to *max_hops* around *node_id*.

        Before ``load``, returns an empty context for the node.
        """
        if not self.call_graph:
            return CodeGraphContext(node_id=node_id)
        return self.call_graph.expand(node_id, max_hops=max_hops)

    def summary(self) -> Dict[str, Any]:
        """Return counts and vector-search status for diagnostics."""
        return {
            "function_count": len(self.functions),
            "index_size": int(getattr(self.faiss_index, "ntotal", 0) or 0),
            "embedding_dim": int(getattr(self.faiss_index, "d", 0) or 0),
            "embedding_available_count": sum(1 for item in self.functions if item.embedding_available),
            "vector_search_enabled": self.vector_search_enabled,
            "vector_search_disabled_reason": self.vector_search_disabled_reason,
            "graph_nodes": len(self.graph_data.get("nodes", []) or []),
            "graph_edges": len(self.graph_data.get("edges", []) or []),
        }