skills upload function

This commit is contained in:
kuangji
2026-05-26 10:30:07 +08:00
parent bb2e55e889
commit 0f8917d874

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import shutil import shutil
import threading import threading
import time import time
import zipfile
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from uuid import uuid4 from uuid import uuid4
@@ -31,15 +32,22 @@ UPLOAD_DIR = ROOT_DIR / "uploads"
OUTPUT_DIR = ROOT_DIR / "outputs" OUTPUT_DIR = ROOT_DIR / "outputs"
SKILL_ROOT = ROOT_DIR / "skills" SKILL_ROOT = ROOT_DIR / "skills"
DEFAULT_SKILL_COLLECTION = "GJB438C-2021_prd_skills" DEFAULT_SKILL_COLLECTION = "GJB438C-2021_prd_skills"
SKILL_COLLECTIONS = [
"GJB438B-2009_prd_skills",
"GJB438C-2021_prd_skills",
]
CONFIG_PATH = ROOT_DIR / "configs" / "api_config.yaml" CONFIG_PATH = ROOT_DIR / "configs" / "api_config.yaml"
MAX_UPLOAD_BYTES = 30 * 1024 * 1024 MAX_UPLOAD_BYTES = 30 * 1024 * 1024
MAX_SKILL_ARCHIVE_BYTES = 50 * 1024 * 1024
ProgressCallback = Callable[[int, str], None] ProgressCallback = Callable[[int, str], None]
def _discover_skill_collections() -> list[str]:
    """Return the sorted slugs of every installed skill collection.

    A collection is a direct sub-directory of ``SKILL_ROOT`` that has an
    ``index.md`` file at its top level.

    Dot-prefixed directories are skipped on purpose:
    ``install_skill_collection_zip`` stages uploads in
    ``SKILL_ROOT / ".<slug>.<hex>.tmp"``, and that staging directory already
    contains ``index.md`` while an install is running (or after one was
    interrupted), so without the filter it would be listed as a collection.

    Returns:
        Directory names in ascending order, or an empty list when
        ``SKILL_ROOT`` is missing or is not a directory.
    """
    # is_dir() also covers the "does not exist" case and avoids
    # NotADirectoryError from iterdir() when the path is a regular file.
    if not SKILL_ROOT.is_dir():
        return []
    return sorted(
        path.name
        for path in SKILL_ROOT.iterdir()
        if not path.name.startswith(".")
        and path.is_dir()
        and (path / "index.md").is_file()
    )
def _skill_collection_path(collection_slug: str) -> Path: def _skill_collection_path(collection_slug: str) -> Path:
path = SKILL_ROOT / collection_slug path = SKILL_ROOT / collection_slug
if not path.exists() or not path.is_dir() or not (path / "index.md").exists(): if not path.exists() or not path.is_dir() or not (path / "index.md").exists():
@@ -49,10 +57,8 @@ def _skill_collection_path(collection_slug: str) -> Path:
def _skill_collection_options() -> list[dict[str, object]]: def _skill_collection_options() -> list[dict[str, object]]:
options: list[dict[str, object]] = [] options: list[dict[str, object]] = []
for collection_slug in SKILL_COLLECTIONS: for collection_slug in _discover_skill_collections():
path = SKILL_ROOT / collection_slug path = SKILL_ROOT / collection_slug
if not path.exists() or not path.is_dir() or not (path / "index.md").exists():
continue
skills = load_skill_catalog(path) skills = load_skill_catalog(path)
options.append( options.append(
{ {
@@ -64,6 +70,64 @@ def _skill_collection_options() -> list[dict[str, object]]:
return options return options
def _validate_skill_archive_member(member_name: str) -> None:
path = Path(member_name)
if not member_name or "\\" in member_name or member_name.startswith(("/", "\\")) or path.is_absolute():
raise HTTPException(status_code=400, detail="压缩包包含非法路径")
if any(part in {"", ".", ".."} for part in path.parts):
raise HTTPException(status_code=400, detail="压缩包包含非法路径")
# Upper bound on the total *uncompressed* size an uploaded collection may
# declare; guards extraction of an untrusted archive against zip bombs.
_MAX_SKILL_EXTRACTED_BYTES = 500 * 1024 * 1024


def install_skill_collection_zip(archive_path: Path, collection_slug: str) -> dict[str, object]:
    """Validate a skill-collection zip and install it under ``SKILL_ROOT``.

    The archive is extracted into a hidden staging directory next to the
    target, checked with ``load_skill_catalog``, and only then moved into
    place.  An existing collection with the same slug is replaced; it is
    moved aside first and restored if the final rename fails, so a failed
    install never destroys the previous version.

    Args:
        archive_path: Path of the uploaded zip file on disk.
        collection_slug: Directory name to install the collection under.

    Returns:
        A dict with ``slug``, ``label`` (slug without ``_prd_skills``) and
        ``skill_count``.

    Raises:
        HTTPException: 400 for an invalid slug, an invalid/empty archive, an
            illegal member path, a missing root ``index.md`` or a collection
            without skills; 413 when the declared uncompressed size exceeds
            ``_MAX_SKILL_EXTRACTED_BYTES``.
    """
    # A leading dot (which also covers "." and "..") is refused because
    # dot-prefixed names are reserved for the staging/backup directories below.
    if not collection_slug or collection_slug.startswith("."):
        raise HTTPException(status_code=400, detail="技能合集名称无效")
    if "/" in collection_slug or "\\" in collection_slug:
        raise HTTPException(status_code=400, detail="技能合集名称无效")
    if not zipfile.is_zipfile(archive_path):
        raise HTTPException(status_code=400, detail="仅支持有效的 zip 压缩包")

    SKILL_ROOT.mkdir(parents=True, exist_ok=True)
    target_dir = SKILL_ROOT / collection_slug
    token = uuid4().hex
    temp_dir = SKILL_ROOT / f".{collection_slug}.{token}.tmp"
    backup_dir = SKILL_ROOT / f".{collection_slug}.{token}.bak"
    try:
        with zipfile.ZipFile(archive_path) as archive:
            members = archive.infolist()
            if not members:
                raise HTTPException(status_code=400, detail="压缩包为空")
            names = [member.filename for member in members]
            for name in names:
                _validate_skill_archive_member(name)
            if "index.md" not in names:
                raise HTTPException(status_code=400, detail="技能合集压缩包根目录必须包含 index.md")
            # Check the declared sizes before writing anything to disk.
            if sum(member.file_size for member in members) > _MAX_SKILL_EXTRACTED_BYTES:
                raise HTTPException(status_code=413, detail="技能合集解压后超过 500MB 限制")
            archive.extractall(temp_dir)

        skills = load_skill_catalog(temp_dir)
        if not skills:
            raise HTTPException(status_code=400, detail="技能合集未包含有效 SKILL.md")

        # Swap the new collection in; keep the old one until the swap succeeded.
        previous_exists = target_dir.exists()
        if previous_exists:
            target_dir.rename(backup_dir)
        try:
            temp_dir.rename(target_dir)
        except OSError:
            if previous_exists:
                backup_dir.rename(target_dir)
            raise
        if previous_exists:
            # Best effort: the new version is already live at this point.
            shutil.rmtree(backup_dir, ignore_errors=True)

        return {
            "slug": collection_slug,
            "label": collection_slug.replace("_prd_skills", ""),
            "skill_count": len(skills),
        }
    except zipfile.BadZipFile as exc:
        raise HTTPException(status_code=400, detail="仅支持有效的 zip 压缩包") from exc
    finally:
        # After a successful install the staging directory has been renamed
        # away, so this only removes leftovers of a failed attempt.  Errors
        # are ignored so cleanup never masks the original exception.
        shutil.rmtree(temp_dir, ignore_errors=True)
@dataclass @dataclass
class AnalysisTask: class AnalysisTask:
task_id: str task_id: str
@@ -247,18 +311,55 @@ def _run_analysis_task(
@app.get("/", response_class=HTMLResponse)
def index(request: Request) -> HTMLResponse:
    """Render ``index.html`` with the provider name and skill collections.

    The API config is loaded from ``CONFIG_PATH`` and the collection options
    are computed once, then reused for both the list and its count.
    """
    api_settings = load_api_config(CONFIG_PATH)
    available_collections = _skill_collection_options()
    page_context = {
        "default_provider": api_settings.provider_name,
        "skill_collection_count": len(available_collections),
        "skill_collections": available_collections,
        "default_skill_collection": DEFAULT_SKILL_COLLECTION,
    }
    return templates.TemplateResponse(request, "index.html", page_context)
@app.get("/skill-collections")
def list_skill_collections() -> dict[str, object]:
    """Return the current skill-collection options and the default slug."""
    payload: dict[str, object] = {}
    payload["collections"] = _skill_collection_options()
    payload["default_skill_collection"] = DEFAULT_SKILL_COLLECTION
    return payload
@app.post("/skill-collections/upload")
async def upload_skill_collection(file: UploadFile = File(...)) -> dict[str, object]:
    """Accept a ``.zip`` skill collection and install it.

    The collection slug is the uploaded file name without its ``.zip``
    suffix.  The archive is written to a temporary file in ``UPLOAD_DIR``,
    handed to ``install_skill_collection_zip`` and removed afterwards whether
    or not the install succeeded.

    Args:
        file: The uploaded archive.

    Returns:
        A dict with a status ``message``, the installed ``collection`` and the
        refreshed list of ``collections``.

    Raises:
        HTTPException: 400 when the file is not named ``*.zip``; 413 when it
            is larger than ``MAX_SKILL_ARCHIVE_BYTES``; plus any error raised
            by ``install_skill_collection_zip``.
    """
    if not file.filename or not file.filename.lower().endswith(".zip"):
        raise HTTPException(status_code=400, detail="技能合集仅支持上传 .zip 压缩包")

    # Read at most one byte past the limit: enough to detect an oversized
    # upload without loading an arbitrarily large body into memory.
    content = await file.read(MAX_SKILL_ARCHIVE_BYTES + 1)
    if len(content) > MAX_SKILL_ARCHIVE_BYTES:
        raise HTTPException(status_code=413, detail="技能合集压缩包超过 50MB 限制")

    archive_name = Path(file.filename).name
    collection_slug = archive_name[: -len(".zip")]
    # The temporary file name is internal only, so it does not embed the
    # client-supplied name (which may be overlong or contain odd characters).
    archive_path = UPLOAD_DIR / f"{uuid4().hex}.zip"
    UPLOAD_DIR.mkdir(parents=True, exist_ok=True)

    try:
        # Written inside the try so a partial file is removed if writing fails.
        archive_path.write_bytes(content)
        collection = install_skill_collection_zip(archive_path, collection_slug)
    finally:
        if archive_path.exists():
            archive_path.unlink()

    return {
        "message": f"技能合集 {collection['slug']} 上传成功,已加载 {collection['skill_count']} 项技能",
        "collection": collection,
        "collections": _skill_collection_options(),
    }
@app.post("/analyze") @app.post("/analyze")
async def analyze_docx( async def analyze_docx(
file: UploadFile = File(...), file: UploadFile = File(...),