import re
from typing import Any, Dict, List, Optional

from app.services.reranker.external_api import ExternalRerankerClient


def _tokenize(text: str) -> List[str]:
|
|
tokens = re.findall(r"[A-Za-z0-9_]+|[\u4e00-\u9fff]", text.lower())
|
|
return [token for token in tokens if token.strip()]
|
|
|
|
|
|
def _keyword_score(query: str, text: str) -> float:
|
|
query_terms = set(_tokenize(query))
|
|
text_terms = set(_tokenize(text))
|
|
if not query_terms or not text_terms:
|
|
return 0.0
|
|
overlap = len(query_terms.intersection(text_terms))
|
|
return overlap / max(1, len(query_terms))
|
|
|
|
|
|
def format_retrieval_context(rows: List[Dict[str, Any]]) -> str:
    """Render retrieval rows as numbered context blocks separated by blank lines.

    Each block is a one-line header (index, kb_id, source, chunk_id, score)
    followed by the chunk text with surrounding whitespace stripped.
    """
    rendered: List[str] = []
    for position, row in enumerate(rows, start=1):
        document = row["document"]
        meta = document.metadata or {}
        source = meta.get("source") or meta.get("file_name") or "unknown"
        chunk_id = meta.get("chunk_id") or "unknown"
        header = (
            f"[{position}] kb_id={row.get('kb_id')}, source={source}, "
            f"chunk_id={chunk_id}, score={row.get('final_score', 0):.6f}"
        )
        rendered.append(f"{header}\n{document.page_content.strip()}")
    return "\n\n".join(rendered)


class MultiKBRetriever:
    """Merge, deduplicate, and score retrieval candidates from multiple KB stores.

    Final scores blend three signals: vector rank (position within each
    store's similarity results), keyword overlap with the query, and — when
    an external reranker is configured and enabled — a reranker relevance
    score.
    """

    def __init__(
        self,
        *,
        reranker_client: Optional[ExternalRerankerClient] = None,
        reranker_weight: float = 0.75,
        vector_weight: float = 0.2,
        keyword_weight: float = 0.05,
    ):
        # reranker_weight blends the reranker score against the local
        # (vector + keyword) base score; the three defaults sum to 1.0.
        self.reranker_client = reranker_client
        self.reranker_weight = reranker_weight
        self.vector_weight = vector_weight
        self.keyword_weight = keyword_weight

    async def retrieve(
        self,
        *,
        query: str,
        kb_vector_stores: List[Dict[str, Any]],
        fetch_k_per_kb: int = 12,
        top_k: int = 12,
    ) -> List[Dict[str, Any]]:
        """Return up to *top_k* scored candidate rows across all KB stores.

        Args:
            query: User query text.
            kb_vector_stores: Dicts carrying "kb_id" and "store", where the
                store exposes ``similarity_search_with_score(query, k=...)``.
            fetch_k_per_kb: Raw hits to pull from each store before merging.
            top_k: Maximum number of rows returned after scoring.

        Returns:
            Rows sorted by "final_score" descending; each row contains the
            document, its kb_id, and the individual score components.
        """
        candidates: List[Dict[str, Any]] = []

        for kb_store in kb_vector_stores:
            kb_id = kb_store["kb_id"]
            vector_store = kb_store["store"]
            raw = vector_store.similarity_search_with_score(query, k=fetch_k_per_kb)
            total = len(raw)

            for index, item in enumerate(raw):
                # Defensive: skip entries that are not (doc, score)-like pairs.
                if not isinstance(item, (tuple, list)) or not item:
                    continue

                doc = item[0]
                if not hasattr(doc, "page_content"):
                    continue

                metadata = doc.metadata or {}
                # Rank-based score: first hit -> 1.0, decaying linearly by position.
                rank_score = 1.0 - (index / max(1, total))
                lexical_score = _keyword_score(query, doc.page_content)

                candidates.append(
                    {
                        "kb_id": kb_id,
                        "document": doc,
                        "chunk_key": f"{kb_id}:{metadata.get('chunk_id', index)}",
                        "vector_rank_score": round(rank_score, 6),
                        "keyword_score": round(lexical_score, 6),
                    }
                )

        if not candidates:
            return []

        # Dedupe by KB + chunk id to avoid repeated chunks from same collection;
        # on collision keep the occurrence with the better vector rank.
        unique_map: Dict[str, Dict[str, Any]] = {}
        for row in candidates:
            key = row["chunk_key"]
            existing = unique_map.get(key)
            if existing is None or row["vector_rank_score"] > existing["vector_rank_score"]:
                unique_map[key] = row

        merged = list(unique_map.values())
        merged.sort(key=lambda x: x["vector_rank_score"], reverse=True)

        reranker_scores: Optional[List[float]] = None
        if self.reranker_client is not None and self.reranker_client.enabled:
            # NOTE(review): assumes rerank() returns scores aligned with the
            # input document order — confirm against ExternalRerankerClient.
            reranker_scores = await self.reranker_client.rerank(
                query=query,
                documents=[row["document"].page_content for row in merged],
                top_n=min(top_k, len(merged)),
                metadata=[{"kb_id": row["kb_id"]} for row in merged],
            )

        for idx, row in enumerate(merged):
            base_score = (
                self.vector_weight * row["vector_rank_score"]
                + self.keyword_weight * row["keyword_score"]
            )

            # BUGFIX: rerank() is asked for at most top_n scores, which can be
            # fewer than len(merged). Guard the index so rows without a
            # reranker score fall back to the local base score instead of
            # raising IndexError.
            if reranker_scores is not None and idx < len(reranker_scores):
                rerank_value = float(reranker_scores[idx])
                final_score = (
                    self.reranker_weight * rerank_value
                    + (1 - self.reranker_weight) * base_score
                )
                row["reranker_score"] = round(rerank_value, 6)
            else:
                final_score = base_score
                row["reranker_score"] = None

            row["final_score"] = round(final_score, 6)

        merged.sort(key=lambda x: x["final_score"], reverse=True)
        return merged[:top_k]