增加代码知识库;修复文档处理内容;增加API设置
This commit is contained in:
105
rag-web-ui/backend/app/services/code_kb/graph.py
Normal file
105
rag-web-ui/backend/app/services/code_kb/graph.py
Normal file
@@ -0,0 +1,105 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict, deque
|
||||
from typing import Callable, Dict, Iterable, List, Optional, Set
|
||||
|
||||
from app.services.code_kb.schema import CodeFunctionEvidence, CodeGraphContext
|
||||
|
||||
|
||||
class CodeCallGraph:
|
||||
def __init__(
|
||||
self,
|
||||
graph_data: Optional[dict],
|
||||
functions: Iterable[CodeFunctionEvidence],
|
||||
) -> None:
|
||||
self.functions_by_id: Dict[str, CodeFunctionEvidence] = {
|
||||
item.node_id: item for item in functions
|
||||
}
|
||||
self.name_to_ids: Dict[str, List[str]] = defaultdict(list)
|
||||
for item in self.functions_by_id.values():
|
||||
for name in {item.name, item.qualified_name, item.node_id.removeprefix("Function:")}:
|
||||
if name:
|
||||
self.name_to_ids[name].append(item.node_id)
|
||||
|
||||
self.calls_by_id: Dict[str, Set[str]] = defaultdict(set)
|
||||
self.called_by_id: Dict[str, Set[str]] = defaultdict(set)
|
||||
self._load_graph_edges(graph_data or {})
|
||||
self._load_metadata_edges()
|
||||
|
||||
def _load_graph_edges(self, graph_data: dict) -> None:
|
||||
for edge in graph_data.get("edges", []) or []:
|
||||
if edge.get("type") != "CALLS":
|
||||
continue
|
||||
source_id = edge.get("source_id")
|
||||
target_id = edge.get("target_id")
|
||||
if source_id in self.functions_by_id and target_id in self.functions_by_id:
|
||||
self.calls_by_id[source_id].add(target_id)
|
||||
self.called_by_id[target_id].add(source_id)
|
||||
|
||||
def _resolve_name(self, value: str) -> List[str]:
|
||||
if not value:
|
||||
return []
|
||||
if value in self.functions_by_id:
|
||||
return [value]
|
||||
if value.startswith("Function:") and value in self.functions_by_id:
|
||||
return [value]
|
||||
return self.name_to_ids.get(value, [])
|
||||
|
||||
def _load_metadata_edges(self) -> None:
|
||||
for function in self.functions_by_id.values():
|
||||
for callee in function.calls:
|
||||
for target_id in self._resolve_name(callee):
|
||||
if target_id != function.node_id:
|
||||
self.calls_by_id[function.node_id].add(target_id)
|
||||
self.called_by_id[target_id].add(function.node_id)
|
||||
for caller in function.called_by:
|
||||
for source_id in self._resolve_name(caller):
|
||||
if source_id != function.node_id:
|
||||
self.called_by_id[function.node_id].add(source_id)
|
||||
self.calls_by_id[source_id].add(function.node_id)
|
||||
|
||||
def _bfs(
|
||||
self,
|
||||
start_id: str,
|
||||
max_hops: int,
|
||||
next_nodes: Callable[[str], Iterable[str]],
|
||||
) -> List[CodeFunctionEvidence]:
|
||||
seen = {start_id}
|
||||
result: List[CodeFunctionEvidence] = []
|
||||
queue = deque([(start_id, 0)])
|
||||
while queue:
|
||||
current_id, depth = queue.popleft()
|
||||
if depth >= max_hops:
|
||||
continue
|
||||
for next_id in sorted(next_nodes(current_id)):
|
||||
if next_id in seen:
|
||||
continue
|
||||
seen.add(next_id)
|
||||
function = self.functions_by_id.get(next_id)
|
||||
if function:
|
||||
result.append(function)
|
||||
queue.append((next_id, depth + 1))
|
||||
return result
|
||||
|
||||
def expand(self, node_id: str, max_hops: int = 2) -> CodeGraphContext:
|
||||
callers = self._bfs(node_id, max_hops, lambda item: self.called_by_id.get(item, set()))
|
||||
callees = self._bfs(node_id, max_hops, lambda item: self.calls_by_id.get(item, set()))
|
||||
center = self.functions_by_id.get(node_id)
|
||||
center_name = center.name if center else node_id
|
||||
|
||||
call_chains: List[str] = []
|
||||
for caller in callers[:5]:
|
||||
call_chains.append(f"{caller.name} -> {center_name}")
|
||||
for callee in callees[:5]:
|
||||
call_chains.append(f"{center_name} -> {callee.name}")
|
||||
for caller in callers[:3]:
|
||||
for callee in callees[:3]:
|
||||
call_chains.append(f"{caller.name} -> {center_name} -> {callee.name}")
|
||||
|
||||
return CodeGraphContext(
|
||||
node_id=node_id,
|
||||
callers=callers,
|
||||
callees=callees,
|
||||
call_chains=list(dict.fromkeys(call_chains)),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user