skills upload function

This commit is contained in:
kuangji
2026-05-26 10:30:07 +08:00
parent bb2e55e889
commit 0f8917d874

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import shutil
import threading
import time
import zipfile
from dataclasses import dataclass, field
from pathlib import Path
from uuid import uuid4
@@ -31,15 +32,22 @@ UPLOAD_DIR = ROOT_DIR / "uploads"
OUTPUT_DIR = ROOT_DIR / "outputs"
SKILL_ROOT = ROOT_DIR / "skills"
DEFAULT_SKILL_COLLECTION = "GJB438C-2021_prd_skills"
SKILL_COLLECTIONS = [
"GJB438B-2009_prd_skills",
"GJB438C-2021_prd_skills",
]
CONFIG_PATH = ROOT_DIR / "configs" / "api_config.yaml"
MAX_UPLOAD_BYTES = 30 * 1024 * 1024
MAX_SKILL_ARCHIVE_BYTES = 50 * 1024 * 1024
ProgressCallback = Callable[[int, str], None]
def _discover_skill_collections() -> list[str]:
if not SKILL_ROOT.exists():
return []
return sorted(
path.name
for path in SKILL_ROOT.iterdir()
if path.is_dir() and (path / "index.md").is_file()
)
def _skill_collection_path(collection_slug: str) -> Path:
path = SKILL_ROOT / collection_slug
if not path.exists() or not path.is_dir() or not (path / "index.md").exists():
@@ -49,10 +57,8 @@ def _skill_collection_path(collection_slug: str) -> Path:
def _skill_collection_options() -> list[dict[str, object]]:
options: list[dict[str, object]] = []
for collection_slug in SKILL_COLLECTIONS:
for collection_slug in _discover_skill_collections():
path = SKILL_ROOT / collection_slug
if not path.exists() or not path.is_dir() or not (path / "index.md").exists():
continue
skills = load_skill_catalog(path)
options.append(
{
@@ -64,6 +70,64 @@ def _skill_collection_options() -> list[dict[str, object]]:
return options
def _validate_skill_archive_member(member_name: str) -> None:
path = Path(member_name)
if not member_name or "\\" in member_name or member_name.startswith(("/", "\\")) or path.is_absolute():
raise HTTPException(status_code=400, detail="压缩包包含非法路径")
if any(part in {"", ".", ".."} for part in path.parts):
raise HTTPException(status_code=400, detail="压缩包包含非法路径")
def install_skill_collection_zip(archive_path: Path, collection_slug: str) -> dict[str, object]:
if not collection_slug or collection_slug in {".", ".."}:
raise HTTPException(status_code=400, detail="技能合集名称无效")
if "/" in collection_slug or "\\" in collection_slug:
raise HTTPException(status_code=400, detail="技能合集名称无效")
if not zipfile.is_zipfile(archive_path):
raise HTTPException(status_code=400, detail="仅支持有效的 zip 压缩包")
SKILL_ROOT.mkdir(parents=True, exist_ok=True)
target_dir = SKILL_ROOT / collection_slug
temp_dir = SKILL_ROOT / f".{collection_slug}.{uuid4().hex}.tmp"
try:
with zipfile.ZipFile(archive_path) as archive:
members = archive.infolist()
if not members:
raise HTTPException(status_code=400, detail="压缩包为空")
names = [member.filename for member in members]
for name in names:
_validate_skill_archive_member(name)
if "index.md" not in names:
raise HTTPException(status_code=400, detail="技能合集压缩包根目录必须包含 index.md")
archive.extractall(temp_dir)
skills = load_skill_catalog(temp_dir)
if not skills:
raise HTTPException(status_code=400, detail="技能合集未包含有效 SKILL.md")
if target_dir.exists():
shutil.rmtree(target_dir)
temp_dir.rename(target_dir)
return {
"slug": collection_slug,
"label": collection_slug.replace("_prd_skills", ""),
"skill_count": len(skills),
}
except HTTPException:
if temp_dir.exists():
shutil.rmtree(temp_dir)
raise
except zipfile.BadZipFile as exc:
if temp_dir.exists():
shutil.rmtree(temp_dir)
raise HTTPException(status_code=400, detail="仅支持有效的 zip 压缩包") from exc
except Exception:
if temp_dir.exists():
shutil.rmtree(temp_dir)
raise
@dataclass
class AnalysisTask:
task_id: str
@@ -247,18 +311,55 @@ def _run_analysis_task(
@app.get("/", response_class=HTMLResponse)
def index(request: Request) -> HTMLResponse:
settings = load_api_config(CONFIG_PATH)
skill_collections = _skill_collection_options()
return templates.TemplateResponse(
request,
"index.html",
{
"default_provider": settings.provider_name,
"skill_collection_count": len(SKILL_COLLECTIONS),
"skill_collections": _skill_collection_options(),
"skill_collection_count": len(skill_collections),
"skill_collections": skill_collections,
"default_skill_collection": DEFAULT_SKILL_COLLECTION,
},
)
@app.get("/skill-collections")
def list_skill_collections() -> dict[str, object]:
return {
"collections": _skill_collection_options(),
"default_skill_collection": DEFAULT_SKILL_COLLECTION,
}
@app.post("/skill-collections/upload")
async def upload_skill_collection(file: UploadFile = File(...)) -> dict[str, object]:
if not file.filename or not file.filename.lower().endswith(".zip"):
raise HTTPException(status_code=400, detail="技能合集仅支持上传 .zip 压缩包")
content = await file.read()
if len(content) > MAX_SKILL_ARCHIVE_BYTES:
raise HTTPException(status_code=413, detail="技能合集压缩包超过 50MB 限制")
archive_name = Path(file.filename).name
collection_slug = archive_name[:-4]
archive_path = UPLOAD_DIR / f"{uuid4().hex}_{archive_name}"
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
archive_path.write_bytes(content)
try:
collection = install_skill_collection_zip(archive_path, collection_slug)
finally:
if archive_path.exists():
archive_path.unlink()
return {
"message": f"技能合集 {collection['slug']} 上传成功,已加载 {collection['skill_count']} 项技能",
"collection": collection,
"collections": _skill_collection_options(),
}
@app.post("/analyze")
async def analyze_docx(
file: UploadFile = File(...),