增加代码知识库;修复文档处理内容;增加API设置
This commit is contained in:
10
rag-web-ui/backend/app/services/code_kb/__init__.py
Normal file
10
rag-web-ui/backend/app/services/code_kb/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from app.services.code_kb.adapter import CodeKnowledgeBaseAdapter
|
||||
from app.services.code_kb.schema import CodeFunctionEvidence, CodeGraphContext, CodeSearchHit
|
||||
|
||||
__all__ = [
|
||||
"CodeFunctionEvidence",
|
||||
"CodeGraphContext",
|
||||
"CodeKnowledgeBaseAdapter",
|
||||
"CodeSearchHit",
|
||||
]
|
||||
|
||||
517
rag-web-ui/backend/app/services/code_kb/adapter.py
Normal file
517
rag-web-ui/backend/app/services/code_kb/adapter.py
Normal file
@@ -0,0 +1,517 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
|
||||
from app.services.code_kb.graph import CodeCallGraph
|
||||
from app.services.code_kb.schema import CodeFunctionEvidence, CodeGraphContext, CodeSearchHit
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
FIELD_ALIASES = {
|
||||
"name": ("name", "function_name"),
|
||||
"file": ("file", "file_path"),
|
||||
"summary": ("summary",),
|
||||
"logic_flow": ("logic_flow", "logic"),
|
||||
"code_snippet": ("code_snippet", "source", "code"),
|
||||
"calls": ("calls", "called_functions"),
|
||||
"called_by": ("called_by", "caller_functions", "callers"),
|
||||
}
|
||||
|
||||
|
||||
class SimpleVectorIndex:
|
||||
def __init__(self, vectors: Any, dimension: int) -> None:
|
||||
try:
|
||||
import numpy as np
|
||||
except ImportError: # pragma: no cover - depends on deployment environment
|
||||
np = None
|
||||
self._np = np
|
||||
self.vectors = [[float(value) for value in vector] for vector in (vectors or [])]
|
||||
self.d = int(dimension or (len(self.vectors[0]) if self.vectors else 0))
|
||||
self.ntotal = len(self.vectors)
|
||||
|
||||
@classmethod
|
||||
def from_file(cls, vector_path: str) -> Optional["SimpleVectorIndex"]:
|
||||
try:
|
||||
with open(vector_path, "r", encoding="utf-8") as file:
|
||||
payload = json.load(file)
|
||||
except (OSError, UnicodeDecodeError, json.JSONDecodeError):
|
||||
return None
|
||||
if not isinstance(payload, dict) or payload.get("format") != "simple_l2_vector_index":
|
||||
return None
|
||||
return cls(payload.get("vectors") or [], int(payload.get("dimension") or 0))
|
||||
|
||||
def search(self, vector: Any, top_k: int) -> Any:
|
||||
if self.ntotal == 0:
|
||||
if self._np is not None:
|
||||
return self._np.array([[]], dtype="float32"), self._np.array([[]], dtype="int64")
|
||||
return [[]], [[]]
|
||||
if self._np is not None:
|
||||
query = self._np.array(vector, dtype="float32")
|
||||
if query.ndim != 2 or query.shape[1] != self.d:
|
||||
raise ValueError(f"Query embedding dimension does not match index dimension {self.d}.")
|
||||
distances = self._np.sum((self._np.array(self.vectors, dtype="float32") - query[0]) ** 2, axis=1)
|
||||
order = self._np.argsort(distances)[:top_k]
|
||||
return self._np.array([distances[order]], dtype="float32"), self._np.array([order], dtype="int64")
|
||||
|
||||
query_row = vector[0] if isinstance(vector, list) and vector and isinstance(vector[0], list) else vector
|
||||
query_values = [float(value) for value in query_row]
|
||||
if len(query_values) != self.d:
|
||||
raise ValueError(f"Query embedding dimension does not match index dimension {self.d}.")
|
||||
scored = [
|
||||
(sum((left - right) ** 2 for left, right in zip(stored, query_values)), index)
|
||||
for index, stored in enumerate(self.vectors)
|
||||
]
|
||||
scored.sort(key=lambda item: item[0])
|
||||
selected = scored[:top_k]
|
||||
return [[item[0] for item in selected]], [[item[1] for item in selected]]
|
||||
|
||||
def has_nonzero_vectors(self) -> bool:
|
||||
return any(
|
||||
any(abs(float(value)) > 1e-12 for value in vector)
|
||||
for vector in self.vectors
|
||||
)
|
||||
|
||||
|
||||
class UnavailableVectorIndex:
|
||||
d = 0
|
||||
ntotal = 0
|
||||
|
||||
def __init__(self, reason: str) -> None:
|
||||
self.reason = reason
|
||||
|
||||
def search(self, vector: Any, top_k: int) -> Any:
|
||||
raise RuntimeError(self.reason)
|
||||
|
||||
|
||||
def _first_value(source: Dict[str, Any], aliases: Iterable[str], default: Any = None) -> Any:
|
||||
for alias in aliases:
|
||||
value = source.get(alias)
|
||||
if value not in (None, ""):
|
||||
return value
|
||||
return default
|
||||
|
||||
|
||||
def _as_list(value: Any) -> List[str]:
|
||||
if value is None:
|
||||
return []
|
||||
if isinstance(value, list):
|
||||
return [str(item) for item in value if item not in (None, "")]
|
||||
if isinstance(value, tuple):
|
||||
return [str(item) for item in value if item not in (None, "")]
|
||||
if isinstance(value, str):
|
||||
return [value] if value else []
|
||||
return [str(value)]
|
||||
|
||||
|
||||
def _as_int(value: Any) -> Optional[int]:
|
||||
if value in (None, ""):
|
||||
return None
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
|
||||
def _as_bool(value: Any, default: bool = True) -> bool:
|
||||
if value is None:
|
||||
return default
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if isinstance(value, (int, float)):
|
||||
return bool(value)
|
||||
if isinstance(value, str):
|
||||
normalized = value.strip().lower()
|
||||
if normalized in {"1", "true", "yes", "y"}:
|
||||
return True
|
||||
if normalized in {"0", "false", "no", "n"}:
|
||||
return False
|
||||
return default
|
||||
|
||||
|
||||
def _load_json(path: str) -> Any:
|
||||
with open(path, "r", encoding="utf-8") as file:
|
||||
return json.load(file)
|
||||
|
||||
|
||||
def read_code_kb_summary(metadata_path: str, graph_path: str) -> Dict[str, Any]:
|
||||
metadata = _load_json(metadata_path)
|
||||
graph_data = _load_json(graph_path) if graph_path and Path(graph_path).exists() else {}
|
||||
metadata_items = metadata.get("functions", metadata) if isinstance(metadata, dict) else metadata
|
||||
metadata_items = metadata_items or []
|
||||
graph_meta = graph_data.get("metadata", {}) if isinstance(graph_data, dict) else {}
|
||||
return {
|
||||
"function_count": len(metadata_items) if isinstance(metadata_items, list) else 0,
|
||||
"graph_nodes": graph_meta.get("total_nodes"),
|
||||
"graph_edges": graph_meta.get("total_edges"),
|
||||
"project_root": graph_meta.get("project_root"),
|
||||
"generated_at": graph_meta.get("generated_at"),
|
||||
}
|
||||
|
||||
|
||||
class CodeKnowledgeBaseAdapter:
|
||||
def __init__(self, embedding_function: Any = None) -> None:
|
||||
self.embedding_function = embedding_function
|
||||
self.faiss_index: Any = None
|
||||
self.metadata: List[Dict[str, Any]] = []
|
||||
self.graph_data: Dict[str, Any] = {}
|
||||
self.functions: List[CodeFunctionEvidence] = []
|
||||
self.functions_by_id: Dict[str, CodeFunctionEvidence] = {}
|
||||
self.call_graph: Optional[CodeCallGraph] = None
|
||||
self.vector_search_enabled = False
|
||||
self.vector_search_disabled_reason = "Code knowledge base has not been loaded."
|
||||
|
||||
def load(self, vector_path: str, metadata_path: str, graph_path: str) -> None:
|
||||
self._validate_paths(vector_path, metadata_path, graph_path)
|
||||
self.faiss_index = self._read_faiss_index(vector_path)
|
||||
self.metadata = self._read_metadata(metadata_path)
|
||||
self.graph_data = _load_json(graph_path)
|
||||
graph_nodes = self._index_graph_nodes(self.graph_data)
|
||||
|
||||
self.functions = [
|
||||
self._normalize_metadata(row, graph_nodes, index_dimension=self.faiss_index.d)
|
||||
for row in self.metadata
|
||||
]
|
||||
self.functions_by_id = {item.node_id: item for item in self.functions}
|
||||
self.call_graph = CodeCallGraph(self.graph_data, self.functions)
|
||||
self._configure_vector_search()
|
||||
|
||||
def _validate_paths(self, vector_path: str, metadata_path: str, graph_path: str) -> None:
|
||||
missing = [
|
||||
path
|
||||
for path in [vector_path, metadata_path, graph_path]
|
||||
if not path or not Path(path).exists()
|
||||
]
|
||||
if missing:
|
||||
raise FileNotFoundError(f"Code knowledge base files not found: {missing}")
|
||||
|
||||
def _read_faiss_index(self, vector_path: str) -> Any:
|
||||
simple_index = SimpleVectorIndex.from_file(vector_path)
|
||||
if simple_index is not None:
|
||||
return simple_index
|
||||
try:
|
||||
import faiss # type: ignore
|
||||
except ImportError as exc:
|
||||
logger.warning("faiss is not installed; code search will use lexical fallback.")
|
||||
return UnavailableVectorIndex(
|
||||
"faiss is required to read this vector index. Install faiss-cpu or rebuild the code KB."
|
||||
)
|
||||
return faiss.read_index(vector_path)
|
||||
|
||||
def _read_metadata(self, metadata_path: str) -> List[Dict[str, Any]]:
|
||||
metadata = _load_json(metadata_path)
|
||||
if isinstance(metadata, dict):
|
||||
metadata = metadata.get("functions", metadata.get("items", []))
|
||||
if not isinstance(metadata, list):
|
||||
raise ValueError("Code KB metadata must be a list or a dict with functions/items.")
|
||||
return [item for item in metadata if isinstance(item, dict)]
|
||||
|
||||
def _index_graph_nodes(self, graph_data: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
|
||||
return {
|
||||
item.get("id"): item
|
||||
for item in graph_data.get("nodes", []) or []
|
||||
if isinstance(item, dict) and item.get("id")
|
||||
}
|
||||
|
||||
def _normalize_metadata(
|
||||
self,
|
||||
metadata: Dict[str, Any],
|
||||
graph_nodes: Dict[str, Dict[str, Any]],
|
||||
index_dimension: int,
|
||||
) -> CodeFunctionEvidence:
|
||||
name = _first_value(metadata, FIELD_ALIASES["name"], "")
|
||||
node_id = metadata.get("node_id") or (f"Function:{name}" if name else "")
|
||||
graph_node = graph_nodes.get(node_id, {})
|
||||
raw_attributes = graph_node.get("raw_attributes") or {}
|
||||
if not name:
|
||||
name = graph_node.get("name") or node_id.removeprefix("Function:")
|
||||
|
||||
file_path = _first_value(metadata, FIELD_ALIASES["file"], graph_node.get("file_path", ""))
|
||||
embedding_dim = metadata.get("embedding_dim") or index_dimension
|
||||
embedding_available = _as_bool(
|
||||
metadata.get("embedding_available", raw_attributes.get("embedding_available")),
|
||||
default=True,
|
||||
)
|
||||
return CodeFunctionEvidence(
|
||||
node_id=node_id,
|
||||
name=name,
|
||||
qualified_name=metadata.get("qualified_name") or node_id.removeprefix("Function:") or name,
|
||||
file=file_path or "",
|
||||
start_line=_as_int(metadata.get("start_line") or graph_node.get("start_line")),
|
||||
end_line=_as_int(metadata.get("end_line") or graph_node.get("end_line")),
|
||||
signature=metadata.get("signature") or graph_node.get("signature") or "",
|
||||
summary=_first_value(metadata, FIELD_ALIASES["summary"], graph_node.get("summary", "")) or "",
|
||||
logic_flow=_first_value(
|
||||
metadata, FIELD_ALIASES["logic_flow"], graph_node.get("logic_flow", "")
|
||||
)
|
||||
or "",
|
||||
code_snippet=_first_value(
|
||||
metadata, FIELD_ALIASES["code_snippet"], raw_attributes.get("code_snippet", "")
|
||||
)
|
||||
or "",
|
||||
calls=_as_list(_first_value(metadata, FIELD_ALIASES["calls"], raw_attributes.get("calls", []))),
|
||||
called_by=_as_list(
|
||||
_first_value(metadata, FIELD_ALIASES["called_by"], raw_attributes.get("called_by", []))
|
||||
),
|
||||
includes=_as_list(metadata.get("includes") or raw_attributes.get("includes")),
|
||||
embedding_model=metadata.get("embedding_model") or "",
|
||||
embedding_dim=int(embedding_dim or 0),
|
||||
embedding_available=embedding_available,
|
||||
raw=metadata,
|
||||
)
|
||||
|
||||
def _configure_vector_search(self) -> None:
|
||||
index_total = int(getattr(self.faiss_index, "ntotal", 0))
|
||||
self.vector_search_enabled = False
|
||||
self.vector_search_disabled_reason = ""
|
||||
if isinstance(self.faiss_index, UnavailableVectorIndex):
|
||||
self.vector_search_disabled_reason = self.faiss_index.reason
|
||||
return
|
||||
if index_total == 0:
|
||||
self.vector_search_disabled_reason = "Vector index is empty."
|
||||
return
|
||||
if index_total != len(self.functions):
|
||||
self.vector_search_disabled_reason = (
|
||||
f"Vector index size ({index_total}) differs from metadata size ({len(self.functions)})."
|
||||
)
|
||||
logger.warning(
|
||||
"FAISS index size (%s) differs from metadata size (%s)",
|
||||
index_total,
|
||||
len(self.functions),
|
||||
)
|
||||
return
|
||||
if not self._index_has_nonzero_vectors():
|
||||
self.vector_search_disabled_reason = "Vector index contains only zero embeddings."
|
||||
logger.info("Code KB vector search disabled: %s", self.vector_search_disabled_reason)
|
||||
return
|
||||
if not any(item.embedding_available for item in self.functions):
|
||||
self.vector_search_disabled_reason = "No function metadata has usable embeddings."
|
||||
logger.info("Code KB vector search disabled: %s", self.vector_search_disabled_reason)
|
||||
return
|
||||
self.vector_search_enabled = True
|
||||
|
||||
def _index_has_nonzero_vectors(self) -> bool:
|
||||
has_nonzero_vectors = getattr(self.faiss_index, "has_nonzero_vectors", None)
|
||||
if callable(has_nonzero_vectors):
|
||||
return bool(has_nonzero_vectors())
|
||||
|
||||
reconstruct_n = getattr(self.faiss_index, "reconstruct_n", None)
|
||||
if not callable(reconstruct_n):
|
||||
return True
|
||||
try:
|
||||
import numpy as np
|
||||
|
||||
index_total = int(getattr(self.faiss_index, "ntotal", 0) or 0)
|
||||
sample_size = min(index_total, 1000)
|
||||
if sample_size <= 0:
|
||||
return False
|
||||
vectors = reconstruct_n(0, sample_size)
|
||||
return bool(np.any(np.abs(vectors) > 1e-12))
|
||||
except Exception as exc: # pragma: no cover - depends on FAISS capabilities
|
||||
logger.debug("Could not inspect FAISS vectors for zero embeddings: %s", exc)
|
||||
return True
|
||||
|
||||
def _get_embedding_function(self) -> Any:
|
||||
if self.embedding_function is None:
|
||||
from app.services.embedding.embedding_factory import EmbeddingsFactory
|
||||
|
||||
self.embedding_function = EmbeddingsFactory.create()
|
||||
return self.embedding_function
|
||||
|
||||
def _embed_query(self, query: str) -> np.ndarray:
|
||||
try:
|
||||
import numpy as np
|
||||
except ImportError as exc:
|
||||
raise RuntimeError("numpy is required to search a code knowledge base.") from exc
|
||||
embedding_function = self._get_embedding_function()
|
||||
if hasattr(embedding_function, "embed_query"):
|
||||
vector = embedding_function.embed_query(query)
|
||||
elif callable(embedding_function):
|
||||
vector = embedding_function(query)
|
||||
else:
|
||||
raise TypeError("Unsupported embedding function.")
|
||||
query_vector = np.array([vector], dtype="float32")
|
||||
index_dimension = int(getattr(self.faiss_index, "d", 0) or 0)
|
||||
if index_dimension and query_vector.shape[1] != index_dimension:
|
||||
raise ValueError(
|
||||
f"Query embedding dimension {query_vector.shape[1]} does not match "
|
||||
f"FAISS index dimension {index_dimension}."
|
||||
)
|
||||
return query_vector
|
||||
|
||||
@staticmethod
|
||||
def distance_to_similarity(distance: float) -> float:
|
||||
if math.isnan(distance) or distance < 0:
|
||||
return 0.0
|
||||
if distance <= 2.0:
|
||||
return max(0.0, min(1.0, 1.0 - distance / 2.0))
|
||||
return max(0.0, min(1.0, 1.0 / (1.0 + distance)))
|
||||
|
||||
def search_functions(
|
||||
self,
|
||||
query: str,
|
||||
top_k: int = 8,
|
||||
min_similarity: float = 0.0,
|
||||
) -> List[CodeSearchHit]:
|
||||
if self.faiss_index is None:
|
||||
raise RuntimeError("Code knowledge base has not been loaded.")
|
||||
if not query.strip():
|
||||
return []
|
||||
if not self.vector_search_enabled:
|
||||
logger.info(
|
||||
"Vector code search disabled, using lexical search: %s",
|
||||
self.vector_search_disabled_reason,
|
||||
)
|
||||
return self._lexical_search_functions(query, top_k=top_k, min_similarity=min_similarity)
|
||||
|
||||
try:
|
||||
vector = self._embed_query(query)
|
||||
distances, indices = self.faiss_index.search(vector, top_k)
|
||||
except Exception as exc:
|
||||
logger.warning("Vector code search failed, falling back to lexical search: %s", exc)
|
||||
return self._lexical_search_functions(query, top_k=top_k, min_similarity=min_similarity)
|
||||
|
||||
hits: List[CodeSearchHit] = []
|
||||
for rank, raw_index in enumerate(indices[0], start=1):
|
||||
index = int(raw_index)
|
||||
if index < 0 or index >= len(self.functions):
|
||||
continue
|
||||
evidence = self.functions[index]
|
||||
if not evidence.embedding_available:
|
||||
continue
|
||||
distance = float(distances[0][rank - 1])
|
||||
similarity = self.distance_to_similarity(distance)
|
||||
if similarity < min_similarity:
|
||||
continue
|
||||
hits.append(
|
||||
CodeSearchHit(
|
||||
evidence=evidence,
|
||||
similarity=similarity,
|
||||
distance=distance,
|
||||
rank=rank,
|
||||
)
|
||||
)
|
||||
if not hits:
|
||||
return self._lexical_search_functions(query, top_k=top_k, min_similarity=min_similarity)
|
||||
return hits
|
||||
|
||||
def _lexical_search_functions(
|
||||
self,
|
||||
query: str,
|
||||
top_k: int = 8,
|
||||
min_similarity: float = 0.0,
|
||||
) -> List[CodeSearchHit]:
|
||||
query_tokens = self._tokens(query)
|
||||
if not query_tokens:
|
||||
return []
|
||||
scored: List[CodeSearchHit] = []
|
||||
for evidence in self.functions:
|
||||
text = self._function_search_text(evidence)
|
||||
score = self._lexical_similarity(query_tokens, text, evidence)
|
||||
if score < min_similarity:
|
||||
continue
|
||||
scored.append(
|
||||
CodeSearchHit(
|
||||
evidence=evidence,
|
||||
similarity=score,
|
||||
distance=max(0.0, 1.0 - score),
|
||||
rank=0,
|
||||
)
|
||||
)
|
||||
scored.sort(key=lambda item: item.similarity, reverse=True)
|
||||
for rank, item in enumerate(scored[:top_k], start=1):
|
||||
item.rank = rank
|
||||
return scored[:top_k]
|
||||
|
||||
@staticmethod
|
||||
def _tokens(text: str) -> List[str]:
|
||||
normalized = (text or "").lower()
|
||||
tokens = re.findall(r"[a-z_][a-z0-9_]*|\d+(?:\.\d+)?|[\u4e00-\u9fff]{2,}", normalized)
|
||||
expanded: List[str] = []
|
||||
for token in tokens:
|
||||
expanded.append(token)
|
||||
if re.fullmatch(r"[\u4e00-\u9fff]{3,}", token):
|
||||
expanded.extend(token[index : index + 2] for index in range(len(token) - 1))
|
||||
return expanded
|
||||
|
||||
@staticmethod
|
||||
def _function_search_text(evidence: CodeFunctionEvidence) -> str:
|
||||
return "\n".join(
|
||||
[
|
||||
evidence.node_id,
|
||||
evidence.name,
|
||||
evidence.qualified_name,
|
||||
evidence.file,
|
||||
evidence.signature,
|
||||
evidence.summary,
|
||||
evidence.logic_flow,
|
||||
evidence.code_snippet[:4000],
|
||||
" ".join(evidence.calls),
|
||||
" ".join(evidence.called_by),
|
||||
" ".join(evidence.includes),
|
||||
]
|
||||
).lower()
|
||||
|
||||
def _lexical_similarity(
|
||||
self,
|
||||
query_tokens: List[str],
|
||||
text: str,
|
||||
evidence: CodeFunctionEvidence,
|
||||
) -> float:
|
||||
text_tokens = set(self._tokens(text))
|
||||
if not text_tokens:
|
||||
return 0.0
|
||||
unique_query_tokens = list(dict.fromkeys(query_tokens))
|
||||
overlap_count = sum(1 for token in unique_query_tokens if token in text_tokens)
|
||||
substring_count = sum(1 for token in unique_query_tokens if token in text)
|
||||
overlap_score = overlap_count / max(1, len(unique_query_tokens))
|
||||
substring_score = substring_count / max(1, len(unique_query_tokens))
|
||||
|
||||
name_text = f"{evidence.name} {evidence.qualified_name}".lower()
|
||||
name_hits = sum(1 for token in unique_query_tokens if token in name_text)
|
||||
name_score = min(1.0, name_hits / max(1, min(len(unique_query_tokens), 4)))
|
||||
evidence_score = 0.0
|
||||
if evidence.summary:
|
||||
evidence_score += 0.08
|
||||
if evidence.logic_flow:
|
||||
evidence_score += 0.05
|
||||
if evidence.code_snippet:
|
||||
evidence_score += 0.04
|
||||
if evidence.start_line is not None and evidence.file:
|
||||
evidence_score += 0.03
|
||||
|
||||
return max(
|
||||
0.0,
|
||||
min(
|
||||
1.0,
|
||||
overlap_score * 0.45
|
||||
+ substring_score * 0.35
|
||||
+ name_score * 0.12
|
||||
+ evidence_score,
|
||||
),
|
||||
)
|
||||
|
||||
def get_function(self, node_id: str) -> Optional[CodeFunctionEvidence]:
|
||||
return self.functions_by_id.get(node_id)
|
||||
|
||||
def expand_call_context(self, node_id: str, max_hops: int = 2) -> CodeGraphContext:
|
||||
if not self.call_graph:
|
||||
return CodeGraphContext(node_id=node_id)
|
||||
return self.call_graph.expand(node_id, max_hops=max_hops)
|
||||
|
||||
def summary(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"function_count": len(self.functions),
|
||||
"index_size": int(getattr(self.faiss_index, "ntotal", 0) or 0),
|
||||
"embedding_dim": int(getattr(self.faiss_index, "d", 0) or 0),
|
||||
"embedding_available_count": sum(1 for item in self.functions if item.embedding_available),
|
||||
"vector_search_enabled": self.vector_search_enabled,
|
||||
"vector_search_disabled_reason": self.vector_search_disabled_reason,
|
||||
"graph_nodes": len(self.graph_data.get("nodes", []) or []),
|
||||
"graph_edges": len(self.graph_data.get("edges", []) or []),
|
||||
}
|
||||
46
rag-web-ui/backend/app/services/code_kb/formatter.py
Normal file
46
rag-web-ui/backend/app/services/code_kb/formatter.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Iterable, List
|
||||
|
||||
from app.services.code_kb.schema import CodeGraphContext, CodeSearchHit
|
||||
|
||||
|
||||
def _clip(value: str, limit: int) -> str:
|
||||
value = value or ""
|
||||
if len(value) <= limit:
|
||||
return value
|
||||
return value[:limit].rstrip() + "\n...[truncated]"
|
||||
|
||||
|
||||
def format_evidence_context(
|
||||
hits: Iterable[CodeSearchHit],
|
||||
graph_contexts: Iterable[CodeGraphContext],
|
||||
max_code_chars: int = 1000,
|
||||
max_text_chars: int = 800,
|
||||
) -> str:
|
||||
context_by_node = {item.node_id: item for item in graph_contexts}
|
||||
blocks: List[str] = []
|
||||
for hit in hits:
|
||||
item = hit.evidence
|
||||
graph = context_by_node.get(item.node_id)
|
||||
lines = [
|
||||
f"[Function Evidence #{hit.rank}]",
|
||||
f"node_id: {item.node_id}",
|
||||
f"name: {item.name}",
|
||||
f"qualified_name: {item.qualified_name}",
|
||||
f"file: {item.file}",
|
||||
f"lines: {item.start_line}-{item.end_line}",
|
||||
f"similarity: {hit.similarity:.4f}",
|
||||
f"signature: {_clip(item.signature, 300)}",
|
||||
f"summary: {_clip(item.summary, max_text_chars)}",
|
||||
f"logic_flow: {_clip(item.logic_flow, max_text_chars)}",
|
||||
f"calls: {', '.join(item.calls[:20]) or '-'}",
|
||||
f"called_by: {', '.join(item.called_by[:20]) or '-'}",
|
||||
]
|
||||
if graph:
|
||||
lines.append(f"call_chain_evidence: {'; '.join(graph.call_chains[:12]) or '-'}")
|
||||
if item.code_snippet:
|
||||
lines.extend(["code_snippet:", _clip(item.code_snippet, max_code_chars)])
|
||||
blocks.append("\n".join(lines))
|
||||
return "\n\n---\n\n".join(blocks)
|
||||
|
||||
105
rag-web-ui/backend/app/services/code_kb/graph.py
Normal file
105
rag-web-ui/backend/app/services/code_kb/graph.py
Normal file
@@ -0,0 +1,105 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict, deque
|
||||
from typing import Callable, Dict, Iterable, List, Optional, Set
|
||||
|
||||
from app.services.code_kb.schema import CodeFunctionEvidence, CodeGraphContext
|
||||
|
||||
|
||||
class CodeCallGraph:
|
||||
def __init__(
|
||||
self,
|
||||
graph_data: Optional[dict],
|
||||
functions: Iterable[CodeFunctionEvidence],
|
||||
) -> None:
|
||||
self.functions_by_id: Dict[str, CodeFunctionEvidence] = {
|
||||
item.node_id: item for item in functions
|
||||
}
|
||||
self.name_to_ids: Dict[str, List[str]] = defaultdict(list)
|
||||
for item in self.functions_by_id.values():
|
||||
for name in {item.name, item.qualified_name, item.node_id.removeprefix("Function:")}:
|
||||
if name:
|
||||
self.name_to_ids[name].append(item.node_id)
|
||||
|
||||
self.calls_by_id: Dict[str, Set[str]] = defaultdict(set)
|
||||
self.called_by_id: Dict[str, Set[str]] = defaultdict(set)
|
||||
self._load_graph_edges(graph_data or {})
|
||||
self._load_metadata_edges()
|
||||
|
||||
def _load_graph_edges(self, graph_data: dict) -> None:
|
||||
for edge in graph_data.get("edges", []) or []:
|
||||
if edge.get("type") != "CALLS":
|
||||
continue
|
||||
source_id = edge.get("source_id")
|
||||
target_id = edge.get("target_id")
|
||||
if source_id in self.functions_by_id and target_id in self.functions_by_id:
|
||||
self.calls_by_id[source_id].add(target_id)
|
||||
self.called_by_id[target_id].add(source_id)
|
||||
|
||||
def _resolve_name(self, value: str) -> List[str]:
|
||||
if not value:
|
||||
return []
|
||||
if value in self.functions_by_id:
|
||||
return [value]
|
||||
if value.startswith("Function:") and value in self.functions_by_id:
|
||||
return [value]
|
||||
return self.name_to_ids.get(value, [])
|
||||
|
||||
def _load_metadata_edges(self) -> None:
|
||||
for function in self.functions_by_id.values():
|
||||
for callee in function.calls:
|
||||
for target_id in self._resolve_name(callee):
|
||||
if target_id != function.node_id:
|
||||
self.calls_by_id[function.node_id].add(target_id)
|
||||
self.called_by_id[target_id].add(function.node_id)
|
||||
for caller in function.called_by:
|
||||
for source_id in self._resolve_name(caller):
|
||||
if source_id != function.node_id:
|
||||
self.called_by_id[function.node_id].add(source_id)
|
||||
self.calls_by_id[source_id].add(function.node_id)
|
||||
|
||||
def _bfs(
|
||||
self,
|
||||
start_id: str,
|
||||
max_hops: int,
|
||||
next_nodes: Callable[[str], Iterable[str]],
|
||||
) -> List[CodeFunctionEvidence]:
|
||||
seen = {start_id}
|
||||
result: List[CodeFunctionEvidence] = []
|
||||
queue = deque([(start_id, 0)])
|
||||
while queue:
|
||||
current_id, depth = queue.popleft()
|
||||
if depth >= max_hops:
|
||||
continue
|
||||
for next_id in sorted(next_nodes(current_id)):
|
||||
if next_id in seen:
|
||||
continue
|
||||
seen.add(next_id)
|
||||
function = self.functions_by_id.get(next_id)
|
||||
if function:
|
||||
result.append(function)
|
||||
queue.append((next_id, depth + 1))
|
||||
return result
|
||||
|
||||
def expand(self, node_id: str, max_hops: int = 2) -> CodeGraphContext:
|
||||
callers = self._bfs(node_id, max_hops, lambda item: self.called_by_id.get(item, set()))
|
||||
callees = self._bfs(node_id, max_hops, lambda item: self.calls_by_id.get(item, set()))
|
||||
center = self.functions_by_id.get(node_id)
|
||||
center_name = center.name if center else node_id
|
||||
|
||||
call_chains: List[str] = []
|
||||
for caller in callers[:5]:
|
||||
call_chains.append(f"{caller.name} -> {center_name}")
|
||||
for callee in callees[:5]:
|
||||
call_chains.append(f"{center_name} -> {callee.name}")
|
||||
for caller in callers[:3]:
|
||||
for callee in callees[:3]:
|
||||
call_chains.append(f"{caller.name} -> {center_name} -> {callee.name}")
|
||||
|
||||
return CodeGraphContext(
|
||||
node_id=node_id,
|
||||
callers=callers,
|
||||
callees=callees,
|
||||
call_chains=list(dict.fromkeys(call_chains)),
|
||||
)
|
||||
|
||||
24
rag-web-ui/backend/app/services/code_kb/retriever.py
Normal file
24
rag-web-ui/backend/app/services/code_kb/retriever.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
|
||||
from app.services.code_kb.adapter import CodeKnowledgeBaseAdapter
|
||||
from app.services.code_kb.schema import CodeSearchHit
|
||||
|
||||
|
||||
class CodeFunctionRetriever:
|
||||
def __init__(self, adapter: CodeKnowledgeBaseAdapter) -> None:
|
||||
self.adapter = adapter
|
||||
|
||||
def retrieve(
|
||||
self,
|
||||
query: str,
|
||||
top_k: int = 8,
|
||||
min_similarity: float = 0.0,
|
||||
) -> List[CodeSearchHit]:
|
||||
return self.adapter.search_functions(
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
min_similarity=min_similarity,
|
||||
)
|
||||
|
||||
63
rag-web-ui/backend/app/services/code_kb/schema.py
Normal file
63
rag-web-ui/backend/app/services/code_kb/schema.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class CodeFunctionEvidence:
|
||||
node_id: str
|
||||
name: str
|
||||
qualified_name: str
|
||||
file: str
|
||||
start_line: Optional[int] = None
|
||||
end_line: Optional[int] = None
|
||||
signature: str = ""
|
||||
summary: str = ""
|
||||
logic_flow: str = ""
|
||||
code_snippet: str = ""
|
||||
calls: List[str] = field(default_factory=list)
|
||||
called_by: List[str] = field(default_factory=list)
|
||||
includes: List[str] = field(default_factory=list)
|
||||
embedding_model: str = ""
|
||||
embedding_dim: int = 0
|
||||
embedding_available: bool = True
|
||||
raw: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return asdict(self)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CodeSearchHit:
|
||||
evidence: CodeFunctionEvidence
|
||||
similarity: float
|
||||
distance: float
|
||||
rank: int
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
data = self.evidence.to_dict()
|
||||
data.update(
|
||||
{
|
||||
"similarity": self.similarity,
|
||||
"distance": self.distance,
|
||||
"rank": self.rank,
|
||||
}
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
@dataclass
|
||||
class CodeGraphContext:
|
||||
node_id: str
|
||||
callers: List[CodeFunctionEvidence] = field(default_factory=list)
|
||||
callees: List[CodeFunctionEvidence] = field(default_factory=list)
|
||||
call_chains: List[str] = field(default_factory=list)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"node_id": self.node_id,
|
||||
"callers": [item.to_dict() for item in self.callers],
|
||||
"callees": [item.to_dict() for item in self.callees],
|
||||
"call_chains": self.call_chains,
|
||||
}
|
||||
Reference in New Issue
Block a user