from __future__ import annotations

import json
import logging
import math
import re
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional

from app.services.code_kb.graph import CodeCallGraph
from app.services.code_kb.schema import CodeFunctionEvidence, CodeGraphContext, CodeSearchHit

logger = logging.getLogger(__name__)

# Accepted spellings for each logical metadata field, tried in the listed order.
FIELD_ALIASES = {
    "name": ("name", "function_name"),
    "file": ("file", "file_path"),
    "summary": ("summary",),
    "logic_flow": ("logic_flow", "logic"),
    "code_snippet": ("code_snippet", "source", "code"),
    "calls": ("calls", "called_functions"),
    "called_by": ("called_by", "caller_functions", "callers"),
}

# Tokenizer patterns used by the lexical search; compiled once at import time.
# A token is an identifier, a (possibly decimal) number, or a run of 2+ CJK characters.
_TOKEN_PATTERN = re.compile(r"[a-z_][a-z0-9_]*|\d+(?:\.\d+)?|[\u4e00-\u9fff]{2,}")
# CJK runs of 3+ characters are additionally expanded into overlapping 2-character pieces.
_CJK_RUN_PATTERN = re.compile(r"[\u4e00-\u9fff]{3,}")


class SimpleVectorIndex:
    """Brute-force squared-L2 vector index kept in process memory.

    Exposes the ``d`` / ``ntotal`` / ``search`` attributes that the adapter reads
    from its index object. numpy is used when it can be imported; otherwise a
    pure-Python implementation is used.
    """

    def __init__(self, vectors: Any, dimension: int) -> None:
        """Store ``vectors`` as lists of floats.

        Args:
            vectors: Iterable of numeric rows; ``None`` or empty yields an empty index.
            dimension: Declared vector length. When falsy, the length of the first
                stored vector is used (0 for an empty index).
        """
        try:
            import numpy as np
        except ImportError:  # pragma: no cover - depends on deployment environment
            np = None
        self._np = np
        self.vectors = [[float(value) for value in vector] for vector in (vectors or [])]
        self.d = int(dimension or (len(self.vectors[0]) if self.vectors else 0))
        self.ntotal = len(self.vectors)
        # float32 matrix of ``self.vectors``, built on the first numpy search so the
        # list-to-array conversion is not repeated for every query.
        self._matrix: Any = None

    @classmethod
    def from_file(cls, vector_path: str) -> Optional["SimpleVectorIndex"]:
        """Load an index from a JSON file, or return ``None`` if it is not one.

        ``None`` is returned when the file cannot be read or decoded as UTF-8 JSON,
        or when the payload is not a dict whose ``"format"`` equals
        ``"simple_l2_vector_index"``.
        """
        try:
            with open(vector_path, "r", encoding="utf-8") as file:
                payload = json.load(file)
        except (OSError, UnicodeDecodeError, json.JSONDecodeError):
            return None
        if not isinstance(payload, dict) or payload.get("format") != "simple_l2_vector_index":
            return None
        return cls(payload.get("vectors") or [], int(payload.get("dimension") or 0))

    def search(self, vector: Any, top_k: int) -> Any:
        """Return the ``top_k`` stored vectors closest to ``vector``.

        Args:
            vector: Either a single flat query vector or a 2-D batch whose first
                row is the query. Both forms are accepted with and without numpy.
            top_k: Maximum number of results; values below zero are treated as zero.

        Returns:
            A ``(distances, indices)`` pair, each holding one row ordered by
            ascending squared L2 distance. numpy arrays are returned when numpy
            is available, nested lists otherwise.

        Raises:
            ValueError: If the query length differs from ``self.d``.
        """
        np = self._np
        # A negative slice bound would drop results from the tail instead of limiting them.
        limit = max(0, int(top_k))

        if self.ntotal == 0:
            if np is not None:
                return np.array([[]], dtype="float32"), np.array([[]], dtype="int64")
            return [[]], [[]]

        if np is not None:
            query = np.array(vector, dtype="float32")
            if query.ndim == 1:
                # Accept a flat vector, matching the pure-Python path below.
                query = query.reshape(1, -1)
            if query.ndim != 2 or query.shape[1] != self.d:
                raise ValueError(
                    f"Query embedding dimension does not match index dimension {self.d}."
                )
            if self._matrix is None:
                self._matrix = np.array(self.vectors, dtype="float32")
            distances = np.sum((self._matrix - query[0]) ** 2, axis=1)
            order = np.argsort(distances)[:limit]
            return (
                np.array([distances[order]], dtype="float32"),
                np.array([order], dtype="int64"),
            )

        query_row = vector
        if isinstance(vector, (list, tuple)) and vector and isinstance(vector[0], (list, tuple)):
            query_row = vector[0]
        query_values = [float(value) for value in query_row]
        if len(query_values) != self.d:
            raise ValueError(f"Query embedding dimension does not match index dimension {self.d}.")
        scored = [
            (sum((left - right) ** 2 for left, right in zip(stored, query_values)), index)
            for index, stored in enumerate(self.vectors)
        ]
        scored.sort(key=lambda item: item[0])
        selected = scored[:limit]
        return [[item[0] for item in selected]], [[item[1] for item in selected]]

    def has_nonzero_vectors(self) -> bool:
        """Return ``True`` if any stored component has magnitude above ``1e-12``."""
        return any(
            any(abs(float(value)) > 1e-12 for value in vector)
            for vector in self.vectors
        )


class UnavailableVectorIndex:
    """Placeholder index used when the real vector index cannot be read.

    Reports zero size and dimension; ``search`` always raises with the stored reason.
    """

    d = 0
    ntotal = 0

    def __init__(self, reason: str) -> None:
        self.reason = reason

    def search(self, vector: Any, top_k: int) -> Any:
        """Always raise ``RuntimeError`` carrying ``self.reason``."""
        raise RuntimeError(self.reason)


def _first_value(source: Dict[str, Any], aliases: Iterable[str], default: Any = None) -> Any:
    """Return the first value in ``source`` under ``aliases`` that is neither ``None`` nor ``""``."""
    for alias in aliases:
        value = source.get(alias)
        if value not in (None, ""):
            return value
    return default


def _as_list(value: Any) -> List[str]:
    """Coerce ``value`` to a list of strings.

    ``None`` and the empty string become ``[]``; lists and tuples are stringified
    element-wise with ``None``/``""`` entries dropped; any other value becomes a
    single-element list.
    """
    if value is None:
        return []
    if isinstance(value, (list, tuple)):
        return [str(item) for item in value if item not in (None, "")]
    if isinstance(value, str):
        return [value] if value else []
    return [str(value)]


def _as_int(value: Any) -> Optional[int]:
    """Return ``int(value)``, or ``None`` when the value is empty or not convertible."""
    if value in (None, ""):
        return None
    try:
        return int(value)
    except (TypeError, ValueError):
        return None


def _as_bool(value: Any, default: bool = True) -> bool:
    """Interpret ``value`` as a boolean, returning ``default`` when it is unrecognised.

    Booleans pass through, numbers use their truthiness, and strings are matched
    case-insensitively against ``1/true/yes/y`` and ``0/false/no/n``.
    """
    if value is None:
        return default
    if isinstance(value, bool):
        return value
    if isinstance(value, (int, float)):
        return bool(value)
    if isinstance(value, str):
        normalized = value.strip().lower()
        if normalized in {"1", "true", "yes", "y"}:
            return True
        if normalized in {"0", "false", "no", "n"}:
            return False
    return default


def _load_json(path: str) -> Any:
    """Read and parse the UTF-8 JSON file at ``path``."""
    with open(path, "r", encoding="utf-8") as file:
        return json.load(file)


def read_code_kb_summary(metadata_path: str, graph_path: str) -> Dict[str, Any]:
    """Summarise a code knowledge base from its files without building an adapter.

    Args:
        metadata_path: JSON file holding either a list of function rows or a dict
            with the rows under ``"functions"`` or ``"items"`` (the same two keys
            that ``CodeKnowledgeBaseAdapter._read_metadata`` accepts).
        graph_path: JSON graph file; treated as empty when blank or missing.

    Returns:
        A dict with ``function_count`` plus the ``total_nodes``, ``total_edges``,
        ``project_root`` and ``generated_at`` values from the graph's ``"metadata"``
        section (``None`` where absent).
    """
    metadata = _load_json(metadata_path)
    graph_data = _load_json(graph_path) if graph_path and Path(graph_path).exists() else {}

    if isinstance(metadata, dict):
        metadata_items = metadata.get("functions", metadata.get("items"))
    else:
        metadata_items = metadata
    metadata_items = metadata_items or []

    graph_meta = graph_data.get("metadata") if isinstance(graph_data, dict) else None
    if not isinstance(graph_meta, dict):
        # Covers a missing section as well as an explicit JSON null.
        graph_meta = {}

    return {
        "function_count": len(metadata_items) if isinstance(metadata_items, list) else 0,
        "graph_nodes": graph_meta.get("total_nodes"),
        "graph_edges": graph_meta.get("total_edges"),
        "project_root": graph_meta.get("project_root"),
        "generated_at": graph_meta.get("generated_at"),
    }


class CodeKnowledgeBaseAdapter:
    """Loads a code knowledge base and answers function-search and call-graph queries.

    Search uses the vector index when it is usable and falls back to a lexical
    token-overlap score otherwise.
    """

    def __init__(self, embedding_function: Any = None) -> None:
        """Create an unloaded adapter.

        Args:
            embedding_function: Object with an ``embed_query`` method, or a callable
                taking the query string. When ``None``, one is created on first use.
        """
        self.embedding_function = embedding_function
        self.faiss_index: Any = None
        self.metadata: List[Dict[str, Any]] = []
        self.graph_data: Dict[str, Any] = {}
        self.functions: List[CodeFunctionEvidence] = []
        self.functions_by_id: Dict[str, CodeFunctionEvidence] = {}
        self.call_graph: Optional[CodeCallGraph] = None
        self.vector_search_enabled = False
        self.vector_search_disabled_reason = "Code knowledge base has not been loaded."

    def load(self, vector_path: str, metadata_path: str, graph_path: str) -> None:
        """Load the index, metadata and graph files and rebuild all derived state.

        Raises:
            FileNotFoundError: If any of the three paths is blank or missing.
            ValueError: If the metadata file has an unsupported shape.
        """
        self._validate_paths(vector_path, metadata_path, graph_path)
        self.faiss_index = self._read_faiss_index(vector_path)
        self.metadata = self._read_metadata(metadata_path)
        self.graph_data = _load_json(graph_path)
        graph_nodes = self._index_graph_nodes(self.graph_data)
        self.functions = [
            self._normalize_metadata(row, graph_nodes, index_dimension=self.faiss_index.d)
            for row in self.metadata
        ]
        self.functions_by_id = {item.node_id: item for item in self.functions}
        self.call_graph = CodeCallGraph(self.graph_data, self.functions)
        self._configure_vector_search()

    def _validate_paths(self, vector_path: str, metadata_path: str, graph_path: str) -> None:
        """Raise ``FileNotFoundError`` listing every path that is blank or does not exist."""
        missing = [
            path
            for path in [vector_path, metadata_path, graph_path]
            if not path or not Path(path).exists()
        ]
        if missing:
            raise FileNotFoundError(f"Code knowledge base files not found: {missing}")

    def _read_faiss_index(self, vector_path: str) -> Any:
        """Open the vector index at ``vector_path``.

        The JSON ``SimpleVectorIndex`` format is tried first. Otherwise the file is
        handed to ``faiss.read_index``; if faiss cannot be imported an
        ``UnavailableVectorIndex`` is returned so that loading can still complete.
        """
        simple_index = SimpleVectorIndex.from_file(vector_path)
        if simple_index is not None:
            return simple_index
        try:
            import faiss  # type: ignore
        except ImportError:
            logger.warning("faiss is not installed; code search will use lexical fallback.")
            return UnavailableVectorIndex(
                "faiss is required to read this vector index. Install faiss-cpu or rebuild the code KB."
            )
        return faiss.read_index(vector_path)

    def _read_metadata(self, metadata_path: str) -> List[Dict[str, Any]]:
        """Return the dict rows from the metadata file.

        The file may be a list, or a dict holding the list under ``"functions"`` or
        ``"items"``. Non-dict rows are dropped.

        Raises:
            ValueError: If no list can be extracted.
        """
        metadata = _load_json(metadata_path)
        if isinstance(metadata, dict):
            metadata = metadata.get("functions", metadata.get("items", []))
        if not isinstance(metadata, list):
            raise ValueError("Code KB metadata must be a list or a dict with functions/items.")
        return [item for item in metadata if isinstance(item, dict)]

    def _index_graph_nodes(self, graph_data: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
        """Map node ``id`` to node dict for every dict node that has a truthy id."""
        return {
            item.get("id"): item
            for item in graph_data.get("nodes", []) or []
            if isinstance(item, dict) and item.get("id")
        }

    def _normalize_metadata(
        self,
        metadata: Dict[str, Any],
        graph_nodes: Dict[str, Dict[str, Any]],
        index_dimension: int,
    ) -> CodeFunctionEvidence:
        """Build a ``CodeFunctionEvidence`` from one metadata row.

        Values are taken from the row first (honouring ``FIELD_ALIASES``) and fall
        back to the matching graph node or its ``raw_attributes``.

        Args:
            metadata: One raw metadata row.
            graph_nodes: Graph nodes keyed by node id.
            index_dimension: Vector index dimension, used when the row carries no
                usable ``embedding_dim``.
        """
        name = _first_value(metadata, FIELD_ALIASES["name"], "")
        node_id = metadata.get("node_id") or (f"Function:{name}" if name else "")
        graph_node = graph_nodes.get(node_id, {})
        raw_attributes = graph_node.get("raw_attributes") or {}
        if not name:
            name = graph_node.get("name") or node_id.removeprefix("Function:")
        file_path = _first_value(metadata, FIELD_ALIASES["file"], graph_node.get("file_path", ""))

        # A malformed embedding_dim must not abort loading the whole knowledge base,
        # so an unparseable value falls back to the index dimension.
        embedding_dim = _as_int(metadata.get("embedding_dim") or index_dimension)
        if embedding_dim is None:
            embedding_dim = _as_int(index_dimension) or 0

        embedding_available = _as_bool(
            metadata.get("embedding_available", raw_attributes.get("embedding_available")),
            default=True,
        )
        return CodeFunctionEvidence(
            node_id=node_id,
            name=name,
            qualified_name=metadata.get("qualified_name") or node_id.removeprefix("Function:") or name,
            file=file_path or "",
            start_line=_as_int(metadata.get("start_line") or graph_node.get("start_line")),
            end_line=_as_int(metadata.get("end_line") or graph_node.get("end_line")),
            signature=metadata.get("signature") or graph_node.get("signature") or "",
            summary=_first_value(metadata, FIELD_ALIASES["summary"], graph_node.get("summary", ""))
            or "",
            logic_flow=_first_value(
                metadata, FIELD_ALIASES["logic_flow"], graph_node.get("logic_flow", "")
            )
            or "",
            code_snippet=_first_value(
                metadata, FIELD_ALIASES["code_snippet"], raw_attributes.get("code_snippet", "")
            )
            or "",
            calls=_as_list(
                _first_value(metadata, FIELD_ALIASES["calls"], raw_attributes.get("calls", []))
            ),
            called_by=_as_list(
                _first_value(metadata, FIELD_ALIASES["called_by"], raw_attributes.get("called_by", []))
            ),
            includes=_as_list(metadata.get("includes") or raw_attributes.get("includes")),
            embedding_model=metadata.get("embedding_model") or "",
            embedding_dim=embedding_dim,
            embedding_available=embedding_available,
            raw=metadata,
        )

    def _configure_vector_search(self) -> None:
        """Decide whether vector search is usable and record the reason when it is not.

        Vector search is enabled only if the index is available, non-empty, the same
        size as the function list, not all-zero, and at least one function reports a
        usable embedding.
        """
        index_total = int(getattr(self.faiss_index, "ntotal", 0))
        self.vector_search_enabled = False
        self.vector_search_disabled_reason = ""

        if isinstance(self.faiss_index, UnavailableVectorIndex):
            self.vector_search_disabled_reason = self.faiss_index.reason
            return
        if index_total == 0:
            self.vector_search_disabled_reason = "Vector index is empty."
            return
        if index_total != len(self.functions):
            # Search results are index positions into self.functions, so the two
            # must line up one-to-one.
            self.vector_search_disabled_reason = (
                f"Vector index size ({index_total}) differs from metadata size ({len(self.functions)})."
            )
            logger.warning(
                "FAISS index size (%s) differs from metadata size (%s)",
                index_total,
                len(self.functions),
            )
            return
        if not self._index_has_nonzero_vectors():
            self.vector_search_disabled_reason = "Vector index contains only zero embeddings."
            logger.info("Code KB vector search disabled: %s", self.vector_search_disabled_reason)
            return
        if not any(item.embedding_available for item in self.functions):
            self.vector_search_disabled_reason = "No function metadata has usable embeddings."
            logger.info("Code KB vector search disabled: %s", self.vector_search_disabled_reason)
            return
        self.vector_search_enabled = True

    def _index_has_nonzero_vectors(self) -> bool:
        """Best-effort check that the index holds at least one non-zero vector.

        Uses the index's own ``has_nonzero_vectors`` when present; otherwise samples
        up to 1000 vectors via ``reconstruct_n``. Returns ``True`` when the index
        cannot be inspected, so an uninspectable index is not disabled.
        """
        has_nonzero_vectors = getattr(self.faiss_index, "has_nonzero_vectors", None)
        if callable(has_nonzero_vectors):
            return bool(has_nonzero_vectors())
        reconstruct_n = getattr(self.faiss_index, "reconstruct_n", None)
        if not callable(reconstruct_n):
            return True
        try:
            import numpy as np

            index_total = int(getattr(self.faiss_index, "ntotal", 0) or 0)
            sample_size = min(index_total, 1000)
            if sample_size <= 0:
                return False
            vectors = reconstruct_n(0, sample_size)
            return bool(np.any(np.abs(vectors) > 1e-12))
        except Exception as exc:  # pragma: no cover - depends on FAISS capabilities
            logger.debug("Could not inspect FAISS vectors for zero embeddings: %s", exc)
            return True

    def _get_embedding_function(self) -> Any:
        """Return the embedding function, creating it from ``EmbeddingsFactory`` on first use."""
        if self.embedding_function is None:
            from app.services.embedding.embedding_factory import EmbeddingsFactory

            self.embedding_function = EmbeddingsFactory.create()
        return self.embedding_function

    def _embed_query(self, query: str) -> Any:
        """Embed ``query`` into a ``(1, dim)`` float32 numpy array.

        The return type is annotated ``Any`` because numpy is imported only inside
        this method and is not a module-level name.

        Raises:
            RuntimeError: If numpy is not installed.
            TypeError: If the embedding function is neither callable nor has ``embed_query``.
            ValueError: If the embedding length differs from the index dimension.
        """
        try:
            import numpy as np
        except ImportError as exc:
            raise RuntimeError("numpy is required to search a code knowledge base.") from exc

        embedding_function = self._get_embedding_function()
        if hasattr(embedding_function, "embed_query"):
            vector = embedding_function.embed_query(query)
        elif callable(embedding_function):
            vector = embedding_function(query)
        else:
            raise TypeError("Unsupported embedding function.")

        query_vector = np.array([vector], dtype="float32")
        index_dimension = int(getattr(self.faiss_index, "d", 0) or 0)
        if index_dimension and query_vector.shape[1] != index_dimension:
            raise ValueError(
                f"Query embedding dimension {query_vector.shape[1]} does not match "
                f"FAISS index dimension {index_dimension}."
            )
        return query_vector

    @staticmethod
    def distance_to_similarity(distance: float) -> float:
        """Map an index distance to a similarity score in ``[0, 1]``.

        NaN and negative distances give ``0.0``; distances up to ``2.0`` map
        linearly as ``1 - distance / 2``; larger ones decay as ``1 / (1 + distance)``.
        """
        if math.isnan(distance) or distance < 0:
            return 0.0
        if distance <= 2.0:
            return max(0.0, min(1.0, 1.0 - distance / 2.0))
        return max(0.0, min(1.0, 1.0 / (1.0 + distance)))

    def search_functions(
        self,
        query: str,
        top_k: int = 8,
        min_similarity: float = 0.0,
    ) -> List[CodeSearchHit]:
        """Return up to ``top_k`` functions matching ``query``.

        Vector search is used when enabled. Lexical search is used instead when
        vector search is disabled, raises, or produces no acceptable hit.

        Args:
            query: Free-text query; blank or ``None`` yields no results.
            top_k: Maximum number of hits; values below 1 yield no results.
            min_similarity: Hits scoring below this value are dropped.

        Raises:
            RuntimeError: If ``load`` has not been called.
        """
        if self.faiss_index is None:
            raise RuntimeError("Code knowledge base has not been loaded.")
        if top_k <= 0 or not (query or "").strip():
            return []
        if not self.vector_search_enabled:
            logger.info(
                "Vector code search disabled, using lexical search: %s",
                self.vector_search_disabled_reason,
            )
            return self._lexical_search_functions(query, top_k=top_k, min_similarity=min_similarity)

        try:
            vector = self._embed_query(query)
            distances, indices = self.faiss_index.search(vector, top_k)
        except Exception as exc:
            # Deliberate best-effort: any embedding or index failure degrades to lexical search.
            logger.warning("Vector code search failed, falling back to lexical search: %s", exc)
            return self._lexical_search_functions(query, top_k=top_k, min_similarity=min_similarity)

        hits: List[CodeSearchHit] = []
        # rank reflects the position in the raw index result, including skipped entries.
        for rank, (raw_index, raw_distance) in enumerate(zip(indices[0], distances[0]), start=1):
            index = int(raw_index)
            if index < 0 or index >= len(self.functions):
                continue
            evidence = self.functions[index]
            if not evidence.embedding_available:
                continue
            distance = float(raw_distance)
            similarity = self.distance_to_similarity(distance)
            if similarity < min_similarity:
                continue
            hits.append(
                CodeSearchHit(
                    evidence=evidence,
                    similarity=similarity,
                    distance=distance,
                    rank=rank,
                )
            )

        if not hits:
            return self._lexical_search_functions(query, top_k=top_k, min_similarity=min_similarity)
        return hits

    def _lexical_search_functions(
        self,
        query: str,
        top_k: int = 8,
        min_similarity: float = 0.0,
    ) -> List[CodeSearchHit]:
        """Rank functions by token overlap with ``query``.

        Returns at most ``top_k`` hits sorted by descending score, with ``rank``
        starting at 1 and ``distance`` set to ``1 - score`` (floored at 0).
        """
        if top_k <= 0:
            return []
        query_tokens = self._tokens(query)
        if not query_tokens:
            return []

        scored: List[CodeSearchHit] = []
        for evidence in self.functions:
            text = self._function_search_text(evidence)
            score = self._lexical_similarity(query_tokens, text, evidence)
            if score < min_similarity:
                continue
            scored.append(
                CodeSearchHit(
                    evidence=evidence,
                    similarity=score,
                    distance=max(0.0, 1.0 - score),
                    rank=0,
                )
            )

        scored.sort(key=lambda item: item.similarity, reverse=True)
        top_hits = scored[:top_k]
        for rank, item in enumerate(top_hits, start=1):
            item.rank = rank
        return top_hits

    @staticmethod
    def _tokens(text: str) -> List[str]:
        """Split lower-cased ``text`` into identifier, number and CJK tokens.

        CJK runs of three or more characters are kept whole and also expanded into
        their overlapping two-character pieces.
        """
        normalized = (text or "").lower()
        expanded: List[str] = []
        for token in _TOKEN_PATTERN.findall(normalized):
            expanded.append(token)
            if _CJK_RUN_PATTERN.fullmatch(token):
                expanded.extend(token[index : index + 2] for index in range(len(token) - 1))
        return expanded

    @staticmethod
    def _function_search_text(evidence: CodeFunctionEvidence) -> str:
        """Concatenate the searchable fields of ``evidence`` into one lower-cased string.

        Only the first 4000 characters of the code snippet are included.
        """
        return "\n".join(
            [
                evidence.node_id,
                evidence.name,
                evidence.qualified_name,
                evidence.file,
                evidence.signature,
                evidence.summary,
                evidence.logic_flow,
                evidence.code_snippet[:4000],
                " ".join(evidence.calls),
                " ".join(evidence.called_by),
                " ".join(evidence.includes),
            ]
        ).lower()

    def _lexical_similarity(
        self,
        query_tokens: List[str],
        text: str,
        evidence: CodeFunctionEvidence,
    ) -> float:
        """Score how well ``query_tokens`` match a function, in ``[0, 1]``.

        The score is a weighted sum of exact token overlap (0.45), substring matches
        in the text (0.35) and matches in the function name (0.12), plus small fixed
        bonuses for functions that carry a summary, logic flow, snippet or location.
        """
        text_tokens = set(self._tokens(text))
        if not text_tokens:
            return 0.0

        # De-duplicate while preserving order so repeated query words are not over-counted.
        unique_query_tokens = list(dict.fromkeys(query_tokens))
        token_total = max(1, len(unique_query_tokens))

        overlap_count = sum(1 for token in unique_query_tokens if token in text_tokens)
        substring_count = sum(1 for token in unique_query_tokens if token in text)
        overlap_score = overlap_count / token_total
        substring_score = substring_count / token_total

        name_text = f"{evidence.name} {evidence.qualified_name}".lower()
        name_hits = sum(1 for token in unique_query_tokens if token in name_text)
        # The denominator is capped at 4 so long queries can still reach a full name score.
        name_score = min(1.0, name_hits / max(1, min(len(unique_query_tokens), 4)))

        evidence_score = 0.0
        if evidence.summary:
            evidence_score += 0.08
        if evidence.logic_flow:
            evidence_score += 0.05
        if evidence.code_snippet:
            evidence_score += 0.04
        if evidence.start_line is not None and evidence.file:
            evidence_score += 0.03

        return max(
            0.0,
            min(
                1.0,
                overlap_score * 0.45
                + substring_score * 0.35
                + name_score * 0.12
                + evidence_score,
            ),
        )

    def get_function(self, node_id: str) -> Optional[CodeFunctionEvidence]:
        """Return the loaded function with this ``node_id``, or ``None``."""
        return self.functions_by_id.get(node_id)

    def expand_call_context(self, node_id: str, max_hops: int = 2) -> CodeGraphContext:
        """Expand the call graph around ``node_id`` up to ``max_hops``.

        Returns a ``CodeGraphContext`` built from ``node_id`` alone when no call
        graph has been loaded.
        """
        if not self.call_graph:
            return CodeGraphContext(node_id=node_id)
        return self.call_graph.expand(node_id, max_hops=max_hops)

    def summary(self) -> Dict[str, Any]:
        """Return counts and vector-search status for the currently loaded knowledge base."""
        return {
            "function_count": len(self.functions),
            "index_size": int(getattr(self.faiss_index, "ntotal", 0) or 0),
            "embedding_dim": int(getattr(self.faiss_index, "d", 0) or 0),
            "embedding_available_count": sum(
                1 for item in self.functions if item.embedding_available
            ),
            "vector_search_enabled": self.vector_search_enabled,
            "vector_search_disabled_reason": self.vector_search_disabled_reason,
            "graph_nodes": len(self.graph_data.get("nodes", []) or []),
            "graph_edges": len(self.graph_data.get("edges", []) or []),
        }