Files
2026-05-26 10:30:07 +08:00

443 lines
15 KiB
Python

from __future__ import annotations

import shutil
import threading
import time
import zipfile
from dataclasses import dataclass, field, replace
from pathlib import Path
from typing import Callable
from uuid import uuid4

from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates

from app.analyzer import (
    LLMClient,
    build_analysis_prompt,
    heuristic_analysis,
    report_from_model_output,
    select_relevant_skills,
)
from app.config import load_api_config
from app.docx_parser import parse_docx
from app.report_generator import generate_docx_report, generate_markdown_report
from app.skill_loader import load_skill_catalog
# Project root: this module lives in <root>/app/, so climb two levels.
ROOT_DIR = Path(__file__).resolve().parents[1]
# Runtime directories (created on demand by the handlers that write to them).
UPLOAD_DIR = ROOT_DIR.joinpath("uploads")
OUTPUT_DIR = ROOT_DIR.joinpath("outputs")
# Each installed skill collection is a sub-directory of SKILL_ROOT containing index.md.
SKILL_ROOT = ROOT_DIR.joinpath("skills")
DEFAULT_SKILL_COLLECTION = "GJB438C-2021_prd_skills"
CONFIG_PATH = ROOT_DIR.joinpath("configs", "api_config.yaml")
# Upload size limits in bytes: 30 MiB per DOCX, 50 MiB per skill-collection zip.
MAX_UPLOAD_BYTES = 30 * 1024 ** 2
MAX_SKILL_ARCHIVE_BYTES = 50 * 1024 ** 2
# Signature of progress hooks: (percent, human-readable message).
ProgressCallback = Callable[[int, str], None]
def _discover_skill_collections() -> list[str]:
if not SKILL_ROOT.exists():
return []
return sorted(
path.name
for path in SKILL_ROOT.iterdir()
if path.is_dir() and (path / "index.md").is_file()
)
def _skill_collection_path(collection_slug: str) -> Path:
path = SKILL_ROOT / collection_slug
if not path.exists() or not path.is_dir() or not (path / "index.md").exists():
raise HTTPException(status_code=400, detail="技能集合不存在")
return path
def _skill_collection_options() -> list[dict[str, object]]:
    """Describe every installed skill collection for the UI and the JSON API.

    Each entry carries the directory ``slug``, a display ``label`` (the slug
    with ``"_prd_skills"`` removed) and the number of skills loaded from it.
    """
    return [
        {
            "slug": slug,
            "label": slug.replace("_prd_skills", ""),
            "skill_count": len(load_skill_catalog(SKILL_ROOT / slug)),
        }
        for slug in _discover_skill_collections()
    ]
def _validate_skill_archive_member(member_name: str) -> None:
path = Path(member_name)
if not member_name or "\\" in member_name or member_name.startswith(("/", "\\")) or path.is_absolute():
raise HTTPException(status_code=400, detail="压缩包包含非法路径")
if any(part in {"", ".", ".."} for part in path.parts):
raise HTTPException(status_code=400, detail="压缩包包含非法路径")
def _move_into_place(staged_dir: Path, target_dir: Path, backup_dir: Path) -> None:
    """Swap ``staged_dir`` in as ``target_dir``, restoring the old ``target_dir`` if that fails."""
    replacing = target_dir.exists()
    if replacing:
        # Move the previous collection aside rather than deleting it first, so a
        # failed rename below can put it back instead of losing it.
        target_dir.rename(backup_dir)
    try:
        staged_dir.rename(target_dir)
    except OSError:
        if replacing:
            backup_dir.rename(target_dir)
        raise
    if replacing:
        shutil.rmtree(backup_dir, ignore_errors=True)


def install_skill_collection_zip(
    archive_path: Path,
    collection_slug: str,
    *,
    max_extracted_bytes: int = 200 * 1024 * 1024,
) -> dict[str, object]:
    """Validate a skill-collection zip and install it as ``SKILL_ROOT/<collection_slug>``.

    The archive is extracted into a hidden staging directory next to the target.
    The staging directory is swapped into place only after every member name,
    the root ``index.md`` and the skill catalogue have been validated. Any
    failure removes the staging directory.

    Args:
        archive_path: Path of the uploaded ``.zip`` file.
        collection_slug: Directory name to install under. It must be a
            non-empty, non-hidden single path component.
        max_extracted_bytes: Upper bound on the total uncompressed size the
            archive declares for its members. This guards against
            decompression bombs, which the compressed-size limit alone does
            not stop.

    Returns:
        ``{"slug", "label", "skill_count"}`` describing the installed collection.

    Raises:
        HTTPException: 400 for an invalid slug, archive, member path or layout,
            or when no valid ``SKILL.md`` is found; 413 when the archive would
            expand beyond ``max_extracted_bytes``.
    """
    # Leading "." also covers "." / ".." and keeps user slugs out of the
    # hidden ".<slug>.<hex>.tmp/.old" names used for staging below.
    if (
        not collection_slug
        or collection_slug.startswith(".")
        or "/" in collection_slug
        or "\\" in collection_slug
    ):
        raise HTTPException(status_code=400, detail="技能合集名称无效")
    if not zipfile.is_zipfile(archive_path):
        raise HTTPException(status_code=400, detail="仅支持有效的 zip 压缩包")

    SKILL_ROOT.mkdir(parents=True, exist_ok=True)
    target_dir = SKILL_ROOT / collection_slug
    token = uuid4().hex
    temp_dir = SKILL_ROOT / f".{collection_slug}.{token}.tmp"
    try:
        with zipfile.ZipFile(archive_path) as archive:
            members = archive.infolist()
            if not members:
                raise HTTPException(status_code=400, detail="压缩包为空")
            names = [member.filename for member in members]
            for name in names:
                _validate_skill_archive_member(name)
            if "index.md" not in names:
                raise HTTPException(status_code=400, detail="技能合集压缩包根目录必须包含 index.md")
            # The declared sizes bound what extraction writes, because CPython's
            # zip reader stops each member at its declared file_size.
            declared_size = sum(member.file_size for member in members)
            if declared_size > max_extracted_bytes:
                limit_mb = max_extracted_bytes // (1024 * 1024)
                raise HTTPException(status_code=413, detail=f"技能合集解压后超过 {limit_mb}MB 限制")
            archive.extractall(temp_dir)
        skills = load_skill_catalog(temp_dir)
        if not skills:
            raise HTTPException(status_code=400, detail="技能合集未包含有效 SKILL.md")
        _move_into_place(temp_dir, target_dir, SKILL_ROOT / f".{collection_slug}.{token}.old")
    except zipfile.BadZipFile as exc:
        raise HTTPException(status_code=400, detail="仅支持有效的 zip 压缩包") from exc
    finally:
        # Single cleanup for every failure path (including BaseException).
        # After a successful move, temp_dir no longer exists.
        if temp_dir.exists():
            shutil.rmtree(temp_dir, ignore_errors=True)
    return {
        "slug": collection_slug,
        "label": collection_slug.replace("_prd_skills", ""),
        "skill_count": len(skills),
    }
@dataclass
class AnalysisTask:
    """State of one background DOCX analysis, as served by ``GET /tasks/{task_id}``."""

    task_id: str
    source_filename: str
    # Values set in this module: "queued" -> "running" -> "completed" | "error".
    status: str = "queued"
    # Percentage; this module reports values from 0 to 100.
    progress: int = 0
    message: str = "任务已创建"
    summary: str = ""
    matched_skills: list[str] = field(default_factory=list)
    # Download links keyed by format, e.g. {"markdown": "/download/<file>"}.
    downloads: dict[str, str] = field(default_factory=dict)
    error: str = ""
    # time.time() timestamps.
    created_at: float = field(default_factory=time.time)
    updated_at: float = field(default_factory=time.time)

    def to_dict(self) -> dict[str, object]:
        """Return a JSON-serialisable snapshot of the task.

        ``matched_skills`` and ``downloads`` are copied, so later changes to the
        task do not leak into an already-built snapshot, and edits to the
        snapshot cannot corrupt the task.
        """
        return {
            "task_id": self.task_id,
            "source_filename": self.source_filename,
            "status": self.status,
            "progress": self.progress,
            "message": self.message,
            "summary": self.summary,
            "matched_skills": list(self.matched_skills),
            "downloads": dict(self.downloads),
            "error": self.error,
            "created_at": self.created_at,
            "updated_at": self.updated_at,
        }
class AnalysisTaskStore:
    """In-memory registry of analysis tasks keyed by task id, guarded by a lock.

    Stored tasks are only read or mutated while the lock is held. Every public
    method returns a detached copy, so an HTTP handler serialising a task while
    a worker thread updates it never sees a half-applied update.

    NOTE(review): tasks are never evicted, so memory grows with the number of
    analyses submitted during the process lifetime.
    """

    def __init__(self) -> None:
        self._tasks: dict[str, AnalysisTask] = {}
        self._lock = threading.Lock()

    @staticmethod
    def _snapshot(task: AnalysisTask) -> AnalysisTask:
        """Return a copy of ``task`` whose list/dict fields are copied too (call under the lock)."""
        return replace(
            task,
            matched_skills=list(task.matched_skills),
            downloads=dict(task.downloads),
        )

    def create(self, source_filename: str) -> AnalysisTask:
        """Register a new queued task with a random hex id and return a copy of it."""
        task = AnalysisTask(task_id=uuid4().hex, source_filename=source_filename)
        with self._lock:
            self._tasks[task.task_id] = task
            return self._snapshot(task)

    def update(
        self,
        task_id: str,
        *,
        status: str | None = None,
        progress: int | None = None,
        message: str | None = None,
        summary: str | None = None,
        matched_skills: list[str] | None = None,
        downloads: dict[str, str] | None = None,
        error: str | None = None,
    ) -> AnalysisTask:
        """Apply every non-None field to the task, refresh ``updated_at`` and return a copy.

        Raises:
            KeyError: if ``task_id`` was never created.
        """
        with self._lock:
            task = self._tasks[task_id]
            if status is not None:
                task.status = status
            if progress is not None:
                task.progress = progress
            if message is not None:
                task.message = message
            if summary is not None:
                task.summary = summary
            if matched_skills is not None:
                # Copy so later edits to the caller's list cannot bypass the lock.
                task.matched_skills = list(matched_skills)
            if downloads is not None:
                task.downloads = dict(downloads)
            if error is not None:
                task.error = error
            task.updated_at = time.time()
            return self._snapshot(task)

    def get(self, task_id: str) -> AnalysisTask | None:
        """Return a copy of the task, or ``None`` for an unknown id."""
        with self._lock:
            task = self._tasks.get(task_id)
            return None if task is None else self._snapshot(task)
# Process-wide task registry, written by the /analyze worker threads and read by
# GET /tasks/{task_id}; it is a plain in-memory store, so tasks do not survive a restart.
TASK_STORE = AnalysisTaskStore()
app = FastAPI(title="GJB438C DOCX 规范分析")
# Page templates and static assets are read from the ``app`` package directory.
templates = Jinja2Templates(directory=str(ROOT_DIR / "app" / "templates"))
app.mount("/static", StaticFiles(directory=str(ROOT_DIR / "app" / "static")), name="static")
def analyze_saved_docx(
    upload_path: Path,
    provider: str | None = None,
    use_model: bool = True,
    display_filename: str | None = None,
    skill_collection: str = DEFAULT_SKILL_COLLECTION,
    progress_callback: ProgressCallback | None = None,
) -> dict[str, object]:
    """Analyse a DOCX file already on disk and write a Markdown report to ``OUTPUT_DIR``.

    Pipeline:
    1. Parse the document.
    2. Load ``skill_collection`` and select the relevant skills.
    3. Read the API config.
    4. Either ask the configured model (``use_model=True``) or run the local
       heuristic analysis. If the model call, or turning its output into a
       report, raises, the heuristic report is used instead. That report is
       relabelled as a degraded model result and carries the error text.

    Progress is reported through ``progress_callback`` as ``(percent, message)``
    pairs, ending with 100.

    Args:
        upload_path: Saved DOCX file.
        provider: Provider name from the API config; empty/None selects the default.
        use_model: Whether to call the model or only use local rules.
        display_filename: Name to show for the document instead of the on-disk one.
        skill_collection: Slug of the installed skill collection to use.
        progress_callback: Optional progress hook.

    Returns:
        A dict with ``source_filename``, ``summary``, ``matched_skills``,
        ``downloads`` (``{"markdown": "/download/<name>"}``) and ``markdown_filename``.

    Raises:
        HTTPException: 400 when ``skill_collection`` is not an installed collection.
        Errors from parsing, config loading or report writing propagate unchanged.
    """
    notify = progress_callback if progress_callback is not None else (lambda _percent, _message: None)

    notify(5, "正在解析 DOCX 文档")
    parsed = parse_docx(upload_path, display_filename=display_filename)
    notify(20, "DOCX 解析完成,正在加载技能规范")
    catalog = load_skill_catalog(_skill_collection_path(skill_collection))
    notify(35, "技能规范已加载,正在匹配候选技能")
    selected_skills = select_relevant_skills(parsed, catalog)
    notify(50, f"已匹配 {len(selected_skills)} 项技能,正在读取模型配置")
    settings = load_api_config(CONFIG_PATH, provider_name=provider or None)

    if not use_model:
        notify(70, "已关闭模型分析,正在使用本地规则生成结果")
        report = heuristic_analysis(parsed, selected_skills)
    else:
        model_name = settings.provider.model
        notify(65, f"正在调用 {model_name} 进行分析")
        prompt = build_analysis_prompt(parsed, selected_skills)
        try:
            model_output = LLMClient(settings.provider).complete(prompt)
            report = report_from_model_output(
                parsed,
                selected_skills,
                settings.provider_name,
                model_name,
                model_output,
            )
        except Exception as exc:
            # Deliberate degradation: any model-side failure falls back to the
            # local heuristic report, annotated with the failure reason.
            fallback = heuristic_analysis(parsed, selected_skills)
            failure_note = f"模型调用失败:{exc}"
            report = type(fallback)(
                source_filename=fallback.source_filename,
                provider_name=settings.provider_name,
                model_name=f"{model_name} (调用失败,已降级)",
                matched_skills=fallback.matched_skills,
                summary=f"{fallback.summary};{failure_note}",
                findings=fallback.findings,
                recommendations=fallback.recommendations,
                raw_model_output=f"{failure_note}\n\n{fallback.raw_model_output}",
            )

    notify(85, "正在生成 Markdown 分析文档")
    markdown_path = generate_markdown_report(report, OUTPUT_DIR)
    notify(100, "分析完成")
    markdown_name = markdown_path.name
    return {
        "source_filename": parsed.filename,
        "summary": report.summary,
        "matched_skills": report.matched_skills,
        "downloads": {"markdown": f"/download/{markdown_name}"},
        "markdown_filename": markdown_name,
    }
def _run_analysis_task(
    task_id: str,
    upload_path: Path,
    provider: str | None,
    use_model: bool,
    display_filename: str,
    skill_collection: str = DEFAULT_SKILL_COLLECTION,
) -> None:
    """Worker-thread entry point: run ``analyze_saved_docx`` and record the outcome in ``TASK_STORE``.

    The thread is fire-and-forget (started as a daemon and never joined in this
    module). Every exception is therefore caught and stored on the task as
    status ``"error"``, so clients polling the task can see it.
    """

    def record_progress(percent: int, note: str) -> None:
        TASK_STORE.update(task_id, status="running", progress=percent, message=note)

    try:
        record_progress(1, "任务已启动")
        outcome = analyze_saved_docx(
            upload_path,
            provider=provider,
            use_model=use_model,
            display_filename=display_filename,
            skill_collection=skill_collection,
            progress_callback=record_progress,
        )
        final_fields = {
            "status": "completed",
            "progress": 100,
            "message": "分析完成",
            "summary": str(outcome["summary"]),
            "matched_skills": list(outcome["matched_skills"]),
            "downloads": dict(outcome["downloads"]),
        }
        TASK_STORE.update(task_id, **final_fields)
    except Exception as exc:
        TASK_STORE.update(task_id, status="error", progress=100, message="分析失败", error=str(exc))
@app.get("/", response_class=HTMLResponse)
def index(request: Request) -> HTMLResponse:
    """Render ``index.html`` with the default provider and the selectable skill collections."""
    api_settings = load_api_config(CONFIG_PATH)
    collections = _skill_collection_options()
    page_context = {
        "default_provider": api_settings.provider_name,
        "skill_collection_count": len(collections),
        "skill_collections": collections,
        "default_skill_collection": DEFAULT_SKILL_COLLECTION,
    }
    return templates.TemplateResponse(request, "index.html", page_context)
@app.get("/skill-collections")
def list_skill_collections() -> dict[str, object]:
    """Return the installed skill collections together with the default collection slug."""
    payload: dict[str, object] = {"collections": _skill_collection_options()}
    payload["default_skill_collection"] = DEFAULT_SKILL_COLLECTION
    return payload
@app.post("/skill-collections/upload")
async def upload_skill_collection(file: UploadFile = File(...)) -> dict[str, object]:
    """Accept a skill-collection ``.zip`` upload and install it.

    The collection slug is the archive's file name without its ``.zip`` suffix.
    The body is streamed to a uniquely named temporary file in ``UPLOAD_DIR``
    in 1 MiB chunks. An oversized upload is rejected as soon as it passes
    ``MAX_SKILL_ARCHIVE_BYTES``, instead of first being buffered in memory in
    full. The temporary file is always removed afterwards.

    Returns:
        A success message, the installed collection and the refreshed collection list.

    Raises:
        HTTPException: 400 for a non-``.zip`` file name, 413 when the archive
            exceeds the size limit, plus anything ``install_skill_collection_zip`` raises.
    """
    if not file.filename or not file.filename.lower().endswith(".zip"):
        raise HTTPException(status_code=400, detail="技能合集仅支持上传 .zip 压缩包")
    archive_name = Path(file.filename).name
    collection_slug = archive_name[:-4]
    UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
    archive_path = UPLOAD_DIR / f"{uuid4().hex}_{archive_name}"
    chunk_size = 1024 * 1024
    try:
        received = 0
        with archive_path.open("wb") as handle:
            while True:
                chunk = await file.read(chunk_size)
                if not chunk:
                    break
                received += len(chunk)
                if received > MAX_SKILL_ARCHIVE_BYTES:
                    raise HTTPException(status_code=413, detail="技能合集压缩包超过 50MB 限制")
                handle.write(chunk)
        collection = install_skill_collection_zip(archive_path, collection_slug)
    finally:
        # missing_ok: /cleanup may have removed the file concurrently.
        archive_path.unlink(missing_ok=True)
    return {
        "message": f"技能合集 {collection['slug']} 上传成功,已加载 {collection['skill_count']} 项技能",
        "collection": collection,
        "collections": _skill_collection_options(),
    }
@app.post("/analyze")
async def analyze_docx(
    file: UploadFile = File(...),
    provider: str | None = Form(None),
    use_model: str = Form("true"),
    skill_collection: str = Form(DEFAULT_SKILL_COLLECTION),
):
    """Save an uploaded ``.docx`` and start its analysis on a background thread.

    The upload is streamed to ``UPLOAD_DIR`` in 1 MiB chunks and rejected once
    it exceeds ``MAX_UPLOAD_BYTES``. A partially written file is deleted on any
    failure. ``use_model`` is truthy for "1", "true", "yes" or "on"
    (case-insensitive).

    Returns:
        The new task id, the URL to poll, and the task's initial status/progress.

    Raises:
        HTTPException: 400 for a non-``.docx`` file name, 413 when the file is too large.
    """
    if not file.filename or not file.filename.lower().endswith(".docx"):
        raise HTTPException(status_code=400, detail="仅支持上传 .docx 文件")
    UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    display_name = Path(file.filename).name
    upload_path = UPLOAD_DIR / f"{uuid4().hex}_{display_name}"
    chunk_size = 1024 * 1024
    try:
        received = 0
        with upload_path.open("wb") as handle:
            while True:
                chunk = await file.read(chunk_size)
                if not chunk:
                    break
                received += len(chunk)
                if received > MAX_UPLOAD_BYTES:
                    raise HTTPException(status_code=413, detail="文件超过 30MB 限制")
                handle.write(chunk)
    except BaseException:
        # Covers the size limit, I/O errors and request cancellation alike.
        upload_path.unlink(missing_ok=True)
        raise
    should_use_model = use_model.lower() in {"1", "true", "yes", "on"}
    task = TASK_STORE.create(display_name)
    # Build the response before the worker starts so it reports the freshly
    # created state rather than whatever the thread has already reached.
    response = {
        "task_id": task.task_id,
        "status_url": f"/tasks/{task.task_id}",
        "status": task.status,
        "progress": task.progress,
        "message": "任务已提交",
    }
    threading.Thread(
        target=_run_analysis_task,
        args=(
            task.task_id,
            upload_path,
            provider,
            should_use_model,
            display_name,
            skill_collection,
        ),
        daemon=True,
    ).start()
    return response
@app.get("/tasks/{task_id}")
def get_task(task_id: str) -> dict[str, object]:
    """Return the current state of an analysis task, or 404 if the id is unknown."""
    found = TASK_STORE.get(task_id)
    if found is not None:
        return found.to_dict()
    raise HTTPException(status_code=404, detail="任务不存在")
@app.get("/download/{filename}")
def download_report(filename: str):
    """Serve a generated report from ``OUTPUT_DIR``.

    Only the final path component of ``filename`` is used, so requests cannot
    reach outside ``OUTPUT_DIR``. The media type follows the (case-sensitive)
    suffix: ``.docx``, ``.md``, otherwise ``application/octet-stream``.

    Raises:
        HTTPException: 404 when no such regular file exists.
    """
    target = OUTPUT_DIR / Path(filename).name
    if not (target.exists() and target.is_file()):
        raise HTTPException(status_code=404, detail="报告不存在")
    media_type = {
        ".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        ".md": "text/markdown; charset=utf-8",
    }.get(target.suffix, "application/octet-stream")
    return FileResponse(target, filename=target.name, media_type=media_type)
@app.post("/cleanup")
def cleanup_runtime_files() -> dict[str, int]:
    """Delete every entry directly inside ``UPLOAD_DIR`` and ``OUTPUT_DIR``.

    Regular files and symlinks are unlinked (a symlink is never followed), and
    real directories are removed recursively. Entries that disappear between
    listing and deletion are skipped instead of failing the request. Such
    removals happen when, for example, an upload handler deletes its temporary
    archive at the same moment. Other special files are left alone.

    NOTE(review): this also removes uploads/reports belonging to analyses that
    are still running, and leaves their download links in TASK_STORE dangling.

    Returns:
        ``{"removed": n}`` with the number of entries actually deleted.
    """
    removed = 0
    for directory in (UPLOAD_DIR, OUTPUT_DIR):
        if not directory.is_dir():
            continue
        for entry in directory.iterdir():
            try:
                if entry.is_symlink() or entry.is_file():
                    entry.unlink()
                elif entry.is_dir():
                    shutil.rmtree(entry)
                else:
                    continue
            except FileNotFoundError:
                continue
            removed += 1
    return {"removed": removed}