"""Helpers for listing YAML config files, reading their task lists, and
writing a temporary config with overridden task / CGA settings."""
import json
import os
from typing import Any, Dict, List

import yaml

CONFIG_DIR = "config"
CONFIGS_DIR = os.path.join(CONFIG_DIR, "configs")


def get_config_files() -> List[str]:
    """Return the sorted names of all ``.yaml`` / ``.yml`` files in ``CONFIGS_DIR``.

    Returns an empty list when ``CONFIGS_DIR`` does not exist or is not a
    directory.
    """
    if not os.path.isdir(CONFIGS_DIR):
        return []
    files = [f for f in os.listdir(CONFIGS_DIR) if f.endswith(('.yaml', '.yml'))]
    return sorted(files)


def get_config_path(config_name: str) -> str:
    """Return the path of the config file *config_name* inside ``CONFIGS_DIR``."""
    return os.path.join(CONFIGS_DIR, config_name)


def _load_yaml_mapping(path: str) -> Dict[str, Any]:
    """Load *path* with ``yaml.safe_load`` and return its top-level mapping.

    An empty document (``safe_load`` returns ``None``) yields ``{}``.

    Raises:
        ValueError: if the document's top level is not a mapping.
    """
    with open(path, 'r', encoding='utf-8') as f:
        data = yaml.safe_load(f)
    if data is None:
        return {}
    if not isinstance(data, dict):
        raise ValueError(f"Top level of {path} must be a mapping, got {type(data).__name__}")
    return data


def _as_dict(value: Any) -> Dict[str, Any]:
    """Return *value* if it is a dict, otherwise an empty dict (e.g. for YAML ``null``)."""
    return value if isinstance(value, dict) else {}


def _ensure_dict(parent: Dict[str, Any], key: str) -> Dict[str, Any]:
    """Return ``parent[key]``, first replacing a missing or non-dict value with ``{}``."""
    child = parent.get(key)
    if not isinstance(child, dict):
        child = {}
        parent[key] = child
    return child


def load_tasks_from_config(config_path: str) -> List[str]:
    """Return the task IDs selected by the config at *config_path*.

    Reads ``autoline.probset.only``. A single string is treated as a one-item
    list. When that list is missing or empty, falls back to every ``task_id``
    in the JSONL dataset named by ``autoline.probset.path`` (if that file
    exists).

    Returns an empty list when *config_path* is not an existing file.

    Raises:
        ValueError: if the config's top level is not a mapping.
    """
    if not os.path.isfile(config_path):
        return []

    config = _load_yaml_mapping(config_path)
    # Each level may be absent or explicitly null in the YAML.
    probset = _as_dict(_as_dict(config.get('autoline')).get('probset'))

    tasks = probset.get('only') or []
    tasks = [tasks] if isinstance(tasks, str) else list(tasks)

    # If the config has no `only` list, try to take every task from the dataset.
    if not tasks:
        data_path = probset.get('path')
        # The str check prevents an int from being treated as a file descriptor by os.path.
        if isinstance(data_path, str) and data_path and os.path.isfile(data_path):
            tasks = get_all_tasks_from_dataset(data_path)
    return tasks


def get_all_tasks_from_dataset(jsonl_path: str) -> List[str]:
    """Return every non-empty ``task_id`` found in the JSONL file *jsonl_path*.

    Blank lines and non-object lines are ignored. A malformed line is
    reported and skipped so the rest of the file is still read. An I/O or
    decoding error is reported, and the IDs collected so far are returned.
    """
    tasks: List[str] = []
    try:
        with open(jsonl_path, 'r', encoding='utf-8') as f:
            for line_no, raw in enumerate(f, start=1):
                line = raw.strip()
                if not line:
                    continue
                try:
                    obj = json.loads(line)
                except json.JSONDecodeError as e:
                    print(f"Skipping malformed line {line_no} in {jsonl_path}: {e}")
                    continue
                if not isinstance(obj, dict):
                    continue
                task_id = obj.get('task_id', '')
                if task_id:
                    tasks.append(task_id)
    except (OSError, UnicodeDecodeError) as e:
        print(f"Error loading tasks from {jsonl_path}: {e}")
    return tasks


def create_temp_config(
    original_config_path: str,
    selected_tasks: List[str],
    cga_enabled: bool = True,
    max_iter: int = 10,
    output_path: str = "/tmp/correctbench_temp_config.yaml"
) -> str:
    """Write a copy of the config with the user's choices applied; return its path.

    Sets ``autoline.probset.only`` to *selected_tasks*, and
    ``autoline.cga.enabled`` / ``autoline.cga.max_iter`` to *cga_enabled* /
    *max_iter*. Intermediate sections are created when they are missing or
    null. The result is written to *output_path*, overwriting that file.

    Raises:
        ValueError: if the original config's top level is not a mapping.
    """
    config = _load_yaml_mapping(original_config_path)
    autoline = _ensure_dict(config, 'autoline')

    # Update the task list. list() ensures a tuple/iterable is dumped as a plain
    # YAML sequence rather than a `!!python/tuple` tag that safe_load rejects.
    probset = _ensure_dict(autoline, 'probset')
    probset['only'] = list(selected_tasks)

    # Update the CGA settings.
    cga = _ensure_dict(autoline, 'cga')
    cga['enabled'] = cga_enabled
    cga['max_iter'] = max_iter

    # Save the temporary config.
    with open(output_path, 'w', encoding='utf-8') as f:
        yaml.dump(config, f, allow_unicode=True, default_flow_style=False)
    return output_path


def get_default_config() -> str:
    """Return the file name of the default config."""
    return "correctbench.yaml"