"""Retrieval across multiple KB vector stores with optional external reranking."""

import logging
import re
from typing import Any, Dict, List, Optional

from app.services.reranker.external_api import ExternalRerankerClient

logger = logging.getLogger(__name__)

# One token per run of ASCII word characters, or per single character in the
# U+4E00..U+9FFF range (each such character is its own token).
_TOKEN_PATTERN = re.compile(r"[A-Za-z0-9_]+|[\u4e00-\u9fff]")


def _tokenize(text: str) -> List[str]:
    """Split lower-cased *text* into word tokens and single CJK characters."""
    return _TOKEN_PATTERN.findall(text.lower())


def _keyword_score(query: str, text: str) -> float:
    """Return the fraction of distinct query tokens that also occur in *text*.

    The result is in ``[0.0, 1.0]``; it is ``0.0`` when either side has no
    tokens.
    """
    query_terms = set(_tokenize(query))
    text_terms = set(_tokenize(text))
    if not query_terms or not text_terms:
        return 0.0
    return len(query_terms & text_terms) / len(query_terms)


def format_retrieval_context(rows: List[Dict[str, Any]]) -> str:
    """Render retrieved rows as numbered text blocks separated by blank lines.

    Args:
        rows: Row dicts holding a ``"document"`` (an object exposing
            ``page_content`` and ``metadata``) and optionally ``"kb_id"`` and
            ``"final_score"``.

    Returns:
        One block per row: a header line with kb id, source, chunk id and
        score, followed by the stripped document text.
    """
    blocks: List[str] = []
    for position, row in enumerate(rows, start=1):
        doc = row["document"]
        metadata = getattr(doc, "metadata", None) or {}
        source = metadata.get("source") or metadata.get("file_name") or "unknown"

        # Only a missing/empty id is "unknown"; a chunk id of 0 is a real id
        # and must not be swallowed by a truthiness test.
        chunk_id = metadata.get("chunk_id")
        if chunk_id is None or chunk_id == "":
            chunk_id = "unknown"

        # A row whose score is absent or None is shown as 0 instead of
        # failing inside the format spec.
        score = float(row.get("final_score") or 0)

        blocks.append(
            f"[{position}] kb_id={row.get('kb_id')}, source={source}, "
            f"chunk_id={chunk_id}, score={score:.6f}\n"
            f"{doc.page_content.strip()}"
        )
    return "\n\n".join(blocks)


class MultiKBRetriever:
    """Gather candidates from several vector stores and rank them together.

    Each candidate gets a position-based vector score and a keyword-overlap
    score. When an enabled reranker client is configured, its score is
    blended in with weight ``reranker_weight``.
    """

    def __init__(
        self,
        *,
        reranker_client: Optional[ExternalRerankerClient] = None,
        reranker_weight: float = 0.75,
        vector_weight: float = 0.2,
        keyword_weight: float = 0.05,
    ):
        """Store the optional reranker client and the score-blending weights."""
        self.reranker_client = reranker_client
        self.reranker_weight = reranker_weight
        self.vector_weight = vector_weight
        self.keyword_weight = keyword_weight

    async def retrieve(
        self,
        *,
        query: str,
        kb_vector_stores: List[Dict[str, Any]],
        fetch_k_per_kb: int = 12,
        top_k: int = 12,
    ) -> List[Dict[str, Any]]:
        """Return the ``top_k`` best candidate rows for *query*.

        Args:
            query: The search text.
            kb_vector_stores: Dicts with a ``"kb_id"`` and a ``"store"``; the
                store must provide ``similarity_search_with_score(query, k=...)``.
            fetch_k_per_kb: Number of results requested from each store.
            top_k: Maximum number of rows returned.

        Returns:
            Row dicts sorted by ``final_score`` (descending). Each row holds
            ``kb_id``, ``document``, ``chunk_key``, ``vector_rank_score``,
            ``keyword_score``, ``reranker_score`` (``None`` when no reranker
            score was used) and ``final_score``.
        """
        candidates = self._collect_candidates(query, kb_vector_stores, fetch_k_per_kb)
        if not candidates:
            return []

        merged = self._dedupe(candidates)
        merged.sort(key=lambda row: row["vector_rank_score"], reverse=True)

        reranker_scores = await self._rerank(query, merged, top_k)

        for idx, row in enumerate(merged):
            base_score = (
                self.vector_weight * row["vector_rank_score"]
                + self.keyword_weight * row["keyword_score"]
            )
            if reranker_scores is not None:
                rerank_value = float(reranker_scores[idx])
                final_score = (
                    self.reranker_weight * rerank_value
                    + (1 - self.reranker_weight) * base_score
                )
                row["reranker_score"] = round(rerank_value, 6)
            else:
                final_score = base_score
                row["reranker_score"] = None
            row["final_score"] = round(final_score, 6)

        merged.sort(key=lambda row: row["final_score"], reverse=True)
        return merged[:top_k]

    @staticmethod
    def _collect_candidates(
        query: str,
        kb_vector_stores: List[Dict[str, Any]],
        fetch_k_per_kb: int,
    ) -> List[Dict[str, Any]]:
        """Query every store and turn its hits into scored candidate rows."""
        candidates: List[Dict[str, Any]] = []
        for kb_store in kb_vector_stores:
            kb_id = kb_store["kb_id"]
            # NOTE(review): this is a plain (non-awaited) call made from an
            # async method; confirm it does not block the event loop.
            raw = kb_store["store"].similarity_search_with_score(query, k=fetch_k_per_kb)
            total = len(raw)
            for index, item in enumerate(raw):
                if not isinstance(item, (tuple, list)) or not item:
                    continue
                doc = item[0]
                if not hasattr(doc, "page_content"):
                    continue
                metadata = getattr(doc, "metadata", None) or {}

                # Fall back to the result position when the id is absent OR
                # explicitly None; otherwise every None-id chunk of a KB would
                # share one key and be collapsed by the dedupe step.
                chunk_id = metadata.get("chunk_id")
                if chunk_id is None:
                    chunk_id = index

                # Position-based score: 1.0 for the first hit, decreasing
                # linearly. The store's own score (item[1]) is not used.
                rank_score = 1.0 - (index / max(1, total))
                lexical_score = _keyword_score(query, doc.page_content)
                candidates.append(
                    {
                        "kb_id": kb_id,
                        "document": doc,
                        "chunk_key": f"{kb_id}:{chunk_id}",
                        "vector_rank_score": round(rank_score, 6),
                        "keyword_score": round(lexical_score, 6),
                    }
                )
        return candidates

    @staticmethod
    def _dedupe(candidates: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Keep one row per ``chunk_key``, preferring the higher vector score."""
        unique_map: Dict[str, Dict[str, Any]] = {}
        for row in candidates:
            existing = unique_map.get(row["chunk_key"])
            # Strict comparison: on a tie the first-seen row is kept.
            if existing is None or row["vector_rank_score"] > existing["vector_rank_score"]:
                unique_map[row["chunk_key"]] = row
        return list(unique_map.values())

    async def _rerank(
        self,
        query: str,
        merged: List[Dict[str, Any]],
        top_k: int,
    ) -> Optional[List[float]]:
        """Ask the reranker for scores, or return ``None`` if they are unusable.

        ``None`` is returned when no enabled client is configured, when the
        client itself returns ``None``, or when it returns fewer scores than
        there are rows (they could not be indexed per row).
        """
        client = self.reranker_client
        if client is None or not client.enabled:
            return None

        scores = await client.rerank(
            query=query,
            documents=[row["document"].page_content for row in merged],
            top_n=min(top_k, len(merged)),
            metadata=[{"kb_id": row["kb_id"]} for row in merged],
        )
        # NOTE(review): scores are consumed by position, so this assumes the
        # client returns one score per input document in input order; confirm
        # against ExternalRerankerClient.rerank.
        if scores is not None and len(scores) < len(merged):
            logger.warning(
                "Reranker returned %d scores for %d documents; using base scores only",
                len(scores),
                len(merged),
            )
            return None
        return scores