Files

165 lines
5.4 KiB
Python
Raw Permalink Normal View History

2026-04-13 11:34:23 +08:00
import asyncio
import json
import logging
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from urllib import request
@dataclass
class ExternalRerankerClient:
    """Best-effort HTTP client for an external document reranking service.

    Sends ``query`` and ``documents`` to ``api_url`` as a JSON POST and maps the
    response back to one relevance score per input document. A DashScope-style
    request body is used when ``api_url`` points at the DashScope rerank
    service; any other endpoint receives a flat
    ``{"model", "query", "documents", "top_n"}`` body.
    """

    # Endpoint to POST to; an empty string disables the client (see ``enabled``).
    api_url: str
    # Sent as a ``Bearer`` token in the ``Authorization`` header when non-empty.
    api_key: str = ""
    # Model name forwarded verbatim in the request body.
    model: str = ""
    # Timeout passed to ``urllib.request.urlopen``.
    timeout_seconds: float = 8.0

    @property
    def enabled(self) -> bool:
        """Whether an endpoint is configured (non-empty ``api_url``)."""
        return bool(self.api_url)

    @property
    def is_dashscope_rerank(self) -> bool:
        """Whether ``api_url`` targets the DashScope rerank service."""
        return (
            "dashscope.aliyuncs.com" in self.api_url
            and "/services/rerank/" in self.api_url
        )

    async def rerank(
        self,
        *,
        query: str,
        documents: List[str],
        top_n: Optional[int] = None,
        metadata: Optional[List[Dict[str, Any]]] = None,
    ) -> Optional[List[float]]:
        """Score each entry of ``documents`` against ``query``.

        Args:
            query: Search query text.
            documents: Candidate texts; the returned scores are aligned with
                this list by position.
            top_n: Number of results to request from the service. ``None`` or a
                non-positive value requests all documents; larger values are
                clamped to ``len(documents)``.
            metadata: Optional extra data, forwarded as a top-level
                ``"metadata"`` key of the request body when non-empty.

        Returns:
            One float per document (0.0 for documents the service did not
            score), ``[]`` when ``documents`` is empty, or ``None`` when the
            client is disabled or the request/response handling failed.
            Failures are logged, not raised, so callers can fall back to their
            own ordering.
        """
        if not self.enabled:
            return None
        if not documents:
            return []
        doc_count = len(documents)
        # None / 0 / negative mean "all documents"; never ask for more results
        # than there are inputs.
        if top_n is None or top_n <= 0:
            effective_top_n = doc_count
        else:
            effective_top_n = min(top_n, doc_count)
        payload = self._build_payload(
            query=query,
            documents=documents,
            top_n=effective_top_n,
            metadata=metadata,
        )
        try:
            # _post_json uses blocking urllib, so keep it off the event loop.
            response = await asyncio.to_thread(self._post_json, payload)
            return self._parse_scores(response, doc_count)
        except Exception:
            # Deliberately best-effort: a reranker outage must not break the
            # caller, but the failure should still be visible in the logs.
            logging.getLogger(__name__).warning(
                "External rerank request to %s failed; returning None",
                self.api_url,
                exc_info=True,
            )
            return None

    def _post_json(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """POST ``payload`` as JSON to ``api_url`` and return the decoded JSON body.

        This call blocks; ``rerank`` runs it in a worker thread. Errors raised
        by ``urllib.request.urlopen`` (e.g. ``urllib.error.HTTPError`` for
        error statuses) and by ``json.loads`` propagate to the caller.
        """
        headers = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        req = request.Request(
            self.api_url,
            data=json.dumps(payload).encode("utf-8"),
            headers=headers,
            method="POST",
        )
        with request.urlopen(req, timeout=self.timeout_seconds) as resp:
            return json.loads(resp.read().decode("utf-8"))

    def _build_payload(
        self,
        *,
        query: str,
        documents: List[str],
        top_n: int,
        metadata: Optional[List[Dict[str, Any]]],
    ) -> Dict[str, Any]:
        """Build the JSON request body for the configured endpoint.

        DashScope rerank endpoints get
        ``{"model", "input": {"query", "documents"}, "parameters": {...}}``;
        every other endpoint gets the flat
        ``{"model", "query", "documents", "top_n"}`` shape. A non-empty
        ``metadata`` is attached as a top-level ``"metadata"`` key in both.
        """
        payload: Dict[str, Any]
        if self.is_dashscope_rerank:
            payload = {
                "model": self.model,
                "input": {
                    "query": query,
                    "documents": documents,
                },
                "parameters": {
                    # Only "index"/"relevance_score" are read back, so having
                    # the service echo every document would just inflate the
                    # response.
                    "return_documents": False,
                    "top_n": top_n,
                },
            }
        else:
            payload = {
                "model": self.model,
                "query": query,
                "documents": documents,
                "top_n": top_n,
            }
        if metadata:
            # NOTE(review): confirm the target API accepts (or ignores) an
            # extra top-level "metadata" field.
            payload["metadata"] = metadata
        return payload

    def _parse_scores(self, response: Dict[str, Any], expected_len: int) -> List[float]:
        """Map a rerank response to ``expected_len`` scores aligned with the inputs.

        Recognized shapes, checked in this order:

        * DashScope: ``{"output": {"results": [{"index": 0, "relevance_score": 0.98}, ...]}}``
        * ``{"results": [{"index": 0, "relevance_score": 0.98}, ...]}``
        * ``{"scores": [0.9, 0.1, ...]}`` (positional)
        * ``{"data": [{"index": 0, "score": 0.88}, ...]}``

        Documents without a usable score get 0.0. Entries whose index is not an
        int in ``[0, expected_len)`` are ignored.

        Raises:
            ValueError: if ``response`` matches none of the shapes above, so
                the caller treats it as a failure instead of a list of
                all-zero scores.
        """
        if not isinstance(response, dict):
            raise ValueError(
                f"Unexpected rerank response type: {type(response).__name__}"
            )

        output_block = response.get("output")
        if isinstance(output_block, dict) and isinstance(output_block.get("results"), list):
            return self._scores_from_indexed(
                output_block["results"], expected_len, "relevance_score", "score"
            )

        if isinstance(response.get("results"), list):
            return self._scores_from_indexed(
                response["results"], expected_len, "relevance_score", "score"
            )

        if isinstance(response.get("scores"), list):
            values = response["scores"]
            # Positional list: a missing tail counts as 0.0.
            return [
                self._to_score(values[i]) if i < len(values) else 0.0
                for i in range(expected_len)
            ]

        if isinstance(response.get("data"), list):
            # Entries in this shape carry "score", so it takes precedence over
            # "relevance_score".
            return self._scores_from_indexed(
                response["data"], expected_len, "score", "relevance_score"
            )

        raise ValueError(f"Unrecognized rerank response keys: {sorted(response)}")

    @staticmethod
    def _to_score(value: Any) -> float:
        """Convert ``value`` to float, falling back to 0.0 when it is not numeric."""
        try:
            return float(value)
        except (TypeError, ValueError, OverflowError):
            return 0.0

    @classmethod
    def _scores_from_indexed(
        cls,
        items: List[Any],
        expected_len: int,
        primary_key: str,
        fallback_key: str,
    ) -> List[float]:
        """Scatter ``{"index": i, <key>: score}`` entries into a dense score list.

        ``primary_key`` is read first, then ``fallback_key``. Non-dict entries
        and out-of-range indices are skipped. A later entry for the same index
        overwrites an earlier one.
        """
        scores = [0.0] * expected_len
        for item in items:
            if not isinstance(item, dict):
                continue
            idx = item.get("index")
            # bool is an int subclass; reject it so True/False are not read as
            # indices 1/0.
            if isinstance(idx, bool) or not isinstance(idx, int):
                continue
            if not 0 <= idx < expected_len:
                continue
            scores[idx] = cls._to_score(item.get(primary_key, item.get(fallback_key, 0.0)))
        return scores