276 lines
9.0 KiB
Python
276 lines
9.0 KiB
Python
"""
|
|
Batch runner for the paper-style CorrectBench experiments.
|
|
|
|
This script runs paired Baseline/CGA conditions across the FSM/protocol task set
|
|
by generating per-run config files and invoking `main.py` in isolated subprocesses.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import csv
|
|
import json
|
|
import subprocess
|
|
import sys
|
|
import time
|
|
from copy import deepcopy
|
|
from pathlib import Path
|
|
|
|
import yaml
|
|
|
|
# Directory one level above the folder that contains this script.
PROJECT_ROOT = Path(__file__).resolve().parents[1]

# Put the project root at the front of the import search path (only once) so
# that project packages resolve no matter where the script is launched from.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
|
|
|
|
from experiments.paper_tasks import FSM_PROTOCOL_TASKS
|
|
|
|
|
|
# YAML file used as the template for generated run configs when
# --base-config is not given.
DEFAULT_BASE_CONFIG = PROJECT_ROOT.joinpath("config", "configs", "paper_fsm_qwen.yaml")
# Parent directory for per-experiment manifests when --manifest-path is not given.
DEFAULT_MANIFEST_DIR = PROJECT_ROOT.joinpath("analysis", "paper_runs")
|
|
|
|
|
|
def sanitize_token(value: str) -> str:
    """Return *value* with every unsafe character replaced by an underscore.

    Characters for which ``str.isalnum`` is true, plus ``-`` and ``_``, are
    kept as they are. Every other character is replaced one-for-one by ``_``,
    so the result always has the same length as the input.
    """
    passthrough = {"-", "_"}
    cleaned = []
    for char in value:
        if char.isalnum() or char in passthrough:
            cleaned.append(char)
        else:
            cleaned.append("_")
    return "".join(cleaned)
|
|
|
|
|
|
def default_python_bin() -> str:
    """Choose the interpreter used to launch each run.

    Returns:
        The path of ``venv/bin/python`` under the project root when that file
        exists, otherwise the interpreter running this script
        (``sys.executable``).
    """
    candidate = PROJECT_ROOT / "venv" / "bin" / "python"
    return str(candidate) if candidate.exists() else sys.executable
|
|
|
|
|
|
def load_yaml(path: Path) -> dict:
    """Parse the UTF-8 YAML file at *path* with ``yaml.safe_load``.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file. The annotation says
        ``dict``, but nothing here checks the type of the top-level node.
    """
    with path.open("r", encoding="utf-8") as stream:
        parsed = yaml.safe_load(stream)
    return parsed
|
|
|
|
|
|
def save_yaml(path: Path, data: dict) -> None:
    """Write *data* to *path* as YAML, creating parent directories as needed.

    Keys are emitted in insertion order (``sort_keys=False``) and non-ASCII
    characters are escaped (``allow_unicode=False``).
    """
    target_dir = path.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as stream:
        yaml.safe_dump(data, stream, sort_keys=False, allow_unicode=False)
|
|
|
|
|
|
def discover_run_dirs(prefix: str) -> list[Path]:
    """List directories under ``saves/`` whose name starts with ``<prefix>_``.

    The search is recursive. The result is ordered by modification time,
    oldest first, and is empty when the ``saves`` directory does not exist.
    """
    saves_root = PROJECT_ROOT / "saves"
    if not saves_root.exists():
        return []

    matches = [entry for entry in saves_root.glob(f"**/{prefix}_*") if entry.is_dir()]
    matches.sort(key=lambda entry: entry.stat().st_mtime)
    return matches
|
|
|
|
|
|
def build_run_config(
    base_config: dict,
    model: str,
    condition: str,
    repeat_idx: int,
    experiment_name: str,
    tasks: list[str],
) -> dict:
    """Derive the config for one run from the base template.

    *base_config* is deep-copied and never modified. The copy must already
    contain the ``gpt``, ``autoline.probset``, ``autoline.cga`` and
    ``save.pub`` sections; a missing one raises ``KeyError``.

    Args:
        base_config: Parsed template configuration.
        model: Model name written to both ``gpt.model`` and ``gpt.rtlgen_model``.
        condition: Run condition; CGA is enabled only when this equals ``"cga"``.
        repeat_idx: Repeat number, rendered zero-padded to two digits.
        experiment_name: Tag used in the result path, save prefix and subdir.
        tasks: Task names copied into ``autoline.probset.only``.

    Returns:
        The populated configuration dictionary for this run.
    """
    cfg = deepcopy(base_config)
    model_slug = sanitize_token(model)
    repeat_tag = f"{repeat_idx:02d}"
    prefix = f"{experiment_name}_{model_slug}_{condition}_r{repeat_tag}"

    gpt_section = cfg["gpt"]
    gpt_section["model"] = model
    gpt_section["rtlgen_model"] = model

    autoline = cfg["autoline"]
    autoline["probset"]["only"] = list(tasks)
    autoline["cga"]["enabled"] = condition == "cga"
    autoline["result_path"] = f"results/paper/{experiment_name}/{model_slug}/{condition}/repeat_{repeat_tag}"

    pub_section = cfg["save"]["pub"]
    pub_section["prefix"] = prefix
    pub_section["subdir"] = f"Paper_Experiments/{experiment_name}/{model_slug}/{condition}"
    return cfg
|
|
|
|
|
|
def save_manifest(manifest_path: Path, rows: list[dict]) -> None:
    """Write the run manifest as JSON plus a CSV sibling file.

    The JSON file at *manifest_path* holds *rows* verbatim. The CSV file has
    the same path with a ``.csv`` suffix and contains only the fixed column
    set below; keys missing from a row become empty cells and extra keys are
    dropped. Parent directories are created when needed.
    """
    manifest_path.parent.mkdir(parents=True, exist_ok=True)
    with manifest_path.open("w", encoding="utf-8") as json_file:
        json.dump(rows, json_file, indent=2, ensure_ascii=False)

    columns = [
        "experiment_name",
        "model",
        "condition",
        "repeat",
        "task_count",
        "config_path",
        "run_dir",
        "returncode",
        "duration_sec",
        "timestamp",
    ]
    csv_path = manifest_path.with_suffix(".csv")
    with csv_path.open("w", encoding="utf-8", newline="") as csv_file:
        writer = csv.DictWriter(csv_file, fieldnames=columns)
        writer.writeheader()
        # Project each row onto the fixed columns before handing it to the writer.
        writer.writerows(
            {column: row.get(column, "") for column in columns} for row in rows
        )
|
|
|
|
|
|
def run_single_experiment(
    python_bin: str,
    config_path: Path,
    prefix: str,
) -> tuple[int, float, str]:
    """Launch ``main.py`` for one generated config and report the outcome.

    The subprocess runs with the project root as its working directory and
    inherits this process's stdio. A non-zero exit status does not raise; it
    is returned to the caller.

    Args:
        python_bin: Interpreter used to run ``main.py``.
        config_path: Generated YAML config passed via ``-c``.
        prefix: Save prefix of the run, used to locate its output directory.

    Returns:
        ``(returncode, duration_sec, run_dir)``. ``run_dir`` is the newest
        directory reported by ``discover_run_dirs(prefix)`` that did not exist
        before the launch. If no new directory appeared, it falls back to the
        newest existing match, or ``""`` when there is none.
    """
    before_dirs = {str(path) for path in discover_run_dirs(prefix)}

    cmd = [python_bin, "-u", str(PROJECT_ROOT / "main.py"), "-c", str(config_path)]
    # Use the monotonic clock for elapsed time: wall-clock adjustments (NTP,
    # manual changes) during a long run must not distort the duration.
    started_at = time.monotonic()
    completed = subprocess.run(cmd, cwd=str(PROJECT_ROOT))
    duration_sec = time.monotonic() - started_at

    after_dirs = discover_run_dirs(prefix)
    new_dirs = [path for path in after_dirs if str(path) not in before_dirs]
    if new_dirs:
        run_dir = str(new_dirs[-1])
    elif after_dirs:
        run_dir = str(after_dirs[-1])
    else:
        run_dir = ""
    return completed.returncode, duration_sec, run_dir
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the batch runner.

    Options are declared in a table and registered in order, so ``--help``
    lists them in the sequence shown here.
    """
    parser = argparse.ArgumentParser(description="Run paired Baseline/CGA CorrectBench paper experiments.")

    option_table = [
        (
            "--base-config",
            {
                "default": str(DEFAULT_BASE_CONFIG),
                "help": "Base YAML config used as the template for all generated runs.",
            },
        ),
        (
            "--python-bin",
            {
                "default": default_python_bin(),
                "help": "Python interpreter used to launch main.py for each run.",
            },
        ),
        (
            "--experiment-name",
            {
                "default": "paper_fsm",
                "help": "Experiment tag used in save subdirectories and manifests.",
            },
        ),
        (
            "--models",
            {
                "nargs": "+",
                "default": ["qwen-max"],
                "help": "One or more models to evaluate.",
            },
        ),
        (
            "--conditions",
            {
                "nargs": "+",
                "choices": ["baseline", "cga"],
                "default": ["baseline", "cga"],
                "help": "Conditions to run for each repeat.",
            },
        ),
        (
            "--repeats",
            {
                "type": int,
                "default": 5,
                "help": "Number of repeats per condition.",
            },
        ),
        (
            "--start-repeat",
            {
                "type": int,
                "default": 1,
                "help": "1-based repeat index to start from.",
            },
        ),
        (
            "--limit-tasks",
            {
                "type": int,
                "default": 0,
                "help": "Optional prefix limit for quick smoke runs.",
            },
        ),
        (
            "--sleep-seconds",
            {
                "type": float,
                "default": 0.0,
                "help": "Optional cooldown between runs.",
            },
        ),
        (
            "--manifest-path",
            {
                "default": "",
                "help": "Optional manifest path. Defaults to analysis/paper_runs/<experiment_name>/run_manifest.json.",
            },
        ),
        (
            "--dry-run",
            {
                "action": "store_true",
                "help": "Generate configs and manifest entries without launching runs.",
            },
        ),
    ]
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    return parser.parse_args()
|
|
|
|
|
|
def main() -> int:
    """Generate per-run configs, execute the runs, and record a manifest.

    For every (model, condition, repeat) combination a config is derived from
    the base template and written under ``generated_configs`` next to the
    manifest. Unless ``--dry-run`` is set, ``main.py`` is then launched for it.
    Relative ``--base-config`` and ``--manifest-path`` values are resolved
    against the project root.

    The manifest is rewritten after every run, so an interrupted or crashed
    batch still leaves a record of the runs that finished.

    Returns:
        Always ``0``. Per-run exit codes are recorded in the manifest and do
        not affect this value.
    """
    args = parse_args()
    base_config_path = Path(args.base_config)
    if not base_config_path.is_absolute():
        base_config_path = (PROJECT_ROOT / base_config_path).resolve()
    base_config = load_yaml(base_config_path)

    tasks = list(FSM_PROTOCOL_TASKS)
    if args.limit_tasks > 0:
        tasks = tasks[: args.limit_tasks]

    experiment_name = sanitize_token(args.experiment_name)
    if args.manifest_path:
        manifest_path = Path(args.manifest_path)
    else:
        manifest_path = DEFAULT_MANIFEST_DIR / experiment_name / "run_manifest.json"
    if not manifest_path.is_absolute():
        manifest_path = (PROJECT_ROOT / manifest_path).resolve()
    generated_config_dir = manifest_path.parent / "generated_configs"

    manifest_rows: list[dict] = []
    # An empty range (start beyond the last repeat) yields zero runs.
    repeat_indices = range(args.start_repeat, args.repeats + 1)
    total_runs = len(args.models) * len(args.conditions) * len(repeat_indices)
    launched = 0

    for model in args.models:
        for condition in args.conditions:
            for repeat_idx in repeat_indices:
                cfg = build_run_config(
                    base_config=base_config,
                    model=model,
                    condition=condition,
                    repeat_idx=repeat_idx,
                    experiment_name=experiment_name,
                    tasks=tasks,
                )
                prefix = cfg["save"]["pub"]["prefix"]
                config_path = generated_config_dir / f"{prefix}.yaml"
                save_yaml(config_path, cfg)

                launched += 1
                print(
                    f"[{launched}/{total_runs}] condition={condition} repeat={repeat_idx:02d} "
                    f"model={model} tasks={len(tasks)}"
                )

                if args.dry_run:
                    # Dry runs only produce configs and placeholder results.
                    returncode, duration_sec, run_dir = 0, 0.0, ""
                else:
                    returncode, elapsed, run_dir = run_single_experiment(
                        args.python_bin, config_path, prefix
                    )
                    duration_sec = round(elapsed, 2)

                manifest_rows.append(
                    {
                        "experiment_name": experiment_name,
                        "model": model,
                        "condition": condition,
                        "repeat": repeat_idx,
                        "task_count": len(tasks),
                        "config_path": str(config_path),
                        "run_dir": run_dir,
                        "returncode": returncode,
                        "duration_sec": duration_sec,
                        "timestamp": int(time.time()),
                    }
                )
                # Persist progress now: a later crash or Ctrl-C must not lose
                # the records of runs that already finished.
                save_manifest(manifest_path, manifest_rows)

                if not args.dry_run and args.sleep_seconds > 0:
                    time.sleep(args.sleep_seconds)

    # Final write also covers the zero-run case, where the loop never saved.
    save_manifest(manifest_path, manifest_rows)

    print(f"Manifest saved to {manifest_path}")
    return 0
|
|
|
|
|
|
# Script entry point: use main()'s return value as the process exit status.
if __name__ == "__main__":
    sys.exit(main())
|