93 lines
2.8 KiB
Python
93 lines
2.8 KiB
Python
import os
|
|
import yaml
|
|
from typing import List, Dict, Any
|
|
|
|
# Root directory for configuration data (a relative path).
CONFIG_DIR = "config"
# Sub-directory of CONFIG_DIR that holds the individual config files.
CONFIGS_DIR = os.path.join(CONFIG_DIR, "configs")
|
|
|
|
|
|
def get_config_files() -> List[str]:
    """Return the sorted names of the YAML config files in ``CONFIGS_DIR``.

    Only directory entries whose names end in ``.yaml`` or ``.yml`` are
    included. Names are returned without the directory prefix.

    Returns:
        A sorted list of file names, or an empty list when ``CONFIGS_DIR``
        does not exist or is not a directory.
    """
    # isdir (rather than exists) so that a regular file sitting at
    # CONFIGS_DIR yields an empty result instead of NotADirectoryError.
    if not os.path.isdir(CONFIGS_DIR):
        return []
    return sorted(
        name for name in os.listdir(CONFIGS_DIR)
        if name.endswith(('.yaml', '.yml'))
    )
|
|
|
|
|
|
def get_config_path(config_name: str) -> str:
    """Return the path of ``config_name`` joined onto ``CONFIGS_DIR``.

    The path is only constructed; no check is made that the file exists.
    """
    full_path = os.path.join(CONFIGS_DIR, config_name)
    return full_path
|
|
|
|
|
|
def load_tasks_from_config(config_path: str) -> List[str]:
    """Load the list of selected task IDs from a YAML config file.

    The list is read from ``autoline.probset.only``. When that entry is
    missing or empty, every task ID found in the dataset file named by
    ``autoline.probset.path`` is used instead, provided that file exists.

    Args:
        config_path: Path to the YAML config file.

    Returns:
        The task IDs, or an empty list when the config file does not exist
        or no tasks can be determined.
    """
    if not os.path.exists(config_path):
        return []

    with open(config_path, 'r', encoding='utf-8') as f:
        # safe_load returns None for an empty document; treat it as {}.
        config = yaml.safe_load(f) or {}

    # A key that is present but null (e.g. a bare ``autoline:``) loads as
    # None, so ``or`` is used instead of relying on .get() defaults.
    autoline = config.get('autoline') or {}
    probset = autoline.get('probset') or {}
    tasks = probset.get('only') or []

    # No explicit ``only`` list: fall back to every task in the dataset.
    if not tasks:
        data_path = probset.get('path', '')
        if data_path and os.path.exists(data_path):
            tasks = get_all_tasks_from_dataset(data_path)

    return tasks
|
|
|
|
|
|
def get_all_tasks_from_dataset(jsonl_path: str) -> List[str]:
    """Collect every non-empty ``task_id`` from a JSONL dataset file.

    Each non-blank line is parsed as a JSON document. Lines that are not
    valid JSON are reported on stdout and skipped; lines whose JSON value is
    not an object, or that have no (or an empty) ``task_id``, are skipped
    silently. A bad line no longer discards the tasks that follow it.

    Args:
        jsonl_path: Path to the JSONL file, read as UTF-8.

    Returns:
        The task IDs in file order. If the file cannot be opened or decoded,
        the error is printed and the IDs collected so far (possibly none)
        are returned.
    """
    import json  # function-scope import, done once per call rather than per line

    tasks = []
    try:
        with open(jsonl_path, 'r', encoding='utf-8') as f:
            for line_no, raw_line in enumerate(f, start=1):
                line = raw_line.strip()
                if not line:
                    continue
                try:
                    obj = json.loads(line)
                except ValueError as e:  # json.JSONDecodeError is a ValueError
                    print(f"Error loading tasks from {jsonl_path}: line {line_no}: {e}")
                    continue
                # Only JSON objects can carry a task_id.
                if not isinstance(obj, dict):
                    continue
                task_id = obj.get('task_id', '')
                if task_id:
                    tasks.append(task_id)
    except (OSError, UnicodeError) as e:
        # Best effort: report file-level failures and return what we have.
        print(f"Error loading tasks from {jsonl_path}: {e}")
    return tasks
|
|
|
|
|
|
def create_temp_config(
    original_config_path: str,
    selected_tasks: List[str],
    cga_enabled: bool = True,
    max_iter: int = 10,
    output_path: str = "/tmp/correctbench_temp_config.yaml"
) -> str:
    """Write a copy of a config file with the user's selections applied.

    The original YAML is loaded, ``autoline.probset.only`` is set to
    ``selected_tasks``, ``autoline.cga.enabled`` and ``autoline.cga.max_iter``
    are set from the arguments, and the result is dumped to ``output_path``.
    Missing or null ``autoline`` / ``probset`` / ``cga`` sections are created.

    Args:
        original_config_path: Path of the YAML config to start from.
        selected_tasks: Task IDs to store under ``autoline.probset.only``.
        cga_enabled: Value stored under ``autoline.cga.enabled``.
        max_iter: Value stored under ``autoline.cga.max_iter``.
        output_path: Where the modified config is written (overwritten if
            it already exists).

    Returns:
        ``output_path``.

    Raises:
        OSError: If the original config cannot be read or the output file
            cannot be written (no existence check is made beforehand).
    """
    with open(original_config_path, 'r', encoding='utf-8') as f:
        # safe_load returns None for an empty document; start from {} then.
        config = yaml.safe_load(f) or {}

    # A section that is absent or explicitly null is replaced by a new dict;
    # .get() covers both cases where a plain ``in`` test would miss null.
    autoline = config.get('autoline')
    if autoline is None:
        autoline = config['autoline'] = {}

    # Update the task list.
    probset = autoline.get('probset')
    if probset is None:
        probset = autoline['probset'] = {}
    probset['only'] = selected_tasks

    # Update the CGA settings.
    cga = autoline.get('cga')
    if cga is None:
        cga = autoline['cga'] = {}
    cga['enabled'] = cga_enabled
    cga['max_iter'] = max_iter

    # Save the temporary config.
    with open(output_path, 'w', encoding='utf-8') as f:
        yaml.dump(config, f, allow_unicode=True, default_flow_style=False)

    return output_path
|
|
|
|
|
|
def get_default_config() -> str:
    """Return the file name of the default config."""
    return "correctbench.yaml"