init. project
This commit is contained in:
164
rag-web-ui/backend/app/services/reranker/external_api.py
Normal file
164
rag-web-ui/backend/app/services/reranker/external_api.py
Normal file
@@ -0,0 +1,164 @@
|
||||
import asyncio
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib import request
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExternalRerankerClient:
|
||||
api_url: str
|
||||
api_key: str = ""
|
||||
model: str = ""
|
||||
timeout_seconds: float = 8.0
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return bool(self.api_url)
|
||||
|
||||
@property
|
||||
def is_dashscope_rerank(self) -> bool:
|
||||
return "dashscope.aliyuncs.com" in self.api_url and "/services/rerank/" in self.api_url
|
||||
|
||||
async def rerank(
|
||||
self,
|
||||
*,
|
||||
query: str,
|
||||
documents: List[str],
|
||||
top_n: Optional[int] = None,
|
||||
metadata: Optional[List[Dict[str, Any]]] = None,
|
||||
) -> Optional[List[float]]:
|
||||
if not self.enabled:
|
||||
return None
|
||||
if not documents:
|
||||
return []
|
||||
|
||||
payload = self._build_payload(
|
||||
query=query,
|
||||
documents=documents,
|
||||
top_n=top_n or len(documents),
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
try:
|
||||
response = await asyncio.to_thread(self._post_json, payload)
|
||||
scores = self._parse_scores(response, len(documents))
|
||||
return scores
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def _post_json(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if self.api_key:
|
||||
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||
|
||||
req = request.Request(
|
||||
self.api_url,
|
||||
data=json.dumps(payload).encode("utf-8"),
|
||||
headers=headers,
|
||||
method="POST",
|
||||
)
|
||||
with request.urlopen(req, timeout=self.timeout_seconds) as resp:
|
||||
body = resp.read().decode("utf-8")
|
||||
return json.loads(body)
|
||||
|
||||
def _build_payload(
|
||||
self,
|
||||
*,
|
||||
query: str,
|
||||
documents: List[str],
|
||||
top_n: int,
|
||||
metadata: Optional[List[Dict[str, Any]]],
|
||||
) -> Dict[str, Any]:
|
||||
if self.is_dashscope_rerank:
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"input": {
|
||||
"query": query,
|
||||
"documents": documents,
|
||||
},
|
||||
"parameters": {
|
||||
"return_documents": True,
|
||||
"top_n": top_n,
|
||||
},
|
||||
}
|
||||
if metadata:
|
||||
payload["metadata"] = metadata
|
||||
return payload
|
||||
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"query": query,
|
||||
"documents": documents,
|
||||
"top_n": top_n,
|
||||
}
|
||||
if metadata:
|
||||
payload["metadata"] = metadata
|
||||
return payload
|
||||
|
||||
def _parse_scores(self, response: Dict[str, Any], expected_len: int) -> List[float]:
|
||||
# DashScope format:
|
||||
# {"output": {"results": [{"index": 0, "relevance_score": 0.98}, ...]}}
|
||||
output_block = response.get("output")
|
||||
if isinstance(output_block, dict) and isinstance(output_block.get("results"), list):
|
||||
raw_results = output_block["results"]
|
||||
scores = [0.0] * expected_len
|
||||
for item in raw_results:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
idx = item.get("index")
|
||||
score = item.get("relevance_score", item.get("score", 0.0))
|
||||
if isinstance(idx, int) and 0 <= idx < expected_len:
|
||||
try:
|
||||
scores[idx] = float(score)
|
||||
except Exception:
|
||||
scores[idx] = 0.0
|
||||
return scores
|
||||
|
||||
# Common response format #1:
|
||||
# {"results": [{"index": 0, "relevance_score": 0.98}, ...]}
|
||||
if isinstance(response.get("results"), list):
|
||||
raw_results = response["results"]
|
||||
scores = [0.0] * expected_len
|
||||
for item in raw_results:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
idx = item.get("index")
|
||||
score = item.get("relevance_score", item.get("score", 0.0))
|
||||
if isinstance(idx, int) and 0 <= idx < expected_len:
|
||||
try:
|
||||
scores[idx] = float(score)
|
||||
except Exception:
|
||||
scores[idx] = 0.0
|
||||
return scores
|
||||
|
||||
# Common response format #2:
|
||||
# {"scores": [0.9, 0.1, ...]}
|
||||
if isinstance(response.get("scores"), list):
|
||||
values = response["scores"]
|
||||
scores: List[float] = []
|
||||
for i in range(expected_len):
|
||||
try:
|
||||
scores.append(float(values[i]))
|
||||
except Exception:
|
||||
scores.append(0.0)
|
||||
return scores
|
||||
|
||||
# Common response format #3:
|
||||
# {"data": [{"index": 0, "score": 0.88}, ...]}
|
||||
if isinstance(response.get("data"), list):
|
||||
raw_results = response["data"]
|
||||
scores = [0.0] * expected_len
|
||||
for item in raw_results:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
idx = item.get("index")
|
||||
score = item.get("score", item.get("relevance_score", 0.0))
|
||||
if isinstance(idx, int) and 0 <= idx < expected_len:
|
||||
try:
|
||||
scores[idx] = float(score)
|
||||
except Exception:
|
||||
scores[idx] = 0.0
|
||||
return scores
|
||||
|
||||
return [0.0] * expected_len
|
||||
Reference in New Issue
Block a user