init. project
This commit is contained in:
85
rag-web-ui/backend/app/services/hybrid_retriever.py
Normal file
85
rag-web-ui/backend/app/services/hybrid_retriever.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import re
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from app.services.vector_store.base import BaseVectorStore
|
||||
|
||||
|
||||
def _tokenize_for_keyword_score(text: str) -> List[str]:
|
||||
"""Simple multilingual tokenizer for lexical matching without extra dependencies."""
|
||||
tokens = re.findall(r"[A-Za-z0-9_]+|[\u4e00-\u9fff]", text.lower())
|
||||
return [token for token in tokens if token.strip()]
|
||||
|
||||
|
||||
def _keyword_score(query: str, doc_text: str) -> float:
|
||||
query_terms = set(_tokenize_for_keyword_score(query))
|
||||
doc_terms = set(_tokenize_for_keyword_score(doc_text))
|
||||
|
||||
if not query_terms or not doc_terms:
|
||||
return 0.0
|
||||
|
||||
overlap = len(query_terms.intersection(doc_terms))
|
||||
return overlap / max(1, len(query_terms))
|
||||
|
||||
|
||||
def hybrid_search(
    vector_store: BaseVectorStore,
    query: str,
    top_k: int = 6,
    fetch_k: int = 20,
    alpha: float = 0.65,
) -> List[Dict[str, Any]]:
    """Hybrid retrieval via vector candidate generation + lexical reranking.

    Fetches *fetch_k* candidates from the vector store, blends each
    candidate's vector rank with a keyword-overlap score, then returns the
    *top_k* best rows:

        score = alpha * vector_rank_score + (1 - alpha) * keyword_score

    Args:
        vector_store: Store exposing ``similarity_search_with_score(query, k)``.
        query: User query text.
        top_k: Number of rows to return after reranking.
        fetch_k: Number of vector candidates to fetch before reranking.
        alpha: Blend weight; 1.0 means pure vector ranking, 0.0 pure lexical.

    Returns:
        Rows sorted by ``"final_score"`` descending, each a dict with keys
        ``"document"``, ``"vector_rank_score"``, ``"keyword_score"`` and
        ``"final_score"`` (scores rounded to 6 decimal places).
    """
    candidates = vector_store.similarity_search_with_score(query, k=fetch_k)
    if not candidates:
        return []

    candidate_count = len(candidates)
    scored: List[Dict[str, Any]] = []

    for position, candidate in enumerate(candidates):
        # Defensively skip malformed entries; expected shape is (doc, score).
        if not isinstance(candidate, (tuple, list)) or len(candidate) < 1:
            continue

        document = candidate[0]
        if not hasattr(document, "page_content"):
            continue

        # Rank-derived score: first candidate gets 1.0, later ones decay
        # linearly with position in the store's result order.
        vector_rank = 1.0 - (position / max(1, candidate_count))
        lexical = _keyword_score(query, document.page_content)
        blended = alpha * vector_rank + (1.0 - alpha) * lexical

        scored.append(
            {
                "document": document,
                "vector_rank_score": round(vector_rank, 6),
                "keyword_score": round(lexical, 6),
                "final_score": round(blended, 6),
            }
        )

    scored.sort(key=lambda row: row["final_score"], reverse=True)
    return scored[:top_k]
|
||||
|
||||
|
||||
def format_hybrid_context(rows: List[Dict[str, Any]]) -> str:
    """Render reranked retrieval rows as a prompt-ready context string.

    Each row becomes a numbered header line carrying the chunk's source,
    chunk id and final score, followed on the next line by the stripped
    chunk text. Entries are separated by blank lines.

    Args:
        rows: Output of ``hybrid_search`` — dicts holding a ``"document"``
            (with ``page_content`` and ``metadata``) and a ``"final_score"``.

    Returns:
        The concatenated context string; empty string for empty *rows*.
    """
    sections: List[str] = []

    for position, row in enumerate(rows, start=1):
        document = row["document"]
        meta = document.metadata or {}
        origin = meta.get("source") or meta.get("file_name") or "unknown"
        chunk = meta.get("chunk_id") or "unknown"

        header = (
            f"[{position}] source={origin}, chunk_id={chunk}, "
            f"score={row['final_score']}"
        )
        sections.append(f"{header}\n{document.page_content.strip()}")

    return "\n\n".join(sections)
|
||||
Reference in New Issue
Block a user