Files
CGA-bench/web_server.py
2026-05-22 10:02:42 +08:00

432 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
CorrectBench Flask Web Server
提供 Web 界面来管理和监控 CorrectBench 任务
"""
import os
import sys
import json
import subprocess
import threading
import queue
import time
import yaml
from flask import Flask, render_template, jsonify, request, Response, send_from_directory
# Project root: the directory that contains this file.
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))

# Flask application; templates and static assets live under <PROJECT_ROOT>/web/.
app = Flask(
    __name__,
    static_folder=os.path.join(PROJECT_ROOT, "web", "static"),
    template_folder=os.path.join(PROJECT_ROOT, "web", "templates"),
)
# ========== Global state ==========
class TaskState:
    """Process-wide state for the single background task this server manages.

    Route handlers and the worker/reader threads all read and write the
    module-level ``state`` instance created below.
    """

    MAX_LOG_LINES = 5000  # maximum number of log lines kept in memory

    def __init__(self):
        self.running = False  # True while a task is running
        self.process = None  # subprocess.Popen of the running task, or None
        self.log_queue = queue.Queue()  # raw output lines consumed by the SSE stream
        self.stop_event = threading.Event()  # set to request that the task stop
        self.log_lines = []  # recent lines (no trailing newline), capped at MAX_LOG_LINES
        self.config_path = ""  # config path of the current/last task
        self.start_time = None  # time.time() when the current task started
        self._lock = threading.Lock()  # serializes add_log_line's append-and-trim

    def add_log_line(self, line):
        """Append *line* to ``log_lines``, dropping the oldest lines beyond MAX_LOG_LINES.

        Called from the output-reader thread while request handlers may be
        reading ``log_lines``; the lock keeps the append and trim together.
        """
        with self._lock:
            self.log_lines.append(line)
            overflow = len(self.log_lines) - self.MAX_LOG_LINES
            if overflow > 0:
                # Trim in place instead of copying the whole buffer on every line.
                del self.log_lines[:overflow]

    def reset(self):
        """Return to the idle state: no process, fresh queue/stop event, empty log.

        Does not acquire ``_lock``, so it may be called while holding it.
        """
        self.running = False
        self.process = None
        self.log_queue = queue.Queue()
        self.stop_event = threading.Event()
        self.log_lines = []
        self.config_path = ""
        self.start_time = None


state = TaskState()
# ========== Config file management ==========
def scan_config_files():
    """Collect the selectable YAML config files.

    Returns:
        dict: display name -> absolute path, in insertion order. Files found in
        ``config/configs/`` come first, keyed as ``"configs/<file>"`` and sorted
        by filename. Files directly in ``config/`` follow, keyed by bare filename
        and sorted; ``default.yaml`` there is skipped.
    """
    yaml_suffixes = ('.yaml', '.yml')
    found = {}

    nested_dir = os.path.join(PROJECT_ROOT, "config", "configs")
    if os.path.exists(nested_dir):
        found.update(
            (f"configs/{fname}", os.path.join(nested_dir, fname))
            for fname in sorted(os.listdir(nested_dir))
            if fname.endswith(yaml_suffixes)
        )

    top_dir = os.path.join(PROJECT_ROOT, "config")
    if os.path.exists(top_dir):
        found.update(
            (fname, os.path.join(top_dir, fname))
            for fname in sorted(os.listdir(top_dir))
            if fname.endswith(yaml_suffixes) and fname != 'default.yaml'
        )

    return found
def load_config_info(config_path):
    """Summarize the key settings of a YAML config file.

    Args:
        config_path: Path to the YAML config file.

    Returns:
        dict: On success, the keys ``model``, ``rtlgen_model``, ``mode``,
        ``tasks`` and ``itermax``. Missing values are reported as "未指定",
        except ``tasks``, which defaults to []. When ``tasks`` is empty,
        ``dataset`` is added as well (``autoline.probset.path`` or "未指定").
        On failure, the dict is ``{"error": <message>}``.
    """
    if not os.path.exists(config_path):
        return {"error": "配置文件不存在"}

    def _section(mapping, key):
        # yaml.safe_load returns None for an empty file, and a bare "key:" with
        # no entries yields None too; treat anything that is not a mapping as
        # an empty section instead of crashing on .get().
        value = mapping.get(key) if isinstance(mapping, dict) else None
        return value if isinstance(value, dict) else {}

    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            cfg = yaml.safe_load(f)
        gpt = _section(cfg, "gpt")
        autoline = _section(cfg, "autoline")
        probset = _section(autoline, "probset")
        info = {
            "model": gpt.get("model", "未指定"),
            "rtlgen_model": gpt.get("rtlgen_model", "未指定"),
            "mode": _section(cfg, "run").get("mode", "未指定"),
            "tasks": probset.get("only", []),
            "itermax": autoline.get("itermax", "未指定"),
        }
        # Without an explicit task list ("only"), report the dataset path instead.
        if not info["tasks"]:
            data_path = probset.get("path", "")
            info["dataset"] = data_path if data_path else "未指定"
        return info
    except Exception as e:
        # Best-effort summary for the UI: any I/O or YAML parse error is
        # reported to the caller rather than raised.
        return {"error": str(e)}
# ========== Task process management ==========
def run_process(config_path):
    """Run ``main.py -c <config_path>`` as a child process and capture its output.

    Intended as the target of a background thread. Each output line (stdout
    and stderr merged) is put on ``state.log_queue`` for the SSE stream and
    recorded via ``state.add_log_line``. While the child runs,
    ``state.stop_event`` is polled every 0.1 s. When it is set, the child is
    terminated and, if it has not exited within 3 s, killed. A final
    ``[已停止]`` or ``[完成]`` marker line is always emitted. Afterwards
    ``state.running`` is set back to False and ``state.process`` to None.

    Args:
        config_path: Config file path passed to ``main.py`` via ``-c``; the
            child runs with ``cwd=PROJECT_ROOT``.
    """
    state.running = True
    state.start_time = time.time()
    state.config_path = config_path
    proc = None
    try:
        proc = subprocess.Popen(
            [sys.executable, "main.py", "-c", config_path],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            # Replace undecodable bytes rather than raising UnicodeDecodeError,
            # which would kill the reader thread and silently stop log capture.
            errors='replace',
            bufsize=1,  # line-buffered in text mode
            cwd=PROJECT_ROOT,
        )
        state.process = proc

        def _publish(line):
            # Feed both the live SSE queue and the snapshot buffer.
            state.log_queue.put(line)
            state.add_log_line(line.rstrip('\n'))

        def _read_output():
            # readline() blocks, so reading happens on its own thread while the
            # loop below keeps polling stop_event.
            for line in iter(proc.stdout.readline, ''):
                if state.stop_event.is_set():
                    break
                if line:
                    _publish(line)

        reader_thread = threading.Thread(target=_read_output, daemon=True)
        reader_thread.start()

        # Wait for the child to exit, honouring stop requests.
        while proc.poll() is None:
            if state.stop_event.is_set():
                proc.terminate()
                try:
                    proc.wait(timeout=3)
                except subprocess.TimeoutExpired:
                    proc.kill()
                    proc.wait()
                break
            time.sleep(0.1)

        # Give the reader time to forward everything already written.
        reader_thread.join(timeout=5)
        if not reader_thread.is_alive():
            # The reader may have stopped early on stop_event; forward what is
            # left, then release the pipe. This is skipped while the reader is
            # still blocked in readline() (the pipe's write end is still open,
            # e.g. inherited by a grandchild process): a second reader here
            # would race with it and could block forever, leaving
            # state.running stuck at True.
            for line in iter(proc.stdout.readline, ''):
                if line:
                    _publish(line)
            proc.stdout.close()
    except Exception as e:
        state.log_queue.put(f"[错误] {str(e)}\n")
        state.add_log_line(f"[错误] {str(e)}")
    finally:
        # Do not leave the child running if we got here through an exception.
        if proc is not None and proc.poll() is None:
            proc.kill()
            proc.wait()
        if state.stop_event.is_set():
            state.log_queue.put("[已停止]\n")
            state.add_log_line("[已停止]")
        else:
            state.log_queue.put("[完成]\n")
            state.add_log_line("[完成]")
        state.running = False
        state.process = None
# ========== Flask routes ==========
@app.route('/')
def index():
    """Serve the dashboard page rendered from ``index.html``."""
    page = 'index.html'
    return render_template(page)
@app.route('/api/configs')
def get_configs():
    """Return every discovered config file as a JSON list.

    Each entry has ``name`` (display name), ``path`` (absolute path) and
    ``info`` (the summary produced by ``load_config_info``).
    """
    entries = [
        {"name": cfg_name, "path": cfg_path, "info": load_config_info(cfg_path)}
        for cfg_name, cfg_path in scan_config_files().items()
    ]
    return jsonify(entries)
@app.route('/api/config/<path:config_name>')
def get_config_detail(config_name):
    """Return the summary and raw text of one config file.

    Only names produced by ``scan_config_files`` are accepted, so arbitrary
    paths cannot be read through this endpoint.

    Returns:
        200 with ``name``, ``path``, ``info`` and ``content``; 404 if the name
        is unknown; 500 if the file cannot be read.
    """
    path = scan_config_files().get(config_name)
    if path is None:
        return jsonify({"error": "配置文件不存在"}), 404

    info = load_config_info(path)
    try:
        with open(path, 'r', encoding='utf-8') as f:
            content = f.read()
    except (OSError, ValueError) as e:  # ValueError covers UnicodeDecodeError
        return jsonify({"error": str(e)}), 500
    return jsonify({"name": config_name, "path": path, "info": info, "content": content})
@app.route('/api/start', methods=['POST'])
def start_task():
    """Start a generation run for the config named in the JSON request body.

    Request body: ``{"config": "<name as listed by /api/configs>"}``.

    Returns:
        400 if a task is already running or the config name is unknown;
        otherwise ``{"status": "started", "config": ..., "path": ...}``
        after spawning ``run_process`` on a daemon thread.
    """
    # The "already running?" check and the switch to running must be atomic.
    # run_process only sets state.running once its thread is scheduled, so
    # without the lock two concurrent POSTs could both pass the check and
    # launch two subprocesses sharing one log queue.
    with state._lock:
        if state.running:
            return jsonify({"error": "任务正在运行中"}), 400

        # silent=True: a missing, malformed or non-JSON body yields None
        # instead of an HTTP error page; non-object JSON is treated as empty.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        config_name = data.get('config', '')

        config_path = None
        if isinstance(config_name, str):
            config_path = scan_config_files().get(config_name)
        if not config_path:
            return jsonify({"error": f"配置文件不存在: {config_name}"}), 400

        state.reset()  # also installs a fresh, unset stop_event
        state.running = True
        threading.Thread(target=run_process, args=(config_path,), daemon=True).start()

    return jsonify({"status": "started", "config": config_name, "path": config_path})
@app.route('/api/stop', methods=['POST'])
def stop_task():
    """Ask the running task to stop.

    Only sets ``state.stop_event``; the worker thread notices it and
    terminates the child process. Returns 400 when nothing is running.
    """
    if state.running:
        state.stop_event.set()
        return jsonify({"status": "stopping"})
    return jsonify({"error": "没有正在运行的任务"}), 400
@app.route('/api/status')
def get_status():
    """Report whether a task is running, its config, elapsed seconds and log size.

    ``elapsed`` is rounded to 0.1 s and is 0 when no task is running.
    """
    is_running = state.running
    started_at = state.start_time
    if is_running and started_at:
        elapsed = round(time.time() - started_at, 1)
    else:
        elapsed = 0
    payload = {
        "running": is_running,
        "config": state.config_path,
        "elapsed": elapsed,
        "log_count": len(state.log_lines),
    }
    return jsonify(payload)
@app.route('/api/logs')
def get_logs():
    """Stream task log lines to the client as Server-Sent Events (SSE).

    Each line is sent as ``data: {"line": "<text>"}``. The stream ends after
    the ``[完成]`` / ``[已停止]`` marker written by ``run_process``.

    If no line arrives within 2 s, an SSE comment heartbeat is sent so
    browsers and proxies keep the idle connection open. If a run has been
    observed and has since ended without this client receiving its marker
    (lines are removed from the shared ``state.log_queue`` by whichever client
    reads them first), a synthetic marker reflecting how the run ended is sent
    and the stream closes.
    """
    end_markers = ('[完成]', '[已停止]')

    def generate():
        was_running = state.running
        while True:
            # Also remember runs that start after this client connected, so
            # the fallback below can close the stream once they end.
            was_running = was_running or state.running
            try:
                line = state.log_queue.get(timeout=2.0)
            except queue.Empty:
                if was_running and not state.running:
                    marker = '[已停止]' if state.stop_event.is_set() else '[完成]'
                    yield f"data: {json.dumps({'line': marker})}\n\n"
                    break
                yield ": heartbeat\n\n"
                continue

            text = line.rstrip('\n')
            yield f"data: {json.dumps({'line': text})}\n\n"
            if text in end_markers:
                break

    return Response(generate(), mimetype='text/event-stream',
                    headers={
                        'Cache-Control': 'no-cache',
                        'X-Accel-Buffering': 'no',
                        'Connection': 'keep-alive',
                    })
@app.route('/api/logs/snapshot')
def get_logs_snapshot():
    """Return the latest 500 buffered log lines plus the total buffered count (non-streaming)."""
    buffered = state.log_lines
    return jsonify({"lines": buffered[-500:], "total": len(buffered)})
# ========== Result browsing API ==========
# Root of saved run outputs; read by the /api/results/* endpoints below.
SAVES_DIR = os.path.join(PROJECT_ROOT, 'saves')
def _run_summary(root, info):
    """Build the listing entry for one task directory.

    Args:
        root: Path of the directory that contains ``run_info.json``.
        info: Parsed ``run_info.json`` content, or None if it could not be read.
            Anything that is not a dict is treated as "no metrics available".
    """
    if not isinstance(info, dict):
        info = {}
    return {
        "path": os.path.relpath(root, SAVES_DIR),
        "task_id": info.get("task_id", os.path.basename(root)),
        "coverage": info.get("coverage"),
        "full_pass": info.get("full_pass"),
        "time": info.get("time"),
        "token_cost": info.get("token_cost"),
        "circuit_type": info.get("circuit_type"),
        "op_record": info.get("op_record", []),
    }


@app.route('/api/results/runs')
def list_runs():
    """List every task run found under ``saves/`` as JSON.

    Top-level entries of ``saves/`` are visited in reverse name order; each is
    walked recursively, and any directory containing ``run_info.json`` yields
    one entry (path relative to ``saves/``, task id, and the metrics stored in
    that file). An unreadable or malformed ``run_info.json`` still produces an
    entry, with the metrics set to None (``op_record`` to []).
    """
    if not os.path.exists(SAVES_DIR):
        return jsonify([])

    runs = []
    for week_dir in sorted(os.listdir(SAVES_DIR), reverse=True):
        week_path = os.path.join(SAVES_DIR, week_dir)
        if not os.path.isdir(week_path):
            continue
        # Sub-layout may be e.g. Main_Results/CorrectBench/<task> or <task> directly.
        for root, dirs, files in os.walk(week_path):
            # os.walk descends in os.listdir() order, which is arbitrary;
            # sorting in place makes the listing order deterministic.
            dirs.sort()
            if "run_info.json" not in files:
                continue
            try:
                with open(os.path.join(root, "run_info.json"), 'r', encoding='utf-8') as f:
                    info = json.load(f)
            except (OSError, ValueError):  # I/O, decode or JSON syntax error
                info = None
            runs.append(_run_summary(root, info))
    return jsonify(runs)
def _read_optional_text(path, fallback):
    """Return the UTF-8 text of *path*, None if it does not exist, or *fallback* if reading fails."""
    if not os.path.exists(path):
        return None
    try:
        with open(path, 'r', encoding='utf-8') as f:
            return f.read()
    except (OSError, ValueError):  # ValueError covers UnicodeDecodeError
        return fallback


@app.route('/api/results/task/<path:task_path>')
def get_task_detail(task_path):
    """Return the details of one task directory under ``saves/``.

    The response contains ``path`` and ``task_id``. It also contains
    ``run_info`` (parsed JSON), ``final_tb`` (final_TB.v), ``dut`` (DUT.v)
    and ``final_tb_py`` (final_TB.py) for whichever of those files exist;
    a file that exists but cannot be read yields a placeholder string or
    dict. Finally, ``subdirs`` lists the sorted names of immediate
    subdirectories.

    Returns:
        403 if *task_path* points outside ``saves/``; 404 if it is not an
        existing directory; otherwise 200 with the dict above.
    """
    saves_root = os.path.abspath(SAVES_DIR)
    full_path = os.path.normpath(os.path.join(saves_root, task_path))
    # The <path:> converter accepts ".." segments and the server listens on
    # 0.0.0.0 by default, so refuse anything that resolves outside saves/.
    try:
        inside = os.path.commonpath([saves_root, full_path]) == saves_root
    except ValueError:  # e.g. paths on different drives (Windows)
        inside = False
    if not inside:
        return jsonify({"error": "非法路径"}), 403
    if not os.path.isdir(full_path):
        return jsonify({"error": "路径不存在"}), 404

    # normpath so a trailing slash does not yield an empty task id.
    result = {"path": task_path, "task_id": os.path.basename(os.path.normpath(task_path))}

    run_info_path = os.path.join(full_path, "run_info.json")
    if os.path.exists(run_info_path):
        try:
            with open(run_info_path, 'r', encoding='utf-8') as f:
                result["run_info"] = json.load(f)
        except (OSError, ValueError):  # I/O, decode or JSON syntax error
            result["run_info"] = {"error": "读取失败"}

    for key, filename, fallback in (
        ("final_tb", "final_TB.v", "// 读取失败"),
        ("dut", "DUT.v", "// 读取失败"),
        ("final_tb_py", "final_TB.py", "# 读取失败"),
    ):
        text = _read_optional_text(os.path.join(full_path, filename), fallback)
        if text is not None:
            result[key] = text

    # Immediate subdirectories (stage directories).
    result["subdirs"] = sorted(
        item for item in os.listdir(full_path)
        if os.path.isdir(os.path.join(full_path, item))
    )
    return jsonify(result)
if __name__ == '__main__':
    import argparse

    # Command-line options for the development server.
    cli = argparse.ArgumentParser(description='CorrectBench Web Server')
    cli.add_argument('--port', type=int, default=5000, help='服务端口 (默认: 5000)')
    cli.add_argument('--host', type=str, default='0.0.0.0', help='监听地址 (默认: 0.0.0.0)')
    cli.add_argument('--debug', action='store_true', help='启用调试模式 (注意: debug模式下文件变更会重启服务器)')
    opts = cli.parse_args()

    # Startup banner.
    debug_label = '开启' if opts.debug else '关闭'
    banner = (
        "🛠️ CorrectBench Web Server",
        f" 地址: http://{opts.host}:{opts.port}",
        f" 调试: {debug_label}",
        " 提示: 长时间运行任务请勿开启 debug 模式",
    )
    for text in banner:
        print(text)

    # threaded=True so the SSE log stream does not block other requests.
    app.run(host=opts.host, port=opts.port, debug=opts.debug, threaded=True)