432 lines
14 KiB
Python
432 lines
14 KiB
Python
|
|
"""
|
|||
|
|
CorrectBench Flask Web Server

提供 Web 界面来管理和监控 CorrectBench 任务
(Provides a web interface for managing and monitoring CorrectBench tasks.)
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import os
|
|||
|
|
import sys
|
|||
|
|
import json
|
|||
|
|
import subprocess
|
|||
|
|
import threading
|
|||
|
|
import queue
|
|||
|
|
import time
|
|||
|
|
import yaml
|
|||
|
|
|
|||
|
|
from flask import Flask, render_template, jsonify, request, Response, send_from_directory
|
|||
|
|
|
|||
|
|
# Project root: the directory that contains this file. Templates, static
# assets, configs and run results are all located relative to it.
PROJECT_ROOT = os.path.split(os.path.abspath(__file__))[0]

app = Flask(
    __name__,
    template_folder=os.path.join(PROJECT_ROOT, "web", "templates"),
    static_folder=os.path.join(PROJECT_ROOT, "web", "static"),
)
|
|||
|
|
|
|||
|
|
# ========== Global state ==========
class TaskState:
    """State of the single background task, shared across threads.

    The runner thread (and its output-reader thread) write to it, while Flask
    request handlers read from it. Appends to and resets of ``log_lines`` are
    serialized with ``_lock``.
    """

    MAX_LOG_LINES = 5000  # maximum number of log lines kept in memory

    def __init__(self):
        self._lock = threading.Lock()  # guards mutations of log_lines
        self.reset()

    def add_log_line(self, line):
        """Append *line* to the history, dropping the oldest lines beyond MAX_LOG_LINES."""
        with self._lock:
            self.log_lines.append(line)
            excess = len(self.log_lines) - self.MAX_LOG_LINES
            if excess > 0:
                # Trim in place so a concurrent reset() cannot be overwritten
                # by a slice of the old list.
                del self.log_lines[:excess]

    def reset(self):
        """Return every field to its idle value, with a fresh queue and stop event."""
        self.running = False
        self.process = None  # subprocess.Popen of the running task, if any
        self.log_queue = queue.Queue()  # lines waiting to be pushed over SSE
        self.stop_event = threading.Event()  # set to request the task to stop
        with self._lock:
            self.log_lines = []  # bounded history used by the snapshot API
        self.config_path = ""
        self.start_time = None  # time.time() when the task started


state = TaskState()
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ========== Config file management ==========
def scan_config_files(root=None):
    """Discover the selectable YAML config files.

    Scans ``<root>/config/configs/`` first (names are prefixed with
    ``configs/``), then ``<root>/config/`` itself (bare file names, with
    ``default.yaml`` excluded). Entries of each directory are visited in
    sorted order. Only regular files ending in ``.yaml``/``.yml`` count.

    Args:
        root: project directory to scan; defaults to ``PROJECT_ROOT``.

    Returns:
        dict mapping display name -> file path. A missing directory simply
        contributes no entries.
    """
    if root is None:
        root = PROJECT_ROOT
    config_dir = os.path.join(root, "config")

    # (directory, name prefix, file names to skip)
    sources = (
        (os.path.join(config_dir, "configs"), "configs/", ()),
        (config_dir, "", ("default.yaml",)),
    )

    configs = {}
    for directory, prefix, excluded in sources:
        if not os.path.isdir(directory):
            continue
        for fname in sorted(os.listdir(directory)):
            if not fname.endswith(('.yaml', '.yml')) or fname in excluded:
                continue
            path = os.path.join(directory, fname)
            if os.path.isfile(path):  # skip directories that merely look like configs
                configs[prefix + fname] = path
    return configs
|
|||
|
|
|
|||
|
|
|
|||
|
|
def load_config_info(config_path):
|
|||
|
|
"""加载配置文件的基本信息"""
|
|||
|
|
if not os.path.exists(config_path):
|
|||
|
|
return {"error": "配置文件不存在"}
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
with open(config_path, 'r', encoding='utf-8') as f:
|
|||
|
|
cfg = yaml.safe_load(f)
|
|||
|
|
|
|||
|
|
info = {
|
|||
|
|
"model": cfg.get("gpt", {}).get("model", "未指定"),
|
|||
|
|
"rtlgen_model": cfg.get("gpt", {}).get("rtlgen_model", "未指定"),
|
|||
|
|
"mode": cfg.get("run", {}).get("mode", "未指定"),
|
|||
|
|
"tasks": cfg.get("autoline", {}).get("probset", {}).get("only", []),
|
|||
|
|
"itermax": cfg.get("autoline", {}).get("itermax", "未指定"),
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
# 如果没有 only 字段,显示数据集路径
|
|||
|
|
if not info["tasks"]:
|
|||
|
|
data_path = cfg.get("autoline", {}).get("probset", {}).get("path", "")
|
|||
|
|
info["dataset"] = data_path if data_path else "未指定"
|
|||
|
|
|
|||
|
|
return info
|
|||
|
|
except Exception as e:
|
|||
|
|
return {"error": str(e)}
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ========== 任务进程管理 ==========
|
|||
|
|
def run_process(config_path):
|
|||
|
|
"""后台线程运行生成进程"""
|
|||
|
|
global state
|
|||
|
|
|
|||
|
|
state.running = True
|
|||
|
|
state.start_time = time.time()
|
|||
|
|
state.config_path = config_path
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
state.process = subprocess.Popen(
|
|||
|
|
[sys.executable, "main.py", "-c", config_path],
|
|||
|
|
stdout=subprocess.PIPE,
|
|||
|
|
stderr=subprocess.STDOUT,
|
|||
|
|
text=True,
|
|||
|
|
bufsize=1,
|
|||
|
|
cwd=PROJECT_ROOT
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# 启动日志读取线程
|
|||
|
|
def _read_output():
|
|||
|
|
for line in iter(state.process.stdout.readline, ''):
|
|||
|
|
if state.stop_event.is_set():
|
|||
|
|
break
|
|||
|
|
if line:
|
|||
|
|
state.log_queue.put(line)
|
|||
|
|
state.add_log_line(line.rstrip('\n'))
|
|||
|
|
|
|||
|
|
reader_thread = threading.Thread(target=_read_output, daemon=True)
|
|||
|
|
reader_thread.start()
|
|||
|
|
|
|||
|
|
# 等待进程结束,同时检测 stop_event
|
|||
|
|
while state.process.poll() is None:
|
|||
|
|
if state.stop_event.is_set():
|
|||
|
|
state.process.terminate()
|
|||
|
|
try:
|
|||
|
|
state.process.wait(timeout=3)
|
|||
|
|
except subprocess.TimeoutExpired:
|
|||
|
|
state.process.kill()
|
|||
|
|
state.process.wait()
|
|||
|
|
break
|
|||
|
|
time.sleep(0.1)
|
|||
|
|
|
|||
|
|
# 等待读取线程结束,确保所有输出都已入队
|
|||
|
|
reader_thread.join(timeout=5)
|
|||
|
|
|
|||
|
|
# 读取可能残留的输出
|
|||
|
|
for line in iter(state.process.stdout.readline, ''):
|
|||
|
|
if line:
|
|||
|
|
state.log_queue.put(line)
|
|||
|
|
state.add_log_line(line.rstrip('\n'))
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
state.log_queue.put(f"[错误] {str(e)}\n")
|
|||
|
|
state.add_log_line(f"[错误] {str(e)}")
|
|||
|
|
finally:
|
|||
|
|
if state.stop_event.is_set():
|
|||
|
|
state.log_queue.put("[已停止]\n")
|
|||
|
|
state.add_log_line("[已停止]")
|
|||
|
|
else:
|
|||
|
|
state.log_queue.put("[完成]\n")
|
|||
|
|
state.add_log_line("[完成]")
|
|||
|
|
state.running = False
|
|||
|
|
state.process = None
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ========== Flask routes ==========
@app.route('/')
def index():
    """Serve the main page, rendered from the app's template folder."""
    page = 'index.html'
    return render_template(page)
|
|||
|
|
|
|||
|
|
|
|||
|
|
@app.route('/api/configs')
def get_configs():
    """List every discovered config file together with its summary.

    Returns a JSON array of ``{"name", "path", "info"}`` objects, one per
    entry of ``scan_config_files()`` and in that mapping's order.
    """
    return jsonify([
        {"name": name, "path": path, "info": load_config_info(path)}
        for name, path in scan_config_files().items()
    ])
|
|||
|
|
|
|||
|
|
|
|||
|
|
@app.route('/api/config/<path:config_name>')
def get_config_detail(config_name):
    """Return the summary and raw text of the config registered as *config_name*.

    Only names produced by ``scan_config_files()`` are accepted, so this
    endpoint cannot be used to read arbitrary paths.

    Responses: 200 with ``{"name", "path", "info", "content"}``; 404 for an
    unknown name; 500 if the file cannot be read.
    """
    path = scan_config_files().get(config_name)
    if path is None:
        return jsonify({"error": "配置文件不存在"}), 404

    info = load_config_info(path)
    try:
        with open(path, 'r', encoding='utf-8') as f:
            content = f.read()
    except (OSError, ValueError) as e:  # I/O failure or undecodable bytes
        return jsonify({"error": str(e)}), 500
    return jsonify({"name": config_name, "path": path, "info": info, "content": content})
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Serializes start_task's check-then-start sequence so two concurrent requests
# cannot both observe "not running" and launch two processes.
_start_lock = threading.Lock()


@app.route('/api/start', methods=['POST'])
def start_task():
    """Start ``run_process`` for the config named in the JSON body (``{"config": name}``).

    Responses: 200 ``{"status": "started", "config", "path"}``; 400 if a task
    is already running or the config name is unknown.
    """
    with _start_lock:
        if state.running:
            return jsonify({"error": "任务正在运行中"}), 400

        # silent=True: a missing or invalid JSON body falls through to the
        # "config not found" error instead of an HTML 400/415 error page.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        config_name = data.get('config', '')

        configs = scan_config_files()
        # Config names are strings; a list here would also make dict.get raise TypeError.
        config_path = configs.get(config_name) if isinstance(config_name, str) else None
        if not config_path:
            return jsonify({"error": f"配置文件不存在: {config_name}"}), 400

        # reset() already creates a fresh, unset stop_event.
        state.reset()
        # Claim the slot before the worker starts, so no other request can
        # slip in between the check above and run_process setting it.
        state.running = True
        threading.Thread(target=run_process, args=(config_path,), daemon=True).start()

    return jsonify({"status": "started", "config": config_name, "path": config_path})
|
|||
|
|
|
|||
|
|
|
|||
|
|
@app.route('/api/stop', methods=['POST'])
def stop_task():
    """Request the running task to stop.

    This only sets ``state.stop_event``; the runner thread performs the actual
    termination. Returns 400 when no task is running.
    """
    if state.running:
        state.stop_event.set()
        return jsonify({"status": "stopping"})
    return jsonify({"error": "没有正在运行的任务"}), 400
|
|||
|
|
|
|||
|
|
|
|||
|
|
@app.route('/api/status')
def get_status():
    """Report the running flag, config path, elapsed seconds and retained log count.

    ``elapsed`` is rounded to 0.1 s, and is 0 unless a task is running and
    has a start time.
    """
    started = state.start_time
    elapsed = round(time.time() - started, 1) if (state.running and started) else 0
    payload = {
        "running": state.running,
        "config": state.config_path,
        "elapsed": elapsed,
        "log_count": len(state.log_lines),
    }
    return jsonify(payload)
|
|||
|
|
|
|||
|
|
|
|||
|
|
@app.route('/api/logs')
def get_logs():
    """Stream log lines as Server-Sent Events, suitable for long-running tasks.

    Each event is ``data: {"line": ...}``. The stream ends after a "[完成]" or
    "[已停止]" line. If this connection saw a task running and that task has
    since finished without the marker reaching it, a synthesized marker is
    sent instead. During idle periods (2 s without output) a ": heartbeat"
    comment keeps browsers and proxies from timing out.

    NOTE(review): every connection pops from the same ``state.log_queue``, so
    with several clients open each line reaches only one of them.
    """
    end_markers = ('[完成]', '[已停止]')

    def generate():
        was_running = state.running
        while True:
            try:
                line = state.log_queue.get(timeout=2.0).rstrip('\n')
            except queue.Empty:
                if was_running and not state.running:
                    # The task ended but its marker was not delivered here; send one.
                    marker = '[已停止]' if state.stop_event.is_set() else '[完成]'
                    yield f"data: {json.dumps({'line': marker})}\n\n"
                    break
                yield ": heartbeat\n\n"
                continue

            yield f"data: {json.dumps({'line': line})}\n\n"
            if line in end_markers:
                break

    # No "Connection" header: it is hop-by-hop, and PEP 3333 forbids WSGI
    # applications from setting it (wsgiref rejects it outright).
    return Response(generate(), mimetype='text/event-stream',
                    headers={
                        'Cache-Control': 'no-cache',
                        'X-Accel-Buffering': 'no',
                    })
|
|||
|
|
|
|||
|
|
|
|||
|
|
@app.route('/api/logs/snapshot')
def get_logs_snapshot():
    """Return the most recent 500 log lines and the total number retained (non-streaming)."""
    history = state.log_lines
    return jsonify({"lines": history[-500:], "total": len(history)})
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ========== Results browsing API ==========
SAVES_DIR = os.path.join(PROJECT_ROOT, "saves")


def _load_run_info(run_info_path):
    """Parse a run_info.json file. Return {} if it is unreadable, malformed, or not a JSON object."""
    try:
        with open(run_info_path, 'r', encoding='utf-8') as f:
            info = json.load(f)
    except (OSError, ValueError):  # ValueError covers JSON and UTF-8 decode errors
        return {}
    return info if isinstance(info, dict) else {}


@app.route('/api/results/runs')
def list_runs():
    """List every task directory (any directory holding run_info.json) under SAVES_DIR.

    Top-level (week) directories are visited newest-name-first, and the
    directories inside each are walked in name order. Each entry carries the
    path relative to SAVES_DIR plus selected run_info fields. An unreadable
    run_info.json still yields an entry, with those fields set to None
    (``op_record`` to []).
    """
    if not os.path.isdir(SAVES_DIR):
        return jsonify([])

    runs = []
    for week_dir in sorted(os.listdir(SAVES_DIR), reverse=True):
        week_path = os.path.join(SAVES_DIR, week_dir)
        if not os.path.isdir(week_path):
            continue

        # Task directories may sit at any depth (e.g. Main_Results/CorrectBench/<task>).
        for root, dirs, _files in os.walk(week_path):
            dirs.sort()  # deterministic traversal order
            run_info_path = os.path.join(root, "run_info.json")
            if not os.path.isfile(run_info_path):
                continue
            info = _load_run_info(run_info_path)
            runs.append({
                "path": os.path.relpath(root, SAVES_DIR),
                "task_id": info.get("task_id", os.path.basename(root)),
                "coverage": info.get("coverage"),
                "full_pass": info.get("full_pass"),
                "time": info.get("time"),
                "token_cost": info.get("token_cost"),
                "circuit_type": info.get("circuit_type"),
                "op_record": info.get("op_record", []),
            })

    return jsonify(runs)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _resolve_saves_path(rel_path):
    """Map a client-supplied relative path to an absolute path inside SAVES_DIR.

    Returns None if the path escapes SAVES_DIR (via ``..`` or an absolute path).
    """
    base = os.path.abspath(SAVES_DIR)
    target = os.path.abspath(os.path.join(base, rel_path))
    try:
        if os.path.commonpath([base, target]) != base:
            return None
    except ValueError:  # e.g. paths on different drives (Windows)
        return None
    return target


def _read_text_file(path, on_error):
    """Return the UTF-8 text of *path*, or *on_error* if it cannot be read or decoded."""
    try:
        with open(path, 'r', encoding='utf-8') as f:
            return f.read()
    except (OSError, ValueError):
        return on_error


# (result key, file name, placeholder used when the file exists but cannot be read)
_TASK_SOURCE_FILES = (
    ("final_tb", "final_TB.v", "// 读取失败"),
    ("dut", "DUT.v", "// 读取失败"),
    ("final_tb_py", "final_TB.py", "# 读取失败"),
)


@app.route('/api/results/task/<path:task_path>')
def get_task_detail(task_path):
    """Return the details of one task directory under SAVES_DIR.

    Includes run_info.json (parsed), the text of final_TB.v, DUT.v and
    final_TB.py when present, and the sorted names of its sub-directories
    (the per-stage directories). Returns 404 when the path does not exist,
    is not a directory, or lies outside SAVES_DIR.
    """
    # The URL path is untrusted: reject anything that escapes SAVES_DIR.
    full_path = _resolve_saves_path(task_path)
    if full_path is None or not os.path.isdir(full_path):
        return jsonify({"error": "路径不存在"}), 404

    result = {"path": task_path, "task_id": os.path.basename(task_path)}

    run_info_path = os.path.join(full_path, "run_info.json")
    if os.path.exists(run_info_path):
        try:
            with open(run_info_path, 'r', encoding='utf-8') as f:
                result["run_info"] = json.load(f)
        except (OSError, ValueError):
            result["run_info"] = {"error": "读取失败"}

    for key, filename, on_error in _TASK_SOURCE_FILES:
        file_path = os.path.join(full_path, filename)
        if os.path.exists(file_path):
            result[key] = _read_text_file(file_path, on_error)

    result["subdirs"] = [
        name for name in sorted(os.listdir(full_path))
        if os.path.isdir(os.path.join(full_path, name))
    ]

    return jsonify(result)
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == '__main__':
    import argparse

    cli = argparse.ArgumentParser(description='CorrectBench Web Server')
    cli.add_argument('--port', type=int, default=5000, help='服务端口 (默认: 5000)')
    cli.add_argument('--host', type=str, default='0.0.0.0', help='监听地址 (默认: 0.0.0.0)')
    cli.add_argument('--debug', action='store_true', help='启用调试模式 (注意: debug模式下文件变更会重启服务器)')
    opts = cli.parse_args()

    # NOTE(review): the default host 0.0.0.0 exposes the start/stop and results
    # endpoints to the whole network, and this file shows no authentication.
    banner = (
        "🛠️ CorrectBench Web Server",
        f" 地址: http://{opts.host}:{opts.port}",
        f" 调试: {'开启' if opts.debug else '关闭'}",
        " 提示: 长时间运行任务请勿开启 debug 模式",
    )
    for text in banner:
        print(text)

    # threaded=True: SSE log streams hold a connection open for the whole run,
    # so other requests need their own threads.
    app.run(host=opts.host, port=opts.port, debug=opts.debug, threaded=True)
|