experiment_design.py (21014B)
#!/usr/bin/env python3
"""Experiment design and analysis for loop benchmarking.

Generates efficient experiment plans instead of full factorial grids.
Analyzes results to identify which variables have the biggest impact.

Approaches:
1. Main effects sweep: vary one axis at a time from a baseline
2. Fractional factorial: Plackett-Burman screening for binary factors
3. Interaction hunt: full factorial on the top-k most impactful axes
"""

import json
import math
import sys
from itertools import product
from pathlib import Path

try:
    import yaml
except ImportError:
    # PyYAML is only needed to read grid files (load_grid); the analysis
    # commands work without it.
    yaml = None


def load_grid(path):
    """Parse and return the YAML grid definition stored at *path*.

    Raises:
        ImportError: if PyYAML is not installed.
    """
    if yaml is None:
        raise ImportError("PyYAML is required to load grid files (pip install pyyaml)")
    with open(path, encoding="utf-8") as f:
        return yaml.safe_load(f)


def get_axes(grid, profile_name=None):
    """Return ``{axis_name: [value, ...]}`` for the grid, optionally per profile.

    Top-level axes are read from ``grid["axes"][name]["values"]``. When
    *profile_name* names an entry of ``grid["profiles"]`` that has an ``axes``
    mapping, each of its entries (a plain list of values) replaces or adds
    that axis.
    """
    axes = {name: spec["values"] for name, spec in grid["axes"].items()}
    profiles = grid.get("profiles") or {}
    if profile_name and profile_name in profiles:
        profile_axes = (profiles[profile_name] or {}).get("axes") or {}
        axes.update(profile_axes)
    return axes


# ---------------------------------------------------------------------------
# Shared planning helpers
# ---------------------------------------------------------------------------

def _task_axes(grid, axes, task):
    """Return a copy of *axes* with ``grid["task_overrides"][task]["axes"]`` applied."""
    task_axes = dict(axes)
    overrides = (grid.get("task_overrides") or {}).get(task) or {}
    for axis_name, spec in (overrides.get("axes") or {}).items():
        task_axes[axis_name] = spec["values"]
    return task_axes


def _fit_to_task(cell, task_axes, force=True):
    """Return a copy of *cell* whose values are all allowed by *task_axes*.

    An axis whose current value is not allowed (or is missing) is set to its
    first allowed value when *force* is true, or when the axis has exactly one
    allowed value. With ``force=False`` and several choices the cell cannot
    be represented for this task, and None is returned.

    *cell* itself is never modified, so one design point can safely be fitted
    to several tasks in turn.
    """
    fitted = dict(cell)
    for name, values in task_axes.items():
        if fitted.get(name) in values:
            continue
        if force or len(values) == 1:
            fitted[name] = values[0]
        else:
            return None
    return fitted


def _unique_cells(pairs, defaults, grid):
    """Build plan cells from ``(task, cell)`` pairs, keeping the first of each key."""
    cells = []
    seen = set()
    for task, cell in pairs:
        key = _cell_key(task, cell)
        if key not in seen:
            seen.add(key)
            cells.append(_build_cell(task, cell, defaults, grid))
    return cells


# ---------------------------------------------------------------------------
# 1. Main effects sweep
# ---------------------------------------------------------------------------

def main_effects_plan(grid, baseline=None, tasks=None):
    """Generate a one-at-a-time sweep from a baseline.

    For each axis, vary it through all its values while holding everything
    else at baseline. This identifies main effects cheaply.

    Args:
        grid: parsed grid definition (see load_grid).
        baseline: optional ``{axis: value}``. It defaults to the first value
            of every axis. Per task, any axis that is missing from the
            baseline or holds a disallowed value is set to the task's first
            allowed value.
        tasks: task names to plan for; defaults to ``grid["tasks"]``.

    Returns a list of cell dicts.
    """
    tasks = tasks or grid["tasks"]
    defaults = grid["defaults"]
    return _unique_cells(_main_effects_pairs(grid, baseline, tasks), defaults, grid)


def _main_effects_pairs(grid, baseline, tasks):
    """Enumerate ``(task, cell)`` pairs for the one-at-a-time sweep."""
    axes = get_axes(grid)
    if baseline is None:
        baseline = {name: values[0] for name, values in axes.items()}

    pairs = []
    for task in tasks:
        task_axes = _task_axes(grid, axes, task)
        base_cell = _fit_to_task(baseline, task_axes)
        # NOTE(review): the baseline itself is not checked against
        # grid["exclusions"] (only the varied cells are) — confirm intended.
        pairs.append((task, base_cell))

        for axis_name, values in task_axes.items():
            for value in values:
                if value == base_cell[axis_name]:
                    continue
                varied = dict(base_cell)
                varied[axis_name] = value
                if not _is_excluded(varied, grid):
                    pairs.append((task, varied))
    return pairs


# ---------------------------------------------------------------------------
# 2. Plackett-Burman screening
# ---------------------------------------------------------------------------

# First rows of the cyclic Plackett-Burman designs, keyed by run count.
# Row i of a design is the generator rotated left by i; a final all -1 row
# completes it.
_PB_GENERATORS = {
    4: [1, 1, -1],
    8: [1, 1, 1, -1, 1, -1, -1],
    12: [1, 1, -1, 1, 1, 1, -1, -1, -1, 1, -1],
    16: [1, 1, 1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, -1, -1],
    20: [1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, -1, 1, 1, 1, 1, -1, -1],
    24: [1, 1, 1, 1, 1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, -1, 1, -1, -1, -1, -1],
}


def _hadamard_matrix(n):
    """Return a two-level (+1/-1) screening design with at least *n* runs.

    If a tabulated cyclic Plackett-Burman generator exists for *n* (4, 8, ...,
    24), the result is n x (n-1): the generator, its cyclic shifts, and a
    final all -1 row. A smaller untabulated *n* is rounded up to the next
    tabulated size. Beyond 24 runs, a Sylvester Hadamard design of the next
    power-of-two order is used (see _sylvester_design). Previously the
    24-run design was reused there, which left factors past the 23rd
    unvaried.

    In every case each column is balanced (equal numbers of +1 and -1), and
    the columns are mutually orthogonal.
    """
    size = next((s for s in sorted(_PB_GENERATORS) if s >= n), None)
    if size is None:
        return _sylvester_design(n)

    gen = _PB_GENERATORS[size]
    # Slicing builds fresh lists, so callers never alias _PB_GENERATORS.
    matrix = [gen[i:] + gen[:i] for i in range(len(gen))]
    matrix.append([-1] * len(gen))
    return matrix


def _sylvester_design(n):
    """Return an order x (order-1) design from a Sylvester Hadamard matrix.

    ``order`` is the smallest power of two that is >= *n*. The matrix is
    built recursively as ``[[H, H], [H, -H]]``. Its first column is all +1,
    and dropping it leaves balanced, mutually orthogonal columns.
    """
    order = 1
    while order < n:
        order *= 2
    h = [[1]]
    while len(h) < order:
        h = [row + row for row in h] + [row + [-x for x in row] for row in h]
    return [row[1:] for row in h]


def plackett_burman_plan(grid, tasks=None):
    """Generate a Plackett-Burman screening design for binary factors.

    Axes with exactly two values are screened with a Plackett-Burman design.
    For axes with more than two levels (e.g., model: haiku/sonnet/opus), the
    design is repeated for every combination of their values (a full
    factorial over the multi-level axes). Single-valued axes take their only
    value. If there are no two-level axes, this falls back to
    main_effects_plan.

    For each task, a design point is kept only if every axis value is allowed
    by that task's overrides; an axis restricted to a single value is forced
    to it. Points matching grid["exclusions"] are dropped.

    Returns a list of cell dicts.
    """
    tasks = tasks or grid["tasks"]
    defaults = grid["defaults"]
    pairs = _plackett_burman_pairs(grid, tasks)
    if pairs is None:
        return main_effects_plan(grid, tasks=tasks)
    return _unique_cells(pairs, defaults, grid)


def _plackett_burman_pairs(grid, tasks):
    """Enumerate ``(task, cell)`` pairs of the PB design, or None if no binary axes."""
    axes = get_axes(grid)
    binary_axes = {name: values for name, values in axes.items() if len(values) == 2}
    multi_axes = {name: values for name, values in axes.items() if len(values) > 2}
    binary_names = sorted(binary_axes)
    if not binary_names:
        return None

    # Smallest multiple of 4 with at least one column per factor.
    n_runs = math.ceil((len(binary_names) + 1) / 4) * 4
    matrix = _hadamard_matrix(n_runs)

    multi_names = sorted(multi_axes)
    task_axes = {task: _task_axes(grid, axes, task) for task in tasks}

    pairs = []
    for multi_combo in product(*(multi_axes[name] for name in multi_names)):
        fixed = dict(zip(multi_names, multi_combo))
        for row in matrix:
            design_cell = dict(fixed)
            for i, name in enumerate(binary_names):
                # -1 selects the first level, +1 the second.
                level = 1 if i < len(row) and row[i] == 1 else 0
                design_cell[name] = binary_axes[name][level]

            for task in tasks:
                # Fit a fresh copy per task so one task's forced values never
                # leak into the next task's cell.
                cell = _fit_to_task(design_cell, task_axes[task], force=False)
                if cell is not None and not _is_excluded(cell, grid):
                    pairs.append((task, cell))
    return pairs


# ---------------------------------------------------------------------------
# 3. Interaction hunt
# ---------------------------------------------------------------------------

def interaction_hunt_plan(grid, top_axes, tasks=None):
    """Full factorial on a subset of axes, baseline for the rest.

    Args:
        top_axes: list of axis names to fully explore (e.g., ["model", "effort", "linter"])
        tasks: which tasks to include; defaults to ``grid["tasks"]``.

    Non-explored axes take their first value. Per task, any value not allowed
    by the task's overrides is replaced by the task's first allowed value.
    Cells matching grid["exclusions"] are dropped.

    Raises:
        KeyError: if *top_axes* names an axis the grid does not define.

    Returns a list of cell dicts.
    """
    tasks = tasks or grid["tasks"]
    defaults = grid["defaults"]
    return _unique_cells(_interaction_pairs(grid, top_axes, tasks), defaults, grid)


def _interaction_pairs(grid, top_axes, tasks):
    """Enumerate ``(task, cell)`` pairs for the interaction hunt."""
    axes = get_axes(grid)
    explore_names = sorted(set(top_axes))
    unknown = [name for name in explore_names if name not in axes]
    if unknown:
        raise KeyError(f"unknown axes: {', '.join(unknown)}")

    baseline = {name: values[0] for name, values in axes.items()}
    task_axes = {task: _task_axes(grid, axes, task) for task in tasks}

    pairs = []
    for combo in product(*(axes[name] for name in explore_names)):
        explored = dict(baseline)
        explored.update(zip(explore_names, combo))
        for task in tasks:
            # Per-task copy: adjustments for one task must not affect others.
            cell = _fit_to_task(explored, task_axes[task])
            if not _is_excluded(cell, grid):
                pairs.append((task, cell))
    return pairs


# ---------------------------------------------------------------------------
# Analysis: compute effects from results
# ---------------------------------------------------------------------------

# Run-metadata keys that are bookkeeping rather than experiment axes.
_NON_AXIS_META_KEYS = frozenset({
    "task", "cell_id", "run_id", "run_number", "runs_per_cell",
    "max_budget_usd", "timeout_seconds", "base_tools",
    "started_at", "completed_at", "wall_time_seconds", "exit_code",
    "short_id", "short_cell_id", "claude_version", "actual_model",
})


def _mean(values):
    """Arithmetic mean of a non-empty sequence."""
    return sum(values) / len(values)


def _scored_runs(results_dir, metric):
    """Return ``(meta, value)`` for every loaded run that reports *metric*."""
    scored = []
    for run in _load_results(results_dir):
        value = _extract_metric(run, metric)
        if value is not None:
            scored.append((run["meta"], value))
    return scored


def analyze_main_effects(results_dir, metric="score"):
    """Compute the main effect of each axis on a given metric.

    Reads all completed runs, groups them by each axis value, and computes the
    mean metric of each group and its effect (the difference from the grand
    mean over all scored runs).

    Axes are the union of metadata keys over all runs, minus bookkeeping
    keys. Runs that lack an axis are grouped under "unknown", and axes with
    fewer than two observed values are omitted.

    Returns a dict
    ``{axis_name: {"values": {value: {"mean", "effect", "n"}}, "spread": ...}}``.
    Here "spread" is max minus min of the group means, and axes are ordered by
    descending spread. Returns {} when no run reports *metric*.
    """
    scored_runs = _scored_runs(results_dir, metric)
    if not scored_runs:
        return {}

    grand_mean = _mean([val for _, val in scored_runs])

    # Union (not just the first run's keys): an axis absent from some runs
    # must still be analysed.
    meta_keys = set().union(*(meta.keys() for meta, _ in scored_runs))
    axis_names = sorted(meta_keys - _NON_AXIS_META_KEYS)

    effects = {}
    for axis in axis_names:
        groups = {}
        for meta, val in scored_runs:
            groups.setdefault(str(meta.get(axis, "unknown")), []).append(val)
        if len(groups) < 2:
            continue

        means = {value: _mean(vals) for value, vals in groups.items()}
        effects[axis] = {
            "values": {
                value: {
                    "mean": round(means[value], 4),
                    "effect": round(means[value] - grand_mean, 4),
                    "n": len(groups[value]),
                }
                for value in sorted(groups)
            },
            "spread": round(max(means.values()) - min(means.values()), 4),
        }

    # Biggest effects first.
    return dict(sorted(effects.items(), key=lambda item: -item[1]["spread"]))


def analyze_interactions(results_dir, axis_a, axis_b, metric="score"):
    """Compute the interaction effect between two axes.

    Groups runs by ``(axis_a value, axis_b value)``; a missing value shows up
    as "?". Returns a dict with:

    - "table": ``{a: {b: {"mean", "n"}}}`` of observed cells;
    - "grand_mean": unweighted mean of the cell means;
    - "a_means" / "b_means": unweighted marginal means over observed cells;
    - "interactions": ``{"a,b": cell mean - (a_mean + b_mean - grand_mean)}``,
      i.e. the deviation from an additive model;
    - "max_interaction": largest absolute interaction.

    Returns {} when no run reports *metric*.
    """
    groups = {}
    for meta, val in _scored_runs(results_dir, metric):
        key = (str(meta.get(axis_a, "?")), str(meta.get(axis_b, "?")))
        groups.setdefault(key, []).append(val)
    if not groups:
        return {}

    cell_means = {key: _mean(vals) for key, vals in groups.items()}

    table = {}
    for (a_val, b_val), vals in sorted(groups.items()):
        table.setdefault(a_val, {})[b_val] = {
            "mean": round(cell_means[(a_val, b_val)], 4),
            "n": len(vals),
        }

    a_values = sorted(table)
    b_values = sorted({b for _, b in cell_means})

    grand_mean = _mean(list(cell_means.values()))
    a_means = {
        a: _mean([cell_means[(a, b)] for b in b_values if (a, b) in cell_means])
        for a in a_values
    }
    b_means = {
        b: _mean([cell_means[(a, b)] for a in a_values if (a, b) in cell_means])
        for b in b_values
    }

    # NOTE: keys are "a,b" strings; values that contain commas could collide.
    interactions = {}
    for a in a_values:
        for b in b_values:
            if (a, b) in cell_means:
                expected = a_means[a] + b_means[b] - grand_mean
                interactions[f"{a},{b}"] = round(cell_means[(a, b)] - expected, 4)
    max_interaction = max((abs(v) for v in interactions.values()), default=0)

    return {
        "table": table,
        "grand_mean": round(grand_mean, 4),
        "a_means": {k: round(v, 4) for k, v in a_means.items()},
        "b_means": {k: round(v, 4) for k, v in b_means.items()},
        "interactions": interactions,
        "max_interaction": round(max_interaction, 4),
    }


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _cell_key(task, cell):
    """Return a dedup key built from the task and the cell's axis values.

    Bookkeeping keys are ignored. The key is an "_"-joined string of
    "name=value" parts, so values containing "_" or "=" could in principle
    collide.
    """
    axis_names = sorted(k for k in cell.keys() if k not in (
        "task", "cell_id", "runs_per_cell", "max_budget_usd",
        "timeout_seconds", "base_tools",
    ))
    parts = [task] + [f"{k}={cell[k]}" for k in axis_names]
    return "_".join(parts)


def _is_excluded(cell, grid):
    """True if *cell* matches every ``when`` condition of some grid exclusion."""
    return any(
        all(cell.get(key) == value for key, value in exclusion["when"].items())
        for exclusion in (grid.get("exclusions") or [])
    )


def _build_cell(task, cell, defaults, grid):
    """Turn an axis-value dict into a full plan cell.

    Adds task, actual_model (the "model" axis value), a cell_id built from
    compute_grid's abbreviation tables, runs_per_cell (default 3),
    timeout_seconds (default 600), and max_budget_usd (looked up in
    ``defaults["budget"]`` by the cell's "max_budget" level, default "low";
    0.50 if absent).
    """
    # NOTE(review): imported lazily, possibly to avoid an import cycle — confirm.
    from compute_grid import AXIS_ABBREV, VALUE_ABBREV
    axis_names = sorted(cell.keys())

    cell_id_parts = [task] + [
        f"{AXIS_ABBREV.get(k, k)}={VALUE_ABBREV.get(str(cell[k]), cell[k])}"
        for k in axis_names
    ]

    result = dict(cell)
    result["task"] = task
    result["actual_model"] = cell.get("model", "")
    result["cell_id"] = "_".join(cell_id_parts)
    result["runs_per_cell"] = defaults.get("runs_per_cell", 3)
    result["timeout_seconds"] = defaults.get("timeout_seconds", 600)

    budget_key = cell.get("max_budget", "low")
    result["max_budget_usd"] = defaults.get("budget", {}).get(budget_key, 0.50)

    return result


def _load_results(results_dir):
    """Load all completed runs from ``<results_dir>/runs/<run>/``.

    A run counts as completed when meta.json and eval_results.json both exist
    and parse to JSON objects. claude_output.json is optional; if it is
    missing or not an object, an empty dict is used. Unreadable runs are
    skipped.

    Returns a list of ``{"meta", "eval", "claude"}`` dicts in sorted
    run-directory order.
    """
    runs_dir = Path(results_dir) / "runs"
    if not runs_dir.is_dir():
        return []

    runs = []
    # iterdir() order is arbitrary; sort for reproducible analysis.
    for run_dir in sorted(runs_dir.iterdir()):
        if not run_dir.is_dir():
            continue
        meta_path = run_dir / "meta.json"
        eval_path = run_dir / "eval_results.json"
        claude_path = run_dir / "claude_output.json"

        if not meta_path.exists() or not eval_path.exists():
            continue

        try:
            meta = json.loads(meta_path.read_text(encoding="utf-8"))
            eval_results = json.loads(eval_path.read_text(encoding="utf-8"))
            claude_output = (
                json.loads(claude_path.read_text(encoding="utf-8"))
                if claude_path.exists() else {}
            )
        except (ValueError, OSError):
            # ValueError covers JSONDecodeError and UnicodeDecodeError
            # (e.g. a half-written or corrupt file): skip the run.
            continue

        if not isinstance(meta, dict) or not isinstance(eval_results, dict):
            continue
        if not isinstance(claude_output, dict):
            claude_output = {}

        runs.append({
            "meta": meta,
            "eval": eval_results,
            "claude": claude_output,
        })

    return runs


# Metrics read directly from a run section: metric -> (section, key).
_DIRECT_METRICS = {
    "score": ("eval", "score"),
    "cost": ("claude", "total_cost_usd"),
    "turns": ("claude", "num_turns"),
    "wall_time": ("meta", "wall_time_seconds"),
}

# Metrics taken from the "score" of an eval_results sub-section:
# metric -> eval_results key.
_EVAL_SECTION_METRICS = {
    "gameplay": "gameplay_bot",
    "code_quality": "code_analysis",
    "structural": "structural",
    "transcript": "transcript_analysis",
    "sonarqube": "sonarqube",
    "build_quality": "quality",
}


def _numeric(value):
    """Return *value* if it is an int or float, else None."""
    return value if isinstance(value, (int, float)) else None


def _extract_metric(run, metric):
    """Extract a numeric metric from a run, or None if unavailable.

    "pass_rate" maps ``eval["functional"]["pass"]`` to 1.0 or 0.0. All other
    known metrics must be int/float, otherwise None is returned, so that
    non-numeric data cannot break the mean calculations. Unknown metric names
    yield None.
    """
    if metric in _DIRECT_METRICS:
        section, key = _DIRECT_METRICS[metric]
        return _numeric(run[section].get(key))

    if metric == "pass_rate":
        func = run["eval"].get("functional", {})
        if isinstance(func, dict) and "pass" in func:
            return 1.0 if func["pass"] else 0.0
        return None

    section_key = _EVAL_SECTION_METRICS.get(metric)
    if section_key is None:
        return None
    section = run["eval"].get(section_key, {})
    if not isinstance(section, dict):
        return None
    return _numeric(section.get("score"))


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

def main():
    """Command-line entry point: ``plan`` prints cells as JSON lines, ``analyze`` prints JSON."""
    if len(sys.argv) < 3:
        print("Usage:")
        print("  experiment_design.py plan <grid_file> <design> [args...]")
        print("    designs: main_effects, plackett_burman, interaction_hunt")
        print("  experiment_design.py analyze <results_dir> <analysis> [args...]")
        print("    analyses: main_effects, interactions")
        sys.exit(1)

    command = sys.argv[1]

    if command == "plan":
        grid_file = sys.argv[2]
        design = sys.argv[3] if len(sys.argv) > 3 else "main_effects"
        grid = load_grid(grid_file)

        if design == "main_effects":
            cells = main_effects_plan(grid)
        elif design == "plackett_burman":
            cells = plackett_burman_plan(grid)
        elif design == "interaction_hunt":
            raw = sys.argv[4].split(",") if len(sys.argv) > 4 else []
            # Tolerate spaces and trailing commas: "model, effort," -> [model, effort]
            top_axes = [name.strip() for name in raw if name.strip()]
            if not top_axes:
                print("ERROR: interaction_hunt requires comma-separated axis names", file=sys.stderr)
                sys.exit(1)
            unknown = sorted(set(top_axes) - set(get_axes(grid)))
            if unknown:
                print(f"ERROR: unknown axes: {', '.join(unknown)}", file=sys.stderr)
                sys.exit(1)
            cells = interaction_hunt_plan(grid, top_axes)
        else:
            print(f"Unknown design: {design}", file=sys.stderr)
            sys.exit(1)

        print(f"# {design}: {len(cells)} cells", file=sys.stderr)
        for cell in cells:
            print(json.dumps(cell))

    elif command == "analyze":
        results_dir = sys.argv[2]
        analysis = sys.argv[3] if len(sys.argv) > 3 else "main_effects"

        if analysis == "main_effects":
            metric = sys.argv[4] if len(sys.argv) > 4 else "score"
            effects = analyze_main_effects(results_dir, metric)
            print(json.dumps(effects, indent=2))
        elif analysis == "interactions":
            if len(sys.argv) < 6:
                print("ERROR: interactions requires two axis names", file=sys.stderr)
                sys.exit(1)
            axis_a = sys.argv[4]
            axis_b = sys.argv[5]
            metric = sys.argv[6] if len(sys.argv) > 6 else "score"
            result = analyze_interactions(results_dir, axis_a, axis_b, metric)
            print(json.dumps(result, indent=2))
        else:
            print(f"Unknown analysis: {analysis}", file=sys.stderr)
            sys.exit(1)

    else:
        print(f"Unknown command: {command}", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    main()