Files
rag_agent/rag-web-ui/backend/app/services/hybrid_retriever.py
2026-04-13 11:34:23 +08:00

86 lines
2.5 KiB
Python

import re
from typing import Any, Dict, List
from app.services.vector_store.base import BaseVectorStore
def _tokenize_for_keyword_score(text: str) -> List[str]:
"""Simple multilingual tokenizer for lexical matching without extra dependencies."""
tokens = re.findall(r"[A-Za-z0-9_]+|[\u4e00-\u9fff]", text.lower())
return [token for token in tokens if token.strip()]
def _keyword_score(query: str, doc_text: str) -> float:
query_terms = set(_tokenize_for_keyword_score(query))
doc_terms = set(_tokenize_for_keyword_score(doc_text))
if not query_terms or not doc_terms:
return 0.0
overlap = len(query_terms.intersection(doc_terms))
return overlap / max(1, len(query_terms))
def hybrid_search(
    vector_store: BaseVectorStore,
    query: str,
    top_k: int = 6,
    fetch_k: int = 20,
    alpha: float = 0.65,
) -> List[Dict[str, Any]]:
    """
    Hybrid retrieval via vector candidate generation + lexical reranking.
    score = alpha * vector_rank_score + (1 - alpha) * keyword_score

    The vector store supplies ``fetch_k`` candidates; each candidate gets a
    rank-decay vector score (1.0 for the top hit, falling linearly with
    position) blended with a lexical overlap score, and the ``top_k`` best
    blended rows are returned.
    """
    candidates = vector_store.similarity_search_with_score(query, k=fetch_k)
    if not candidates:
        return []

    candidate_count = len(candidates)
    scored: List[Dict[str, Any]] = []
    for position, entry in enumerate(candidates):
        # Defensive: skip anything that is not a (doc, score)-like pair.
        if not isinstance(entry, (tuple, list)) or len(entry) < 1:
            continue
        document = entry[0]
        if not hasattr(document, "page_content"):
            continue
        # Rank-decay score: position 0 -> 1.0, later positions decay linearly.
        vector_component = 1.0 - (position / max(1, candidate_count))
        keyword_component = _keyword_score(query, document.page_content)
        blended = alpha * vector_component + (1.0 - alpha) * keyword_component
        scored.append(
            {
                "document": document,
                "vector_rank_score": round(vector_component, 6),
                "keyword_score": round(keyword_component, 6),
                "final_score": round(blended, 6),
            }
        )

    # Stable descending sort on the blended score, then truncate.
    reranked = sorted(scored, key=lambda row: row["final_score"], reverse=True)
    return reranked[:top_k]
def format_hybrid_context(rows: List[Dict[str, Any]]) -> str:
    """Render reranked rows as a numbered, citation-style context string.

    Each row becomes a ``[i] source=..., chunk_id=..., score=...`` header
    followed by the document's stripped page content; rows are joined with
    blank lines. An empty *rows* list yields an empty string.

    Args:
        rows: Output of ``hybrid_search`` — dicts with a ``document``
            (having ``page_content`` and ``metadata``) and a ``final_score``.

    Returns:
        The formatted context block.
    """
    parts: List[str] = []
    for i, row in enumerate(rows, start=1):
        doc = row["document"]
        metadata = doc.metadata or {}
        # Use explicit None checks rather than `or` fallbacks: falsy but
        # valid values (chunk_id=0, empty-string source) must not be
        # silently replaced with "unknown".
        source = metadata.get("source")
        if source is None:
            source = metadata.get("file_name")
        if source is None:
            source = "unknown"
        chunk_id = metadata.get("chunk_id")
        if chunk_id is None:
            chunk_id = "unknown"
        parts.append(
            (
                f"[{i}] source={source}, chunk_id={chunk_id}, "
                f"score={row['final_score']}\n"
                f"{doc.page_content.strip()}"
            )
        )
    return "\n\n".join(parts)