452 lines
16 KiB
Python
452 lines
16 KiB
Python
from __future__ import annotations
|
|
|
|
import shutil
|
|
import threading
|
|
import time
|
|
import zipfile
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
from uuid import uuid4
|
|
from typing import Callable
|
|
|
|
from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile
|
|
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
|
|
from fastapi.staticfiles import StaticFiles
|
|
from fastapi.templating import Jinja2Templates
|
|
|
|
from app.analyzer import (
|
|
LLMClient,
|
|
build_analysis_prompt,
|
|
heuristic_analysis,
|
|
report_from_model_output,
|
|
select_relevant_skills,
|
|
)
|
|
from app.config import load_api_config
|
|
from app.docx_parser import parse_docx
|
|
from app.report_generator import generate_docx_report, generate_markdown_report
|
|
from app.review_filler import fill_review_docx_from_analysis
|
|
from app.skill_loader import load_skill_catalog
|
|
|
|
|
|
# Repository root: the parent of the directory that contains this module.
ROOT_DIR = Path(__file__).resolve().parent.parent

# Runtime directories and fixed resources, all anchored at the repository root.
UPLOAD_DIR = ROOT_DIR.joinpath("uploads")
OUTPUT_DIR = ROOT_DIR.joinpath("outputs")
SKILL_ROOT = ROOT_DIR.joinpath("skills")
CONFIG_PATH = ROOT_DIR.joinpath("configs", "api_config.yaml")
REVIEW_DOCX_TEMPLATE = ROOT_DIR.joinpath("test", "附录A文档审查.docx")

# Skill collection used when a request does not name one.
DEFAULT_SKILL_COLLECTION = "GJB438C-2021_prd_skills"

# Upload size ceilings, in bytes (30 MiB for documents, 50 MiB for skill archives).
MAX_UPLOAD_BYTES = 30 * 1024 ** 2
MAX_SKILL_ARCHIVE_BYTES = 50 * 1024 ** 2

# Signature of a progress listener: (percent, human-readable message).
ProgressCallback = Callable[[int, str], None]
|
|
|
|
|
|
def _discover_skill_collections() -> list[str]:
    """Return the sorted names of the installed skill collections.

    A collection is a direct subdirectory of ``SKILL_ROOT`` that contains an
    ``index.md`` file.

    Returns:
        Collection directory names in ascending order, or an empty list when
        ``SKILL_ROOT`` is missing or is not a directory.
    """
    if not SKILL_ROOT.is_dir():
        return []
    return sorted(
        entry.name
        for entry in SKILL_ROOT.iterdir()
        # install_skill_collection_zip stages uploads in ".<slug>.<hex>.tmp"
        # directories inside SKILL_ROOT. Those already contain index.md, so
        # without this filter a half-extracted (or orphaned) staging directory
        # would be offered as a selectable collection.
        if not (entry.name.startswith(".") and entry.name.endswith(".tmp"))
        and entry.is_dir()
        and (entry / "index.md").is_file()
    )
|
|
|
|
|
|
def _skill_collection_path(collection_slug: str) -> Path:
    """Resolve a collection name to its directory under ``SKILL_ROOT``.

    Args:
        collection_slug: Name of the collection directory. It reaches this
            function from the ``skill_collection`` form field of ``/analyze``,
            so it is treated as untrusted input.

    Returns:
        The collection directory, which is guaranteed to contain ``index.md``.

    Raises:
        HTTPException: 400 when the name is not a single plain path
            component, or when the directory or its ``index.md`` is missing.
    """
    # Joining an absolute path or one containing ".." onto SKILL_ROOT would
    # escape the skills root, so only a bare directory name is accepted.
    is_plain_name = (
        bool(collection_slug)
        and collection_slug not in {".", ".."}
        and "/" not in collection_slug
        and "\\" not in collection_slug
    )
    if not is_plain_name:
        raise HTTPException(status_code=400, detail="技能集合不存在")

    path = SKILL_ROOT / collection_slug
    if not path.is_dir() or not (path / "index.md").exists():
        raise HTTPException(status_code=400, detail="技能集合不存在")
    return path
|
|
|
|
|
|
def _skill_collection_options() -> list[dict[str, object]]:
    """Describe every discovered skill collection for display or selection.

    Returns:
        One dict per collection, in discovery order, with the keys ``slug``
        (directory name), ``label`` (the slug with ``_prd_skills`` removed)
        and ``skill_count`` (number of entries in the loaded catalog).
    """

    def describe(slug: str) -> dict[str, object]:
        # The catalog is loaded only to count its entries.
        catalog = load_skill_catalog(SKILL_ROOT / slug)
        return {
            "slug": slug,
            "label": slug.replace("_prd_skills", ""),
            "skill_count": len(catalog),
        }

    return [describe(slug) for slug in _discover_skill_collections()]
|
|
|
|
|
|
def _validate_skill_archive_member(member_name: str) -> None:
|
|
path = Path(member_name)
|
|
if not member_name or "\\" in member_name or member_name.startswith(("/", "\\")) or path.is_absolute():
|
|
raise HTTPException(status_code=400, detail="压缩包包含非法路径")
|
|
if any(part in {"", ".", ".."} for part in path.parts):
|
|
raise HTTPException(status_code=400, detail="压缩包包含非法路径")
|
|
|
|
|
|
def install_skill_collection_zip(archive_path: Path, collection_slug: str) -> dict[str, object]:
    """Validate a skill-collection zip and install it under ``SKILL_ROOT``.

    The archive is extracted into a staging directory next to the target,
    checked, and only then swapped into place. An existing collection with
    the same name is replaced.

    Args:
        archive_path: Location of the zip file on disk.
        collection_slug: Directory name to install the collection as. Must be
            a single plain path component.

    Returns:
        A dict with ``slug``, ``label`` (the slug with ``_prd_skills``
        removed) and ``skill_count`` for the installed collection.

    Raises:
        HTTPException: 400 when the slug is invalid, the file is not a valid
            zip, the archive is empty, a member path is unsafe, ``index.md``
            is missing from the archive root, or the loaded catalog is empty.
        OSError: When the filesystem operations of the install itself fail.
    """
    slug_is_valid = (
        bool(collection_slug)
        and collection_slug not in {".", ".."}
        and "/" not in collection_slug
        and "\\" not in collection_slug
    )
    if not slug_is_valid:
        raise HTTPException(status_code=400, detail="技能合集名称无效")
    if not zipfile.is_zipfile(archive_path):
        raise HTTPException(status_code=400, detail="仅支持有效的 zip 压缩包")

    SKILL_ROOT.mkdir(parents=True, exist_ok=True)
    target_dir = SKILL_ROOT / collection_slug
    token = uuid4().hex
    # Staging names live inside SKILL_ROOT so the final rename stays on one
    # filesystem; the random token keeps concurrent installs apart.
    temp_dir = SKILL_ROOT / f".{collection_slug}.{token}.tmp"
    backup_dir = SKILL_ROOT / f".{collection_slug}.{token}.old.tmp"

    try:
        with zipfile.ZipFile(archive_path) as archive:
            names = archive.namelist()
            if not names:
                raise HTTPException(status_code=400, detail="压缩包为空")
            # Every member is vetted before anything is written to disk.
            for name in names:
                _validate_skill_archive_member(name)
            if "index.md" not in names:
                raise HTTPException(status_code=400, detail="技能合集压缩包根目录必须包含 index.md")
            archive.extractall(temp_dir)

        skills = load_skill_catalog(temp_dir)
        if not skills:
            raise HTTPException(status_code=400, detail="技能合集未包含有效 SKILL.md")

        # Move the previous collection aside instead of deleting it first, so
        # a failed swap can put it back rather than leaving nothing installed.
        had_previous = target_dir.exists()
        if had_previous:
            target_dir.rename(backup_dir)
        try:
            temp_dir.rename(target_dir)
        except OSError:
            if had_previous:
                backup_dir.rename(target_dir)
            raise
        if had_previous:
            # The new collection is live; a leftover backup is only clutter.
            shutil.rmtree(backup_dir, ignore_errors=True)

        return {
            "slug": collection_slug,
            "label": collection_slug.replace("_prd_skills", ""),
            "skill_count": len(skills),
        }
    except zipfile.BadZipFile as exc:
        raise HTTPException(status_code=400, detail="仅支持有效的 zip 压缩包") from exc
    finally:
        # After a successful install the staging directory has been renamed
        # away, so this only fires on failure. Cleanup errors are ignored so
        # they cannot mask the exception that caused the failure.
        if temp_dir.exists():
            shutil.rmtree(temp_dir, ignore_errors=True)
|
|
|
|
|
|
@dataclass
class AnalysisTask:
    """State of one background analysis job as reported to polling clients."""

    # Identity of the job and of the uploaded document.
    task_id: str
    source_filename: str
    # Lifecycle marker; this module sets "queued", "running", "completed" and "error".
    status: str = "queued"
    # Completion percentage and the message shown next to it.
    progress: int = 0
    message: str = "任务已创建"
    # Result fields, filled in once the analysis has finished.
    summary: str = ""
    matched_skills: list[str] = field(default_factory=list)
    downloads: dict[str, str] = field(default_factory=dict)
    # Failure description; empty unless the job ended in "error".
    error: str = ""
    # Wall-clock timestamps (seconds since the epoch, from time.time()).
    created_at: float = field(default_factory=time.time)
    updated_at: float = field(default_factory=time.time)

    def to_dict(self) -> dict[str, object]:
        """Return the task as a plain dict keyed by field name.

        The list and dict values are the task's own objects, not copies.
        """
        exported = (
            "task_id",
            "source_filename",
            "status",
            "progress",
            "message",
            "summary",
            "matched_skills",
            "downloads",
            "error",
            "created_at",
            "updated_at",
        )
        return {name: getattr(self, name) for name in exported}
|
|
|
|
|
|
class AnalysisTaskStore:
    """In-memory registry of analysis tasks, guarded by a lock.

    Tasks are written by the worker threads that run the analysis and read
    by the request handlers that serve status polls.
    """

    def __init__(self) -> None:
        self._tasks: dict[str, AnalysisTask] = {}
        self._lock = threading.Lock()

    def create(self, source_filename: str) -> AnalysisTask:
        """Register a new task with a fresh random id and return it.

        Args:
            source_filename: Name of the uploaded document the task is for.
        """
        task = AnalysisTask(task_id=uuid4().hex, source_filename=source_filename)
        with self._lock:
            self._tasks[task.task_id] = task
        return task

    def update(
        self,
        task_id: str,
        *,
        status: str | None = None,
        progress: int | None = None,
        message: str | None = None,
        summary: str | None = None,
        matched_skills: list[str] | None = None,
        downloads: dict[str, str] | None = None,
        error: str | None = None,
    ) -> AnalysisTask:
        """Apply the given field changes to a task and refresh ``updated_at``.

        Arguments left as ``None`` keep their current value.

        Returns:
            The stored task object.

        Raises:
            KeyError: If ``task_id`` is not registered.
        """
        changes = {
            "status": status,
            "progress": progress,
            "message": message,
            "summary": summary,
            "matched_skills": matched_skills,
            "downloads": downloads,
            "error": error,
        }
        with self._lock:
            task = self._tasks[task_id]
            for name, value in changes.items():
                if value is not None:
                    setattr(task, name, value)
            task.updated_at = time.time()
            return task

    def get(self, task_id: str) -> AnalysisTask | None:
        """Return a snapshot of a task, or ``None`` if the id is unknown.

        The snapshot is copied while the lock is held. Handing out the live
        object would let a reader serialise it while ``update`` is halfway
        through, e.g. seeing ``status == "completed"`` before ``downloads``
        has been assigned.
        """
        from dataclasses import replace

        with self._lock:
            task = self._tasks.get(task_id)
            if task is None:
                return None
            return replace(
                task,
                matched_skills=list(task.matched_skills),
                downloads=dict(task.downloads),
            )
|
|
|
|
|
|
# Process-wide task registry shared by the request handlers and worker threads.
TASK_STORE = AnalysisTaskStore()

# Templates and static assets both live inside the "app" package directory.
_APP_DIR = ROOT_DIR / "app"

app = FastAPI(title="GJB438C DOCX 规范分析")
templates = Jinja2Templates(directory=str(_APP_DIR / "templates"))
app.mount(
    "/static",
    StaticFiles(directory=str(_APP_DIR / "static")),
    name="static",
)
|
|
|
|
|
|
def analyze_saved_docx(
    upload_path: Path,
    provider: str | None = None,
    use_model: bool = True,
    display_filename: str | None = None,
    skill_collection: str = DEFAULT_SKILL_COLLECTION,
    progress_callback: ProgressCallback | None = None,
) -> dict[str, object]:
    """Analyse a DOCX file that is already on disk and write the reports.

    Steps, in order: parse the document, load the chosen skill collection,
    select the relevant skills, load the API configuration, build the report
    (from the model, or from local heuristics), then generate the Markdown
    report and the filled-in review DOCX.

    Args:
        upload_path: Path of the saved DOCX.
        provider: Provider name handed to the configuration loader; an empty
            value is passed on as ``None``.
        use_model: When false, the model is not called and the heuristic
            analysis is used directly.
        display_filename: Forwarded to ``parse_docx``.
        skill_collection: Name of the skill collection to load.
        progress_callback: Optional listener called with (percent, message).

    Returns:
        A dict with ``source_filename``, ``summary``, ``matched_skills``,
        ``downloads`` (``/download/...`` URLs for both outputs),
        ``markdown_filename`` and ``review_docx_filename``.

    Raises:
        HTTPException: From ``_skill_collection_path`` when the collection
            cannot be resolved. Errors from the other steps propagate, except
            for the model call, which falls back to the heuristic analysis.
    """

    def notify(percent: int, message: str) -> None:
        if progress_callback is None:
            return
        progress_callback(percent, message)

    notify(5, "正在解析 DOCX 文档")
    parsed = parse_docx(upload_path, display_filename=display_filename)

    notify(20, "DOCX 解析完成,正在加载技能规范")
    collection_dir = _skill_collection_path(skill_collection)
    skills = load_skill_catalog(collection_dir)

    notify(35, "技能规范已加载,正在匹配候选技能")
    selected_skills = select_relevant_skills(parsed, skills)

    notify(50, f"已匹配 {len(selected_skills)} 项技能,正在读取模型配置")
    settings = load_api_config(CONFIG_PATH, provider_name=provider or None)

    if not use_model:
        notify(70, "已关闭模型分析,正在使用本地规则生成结果")
        report = heuristic_analysis(parsed, selected_skills)
    else:
        notify(65, f"正在调用 {settings.provider.model} 进行分析")
        prompt = build_analysis_prompt(parsed, selected_skills)
        try:
            output = LLMClient(settings.provider).complete(prompt)
            report = report_from_model_output(
                parsed,
                selected_skills,
                settings.provider_name,
                settings.provider.model,
                output,
            )
        except Exception as exc:
            # Any failure in the model call or in interpreting its output
            # degrades to the heuristic report, annotated with the cause.
            fallback = heuristic_analysis(parsed, selected_skills)
            report = fallback.__class__(
                source_filename=fallback.source_filename,
                provider_name=settings.provider_name,
                model_name=f"{settings.provider.model} (调用失败,已降级)",
                matched_skills=fallback.matched_skills,
                summary=f"{fallback.summary};模型调用失败:{exc}",
                findings=fallback.findings,
                recommendations=fallback.recommendations,
                raw_model_output=f"模型调用失败:{exc}\n\n{fallback.raw_model_output}",
            )

    notify(85, "正在生成 Markdown 分析文档")
    markdown_path = generate_markdown_report(report, OUTPUT_DIR)

    notify(92, "正在生成 DOCX 文档审查单")
    # The review sheet sits next to the Markdown report: "<stem>_review.docx".
    review_docx_path = markdown_path.with_name(f"{markdown_path.stem}_review.docx")
    fill_review_docx_from_analysis(markdown_path, REVIEW_DOCX_TEMPLATE, review_docx_path)

    notify(100, "分析完成")

    markdown_name = markdown_path.name
    review_name = review_docx_path.name
    return {
        "source_filename": parsed.filename,
        "summary": report.summary,
        "matched_skills": report.matched_skills,
        "downloads": {
            "markdown": f"/download/{markdown_name}",
            "review_docx": f"/download/{review_name}",
        },
        "markdown_filename": markdown_name,
        "review_docx_filename": review_name,
    }
|
|
|
|
|
|
def _run_analysis_task(
    task_id: str,
    upload_path: Path,
    provider: str | None,
    use_model: bool,
    display_filename: str,
    skill_collection: str = DEFAULT_SKILL_COLLECTION,
) -> None:
    """Run one analysis and record its progress and outcome in ``TASK_STORE``.

    Used as the target of the worker thread started by ``/analyze``. It never
    raises an ``Exception``: any failure is written to the task as status
    ``"error"`` with the exception text.

    Args:
        task_id: Id of the task to update.
        upload_path: Path of the saved DOCX.
        provider: Provider name forwarded to the analysis.
        use_model: Whether the model should be called.
        display_filename: Original file name forwarded to the analysis.
        skill_collection: Name of the skill collection to use.
    """

    def relay(percent: int, text: str) -> None:
        # Every progress event also marks the task as running.
        TASK_STORE.update(task_id, status="running", progress=percent, message=text)

    try:
        TASK_STORE.update(task_id, status="running", progress=1, message="任务已启动")
        result = analyze_saved_docx(
            upload_path,
            provider=provider,
            use_model=use_model,
            display_filename=display_filename,
            skill_collection=skill_collection,
            progress_callback=relay,
        )
        completion = {
            "status": "completed",
            "progress": 100,
            "message": "分析完成",
            "summary": str(result["summary"]),
            "matched_skills": list(result["matched_skills"]),
            "downloads": dict(result["downloads"]),
        }
        TASK_STORE.update(task_id, **completion)
    except Exception as exc:
        TASK_STORE.update(task_id, status="error", progress=100, message="分析失败", error=str(exc))
|
|
|
|
|
|
@app.get("/", response_class=HTMLResponse)
def index(request: Request) -> HTMLResponse:
    """Render the landing page.

    The template receives the configured default provider, the available
    skill collections with their count, and the default collection name.
    """
    settings = load_api_config(CONFIG_PATH)
    skill_collections = _skill_collection_options()
    context = {
        "default_provider": settings.provider_name,
        "skill_collection_count": len(skill_collections),
        "skill_collections": skill_collections,
        "default_skill_collection": DEFAULT_SKILL_COLLECTION,
    }
    return templates.TemplateResponse(request, "index.html", context)
|
|
|
|
|
|
@app.get("/skill-collections")
def list_skill_collections() -> dict[str, object]:
    """Return the available skill collections and the default collection name."""
    available = _skill_collection_options()
    return {
        "collections": available,
        "default_skill_collection": DEFAULT_SKILL_COLLECTION,
    }
|
|
|
|
|
|
@app.post("/skill-collections/upload")
async def upload_skill_collection(file: UploadFile = File(...)) -> dict[str, object]:
    """Install an uploaded ``.zip`` archive as a skill collection.

    The collection is named after the archive file name without its ``.zip``
    suffix. The archive is written to a temporary file under ``UPLOAD_DIR``
    for installation and removed afterwards, whether or not it succeeded.

    Args:
        file: The uploaded archive.

    Returns:
        A dict with a confirmation ``message``, the installed ``collection``
        and the refreshed list of all ``collections``.

    Raises:
        HTTPException: 400 when the file name does not end in ``.zip`` or the
            archive is rejected by ``install_skill_collection_zip``; 413 when
            the upload exceeds ``MAX_SKILL_ARCHIVE_BYTES``.
    """
    if not file.filename or not file.filename.lower().endswith(".zip"):
        raise HTTPException(status_code=400, detail="技能合集仅支持上传 .zip 压缩包")

    # Read at most one byte past the limit: enough to detect an oversized
    # upload without pulling an arbitrarily large body into memory first.
    content = await file.read(MAX_SKILL_ARCHIVE_BYTES + 1)
    if len(content) > MAX_SKILL_ARCHIVE_BYTES:
        raise HTTPException(status_code=413, detail="技能合集压缩包超过 50MB 限制")

    archive_name = Path(file.filename).name
    # Drop the 4-character ".zip" suffix checked above (in any letter case).
    collection_slug = archive_name[:-4]
    UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
    archive_path = UPLOAD_DIR / f"{uuid4().hex}_{archive_name}"

    try:
        # The write sits inside the try so a partially written archive is
        # removed as well.
        archive_path.write_bytes(content)
        collection = install_skill_collection_zip(archive_path, collection_slug)
    finally:
        archive_path.unlink(missing_ok=True)

    return {
        "message": f"技能合集 {collection['slug']} 上传成功,已加载 {collection['skill_count']} 项技能",
        "collection": collection,
        "collections": _skill_collection_options(),
    }
|
|
|
|
|
|
@app.post("/analyze")
async def analyze_docx(
    file: UploadFile = File(...),
    provider: str | None = Form(None),
    use_model: str = Form("true"),
    skill_collection: str = Form(DEFAULT_SKILL_COLLECTION),
):
    """Accept a ``.docx`` upload and start its analysis in a worker thread.

    The request returns as soon as the task is queued; clients poll the
    returned ``status_url`` for progress and results.

    Args:
        file: The uploaded document.
        provider: Optional provider name for the model call.
        use_model: Form flag; ``1``, ``true``, ``yes`` and ``on`` (any letter
            case) enable the model, anything else disables it.
        skill_collection: Name of the skill collection to analyse against.

    Returns:
        A dict with ``task_id``, ``status_url``, ``status``, ``progress`` and
        ``message``.

    Raises:
        HTTPException: 400 when the file name does not end in ``.docx``; 413
            when the upload exceeds ``MAX_UPLOAD_BYTES``.
    """
    if not file.filename or not file.filename.lower().endswith(".docx"):
        raise HTTPException(status_code=400, detail="仅支持上传 .docx 文件")

    # Read at most one byte past the limit: enough to detect an oversized
    # upload without pulling an arbitrarily large body into memory first.
    content = await file.read(MAX_UPLOAD_BYTES + 1)
    if len(content) > MAX_UPLOAD_BYTES:
        raise HTTPException(status_code=413, detail="文件超过 30MB 限制")

    display_name = Path(file.filename).name
    UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    # The random prefix keeps uploads with the same file name apart.
    upload_path = UPLOAD_DIR / f"{uuid4().hex}_{display_name}"
    upload_path.write_bytes(content)

    should_use_model = use_model.lower() in {"1", "true", "yes", "on"}
    task = TASK_STORE.create(display_name)
    threading.Thread(
        target=_run_analysis_task,
        args=(
            task.task_id,
            upload_path,
            provider,
            should_use_model,
            display_name,
            skill_collection,
        ),
        daemon=True,
    ).start()

    return {
        "task_id": task.task_id,
        "status_url": f"/tasks/{task.task_id}",
        "status": task.status,
        "progress": task.progress,
        "message": "任务已提交",
    }
|
|
|
|
|
|
@app.get("/tasks/{task_id}")
def get_task(task_id: str) -> dict[str, object]:
    """Return the current state of an analysis task.

    Raises:
        HTTPException: 404 when no task with this id is registered.
    """
    task = TASK_STORE.get(task_id)
    if task is not None:
        return task.to_dict()
    raise HTTPException(status_code=404, detail="任务不存在")
|
|
|
|
|
|
@app.get("/download/{filename}")
def download_report(filename: str):
    """Serve a generated report from ``OUTPUT_DIR``.

    Only the final path component of ``filename`` is used, so the lookup
    cannot leave the output directory. ``.docx`` and ``.md`` files get their
    specific media type; everything else is sent as a generic binary stream.

    Raises:
        HTTPException: 404 when the file does not exist or is not a regular
            file.
    """
    path = OUTPUT_DIR / Path(filename).name
    if not (path.exists() and path.is_file()):
        raise HTTPException(status_code=404, detail="报告不存在")

    media_types = {
        ".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        ".md": "text/markdown; charset=utf-8",
    }
    media_type = media_types.get(path.suffix, "application/octet-stream")
    return FileResponse(path, filename=path.name, media_type=media_type)
|
|
|
|
|
|
@app.post("/cleanup")
def cleanup_runtime_files() -> dict[str, int]:
    """Delete everything inside the upload and output directories.

    Regular files are unlinked and subdirectories are removed recursively;
    the two directories themselves are kept.

    Returns:
        ``{"removed": n}`` where ``n`` counts the top-level entries deleted.
    """
    removed = 0
    for directory in (UPLOAD_DIR, OUTPUT_DIR):
        if directory.exists():
            for entry in directory.iterdir():
                if entry.is_file():
                    entry.unlink()
                elif entry.is_dir():
                    shutil.rmtree(entry)
                else:
                    # Neither a file nor a directory: leave it, do not count it.
                    continue
                removed += 1
    return {"removed": removed}
|