Files

106 lines
4.2 KiB
Python
Raw Permalink Normal View History

from __future__ import annotations
from collections import defaultdict, deque
from typing import Callable, Dict, Iterable, List, Optional, Set
from app.services.code_kb.schema import CodeFunctionEvidence, CodeGraphContext
class CodeCallGraph:
    """Bidirectional call graph over a collection of function records.

    Edges are gathered from two sources:

    * ``CALLS`` edges listed under ``graph_data["edges"]``; and
    * the ``calls`` / ``called_by`` name lists carried by each function
      record, resolved to node ids through a name index.

    Every edge is stored in both directions (``calls_by_id`` and
    ``called_by_id``) so callers and callees can each be walked directly.
    """

    def __init__(
        self,
        graph_data: Optional[dict],
        functions: Iterable[CodeFunctionEvidence],
    ) -> None:
        """Index ``functions`` and load call edges from both sources.

        Args:
            graph_data: Optional mapping; only its ``"edges"`` entry is read.
                ``None`` is treated as an empty graph.
            functions: Function records. The attributes read here are
                ``node_id``, ``name``, ``qualified_name``, ``calls`` and
                ``called_by``.
        """
        # If two records share a node_id, the later one wins.
        self.functions_by_id: Dict[str, CodeFunctionEvidence] = {
            item.node_id: item for item in functions
        }

        # Lookup from every known spelling of a function (short name,
        # qualified name, node id without its "Function:" prefix) to the
        # node ids it may refer to. A short name can map to several ids.
        self.name_to_ids: Dict[str, List[str]] = defaultdict(list)
        for item in self.functions_by_id.values():
            aliases = {
                item.name,
                item.qualified_name,
                item.node_id.removeprefix("Function:"),
            }
            for alias in aliases:
                if alias:
                    self.name_to_ids[alias].append(item.node_id)

        self.calls_by_id: Dict[str, Set[str]] = defaultdict(set)
        self.called_by_id: Dict[str, Set[str]] = defaultdict(set)
        self._load_graph_edges(graph_data or {})
        self._load_metadata_edges()

    def _add_edge(self, source_id: str, target_id: str) -> None:
        """Record ``source_id -> target_id`` in both direction indexes."""
        self.calls_by_id[source_id].add(target_id)
        self.called_by_id[target_id].add(source_id)

    def _load_graph_edges(self, graph_data: dict) -> None:
        """Load ``CALLS`` edges whose endpoints are both known functions.

        Edges of any other type, and edges touching an unknown node id, are
        ignored.
        """
        known = self.functions_by_id
        # ``or []`` also covers an explicit ``"edges": None``.
        for edge in graph_data.get("edges", []) or []:
            if edge.get("type") != "CALLS":
                continue
            source_id = edge.get("source_id")
            target_id = edge.get("target_id")
            if source_id in known and target_id in known:
                self._add_edge(source_id, target_id)

    def _resolve_name(self, value: str) -> List[str]:
        """Return the node ids that ``value`` may refer to.

        An exact node id resolves to itself; anything else is looked up in
        the name index. The returned list is always a fresh copy, so callers
        cannot corrupt the index by mutating it. Empty or unknown values
        resolve to ``[]``.
        """
        if not value:
            return []
        if value in self.functions_by_id:
            return [value]
        # ``.get`` avoids inserting empty entries into the defaultdict.
        return list(self.name_to_ids.get(value, ()))

    def _load_metadata_edges(self) -> None:
        """Add edges described by each record's ``calls``/``called_by`` lists.

        Names are resolved through ``_resolve_name``; a name matching several
        functions yields an edge to each. Self-references are skipped. A
        missing (``None``) list is treated as empty, mirroring how a missing
        ``graph_data`` is tolerated.
        """
        for function in self.functions_by_id.values():
            own_id = function.node_id
            for callee in function.calls or ():
                for target_id in self._resolve_name(callee):
                    if target_id != own_id:
                        self._add_edge(own_id, target_id)
            for caller in function.called_by or ():
                for source_id in self._resolve_name(caller):
                    if source_id != own_id:
                        self._add_edge(source_id, own_id)

    def _bfs(
        self,
        start_id: str,
        max_hops: int,
        next_nodes: Callable[[str], Iterable[str]],
    ) -> List[CodeFunctionEvidence]:
        """Collect functions reachable from ``start_id`` within ``max_hops``.

        Args:
            start_id: Node id to start from; never included in the result.
            max_hops: Maximum number of edges to follow. Values ``<= 0``
                yield an empty list.
            next_nodes: Returns the neighbour ids of a given node id.

        Returns:
            Reachable function records in breadth-first order; neighbours of
            a node are visited in sorted id order, so the result is
            deterministic.
        """
        seen = {start_id}
        result: List[CodeFunctionEvidence] = []
        queue = deque([(start_id, 0)])
        while queue:
            current_id, depth = queue.popleft()
            if depth >= max_hops:
                continue
            next_depth = depth + 1
            for next_id in sorted(next_nodes(current_id)):
                if next_id in seen:
                    continue
                seen.add(next_id)
                function = self.functions_by_id.get(next_id)
                if function:
                    result.append(function)
                # Nodes sitting exactly on the hop limit are never expanded,
                # so there is no point in queueing them.
                if next_depth < max_hops:
                    queue.append((next_id, next_depth))
        return result

    def expand(self, node_id: str, max_hops: int = 2) -> CodeGraphContext:
        """Build the caller/callee context around ``node_id``.

        Args:
            node_id: Node id of the center function. If it is unknown, the
                id itself is used as the center's display name.
            max_hops: How many edges to follow in each direction.

        Returns:
            A ``CodeGraphContext`` holding the callers and callees found by
            breadth-first search, plus de-duplicated, order-preserving
            ``call_chains`` strings: up to five ``caller -> center``, up to
            five ``center -> callee`` and up to 3x3
            ``caller -> center -> callee`` combinations.
        """
        callers = self._bfs(
            node_id, max_hops, lambda item: self.called_by_id.get(item, set())
        )
        callees = self._bfs(
            node_id, max_hops, lambda item: self.calls_by_id.get(item, set())
        )

        center = self.functions_by_id.get(node_id)
        center_name = center.name if center else node_id

        # NOTE(review): ``callers``/``callees`` may contain functions more
        # than one hop away, yet each chain below links them directly to the
        # center. For max_hops > 1 a rendered arrow may therefore not match a
        # real edge. Output is kept unchanged for compatibility; confirm
        # whether consumers expect direct edges only.
        call_chains: List[str] = []
        for caller in callers[:5]:
            call_chains.append(f"{caller.name} -> {center_name}")
        for callee in callees[:5]:
            call_chains.append(f"{center_name} -> {callee.name}")
        for caller in callers[:3]:
            for callee in callees[:3]:
                call_chains.append(
                    f"{caller.name} -> {center_name} -> {callee.name}"
                )

        return CodeGraphContext(
            node_id=node_id,
            callers=callers,
            callees=callees,
            # dict.fromkeys drops duplicates while keeping first-seen order.
            call_chains=list(dict.fromkeys(call_chains)),
        )