skills upload function
This commit is contained in:
119
app/main.py
119
app/main.py
@@ -3,6 +3,7 @@ from __future__ import annotations
|
|||||||
import shutil
|
import shutil
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
import zipfile
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
@@ -31,15 +32,22 @@ UPLOAD_DIR = ROOT_DIR / "uploads"
|
|||||||
OUTPUT_DIR = ROOT_DIR / "outputs"
|
OUTPUT_DIR = ROOT_DIR / "outputs"
|
||||||
SKILL_ROOT = ROOT_DIR / "skills"
|
SKILL_ROOT = ROOT_DIR / "skills"
|
||||||
DEFAULT_SKILL_COLLECTION = "GJB438C-2021_prd_skills"
|
DEFAULT_SKILL_COLLECTION = "GJB438C-2021_prd_skills"
|
||||||
SKILL_COLLECTIONS = [
|
|
||||||
"GJB438B-2009_prd_skills",
|
|
||||||
"GJB438C-2021_prd_skills",
|
|
||||||
]
|
|
||||||
CONFIG_PATH = ROOT_DIR / "configs" / "api_config.yaml"
|
CONFIG_PATH = ROOT_DIR / "configs" / "api_config.yaml"
|
||||||
MAX_UPLOAD_BYTES = 30 * 1024 * 1024
|
MAX_UPLOAD_BYTES = 30 * 1024 * 1024
|
||||||
|
MAX_SKILL_ARCHIVE_BYTES = 50 * 1024 * 1024
|
||||||
ProgressCallback = Callable[[int, str], None]
|
ProgressCallback = Callable[[int, str], None]
|
||||||
|
|
||||||
|
|
||||||
|
def _discover_skill_collections() -> list[str]:
|
||||||
|
if not SKILL_ROOT.exists():
|
||||||
|
return []
|
||||||
|
return sorted(
|
||||||
|
path.name
|
||||||
|
for path in SKILL_ROOT.iterdir()
|
||||||
|
if path.is_dir() and (path / "index.md").is_file()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _skill_collection_path(collection_slug: str) -> Path:
|
def _skill_collection_path(collection_slug: str) -> Path:
|
||||||
path = SKILL_ROOT / collection_slug
|
path = SKILL_ROOT / collection_slug
|
||||||
if not path.exists() or not path.is_dir() or not (path / "index.md").exists():
|
if not path.exists() or not path.is_dir() or not (path / "index.md").exists():
|
||||||
@@ -49,10 +57,8 @@ def _skill_collection_path(collection_slug: str) -> Path:
|
|||||||
|
|
||||||
def _skill_collection_options() -> list[dict[str, object]]:
|
def _skill_collection_options() -> list[dict[str, object]]:
|
||||||
options: list[dict[str, object]] = []
|
options: list[dict[str, object]] = []
|
||||||
for collection_slug in SKILL_COLLECTIONS:
|
for collection_slug in _discover_skill_collections():
|
||||||
path = SKILL_ROOT / collection_slug
|
path = SKILL_ROOT / collection_slug
|
||||||
if not path.exists() or not path.is_dir() or not (path / "index.md").exists():
|
|
||||||
continue
|
|
||||||
skills = load_skill_catalog(path)
|
skills = load_skill_catalog(path)
|
||||||
options.append(
|
options.append(
|
||||||
{
|
{
|
||||||
@@ -64,6 +70,64 @@ def _skill_collection_options() -> list[dict[str, object]]:
|
|||||||
return options
|
return options
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_skill_archive_member(member_name: str) -> None:
|
||||||
|
path = Path(member_name)
|
||||||
|
if not member_name or "\\" in member_name or member_name.startswith(("/", "\\")) or path.is_absolute():
|
||||||
|
raise HTTPException(status_code=400, detail="压缩包包含非法路径")
|
||||||
|
if any(part in {"", ".", ".."} for part in path.parts):
|
||||||
|
raise HTTPException(status_code=400, detail="压缩包包含非法路径")
|
||||||
|
|
||||||
|
|
||||||
|
def install_skill_collection_zip(archive_path: Path, collection_slug: str) -> dict[str, object]:
|
||||||
|
if not collection_slug or collection_slug in {".", ".."}:
|
||||||
|
raise HTTPException(status_code=400, detail="技能合集名称无效")
|
||||||
|
if "/" in collection_slug or "\\" in collection_slug:
|
||||||
|
raise HTTPException(status_code=400, detail="技能合集名称无效")
|
||||||
|
if not zipfile.is_zipfile(archive_path):
|
||||||
|
raise HTTPException(status_code=400, detail="仅支持有效的 zip 压缩包")
|
||||||
|
|
||||||
|
SKILL_ROOT.mkdir(parents=True, exist_ok=True)
|
||||||
|
target_dir = SKILL_ROOT / collection_slug
|
||||||
|
temp_dir = SKILL_ROOT / f".{collection_slug}.{uuid4().hex}.tmp"
|
||||||
|
|
||||||
|
try:
|
||||||
|
with zipfile.ZipFile(archive_path) as archive:
|
||||||
|
members = archive.infolist()
|
||||||
|
if not members:
|
||||||
|
raise HTTPException(status_code=400, detail="压缩包为空")
|
||||||
|
names = [member.filename for member in members]
|
||||||
|
for name in names:
|
||||||
|
_validate_skill_archive_member(name)
|
||||||
|
if "index.md" not in names:
|
||||||
|
raise HTTPException(status_code=400, detail="技能合集压缩包根目录必须包含 index.md")
|
||||||
|
archive.extractall(temp_dir)
|
||||||
|
|
||||||
|
skills = load_skill_catalog(temp_dir)
|
||||||
|
if not skills:
|
||||||
|
raise HTTPException(status_code=400, detail="技能合集未包含有效 SKILL.md")
|
||||||
|
|
||||||
|
if target_dir.exists():
|
||||||
|
shutil.rmtree(target_dir)
|
||||||
|
temp_dir.rename(target_dir)
|
||||||
|
return {
|
||||||
|
"slug": collection_slug,
|
||||||
|
"label": collection_slug.replace("_prd_skills", ""),
|
||||||
|
"skill_count": len(skills),
|
||||||
|
}
|
||||||
|
except HTTPException:
|
||||||
|
if temp_dir.exists():
|
||||||
|
shutil.rmtree(temp_dir)
|
||||||
|
raise
|
||||||
|
except zipfile.BadZipFile as exc:
|
||||||
|
if temp_dir.exists():
|
||||||
|
shutil.rmtree(temp_dir)
|
||||||
|
raise HTTPException(status_code=400, detail="仅支持有效的 zip 压缩包") from exc
|
||||||
|
except Exception:
|
||||||
|
if temp_dir.exists():
|
||||||
|
shutil.rmtree(temp_dir)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AnalysisTask:
|
class AnalysisTask:
|
||||||
task_id: str
|
task_id: str
|
||||||
@@ -247,18 +311,55 @@ def _run_analysis_task(
|
|||||||
@app.get("/", response_class=HTMLResponse)
|
@app.get("/", response_class=HTMLResponse)
|
||||||
def index(request: Request) -> HTMLResponse:
|
def index(request: Request) -> HTMLResponse:
|
||||||
settings = load_api_config(CONFIG_PATH)
|
settings = load_api_config(CONFIG_PATH)
|
||||||
|
skill_collections = _skill_collection_options()
|
||||||
return templates.TemplateResponse(
|
return templates.TemplateResponse(
|
||||||
request,
|
request,
|
||||||
"index.html",
|
"index.html",
|
||||||
{
|
{
|
||||||
"default_provider": settings.provider_name,
|
"default_provider": settings.provider_name,
|
||||||
"skill_collection_count": len(SKILL_COLLECTIONS),
|
"skill_collection_count": len(skill_collections),
|
||||||
"skill_collections": _skill_collection_options(),
|
"skill_collections": skill_collections,
|
||||||
"default_skill_collection": DEFAULT_SKILL_COLLECTION,
|
"default_skill_collection": DEFAULT_SKILL_COLLECTION,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/skill-collections")
|
||||||
|
def list_skill_collections() -> dict[str, object]:
|
||||||
|
return {
|
||||||
|
"collections": _skill_collection_options(),
|
||||||
|
"default_skill_collection": DEFAULT_SKILL_COLLECTION,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/skill-collections/upload")
|
||||||
|
async def upload_skill_collection(file: UploadFile = File(...)) -> dict[str, object]:
|
||||||
|
if not file.filename or not file.filename.lower().endswith(".zip"):
|
||||||
|
raise HTTPException(status_code=400, detail="技能合集仅支持上传 .zip 压缩包")
|
||||||
|
|
||||||
|
content = await file.read()
|
||||||
|
if len(content) > MAX_SKILL_ARCHIVE_BYTES:
|
||||||
|
raise HTTPException(status_code=413, detail="技能合集压缩包超过 50MB 限制")
|
||||||
|
|
||||||
|
archive_name = Path(file.filename).name
|
||||||
|
collection_slug = archive_name[:-4]
|
||||||
|
archive_path = UPLOAD_DIR / f"{uuid4().hex}_{archive_name}"
|
||||||
|
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
archive_path.write_bytes(content)
|
||||||
|
|
||||||
|
try:
|
||||||
|
collection = install_skill_collection_zip(archive_path, collection_slug)
|
||||||
|
finally:
|
||||||
|
if archive_path.exists():
|
||||||
|
archive_path.unlink()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"message": f"技能合集 {collection['slug']} 上传成功,已加载 {collection['skill_count']} 项技能",
|
||||||
|
"collection": collection,
|
||||||
|
"collections": _skill_collection_options(),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@app.post("/analyze")
|
@app.post("/analyze")
|
||||||
async def analyze_docx(
|
async def analyze_docx(
|
||||||
file: UploadFile = File(...),
|
file: UploadFile = File(...),
|
||||||
|
|||||||
Reference in New Issue
Block a user