106 lines
4.2 KiB
Python
106 lines
4.2 KiB
Python
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
from collections import defaultdict, deque
|
||
|
|
from typing import Callable, Dict, Iterable, List, Optional, Set
|
||
|
|
|
||
|
|
from app.services.code_kb.schema import CodeFunctionEvidence, CodeGraphContext
|
||
|
|
|
||
|
|
|
||
|
|
class CodeCallGraph:
    """Caller/callee index over a collection of function evidence records.

    Call edges are merged from two sources:

    * ``CALLS`` edges found under ``graph_data["edges"]``;
    * the ``calls`` / ``called_by`` name lists carried by each function record.

    The merged edges are stored in ``calls_by_id`` (caller id -> callee ids)
    and ``called_by_id`` (callee id -> caller ids). Only edges whose two
    endpoints are both present in ``functions_by_id`` are ever recorded.
    """

    # Upper bounds on how many direct neighbours are rendered as chain strings.
    _MAX_PAIR_CHAINS = 5
    _MAX_TRIPLE_FANOUT = 3

    def __init__(
        self,
        graph_data: Optional[dict],
        functions: Iterable[CodeFunctionEvidence],
    ) -> None:
        """Index ``functions`` by id and by name, then load every call edge.

        Args:
            graph_data: Optional mapping holding an ``"edges"`` list; ``None``
                is treated as an empty mapping.
            functions: Records exposing ``node_id``, ``name``,
                ``qualified_name``, ``calls`` and ``called_by``. When two
                records share a ``node_id`` the later one wins.
        """
        self.functions_by_id: Dict[str, CodeFunctionEvidence] = {
            item.node_id: item for item in functions
        }

        # Each function is reachable by its short name, its qualified name and
        # its node id without the "Function:" prefix. Several functions may
        # share one alias, hence the list values.
        self.name_to_ids: Dict[str, List[str]] = defaultdict(list)
        for item in self.functions_by_id.values():
            aliases = {
                item.name,
                item.qualified_name,
                item.node_id.removeprefix("Function:"),
            }
            for alias in aliases:
                if alias:
                    self.name_to_ids[alias].append(item.node_id)

        self.calls_by_id: Dict[str, Set[str]] = defaultdict(set)
        self.called_by_id: Dict[str, Set[str]] = defaultdict(set)
        self._load_graph_edges(graph_data or {})
        self._load_metadata_edges()

    def _add_edge(self, source_id: str, target_id: str) -> None:
        """Record that ``source_id`` calls ``target_id`` in both indexes."""
        self.calls_by_id[source_id].add(target_id)
        self.called_by_id[target_id].add(source_id)

    def _load_graph_edges(self, graph_data: dict) -> None:
        """Load ``CALLS`` edges from ``graph_data["edges"]``.

        Edges of any other type, and edges with an endpoint that is not a
        known function id, are ignored. A missing or ``None`` edge list is
        treated as empty.
        """
        for edge in graph_data.get("edges", []) or []:
            if edge.get("type") != "CALLS":
                continue
            source_id = edge.get("source_id")
            target_id = edge.get("target_id")
            if source_id in self.functions_by_id and target_id in self.functions_by_id:
                self._add_edge(source_id, target_id)

    def _resolve_name(self, value: str) -> List[str]:
        """Map a node id or function name to the matching node ids.

        Args:
            value: A node id, or any alias registered in ``name_to_ids``.

        Returns:
            ``[value]`` when ``value`` is itself a known node id, otherwise a
            copy of the ids registered under that alias. Empty when ``value``
            is falsy or unknown.
        """
        if not value:
            return []
        if value in self.functions_by_id:
            return [value]
        # ``.get`` avoids inserting an empty entry into the defaultdict; the
        # copy keeps callers from mutating the index.
        return list(self.name_to_ids.get(value, ()))

    def _load_metadata_edges(self) -> None:
        """Add edges described by each record's ``calls`` / ``called_by`` lists.

        Names are resolved through ``_resolve_name``; unresolved names are
        dropped and a function is never linked to itself here. A ``None``
        list is treated as empty.
        """
        for function in self.functions_by_id.values():
            own_id = function.node_id
            for callee in function.calls or ():
                for target_id in self._resolve_name(callee):
                    if target_id != own_id:
                        self._add_edge(own_id, target_id)
            for caller in function.called_by or ():
                for source_id in self._resolve_name(caller):
                    if source_id != own_id:
                        self._add_edge(source_id, own_id)

    def _bfs(
        self,
        start_id: str,
        max_hops: int,
        next_nodes: Callable[[str], Iterable[str]],
    ) -> List[CodeFunctionEvidence]:
        """Collect the functions reachable from ``start_id`` within ``max_hops``.

        Args:
            start_id: Node id to start from; never part of the result.
            max_hops: Maximum number of edges to follow; ``0`` or less yields
                an empty list.
            next_nodes: Returns the neighbour ids of a node id.

        Returns:
            Known functions in breadth-first order, with the neighbours of
            each node visited in sorted id order so the result is
            deterministic. Each function appears at most once.
        """
        seen = {start_id}
        result: List[CodeFunctionEvidence] = []
        queue = deque([(start_id, 0)])
        while queue:
            current_id, depth = queue.popleft()
            if depth >= max_hops:
                continue
            for next_id in sorted(next_nodes(current_id)):
                if next_id in seen:
                    continue
                seen.add(next_id)
                function = self.functions_by_id.get(next_id)
                if function is not None:
                    result.append(function)
                queue.append((next_id, depth + 1))
        return result

    def _build_call_chains(
        self,
        node_id: str,
        callers: List[CodeFunctionEvidence],
        callees: List[CodeFunctionEvidence],
    ) -> List[str]:
        """Render ``"a -> b"`` / ``"a -> b -> c"`` strings around ``node_id``.

        ``callers`` and ``callees`` may contain multi-hop neighbours. Only the
        entries with a recorded direct edge to or from ``node_id`` are
        rendered, so every arrow in a chain corresponds to a real edge.

        Returns:
            De-duplicated chains in this order: caller pairs, callee pairs,
            then caller/centre/callee triples. The centre is shown by its
            ``name`` when the node is known, else by ``node_id``.
        """
        center = self.functions_by_id.get(node_id)
        center_name = center.name if center is not None else node_id

        direct_caller_ids = self.called_by_id.get(node_id, ())
        direct_callee_ids = self.calls_by_id.get(node_id, ())
        direct_callers = [f for f in callers if f.node_id in direct_caller_ids]
        direct_callees = [f for f in callees if f.node_id in direct_callee_ids]

        chains: List[str] = [
            f"{caller.name} -> {center_name}"
            for caller in direct_callers[: self._MAX_PAIR_CHAINS]
        ]
        chains.extend(
            f"{center_name} -> {callee.name}"
            for callee in direct_callees[: self._MAX_PAIR_CHAINS]
        )
        for caller in direct_callers[: self._MAX_TRIPLE_FANOUT]:
            for callee in direct_callees[: self._MAX_TRIPLE_FANOUT]:
                chains.append(f"{caller.name} -> {center_name} -> {callee.name}")

        # dict.fromkeys drops duplicates while keeping first-seen order.
        return list(dict.fromkeys(chains))

    def expand(self, node_id: str, max_hops: int = 2) -> CodeGraphContext:
        """Build the call-graph context around ``node_id``.

        Args:
            node_id: Id of the centre function; it does not have to be known.
            max_hops: How many edges to follow in each direction.

        Returns:
            A ``CodeGraphContext`` carrying the callers and callees reachable
            within ``max_hops`` plus the chain strings for the centre's
            direct neighbours.
        """
        callers = self._bfs(
            node_id, max_hops, lambda item: self.called_by_id.get(item, set())
        )
        callees = self._bfs(
            node_id, max_hops, lambda item: self.calls_by_id.get(item, set())
        )
        return CodeGraphContext(
            node_id=node_id,
            callers=callers,
            callees=callees,
            call_chains=self._build_call_chains(node_id, callers, callees),
        )
|
||
|
|
|