432 lines
14 KiB
Python
432 lines
14 KiB
Python
"""
|
||
CorrectBench Flask Web Server
|
||
提供 Web 界面来管理和监控 CorrectBench 任务
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import json
|
||
import subprocess
|
||
import threading
|
||
import queue
|
||
import time
|
||
import yaml
|
||
|
||
from flask import Flask, render_template, jsonify, request, Response, send_from_directory
|
||
|
||
# Project root: the directory that contains this file. Templates and static
# assets are served from <root>/web/templates and <root>/web/static.
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))

app = Flask(
    __name__,
    template_folder=os.path.join(PROJECT_ROOT, "web", "templates"),
    static_folder=os.path.join(PROJECT_ROOT, "web", "static"),
)
|
||
|
||
# ========== 全局状态 ==========
|
||
class TaskState:
    """Global state of the single background task managed by this server.

    Shared between Flask request handlers and the worker threads started by
    run_process(); ``_lock`` guards mutations of ``log_lines``.
    """

    MAX_LOG_LINES = 5000  # maximum number of log lines kept in memory

    def __init__(self):
        self.running = False                 # True while a task is active
        self.process = None                  # Popen object of the running task, or None
        self.log_queue = queue.Queue()       # lines waiting to be streamed by /api/logs
        self.stop_event = threading.Event()  # set to ask the running task to stop
        self.log_lines = []                  # most recent output lines (capped)
        self.config_path = ""                # config file of the current/last run
        self.start_time = None               # time.time() when the run started
        self._lock = threading.Lock()

    def add_log_line(self, line):
        """Append *line* to ``log_lines``, dropping the oldest lines beyond MAX_LOG_LINES.

        Thread-safe: lines may arrive from both the output-reader thread and
        the thread running run_process(), so append + trim happen under the lock.
        """
        with self._lock:
            self.log_lines.append(line)
            overflow = len(self.log_lines) - self.MAX_LOG_LINES
            if overflow > 0:
                # Trim in place rather than rebuilding the whole list.
                del self.log_lines[:overflow]

    def reset(self):
        """Return to the idle state before a new run (fresh queue, event and log list).

        ``_lock`` itself is kept so it stays valid for any concurrent caller.
        """
        with self._lock:
            self.running = False
            self.process = None
            self.log_queue = queue.Queue()
            self.stop_event = threading.Event()
            self.log_lines = []
            self.config_path = ""
            self.start_time = None


state = TaskState()
|
||
|
||
|
||
# ========== 配置文件管理 ==========
|
||
def scan_config_files():
    """Collect the selectable YAML config files.

    First looks in ``config/configs/`` (keys are prefixed with ``configs/``),
    then in ``config/`` itself (keys are bare file names; ``default.yaml`` is
    skipped). Within each directory entries are visited in sorted order.

    Returns:
        dict mapping display name -> path under PROJECT_ROOT.
    """
    yaml_suffixes = ('.yaml', '.yml')
    found = {}

    nested_dir = os.path.join(PROJECT_ROOT, "config", "configs")
    if os.path.exists(nested_dir):
        found.update(
            (f"configs/{entry}", os.path.join(nested_dir, entry))
            for entry in sorted(os.listdir(nested_dir))
            if entry.endswith(yaml_suffixes)
        )

    top_dir = os.path.join(PROJECT_ROOT, "config")
    if os.path.exists(top_dir):
        found.update(
            (entry, os.path.join(top_dir, entry))
            for entry in sorted(os.listdir(top_dir))
            if entry.endswith(yaml_suffixes) and entry != 'default.yaml'
        )

    return found
|
||
|
||
|
||
def load_config_info(config_path):
    """Summarise the main settings of a YAML config file for the web UI.

    Returns a dict with:
        model / rtlgen_model: from the ``gpt`` section,
        mode: from the ``run`` section,
        tasks: ``autoline.probset.only`` (always a list-like value, [] if unset),
        itermax: ``autoline.itermax``.
    Missing scalar values are reported as "未指定". When ``tasks`` is empty a
    ``dataset`` entry (``autoline.probset.path``) is added as well.

    If the file does not exist, cannot be read/parsed, or is not a mapping at
    the top level, returns ``{"error": message}`` instead.
    """
    if not os.path.exists(config_path):
        return {"error": "配置文件不存在"}

    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            cfg = yaml.safe_load(f)
    except Exception as e:
        return {"error": str(e)}

    # An empty YAML document loads as None: treat it as "no settings".
    if cfg is None:
        cfg = {}
    if not isinstance(cfg, dict):
        return {"error": "配置文件格式错误: 顶层必须是映射"}

    def section(mapping, key):
        # A section written as ``key:`` with no value loads as None; treat that
        # (or any other non-mapping value) as empty so the lookups below cannot fail.
        value = mapping.get(key)
        return value if isinstance(value, dict) else {}

    gpt = section(cfg, "gpt")
    autoline = section(cfg, "autoline")
    probset = section(autoline, "probset")

    info = {
        "model": gpt.get("model", "未指定"),
        "rtlgen_model": gpt.get("rtlgen_model", "未指定"),
        "mode": section(cfg, "run").get("mode", "未指定"),
        "tasks": probset.get("only") or [],
        "itermax": autoline.get("itermax", "未指定"),
    }

    # Without an explicit task list, show which dataset will be used instead.
    if not info["tasks"]:
        info["dataset"] = probset.get("path") or "未指定"

    return info
|
||
|
||
|
||
# ========== 任务进程管理 ==========
|
||
def run_process(config_path):
    """Run ``main.py -c <config_path>`` and publish its output; blocks until done.

    Meant to be the target of a background thread (see start_task). The
    child's stdout and stderr are merged and read line by line on a helper
    thread; every line is put on ``state.log_queue`` (streamed by /api/logs)
    and appended to ``state.log_lines``.

    When ``state.stop_event`` is set the child is terminated, and killed if
    it has not exited 3 s later. The run always ends with a "[已停止]" (stop
    requested) or "[完成]" marker line. After that ``state.running`` is False
    and ``state.process`` is None.
    """
    state.running = True
    state.start_time = time.time()
    state.config_path = config_path

    # Capture this run's objects once: state.reset() (called by start_task
    # before the next run) swaps in fresh ones.
    stop_event = state.stop_event
    log_queue = state.log_queue

    # Once `finished` is set nothing more is published, so a reader thread
    # that outlives this function can never append after the end marker.
    finished = threading.Event()
    emit_lock = threading.Lock()

    def emit(text, *, final=False):
        """Publish one line (given without trailing newline) unless the run is closed."""
        with emit_lock:
            if finished.is_set():
                return
            log_queue.put(text + "\n")
            state.add_log_line(text)
            if final:
                finished.set()

    proc = None
    reader = None
    try:
        proc = subprocess.Popen(
            [sys.executable, "main.py", "-c", config_path],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            bufsize=1,
            cwd=PROJECT_ROOT,
        )
        state.process = proc

        def _read_output():
            # Read until EOF. Stopping is handled by terminating the child,
            # which closes its end of the pipe (unless processes it spawned
            # still hold it open).
            for line in iter(proc.stdout.readline, ''):
                emit(line.rstrip('\n'))

        reader = threading.Thread(target=_read_output, daemon=True)
        reader.start()

        # Wait for the child to exit, polling stop_event in between.
        while proc.poll() is None:
            if stop_event.is_set():
                proc.terminate()
                try:
                    proc.wait(timeout=3)
                except subprocess.TimeoutExpired:
                    proc.kill()
                    proc.wait()
                break
            time.sleep(0.1)

        # Give the reader a chance to drain what the child wrote. It is the
        # only thread that reads the pipe, so if it is still blocked after the
        # timeout we simply move on instead of reading concurrently.
        reader.join(timeout=5)

    except Exception as e:
        emit(f"[错误] {e}")
        # Don't leave an orphaned child behind if setup failed after Popen.
        if proc is not None and proc.poll() is None:
            proc.kill()
            proc.wait()
    finally:
        emit("[已停止]" if stop_event.is_set() else "[完成]", final=True)
        # Close our end of the pipe only when no thread is still reading it;
        # we avoid closing the file out from under a blocked read.
        if proc is not None and (reader is None or not reader.is_alive()):
            proc.stdout.close()
        state.running = False
        state.process = None
|
||
|
||
|
||
# ========== Flask 路由 ==========
|
||
@app.route('/')
def index():
    """Serve the main page (the ``index.html`` template)."""
    page = 'index.html'
    return render_template(page)
|
||
|
||
|
||
@app.route('/api/configs')
def get_configs():
    """List every discovered config file together with its parsed summary.

    Response: JSON array of ``{"name", "path", "info"}`` objects, in the order
    produced by scan_config_files(); ``info`` comes from load_config_info().
    """
    entries = [
        {"name": name, "path": path, "info": load_config_info(path)}
        for name, path in scan_config_files().items()
    ]
    return jsonify(entries)
|
||
|
||
|
||
@app.route('/api/config/<path:config_name>')
def get_config_detail(config_name):
    """Return the summary and raw text of one config file.

    ``config_name`` must be one of the keys produced by scan_config_files();
    otherwise responds 404. A read failure responds 500 with the error text.
    """
    path = scan_config_files().get(config_name)
    if path is None:
        return jsonify({"error": "配置文件不存在"}), 404

    info = load_config_info(path)
    try:
        with open(path, 'r', encoding='utf-8') as f:
            content = f.read()
    except Exception as e:
        return jsonify({"error": str(e)}), 500

    return jsonify({"name": config_name, "path": path, "info": info, "content": content})
|
||
|
||
|
||
# Serialises the check-then-launch sequence in start_task() so that two
# concurrent /api/start requests cannot both start a task.
_start_lock = threading.Lock()


@app.route('/api/start', methods=['POST'])
def start_task():
    """Start a generation task for the config named in the JSON request body.

    Expects ``{"config": "<name>"}`` where ``<name>`` is a key returned by
    scan_config_files(). Responds 400 if a task is already running or the
    config is unknown. Otherwise it launches run_process() on a daemon thread
    and returns ``{"status": "started", "config": <name>, "path": <path>}``.
    """
    with _start_lock:
        if state.running:
            return jsonify({"error": "任务正在运行中"}), 400

        # silent=True: a missing or non-JSON body yields None instead of an
        # HTTP error, so it falls through to the "unknown config" response.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        config_name = data.get('config', '')

        configs = scan_config_files()
        config_path = configs.get(config_name) if isinstance(config_name, str) else None
        if not config_path:
            return jsonify({"error": f"配置文件不存在: {config_name}"}), 400

        # reset() also installs a fresh, un-set stop_event.
        state.reset()
        # Mark the task as running *before* the worker starts. run_process only
        # sets this once its thread is scheduled, and until then a second
        # request would pass the check above.
        state.running = True
        state.config_path = config_path
        try:
            threading.Thread(target=run_process, args=(config_path,), daemon=True).start()
        except Exception:
            state.running = False
            raise

    return jsonify({"status": "started", "config": config_name, "path": config_path})
|
||
|
||
|
||
@app.route('/api/stop', methods=['POST'])
def stop_task():
    """Request the running task to stop.

    Only sets ``state.stop_event``; run_process() notices it and terminates
    the child. Responds 400 when no task is running.
    """
    if state.running:
        state.stop_event.set()
        return jsonify({"status": "stopping"})
    return jsonify({"error": "没有正在运行的任务"}), 400
|
||
|
||
|
||
@app.route('/api/status')
def get_status():
    """Report the current task status as JSON.

    ``elapsed`` is the number of seconds since the run started (rounded to
    one decimal) while a task is running, and 0 otherwise.
    """
    elapsed = 0
    if state.running:
        started_at = state.start_time
        if started_at:
            elapsed = round(time.time() - started_at, 1)

    payload = {
        "running": state.running,
        "config": state.config_path,
        "elapsed": elapsed,
        "log_count": len(state.log_lines),
    }
    return jsonify(payload)
|
||
|
||
|
||
@app.route('/api/logs')
def get_logs():
    """Stream task output as Server-Sent Events (SSE); suitable for long runs.

    Each queued line is sent as ``data: {"line": ...}``. The stream ends after
    a "[完成]" / "[已停止]" marker line. While no output arrives, an SSE comment
    heartbeat is sent about every 2 s so browsers and proxies keep the
    connection open.

    NOTE(review): lines are consumed from one shared queue, so with several
    concurrent clients each line reaches only one of them.
    """
    end_markers = ('[完成]', '[已停止]')

    def generate():
        # Whether a task was running when this client connected; used to
        # detect that it finished while the queue was idle.
        was_running = state.running
        while True:
            try:
                # Look up state.log_queue on every iteration: state.reset()
                # (called by start_task) replaces the queue for a new run.
                raw = state.log_queue.get(timeout=2.0)
            except queue.Empty:
                if was_running and not state.running:
                    # The task ended but its marker never reached this client
                    # (e.g. taken by another client); synthesize a matching one.
                    marker = '[已停止]' if state.stop_event.is_set() else '[完成]'
                    yield f"data: {json.dumps({'line': marker})}\n\n"
                    break
                # Heartbeat (SSE comment) to keep browsers/proxies from timing out.
                yield ": heartbeat\n\n"
                continue

            line = raw.rstrip('\n')
            yield f"data: {json.dumps({'line': line})}\n\n"
            if line in end_markers:
                break

    return Response(generate(), mimetype='text/event-stream',
                    headers={
                        'Cache-Control': 'no-cache',
                        'X-Accel-Buffering': 'no',
                        'Connection': 'keep-alive',
                    })
|
||
|
||
|
||
@app.route('/api/logs/snapshot')
def get_logs_snapshot():
    """Return a non-streaming snapshot: up to the last 500 log lines plus the total count."""
    recent = state.log_lines[-500:]
    payload = {"lines": recent, "total": len(state.log_lines)}
    return jsonify(payload)
|
||
|
||
|
||
# ========== 结果浏览 API ==========
|
||
# Root directory scanned for saved run results.
SAVES_DIR = os.path.join(PROJECT_ROOT, "saves")


@app.route('/api/results/runs')
def list_runs():
    """List every saved task run found under SAVES_DIR.

    Top-level entries of SAVES_DIR are visited in reverse-sorted name order.
    Each directory is walked recursively, and every directory containing a
    ``run_info.json`` becomes one entry. Entries take their fields from that
    file. If the file cannot be loaded (or is not a JSON object), the fields
    fall back to None / [] and ``task_id`` to the directory name.
    """
    if not os.path.exists(SAVES_DIR):
        return jsonify([])

    runs = []
    for top_name in sorted(os.listdir(SAVES_DIR), reverse=True):
        top_dir = os.path.join(SAVES_DIR, top_name)
        if not os.path.isdir(top_dir):
            continue

        # Layout may be nested (e.g. Main_Results/CorrectBench/...) or flat.
        for current, _subdirs, _files in os.walk(top_dir):
            info_file = os.path.join(current, "run_info.json")
            if not os.path.exists(info_file):
                continue

            entry = {
                "path": os.path.relpath(current, SAVES_DIR),
                "task_id": os.path.basename(current),
                "coverage": None,
                "full_pass": None,
                "time": None,
                "token_cost": None,
                "circuit_type": None,
                "op_record": [],
            }
            try:
                with open(info_file, 'r', encoding='utf-8') as fh:
                    meta = json.load(fh)
                # All keyword values are evaluated before update() runs, so a
                # failure here leaves `entry` holding only the fallback values.
                entry.update(
                    task_id=meta.get("task_id", entry["task_id"]),
                    coverage=meta.get("coverage"),
                    full_pass=meta.get("full_pass"),
                    time=meta.get("time"),
                    token_cost=meta.get("token_cost"),
                    circuit_type=meta.get("circuit_type"),
                    op_record=meta.get("op_record", []),
                )
            except Exception:
                # Deliberate best-effort: unreadable run_info keeps the defaults.
                pass
            runs.append(entry)

    return jsonify(runs)
|
||
|
||
|
||
def _read_text_or(path, fallback):
    """Return the UTF-8 text of *path*, or *fallback* if it cannot be read."""
    try:
        with open(path, 'r', encoding='utf-8') as f:
            return f.read()
    except Exception:
        # Best-effort: the UI shows the placeholder instead of failing the request.
        return fallback


@app.route('/api/results/task/<path:task_path>')
def get_task_detail(task_path):
    """Return details of one saved task directory (relative to SAVES_DIR).

    Includes, when present: ``run_info`` (parsed run_info.json), ``final_tb``
    (final_TB.v), ``dut`` (DUT.v) and ``final_tb_py`` (final_TB.py). It also
    always includes ``subdirs``, the sorted names of sub-directories.
    Responds 404 if the path is not a directory inside SAVES_DIR.
    """
    saves_root = os.path.abspath(SAVES_DIR)
    full_path = os.path.abspath(os.path.join(saves_root, task_path))

    # task_path comes straight from the URL: refuse anything that resolves
    # outside SAVES_DIR (e.g. through ".." segments), so arbitrary directories
    # cannot be listed or read through this endpoint.
    inside = full_path == saves_root or full_path.startswith(saves_root + os.sep)
    if not inside or not os.path.isdir(full_path):
        return jsonify({"error": "路径不存在"}), 404

    result = {"path": task_path, "task_id": os.path.basename(task_path)}

    run_info_path = os.path.join(full_path, "run_info.json")
    if os.path.exists(run_info_path):
        try:
            with open(run_info_path, 'r', encoding='utf-8') as f:
                result["run_info"] = json.load(f)
        except Exception:
            result["run_info"] = {"error": "读取失败"}

    # (result key, file name, placeholder shown when the file is unreadable)
    text_files = (
        ("final_tb", "final_TB.v", "// 读取失败"),
        ("dut", "DUT.v", "// 读取失败"),
        ("final_tb_py", "final_TB.py", "# 读取失败"),
    )
    for key, filename, fallback in text_files:
        file_path = os.path.join(full_path, filename)
        if os.path.exists(file_path):
            result[key] = _read_text_or(file_path, fallback)

    # Sub-directories (pipeline stage directories).
    result["subdirs"] = sorted(
        entry for entry in os.listdir(full_path)
        if os.path.isdir(os.path.join(full_path, entry))
    )

    return jsonify(result)
|
||
|
||
|
||
if __name__ == '__main__':
    import argparse

    cli = argparse.ArgumentParser(description='CorrectBench Web Server')
    cli.add_argument('--port', type=int, default=5000, help='服务端口 (默认: 5000)')
    cli.add_argument('--host', type=str, default='0.0.0.0', help='监听地址 (默认: 0.0.0.0)')
    cli.add_argument('--debug', action='store_true',
                     help='启用调试模式 (注意: debug模式下文件变更会重启服务器)')
    opts = cli.parse_args()

    debug_label = '开启' if opts.debug else '关闭'
    banner = (
        "🛠️ CorrectBench Web Server",
        f" 地址: http://{opts.host}:{opts.port}",
        f" 调试: {debug_label}",
        " 提示: 长时间运行任务请勿开启 debug 模式",
    )
    for row in banner:
        print(row)

    # NOTE(review): the default host 0.0.0.0 makes the server reachable from
    # other machines; combined with --debug this also exposes Flask's debugger.
    app.run(host=opts.host, port=opts.port, debug=opts.debug, threaded=True)
|