from __future__ import annotations from collections import defaultdict, deque from typing import Callable, Dict, Iterable, List, Optional, Set from app.services.code_kb.schema import CodeFunctionEvidence, CodeGraphContext class CodeCallGraph: def __init__( self, graph_data: Optional[dict], functions: Iterable[CodeFunctionEvidence], ) -> None: self.functions_by_id: Dict[str, CodeFunctionEvidence] = { item.node_id: item for item in functions } self.name_to_ids: Dict[str, List[str]] = defaultdict(list) for item in self.functions_by_id.values(): for name in {item.name, item.qualified_name, item.node_id.removeprefix("Function:")}: if name: self.name_to_ids[name].append(item.node_id) self.calls_by_id: Dict[str, Set[str]] = defaultdict(set) self.called_by_id: Dict[str, Set[str]] = defaultdict(set) self._load_graph_edges(graph_data or {}) self._load_metadata_edges() def _load_graph_edges(self, graph_data: dict) -> None: for edge in graph_data.get("edges", []) or []: if edge.get("type") != "CALLS": continue source_id = edge.get("source_id") target_id = edge.get("target_id") if source_id in self.functions_by_id and target_id in self.functions_by_id: self.calls_by_id[source_id].add(target_id) self.called_by_id[target_id].add(source_id) def _resolve_name(self, value: str) -> List[str]: if not value: return [] if value in self.functions_by_id: return [value] if value.startswith("Function:") and value in self.functions_by_id: return [value] return self.name_to_ids.get(value, []) def _load_metadata_edges(self) -> None: for function in self.functions_by_id.values(): for callee in function.calls: for target_id in self._resolve_name(callee): if target_id != function.node_id: self.calls_by_id[function.node_id].add(target_id) self.called_by_id[target_id].add(function.node_id) for caller in function.called_by: for source_id in self._resolve_name(caller): if source_id != function.node_id: self.called_by_id[function.node_id].add(source_id) self.calls_by_id[source_id].add(function.node_id) def _bfs( self, start_id: str, max_hops: int, next_nodes: Callable[[str], Iterable[str]], ) -> List[CodeFunctionEvidence]: seen = {start_id} result: List[CodeFunctionEvidence] = [] queue = deque([(start_id, 0)]) while queue: current_id, depth = queue.popleft() if depth >= max_hops: continue for next_id in sorted(next_nodes(current_id)): if next_id in seen: continue seen.add(next_id) function = self.functions_by_id.get(next_id) if function: result.append(function) queue.append((next_id, depth + 1)) return result def expand(self, node_id: str, max_hops: int = 2) -> CodeGraphContext: callers = self._bfs(node_id, max_hops, lambda item: self.called_by_id.get(item, set())) callees = self._bfs(node_id, max_hops, lambda item: self.calls_by_id.get(item, set())) center = self.functions_by_id.get(node_id) center_name = center.name if center else node_id call_chains: List[str] = [] for caller in callers[:5]: call_chains.append(f"{caller.name} -> {center_name}") for callee in callees[:5]: call_chains.append(f"{center_name} -> {callee.name}") for caller in callers[:3]: for callee in callees[:3]: call_chains.append(f"{caller.name} -> {center_name} -> {callee.name}") return CodeGraphContext( node_id=node_id, callers=callers, callees=callees, call_chains=list(dict.fromkeys(call_chains)), )