""" CorrectBench Flask Web Server 提供 Web 界面来管理和监控 CorrectBench 任务 """ import os import sys import json import subprocess import threading import queue import time import yaml from flask import Flask, render_template, jsonify, request, Response, send_from_directory # 项目根目录 PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__)) app = Flask(__name__, template_folder=os.path.join(PROJECT_ROOT, "web", "templates"), static_folder=os.path.join(PROJECT_ROOT, "web", "static")) # ========== 全局状态 ========== class TaskState: """全局任务状态管理""" MAX_LOG_LINES = 5000 # 内存中保留的最大日志行数 def __init__(self): self.running = False self.process = None self.log_queue = queue.Queue() self.stop_event = threading.Event() self.log_lines = [] self.config_path = "" self.start_time = None self._lock = threading.Lock() def add_log_line(self, line): """添加日志行,超过上限时裁剪旧日志""" self.log_lines.append(line) if len(self.log_lines) > self.MAX_LOG_LINES: self.log_lines = self.log_lines[-self.MAX_LOG_LINES:] def reset(self): self.running = False self.process = None self.log_queue = queue.Queue() self.stop_event = threading.Event() self.log_lines = [] self.config_path = "" self.start_time = None state = TaskState() # ========== 配置文件管理 ========== def scan_config_files(): """扫描所有配置文件,返回 {name: path} 字典""" configs = {} # 扫描 config/configs/ 目录 configs_dir = os.path.join(PROJECT_ROOT, "config", "configs") if os.path.exists(configs_dir): for f in sorted(os.listdir(configs_dir)): if f.endswith(('.yaml', '.yml')): name = f"configs/{f}" configs[name] = os.path.join(configs_dir, f) # 扫描 config/ 根目录下的 yaml 文件 config_dir = os.path.join(PROJECT_ROOT, "config") if os.path.exists(config_dir): for f in sorted(os.listdir(config_dir)): if f.endswith(('.yaml', '.yml')) and f != 'default.yaml': name = f configs[name] = os.path.join(config_dir, f) return configs def load_config_info(config_path): """加载配置文件的基本信息""" if not os.path.exists(config_path): return {"error": "配置文件不存在"} try: with open(config_path, 'r', encoding='utf-8') as f: cfg = yaml.safe_load(f) info = { "model": cfg.get("gpt", {}).get("model", "未指定"), "rtlgen_model": cfg.get("gpt", {}).get("rtlgen_model", "未指定"), "mode": cfg.get("run", {}).get("mode", "未指定"), "tasks": cfg.get("autoline", {}).get("probset", {}).get("only", []), "itermax": cfg.get("autoline", {}).get("itermax", "未指定"), } # 如果没有 only 字段,显示数据集路径 if not info["tasks"]: data_path = cfg.get("autoline", {}).get("probset", {}).get("path", "") info["dataset"] = data_path if data_path else "未指定" return info except Exception as e: return {"error": str(e)} # ========== 任务进程管理 ========== def run_process(config_path): """后台线程运行生成进程""" global state state.running = True state.start_time = time.time() state.config_path = config_path try: state.process = subprocess.Popen( [sys.executable, "main.py", "-c", config_path], stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1, cwd=PROJECT_ROOT ) # 启动日志读取线程 def _read_output(): for line in iter(state.process.stdout.readline, ''): if state.stop_event.is_set(): break if line: state.log_queue.put(line) state.add_log_line(line.rstrip('\n')) reader_thread = threading.Thread(target=_read_output, daemon=True) reader_thread.start() # 等待进程结束,同时检测 stop_event while state.process.poll() is None: if state.stop_event.is_set(): state.process.terminate() try: state.process.wait(timeout=3) except subprocess.TimeoutExpired: state.process.kill() state.process.wait() break time.sleep(0.1) # 等待读取线程结束,确保所有输出都已入队 reader_thread.join(timeout=5) # 读取可能残留的输出 for line in iter(state.process.stdout.readline, ''): if line: state.log_queue.put(line) state.add_log_line(line.rstrip('\n')) except Exception as e: state.log_queue.put(f"[错误] {str(e)}\n") state.add_log_line(f"[错误] {str(e)}") finally: if state.stop_event.is_set(): state.log_queue.put("[已停止]\n") state.add_log_line("[已停止]") else: state.log_queue.put("[完成]\n") state.add_log_line("[完成]") state.running = False state.process = None # ========== Flask 路由 ========== @app.route('/') def index(): """主页""" return render_template('index.html') @app.route('/api/configs') def get_configs(): """获取所有配置文件列表""" configs = scan_config_files() result = [] for name, path in configs.items(): info = load_config_info(path) result.append({ "name": name, "path": path, "info": info }) return jsonify(result) @app.route('/api/config/') def get_config_detail(config_name): """获取配置文件详情""" configs = scan_config_files() # 查找匹配的配置 for name, path in configs.items(): if name == config_name: info = load_config_info(path) try: with open(path, 'r', encoding='utf-8') as f: content = f.read() return jsonify({"name": name, "path": path, "info": info, "content": content}) except Exception as e: return jsonify({"error": str(e)}), 500 return jsonify({"error": "配置文件不存在"}), 404 @app.route('/api/start', methods=['POST']) def start_task(): """启动任务""" global state if state.running: return jsonify({"error": "任务正在运行中"}), 400 data = request.get_json() or {} config_name = data.get('config', '') configs = scan_config_files() config_path = configs.get(config_name) if not config_path: return jsonify({"error": f"配置文件不存在: {config_name}"}), 400 # 重置状态 state.reset() state.stop_event.clear() # 启动后台线程 thread = threading.Thread(target=run_process, args=(config_path,), daemon=True) thread.start() return jsonify({"status": "started", "config": config_name, "path": config_path}) @app.route('/api/stop', methods=['POST']) def stop_task(): """停止任务""" global state if not state.running: return jsonify({"error": "没有正在运行的任务"}), 400 state.stop_event.set() return jsonify({"status": "stopping"}) @app.route('/api/status') def get_status(): """获取当前任务状态""" global state elapsed = 0 if state.running and state.start_time: elapsed = round(time.time() - state.start_time, 1) return jsonify({ "running": state.running, "config": state.config_path, "elapsed": elapsed, "log_count": len(state.log_lines), }) @app.route('/api/logs') def get_logs(): """获取日志(SSE 流式推送),支持长时间运行""" def generate(): was_running = state.running idle_count = 0 # 空闲计数,用于检测进程结束 while True: try: line = state.log_queue.get(timeout=2.0) idle_count = 0 # SSE 格式 yield f"data: {json.dumps({'line': line.rstrip(chr(10))})}\n\n" # 检测结束标记 if line.rstrip(chr(10)) in ('[完成]', '[已停止]'): break except queue.Empty: idle_count += 1 # 进程已结束,关闭连接 if was_running and not state.running: # 确保结束标记已发送 yield f"data: {json.dumps({'line': '[完成]'})}\n\n" break # 发送心跳,防止浏览器/代理超时断开 yield f": heartbeat\n\n" return Response(generate(), mimetype='text/event-stream', headers={ 'Cache-Control': 'no-cache', 'X-Accel-Buffering': 'no', 'Connection': 'keep-alive', }) @app.route('/api/logs/snapshot') def get_logs_snapshot(): """获取日志快照(非流式)""" lines = state.log_lines[-500:] return jsonify({"lines": lines, "total": len(state.log_lines)}) # ========== 结果浏览 API ========== SAVES_DIR = os.path.join(PROJECT_ROOT, "saves") @app.route('/api/results/runs') def list_runs(): """列出所有运行记录(扫描 saves 目录)""" if not os.path.exists(SAVES_DIR): return jsonify([]) runs = [] # 遍历 saves 下的周目录 for week_dir in sorted(os.listdir(SAVES_DIR), reverse=True): week_path = os.path.join(SAVES_DIR, week_dir) if not os.path.isdir(week_path): continue # 遍历子目录(可能是 Main_Results/CorrectBench/ 或直接是运行目录) for root, dirs, files in os.walk(week_path): # 检查是否包含 run_info.json 或 final_TB.v 的父目录 run_info_path = os.path.join(root, "run_info.json") if os.path.exists(run_info_path): # 这是一个 task 目录 rel_path = os.path.relpath(root, SAVES_DIR) try: with open(run_info_path, 'r', encoding='utf-8') as f: info = json.load(f) runs.append({ "path": rel_path, "task_id": info.get("task_id", os.path.basename(root)), "coverage": info.get("coverage"), "full_pass": info.get("full_pass"), "time": info.get("time"), "token_cost": info.get("token_cost"), "circuit_type": info.get("circuit_type"), "op_record": info.get("op_record", []), }) except Exception: runs.append({ "path": rel_path, "task_id": os.path.basename(root), "coverage": None, "full_pass": None, "time": None, "token_cost": None, "circuit_type": None, "op_record": [], }) return jsonify(runs) @app.route('/api/results/task/') def get_task_detail(task_path): """获取某个 task 的详细信息,包括 final_TB.v、DUT.v、run_info.json""" full_path = os.path.join(SAVES_DIR, task_path) if not os.path.exists(full_path): return jsonify({"error": "路径不存在"}), 404 result = {"path": task_path, "task_id": os.path.basename(task_path)} # 读取 run_info.json run_info_path = os.path.join(full_path, "run_info.json") if os.path.exists(run_info_path): try: with open(run_info_path, 'r', encoding='utf-8') as f: result["run_info"] = json.load(f) except Exception: result["run_info"] = {"error": "读取失败"} # 读取 final_TB.v final_tb_path = os.path.join(full_path, "final_TB.v") if os.path.exists(final_tb_path): try: with open(final_tb_path, 'r', encoding='utf-8') as f: result["final_tb"] = f.read() except Exception: result["final_tb"] = "// 读取失败" # 读取 DUT.v dut_path = os.path.join(full_path, "DUT.v") if os.path.exists(dut_path): try: with open(dut_path, 'r', encoding='utf-8') as f: result["dut"] = f.read() except Exception: result["dut"] = "// 读取失败" # 读取 final_TB.py final_tb_py_path = os.path.join(full_path, "final_TB.py") if os.path.exists(final_tb_py_path): try: with open(final_tb_py_path, 'r', encoding='utf-8') as f: result["final_tb_py"] = f.read() except Exception: result["final_tb_py"] = "# 读取失败" # 列出子目录(阶段目录) subdirs = [] for item in sorted(os.listdir(full_path)): item_path = os.path.join(full_path, item) if os.path.isdir(item_path): subdirs.append(item) result["subdirs"] = subdirs return jsonify(result) if __name__ == '__main__': import argparse parser = argparse.ArgumentParser(description='CorrectBench Web Server') parser.add_argument('--port', type=int, default=5000, help='服务端口 (默认: 5000)') parser.add_argument('--host', type=str, default='0.0.0.0', help='监听地址 (默认: 0.0.0.0)') parser.add_argument('--debug', action='store_true', help='启用调试模式 (注意: debug模式下文件变更会重启服务器)') args = parser.parse_args() print(f"🛠️ CorrectBench Web Server") print(f" 地址: http://{args.host}:{args.port}") print(f" 调试: {'开启' if args.debug else '关闭'}") print(f" 提示: 长时间运行任务请勿开启 debug 模式") app.run(host=args.host, port=args.port, debug=args.debug, threaded=True)