Files
CGA-bench/experiments/run_paper_experiments.py
2026-05-22 10:02:42 +08:00

276 lines
9.0 KiB
Python

"""
Batch runner for the paper-style CorrectBench experiments.
This script runs paired Baseline/CGA conditions across the FSM/protocol task set
by generating per-run config files and invoking `main.py` in isolated subprocesses.
"""
from __future__ import annotations
import argparse
import csv
import json
import subprocess
import sys
import time
from copy import deepcopy
from pathlib import Path
import yaml
# Repository root: two directory levels above this file. It is put on sys.path
# so the project-local ``experiments`` package can be imported when this file
# is executed directly as a script.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
from experiments.paper_tasks import FSM_PROTOCOL_TASKS  # noqa: E402  (needs the sys.path tweak above)

# Template config that every generated per-run config is derived from.
DEFAULT_BASE_CONFIG = PROJECT_ROOT.joinpath("config", "configs", "paper_fsm_qwen.yaml")
# Parent directory for per-experiment manifests and generated configs.
DEFAULT_MANIFEST_DIR = PROJECT_ROOT.joinpath("analysis", "paper_runs")
def sanitize_token(value: str) -> str:
    """Make *value* safe to embed in file names and run prefixes.

    Each character that is neither alphanumeric (per ``str.isalnum``) nor one
    of ``-`` / ``_`` is replaced by ``_``; the result has the same length.
    """
    def _keep_or_underscore(char: str) -> str:
        if char.isalnum() or char in "-_":
            return char
        return "_"

    return "".join(map(_keep_or_underscore, value))
def default_python_bin() -> str:
    """Return the interpreter used to launch ``main.py`` when none is given.

    Prefers the project-local virtualenv at ``<PROJECT_ROOT>/venv`` when it
    exists, otherwise falls back to the interpreter running this script.
    """
    venv_dir = PROJECT_ROOT / "venv"
    # POSIX virtualenvs place the interpreter under bin/, Windows ones under
    # Scripts/; checking both keeps the default working on either platform.
    candidates = (
        venv_dir / "bin" / "python",
        venv_dir / "Scripts" / "python.exe",
    )
    for candidate in candidates:
        if candidate.exists():
            return str(candidate)
    return sys.executable
def load_yaml(path: Path) -> dict:
    """Load a YAML file whose top level must be a mapping.

    Uses ``yaml.safe_load``, so arbitrary Python objects are never constructed.

    Args:
        path: YAML file to read (UTF-8).

    Returns:
        The parsed top-level mapping.

    Raises:
        ValueError: if the document is empty (``safe_load`` yields ``None``) or
            its top level is a list/scalar rather than a mapping. Without this
            check the failure would surface later as an opaque ``TypeError``
            when config sections are indexed.
    """
    with path.open("r", encoding="utf-8") as f:
        data = yaml.safe_load(f)
    if not isinstance(data, dict):
        raise ValueError(
            f"Expected a YAML mapping at the top level of {path}, got {type(data).__name__}"
        )
    return data
def save_yaml(path: Path, data: dict) -> None:
    """Serialize *data* to *path* as YAML, creating parent directories as needed.

    Keys keep their insertion order (``sort_keys=False``) and non-ASCII
    characters are escaped (``allow_unicode=False``); the file is UTF-8.
    """
    target_dir = path.parent
    target_dir.mkdir(exist_ok=True, parents=True)
    with open(path, mode="w", encoding="utf-8") as stream:
        yaml.safe_dump(data, stream, allow_unicode=False, sort_keys=False)
def discover_run_dirs(prefix: str) -> list[Path]:
saves_root = PROJECT_ROOT / "saves"
if not saves_root.exists():
return []
pattern = f"**/{prefix}_*"
return sorted([p for p in saves_root.glob(pattern) if p.is_dir()], key=lambda p: p.stat().st_mtime)
def _ensure_section(parent: dict, key: str) -> dict:
    """Return ``parent[key]``, first storing an empty mapping if it is missing or null."""
    section = parent.get(key)
    if section is None:
        section = {}
        parent[key] = section
    return section


def build_run_config(
    base_config: dict,
    model: str,
    condition: str,
    repeat_idx: int,
    experiment_name: str,
    tasks: list[str],
) -> dict:
    """Derive the config for one (model, condition, repeat) run from *base_config*.

    The base config is deep-copied and never modified. The copy gets:
    ``gpt.model``/``gpt.rtlgen_model`` set to *model*, ``autoline.probset.only``
    set to *tasks*, ``autoline.cga.enabled`` set to ``condition == "cga"``, and
    per-run ``autoline.result_path`` / ``save.pub.subdir`` locations. The save
    prefix has the form ``<experiment>_<model_slug>_<condition>_r<NN>``.

    Intermediate sections that are missing or null in the base config (for
    example a template without a ``cga`` block) are created as empty mappings
    rather than raising ``KeyError``.

    Args:
        base_config: parsed template config.
        model: model name; sanitized before being used in paths/prefixes.
        condition: ``"baseline"`` or ``"cga"``.
        repeat_idx: repeat index, zero-padded to two digits in names.
        experiment_name: experiment tag; sanitized (idempotently) because it is
            embedded in file-system paths.
        tasks: task identifiers; copied into the config.

    Returns:
        The new per-run config dict.
    """
    cfg = deepcopy(base_config)
    experiment_name = sanitize_token(experiment_name)
    model_slug = sanitize_token(model)
    prefix = f"{experiment_name}_{model_slug}_{condition}_r{repeat_idx:02d}"

    gpt = _ensure_section(cfg, "gpt")
    gpt["model"] = model
    gpt["rtlgen_model"] = model

    autoline = _ensure_section(cfg, "autoline")
    _ensure_section(autoline, "probset")["only"] = list(tasks)
    _ensure_section(autoline, "cga")["enabled"] = (condition == "cga")
    autoline["result_path"] = f"results/paper/{experiment_name}/{model_slug}/{condition}/repeat_{repeat_idx:02d}"

    pub = _ensure_section(_ensure_section(cfg, "save"), "pub")
    pub["prefix"] = prefix
    pub["subdir"] = f"Paper_Experiments/{experiment_name}/{model_slug}/{condition}"
    return cfg
def _write_via_temp(target: Path, write_fn, **open_kwargs) -> None:
    """Write *target* by calling ``write_fn(handle)`` on a temp sibling, then renaming it into place."""
    tmp_path = target.with_name(target.name + ".tmp")
    try:
        with tmp_path.open("w", encoding="utf-8", **open_kwargs) as handle:
            write_fn(handle)
        # Path.replace overwrites an existing target, so readers see either the
        # old complete file or the new complete file, never a partial one.
        tmp_path.replace(target)
    except BaseException:
        if tmp_path.exists():
            tmp_path.unlink()
        raise


def save_manifest(manifest_path: Path, rows: list[dict]) -> None:
    """Write the run manifest as JSON plus a CSV sibling with the same stem.

    The JSON file contains every row verbatim. The CSV uses the fixed column set
    below: keys missing from a row become empty cells and extra keys are dropped.
    Each file is first written to a ``.tmp`` sibling and then renamed over the
    target, so an interruption mid-write never leaves a truncated manifest.

    Args:
        manifest_path: destination of the JSON manifest; parent directories are
            created if needed. The CSV goes to ``manifest_path.with_suffix(".csv")``.
        rows: manifest rows, one dict per run.
    """
    manifest_path.parent.mkdir(parents=True, exist_ok=True)

    def _write_json(handle) -> None:
        json.dump(rows, handle, indent=2, ensure_ascii=False)

    _write_via_temp(manifest_path, _write_json)

    fieldnames = [
        "experiment_name",
        "model",
        "condition",
        "repeat",
        "task_count",
        "config_path",
        "run_dir",
        "returncode",
        "duration_sec",
        "timestamp",
    ]

    def _write_csv(handle) -> None:
        writer = csv.DictWriter(handle, fieldnames=fieldnames)
        writer.writeheader()
        for row in rows:
            writer.writerow({key: row.get(key, "") for key in fieldnames})

    # newline="" is required by the csv module to control line endings itself.
    _write_via_temp(manifest_path.with_suffix(".csv"), _write_csv, newline="")
def run_single_experiment(
    python_bin: str,
    config_path: Path,
    prefix: str,
) -> tuple[int, float, str]:
    """Run ``main.py`` once with *config_path* and locate the run directory it used.

    Args:
        python_bin: interpreter used to launch ``main.py`` (run unbuffered, ``-u``).
        config_path: generated per-run YAML config, passed via ``-c``.
        prefix: the config's ``save.pub.prefix``; output directories are
            expected to be named ``<prefix>_*`` under the saves tree.

    Returns:
        ``(returncode, duration_sec, run_dir)``. ``run_dir`` is the newest
        ``<prefix>_*`` directory created during this run. If none was created,
        it is the most recently modified pre-existing one whose mtime changed
        during the run; otherwise ``""``. This avoids attributing a run that
        produced no output to a stale directory from an earlier run. A reused
        directory in which no entries were added, removed, or renamed (its mtime
        is unchanged) is therefore reported as ``""``.
    """
    # Snapshot existing directories and their mtimes so this run's output can be
    # told apart from directories left behind by earlier runs with this prefix.
    before_mtimes = {str(path): path.stat().st_mtime for path in discover_run_dirs(prefix)}
    cmd = [python_bin, "-u", str(PROJECT_ROOT / "main.py"), "-c", str(config_path)]
    # Monotonic clock: durations are immune to wall-clock adjustments.
    started = time.monotonic()
    completed = subprocess.run(cmd, cwd=str(PROJECT_ROOT))
    duration_sec = time.monotonic() - started

    after_dirs = discover_run_dirs(prefix)
    new_dirs = [path for path in after_dirs if str(path) not in before_mtimes]
    if new_dirs:
        run_dir = str(new_dirs[-1])
    else:
        # No new directory: accept only an existing one this run actually touched.
        touched = [path for path in after_dirs if path.stat().st_mtime != before_mtimes[str(path)]]
        run_dir = str(touched[-1]) if touched else ""
    return completed.returncode, duration_sec, run_dir
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse and validate the command-line options.

    Args:
        argv: arguments to parse; ``None`` (the default) reads ``sys.argv[1:]``.

    Returns:
        The parsed namespace.

    Raises:
        SystemExit: via ``parser.error`` when a numeric option is out of range,
            or when ``--help`` is requested.
    """
    parser = argparse.ArgumentParser(description="Run paired Baseline/CGA CorrectBench paper experiments.")
    parser.add_argument(
        "--base-config",
        default=str(DEFAULT_BASE_CONFIG),
        help="Base YAML config used as the template for all generated runs.",
    )
    parser.add_argument(
        "--python-bin",
        default=default_python_bin(),
        help="Python interpreter used to launch main.py for each run.",
    )
    parser.add_argument(
        "--experiment-name",
        default="paper_fsm",
        help="Experiment tag used in save subdirectories and manifests.",
    )
    parser.add_argument(
        "--models",
        nargs="+",
        default=["qwen-max"],
        help="One or more models to evaluate.",
    )
    parser.add_argument(
        "--conditions",
        nargs="+",
        choices=["baseline", "cga"],
        default=["baseline", "cga"],
        help="Conditions to run for each repeat.",
    )
    parser.add_argument(
        "--repeats",
        type=int,
        default=5,
        # main() iterates range(start_repeat, repeats + 1), so this is the last
        # repeat index, not a count, whenever --start-repeat is above 1.
        help="Last repeat index to run (inclusive); with the default --start-repeat 1 "
        "this equals the number of repeats per condition.",
    )
    parser.add_argument(
        "--start-repeat",
        type=int,
        default=1,
        help="1-based repeat index to start from.",
    )
    parser.add_argument(
        "--limit-tasks",
        type=int,
        default=0,
        help="Optional prefix limit for quick smoke runs.",
    )
    parser.add_argument(
        "--sleep-seconds",
        type=float,
        default=0.0,
        help="Optional cooldown between runs.",
    )
    parser.add_argument(
        "--manifest-path",
        default="",
        help="Optional manifest path. Defaults to analysis/paper_runs/<experiment_name>/run_manifest.json.",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Generate configs and manifest entries without launching runs.",
    )
    args = parser.parse_args(argv)

    # Reject values the run loop would otherwise silently misinterpret
    # (e.g. repeat index 0 -> "r00" runs, negative limits treated as "no limit").
    if args.start_repeat < 1:
        parser.error("--start-repeat must be >= 1")
    if args.repeats < 0:
        parser.error("--repeats must be >= 0")
    if args.limit_tasks < 0:
        parser.error("--limit-tasks must be >= 0 (0 means no limit)")
    if args.sleep_seconds < 0:
        parser.error("--sleep-seconds must be >= 0")
    return args
def _manifest_row(
    experiment_name: str,
    model: str,
    condition: str,
    repeat_idx: int,
    task_count: int,
    config_path: Path,
    run_dir: str,
    returncode: int,
    duration_sec: float,
) -> dict:
    """Build one manifest row, stamped with the current Unix time."""
    return {
        "experiment_name": experiment_name,
        "model": model,
        "condition": condition,
        "repeat": repeat_idx,
        "task_count": task_count,
        "config_path": str(config_path),
        "run_dir": run_dir,
        "returncode": returncode,
        "duration_sec": round(duration_sec, 2),
        "timestamp": int(time.time()),
    }


def main() -> int:
    """Generate per-run configs, launch each run, and record a run manifest.

    For every (model, condition, repeat) combination, a config derived from the
    base config is written under ``<manifest dir>/generated_configs/`` and run
    through ``main.py`` (unless ``--dry-run``). The manifest is re-saved after
    every run, so an interrupted batch still keeps the rows of completed runs.

    Returns:
        Process exit code: 0 when every launched run exited with status 0
        (always the case for ``--dry-run``), 1 if any run failed.
    """
    args = parse_args()

    base_config_path = Path(args.base_config)
    if not base_config_path.is_absolute():
        # Relative paths are resolved against the project root, not the CWD.
        base_config_path = (PROJECT_ROOT / base_config_path).resolve()
    base_config = load_yaml(base_config_path)

    tasks = list(FSM_PROTOCOL_TASKS)
    if args.limit_tasks > 0:
        tasks = tasks[: args.limit_tasks]

    experiment_name = sanitize_token(args.experiment_name)
    if args.manifest_path:
        manifest_path = Path(args.manifest_path)
    else:
        manifest_path = DEFAULT_MANIFEST_DIR / experiment_name / "run_manifest.json"
    if not manifest_path.is_absolute():
        manifest_path = (PROJECT_ROOT / manifest_path).resolve()
    generated_config_dir = manifest_path.parent / "generated_configs"

    repeat_indices = range(args.start_repeat, args.repeats + 1)  # empty if repeats < start
    total_runs = len(args.models) * len(args.conditions) * len(repeat_indices)

    manifest_rows: list[dict] = []
    failed_runs = 0
    launched = 0
    for model in args.models:
        for condition in args.conditions:
            for repeat_idx in repeat_indices:
                cfg = build_run_config(
                    base_config=base_config,
                    model=model,
                    condition=condition,
                    repeat_idx=repeat_idx,
                    experiment_name=experiment_name,
                    tasks=tasks,
                )
                prefix = cfg["save"]["pub"]["prefix"]
                config_path = generated_config_dir / f"{prefix}.yaml"
                save_yaml(config_path, cfg)
                launched += 1
                print(
                    f"[{launched}/{total_runs}] condition={condition} repeat={repeat_idx:02d} "
                    f"model={model} tasks={len(tasks)}"
                )

                if args.dry_run:
                    returncode, duration_sec, run_dir = 0, 0.0, ""
                else:
                    returncode, duration_sec, run_dir = run_single_experiment(args.python_bin, config_path, prefix)
                    if returncode != 0:
                        failed_runs += 1
                        print(f"  run failed with returncode={returncode}")

                manifest_rows.append(
                    _manifest_row(
                        experiment_name=experiment_name,
                        model=model,
                        condition=condition,
                        repeat_idx=repeat_idx,
                        task_count=len(tasks),
                        config_path=config_path,
                        run_dir=run_dir,
                        returncode=returncode,
                        duration_sec=duration_sec,
                    )
                )
                # Persist after every run: a crash or Ctrl-C mid-batch would
                # otherwise discard the records of all completed runs.
                save_manifest(manifest_path, manifest_rows)

                # Cool down between real runs only; nothing follows the last one.
                if not args.dry_run and args.sleep_seconds > 0 and launched < total_runs:
                    time.sleep(args.sleep_seconds)

    # Final save also covers the zero-run case, which still writes an empty manifest.
    save_manifest(manifest_path, manifest_rows)
    print(f"Manifest saved to {manifest_path}")
    if failed_runs:
        print(f"{failed_runs}/{launched} runs exited with a non-zero returncode.")
        return 1
    return 0
if __name__ == "__main__":
    # Propagate main()'s integer return value as the process exit status.
    sys.exit(main())