Files
CGA-bench/frontend/config_loader.py
2026-05-22 10:02:42 +08:00

93 lines
2.8 KiB
Python

import os
import yaml
from typing import List, Dict, Any
# Top-level configuration directory. This is a relative path, so it resolves
# against the process's current working directory, not this module's location.
CONFIG_DIR = "config"
# Sub-directory of CONFIG_DIR scanned by get_config_files for YAML configs.
CONFIGS_DIR = os.path.join(CONFIG_DIR, "configs")
def get_config_files() -> List[str]:
    """Return the names of all YAML config files in ``CONFIGS_DIR``.

    Returns:
        Sorted file names (not full paths) ending in ``.yaml`` or ``.yml``.
        Returns an empty list if ``CONFIGS_DIR`` is not an existing directory.
    """
    # isdir rather than exists: a plain file at this path would otherwise
    # reach os.listdir and raise NotADirectoryError.
    if not os.path.isdir(CONFIGS_DIR):
        return []

    yaml_files = []
    for name in os.listdir(CONFIGS_DIR):
        if not name.endswith(('.yaml', '.yml')):
            continue
        # Skip sub-directories that merely carry a YAML-like suffix.
        if os.path.isfile(os.path.join(CONFIGS_DIR, name)):
            yaml_files.append(name)
    return sorted(yaml_files)
def get_config_path(config_name: str) -> str:
    """Map a config file name to its path under ``CONFIGS_DIR``.

    Args:
        config_name: Config file name, e.g. one of the names produced by
            ``get_config_files``.

    Returns:
        ``config_name`` joined onto ``CONFIGS_DIR``. The path is not checked
        for existence.
    """
    # NOTE(review): os.path.join discards CONFIGS_DIR when config_name is an
    # absolute path, and "../" segments are not rejected. Confirm that callers
    # only pass trusted names.
    full_path = os.path.join(CONFIGS_DIR, config_name)
    return full_path
def load_tasks_from_config(config_path: str) -> List[str]:
    """Load the list of task IDs selected by a YAML config file.

    Tasks are read from ``autoline.probset.only``. If that list is missing,
    null, or empty, the function falls back to every task in the JSONL dataset
    named by ``autoline.probset.path``, provided that file exists.

    Args:
        config_path: Path to the YAML config file.

    Returns:
        The selected task IDs. Returns an empty list if the config file does
        not exist or neither source yields any tasks.
    """
    if not os.path.exists(config_path):
        return []

    with open(config_path, 'r', encoding='utf-8') as f:
        # safe_load returns None for an empty document; treat that as {}.
        config = yaml.safe_load(f) or {}

    # `or {}` also covers keys that are present but null (e.g. a bare "autoline:").
    autoline = config.get('autoline') or {}
    probset = autoline.get('probset') or {}
    # A null `only:` must not be returned as None; normalise it to a list.
    tasks = probset.get('only') or []

    # Without an explicit `only` list, fall back to all tasks in the dataset.
    if not tasks:
        data_path = probset.get('path', '')
        if data_path and os.path.exists(data_path):
            tasks = get_all_tasks_from_dataset(data_path)
    return tasks
def get_all_tasks_from_dataset(jsonl_path: str) -> List[str]:
    """Collect every ``task_id`` from a JSON Lines dataset file.

    Each non-blank line is parsed as JSON. Objects with a truthy ``task_id``
    contribute it, in file order. Malformed lines are reported and skipped,
    as are records that are not JSON objects, so that one bad line does not
    hide the tasks that follow it.

    Args:
        jsonl_path: Path to the ``.jsonl`` dataset.

    Returns:
        The task IDs found. If the file cannot be opened, or a read/decode
        error occurs, the error is printed and the IDs collected so far
        (possibly none) are returned.
    """
    import json  # imported once per call instead of once per line

    tasks = []
    try:
        with open(jsonl_path, 'r', encoding='utf-8') as f:
            for lineno, raw_line in enumerate(f, start=1):
                line = raw_line.strip()
                if not line:
                    continue
                try:
                    obj = json.loads(line)
                except ValueError as e:  # includes json.JSONDecodeError
                    print(f"Skipping malformed line {lineno} in {jsonl_path}: {e}")
                    continue
                # Only JSON objects can carry a task_id; skip arrays/scalars.
                if not isinstance(obj, dict):
                    continue
                task_id = obj.get('task_id', '')
                if task_id:
                    tasks.append(task_id)
    except (OSError, ValueError) as e:
        # OSError: missing/unreadable file.
        # ValueError: UnicodeDecodeError raised while iterating the file.
        print(f"Error loading tasks from {jsonl_path}: {e}")
    return tasks
def _ensure_section(parent: Dict[str, Any], key: str) -> Dict[str, Any]:
    """Return ``parent[key]``, first replacing a missing or null entry with ``{}``."""
    section = parent.get(key)
    if section is None:
        section = {}
        parent[key] = section
    return section


def create_temp_config(
    original_config_path: str,
    selected_tasks: List[str],
    cga_enabled: bool = True,
    max_iter: int = 10,
    output_path: str = "/tmp/correctbench_temp_config.yaml"
) -> str:
    """Write a copy of a YAML config with the user's run options applied.

    The copy sets ``autoline.probset.only`` to ``selected_tasks`` and sets
    ``autoline.cga.enabled`` / ``autoline.cga.max_iter``. Sections that are
    missing or null in the original are created.

    Args:
        original_config_path: YAML config to start from.
        selected_tasks: Task IDs to run.
        cga_enabled: Value for ``autoline.cga.enabled``.
        max_iter: Value for ``autoline.cga.max_iter``.
        output_path: Where to write the new config. Its parent directory is
            created if needed. NOTE: the default is a single fixed path, so
            concurrent callers using it overwrite each other's file.

    Returns:
        ``output_path``.
    """
    with open(original_config_path, 'r', encoding='utf-8') as f:
        # safe_load returns None for an empty document; start from {} then.
        config = yaml.safe_load(f) or {}

    autoline = _ensure_section(config, 'autoline')

    # Update the task list. Copy into a plain list: a tuple would be dumped as
    # !!python/tuple, which yaml.safe_load cannot read back.
    probset = _ensure_section(autoline, 'probset')
    probset['only'] = list(selected_tasks)

    # Update the CGA settings.
    cga = _ensure_section(autoline, 'cga')
    cga['enabled'] = cga_enabled
    cga['max_iter'] = max_iter

    # Save the temporary config, creating its directory if it does not exist.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        yaml.dump(config, f, allow_unicode=True, default_flow_style=False)
    return output_path
def get_default_config() -> str:
    """Return the default config file name.

    Returns:
        The bare file name ``"correctbench.yaml"`` (not a full path).
    """
    default_name = "correctbench.yaml"
    return default_name