From 0f8917d874bf1f702f5bf7a4cb83caaf29b2ccbc Mon Sep 17 00:00:00 2001 From: kuangji <819823900@qq.com> Date: Tue, 26 May 2026 10:30:07 +0800 Subject: [PATCH] skills upload function --- app/main.py | 119 ++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 110 insertions(+), 9 deletions(-) diff --git a/app/main.py b/app/main.py index 174d318..dfaadaa 100644 --- a/app/main.py +++ b/app/main.py @@ -3,6 +3,7 @@ from __future__ import annotations import shutil import threading import time +import zipfile from dataclasses import dataclass, field from pathlib import Path from uuid import uuid4 @@ -31,15 +32,22 @@ UPLOAD_DIR = ROOT_DIR / "uploads" OUTPUT_DIR = ROOT_DIR / "outputs" SKILL_ROOT = ROOT_DIR / "skills" DEFAULT_SKILL_COLLECTION = "GJB438C-2021_prd_skills" -SKILL_COLLECTIONS = [ - "GJB438B-2009_prd_skills", - "GJB438C-2021_prd_skills", -] CONFIG_PATH = ROOT_DIR / "configs" / "api_config.yaml" MAX_UPLOAD_BYTES = 30 * 1024 * 1024 +MAX_SKILL_ARCHIVE_BYTES = 50 * 1024 * 1024 ProgressCallback = Callable[[int, str], None] +def _discover_skill_collections() -> list[str]: + if not SKILL_ROOT.exists(): + return [] + return sorted( + path.name + for path in SKILL_ROOT.iterdir() + if path.is_dir() and (path / "index.md").is_file() + ) + + def _skill_collection_path(collection_slug: str) -> Path: path = SKILL_ROOT / collection_slug if not path.exists() or not path.is_dir() or not (path / "index.md").exists(): @@ -49,10 +57,8 @@ def _skill_collection_path(collection_slug: str) -> Path: def _skill_collection_options() -> list[dict[str, object]]: options: list[dict[str, object]] = [] - for collection_slug in SKILL_COLLECTIONS: + for collection_slug in _discover_skill_collections(): path = SKILL_ROOT / collection_slug - if not path.exists() or not path.is_dir() or not (path / "index.md").exists(): - continue skills = load_skill_catalog(path) options.append( { @@ -64,6 +70,64 @@ def _skill_collection_options() -> list[dict[str, object]]: return options +def _validate_skill_archive_member(member_name: str) -> None: + path = Path(member_name) + if not member_name or "\\" in member_name or member_name.startswith(("/", "\\")) or path.is_absolute(): + raise HTTPException(status_code=400, detail="压缩包包含非法路径") + if any(part in {"", ".", ".."} for part in path.parts): + raise HTTPException(status_code=400, detail="压缩包包含非法路径") + + +def install_skill_collection_zip(archive_path: Path, collection_slug: str) -> dict[str, object]: + if not collection_slug or collection_slug in {".", ".."}: + raise HTTPException(status_code=400, detail="技能合集名称无效") + if "/" in collection_slug or "\\" in collection_slug: + raise HTTPException(status_code=400, detail="技能合集名称无效") + if not zipfile.is_zipfile(archive_path): + raise HTTPException(status_code=400, detail="仅支持有效的 zip 压缩包") + + SKILL_ROOT.mkdir(parents=True, exist_ok=True) + target_dir = SKILL_ROOT / collection_slug + temp_dir = SKILL_ROOT / f".{collection_slug}.{uuid4().hex}.tmp" + + try: + with zipfile.ZipFile(archive_path) as archive: + members = archive.infolist() + if not members: + raise HTTPException(status_code=400, detail="压缩包为空") + names = [member.filename for member in members] + for name in names: + _validate_skill_archive_member(name) + if "index.md" not in names: + raise HTTPException(status_code=400, detail="技能合集压缩包根目录必须包含 index.md") + archive.extractall(temp_dir) + + skills = load_skill_catalog(temp_dir) + if not skills: + raise HTTPException(status_code=400, detail="技能合集未包含有效 SKILL.md") + + if target_dir.exists(): + shutil.rmtree(target_dir) + temp_dir.rename(target_dir) + return { + "slug": collection_slug, + "label": collection_slug.replace("_prd_skills", ""), + "skill_count": len(skills), + } + except HTTPException: + if temp_dir.exists(): + shutil.rmtree(temp_dir) + raise + except zipfile.BadZipFile as exc: + if temp_dir.exists(): + shutil.rmtree(temp_dir) + raise HTTPException(status_code=400, detail="仅支持有效的 zip 压缩包") from exc + except Exception: + if temp_dir.exists(): + shutil.rmtree(temp_dir) + raise + + @dataclass class AnalysisTask: task_id: str @@ -247,18 +311,55 @@ def _run_analysis_task( @app.get("/", response_class=HTMLResponse) def index(request: Request) -> HTMLResponse: settings = load_api_config(CONFIG_PATH) + skill_collections = _skill_collection_options() return templates.TemplateResponse( request, "index.html", { "default_provider": settings.provider_name, - "skill_collection_count": len(SKILL_COLLECTIONS), - "skill_collections": _skill_collection_options(), + "skill_collection_count": len(skill_collections), + "skill_collections": skill_collections, "default_skill_collection": DEFAULT_SKILL_COLLECTION, }, ) +@app.get("/skill-collections") +def list_skill_collections() -> dict[str, object]: + return { + "collections": _skill_collection_options(), + "default_skill_collection": DEFAULT_SKILL_COLLECTION, + } + + +@app.post("/skill-collections/upload") +async def upload_skill_collection(file: UploadFile = File(...)) -> dict[str, object]: + if not file.filename or not file.filename.lower().endswith(".zip"): + raise HTTPException(status_code=400, detail="技能合集仅支持上传 .zip 压缩包") + + content = await file.read() + if len(content) > MAX_SKILL_ARCHIVE_BYTES: + raise HTTPException(status_code=413, detail="技能合集压缩包超过 50MB 限制") + + archive_name = Path(file.filename).name + collection_slug = archive_name[:-4] + archive_path = UPLOAD_DIR / f"{uuid4().hex}_{archive_name}" + UPLOAD_DIR.mkdir(parents=True, exist_ok=True) + archive_path.write_bytes(content) + + try: + collection = install_skill_collection_zip(archive_path, collection_slug) + finally: + if archive_path.exists(): + archive_path.unlink() + + return { + "message": f"技能合集 {collection['slug']} 上传成功,已加载 {collection['skill_count']} 项技能", + "collection": collection, + "collections": _skill_collection_options(), + } + + @app.post("/analyze") async def analyze_docx( file: UploadFile = File(...),