"""HTTP client for external rerank services (DashScope and generic rerank APIs)."""
import asyncio
|
|
import json
|
|
from dataclasses import dataclass
|
|
from typing import Any, Dict, List, Optional
|
|
from urllib import request
|
|
|
|
|
|
@dataclass
class ExternalRerankerClient:
    """Thin, best-effort HTTP client for an external rerank endpoint.

    Sends (query, documents) to a rerank API and maps the response back to a
    dense per-document score list.  The client is disabled when ``api_url`` is
    empty, and every network/parse failure degrades to ``None`` so callers can
    fall back to their own ranking.

    Attributes:
        api_url: Endpoint URL; an empty string disables the client.
        api_key: Optional bearer token sent as ``Authorization: Bearer ...``.
        model: Model name forwarded verbatim in the request payload.
        timeout_seconds: Socket timeout for the blocking POST.
    """

    api_url: str
    api_key: str = ""
    model: str = ""
    timeout_seconds: float = 8.0

    @property
    def enabled(self) -> bool:
        """True when an endpoint URL is configured."""
        return bool(self.api_url)

    @property
    def is_dashscope_rerank(self) -> bool:
        """True for DashScope's rerank endpoint, which uses a nested payload schema."""
        return "dashscope.aliyuncs.com" in self.api_url and "/services/rerank/" in self.api_url

    async def rerank(
        self,
        *,
        query: str,
        documents: List[str],
        top_n: Optional[int] = None,
        metadata: Optional[List[Dict[str, Any]]] = None,
    ) -> Optional[List[float]]:
        """Score *documents* against *query* via the external service.

        Args:
            query: The search query.
            documents: Candidate texts to score.
            top_n: How many results the service should rank; falsy values
                default to ``len(documents)``.
            metadata: Optional per-document metadata forwarded to the service.

        Returns:
            A list of ``len(documents)`` floats (unranked positions stay 0.0),
            ``[]`` for empty input, or ``None`` when the client is disabled or
            the request/parse fails (deliberate best-effort contract).
        """
        if not self.enabled:
            return None
        if not documents:
            return []

        payload = self._build_payload(
            query=query,
            documents=documents,
            top_n=top_n or len(documents),
            metadata=metadata,
        )

        try:
            # urllib is blocking; run it off the event loop.
            response = await asyncio.to_thread(self._post_json, payload)
            return self._parse_scores(response, len(documents))
        except Exception:
            # Best-effort: any transport/JSON failure means "no external scores".
            return None

    def _post_json(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """POST *payload* as JSON to ``api_url`` and return the decoded response.

        Raises:
            urllib.error.URLError / ValueError: On transport or decode failure
                (handled by the caller's best-effort wrapper).
        """
        headers = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        req = request.Request(
            self.api_url,
            data=json.dumps(payload).encode("utf-8"),
            headers=headers,
            method="POST",
        )
        with request.urlopen(req, timeout=self.timeout_seconds) as resp:
            body = resp.read().decode("utf-8")
        return json.loads(body)

    def _build_payload(
        self,
        *,
        query: str,
        documents: List[str],
        top_n: int,
        metadata: Optional[List[Dict[str, Any]]],
    ) -> Dict[str, Any]:
        """Build the request body in the schema the target endpoint expects.

        DashScope nests query/documents under ``input`` and options under
        ``parameters``; generic rerank APIs take a flat body.  ``metadata`` is
        attached top-level in both cases when present.
        """
        if self.is_dashscope_rerank:
            payload: Dict[str, Any] = {
                "model": self.model,
                "input": {
                    "query": query,
                    "documents": documents,
                },
                "parameters": {
                    "return_documents": True,
                    "top_n": top_n,
                },
            }
        else:
            payload = {
                "model": self.model,
                "query": query,
                "documents": documents,
                "top_n": top_n,
            }
        if metadata:
            payload["metadata"] = metadata
        return payload

    @staticmethod
    def _indexed_scores(
        items: List[Any],
        expected_len: int,
        primary_key: str,
        fallback_key: str,
    ) -> List[float]:
        """Map ``[{"index": i, <key>: score}, ...]`` onto a dense score list.

        Non-dict items, out-of-range indices, and non-numeric scores are
        tolerated; affected positions keep 0.0.
        """
        scores = [0.0] * expected_len
        for item in items:
            if not isinstance(item, dict):
                continue
            idx = item.get("index")
            if not (isinstance(idx, int) and 0 <= idx < expected_len):
                continue
            raw = item.get(primary_key, item.get(fallback_key, 0.0))
            try:
                scores[idx] = float(raw)
            except (TypeError, ValueError):
                scores[idx] = 0.0
        return scores

    def _parse_scores(self, response: Dict[str, Any], expected_len: int) -> List[float]:
        """Extract a dense ``expected_len``-long score list from *response*.

        Recognized formats (checked in order):
          * DashScope: ``{"output": {"results": [{"index", "relevance_score"}]}}``
          * Generic #1: ``{"results": [{"index", "relevance_score"}]}``
          * Generic #2: ``{"scores": [0.9, 0.1, ...]}`` (positional)
          * Generic #3: ``{"data": [{"index", "score"}]}``

        Unrecognized responses yield all zeros.
        """
        output_block = response.get("output")
        if isinstance(output_block, dict) and isinstance(output_block.get("results"), list):
            return self._indexed_scores(
                output_block["results"], expected_len, "relevance_score", "score"
            )

        if isinstance(response.get("results"), list):
            return self._indexed_scores(
                response["results"], expected_len, "relevance_score", "score"
            )

        if isinstance(response.get("scores"), list):
            values = response["scores"]
            scores: List[float] = []
            for i in range(expected_len):
                try:
                    scores.append(float(values[i]))
                except (IndexError, TypeError, ValueError):
                    # Tolerate short lists and non-numeric entries.
                    scores.append(0.0)
            return scores

        if isinstance(response.get("data"), list):
            return self._indexed_scores(
                response["data"], expected_len, "score", "relevance_score"
            )

        return [0.0] * expected_len