compute_grid.py (7001B)
#!/usr/bin/env python3
"""Compute the experiment grid from grid.yaml.

Reads the grid definition, selects a profile, computes the cartesian product
of all axes, applies exclusions and task overrides, and outputs one JSON
object per cell (JSONL format).

Instead of a profile, the second argument may name an experimental design
(main_effects, plackett_burman, interaction_hunt); those plans are built by
the experiment_design module and printed in the same JSONL format.

Usage:
    python3 compute_grid.py <grid_file> [profile|design] [design_args]

If no profile is given, uses "smoke". interaction_hunt additionally takes a
comma-separated list of axis names as its third argument.
"""

import json
import sys
from itertools import product

# Short axis names for cell_id to avoid filesystem path length limits (ext4: 255 chars)
AXIS_ABBREV = {
    "context_file": "ctx",
    "effort": "eff",
    "human_language": "hlang",
    "language": "lang",
    "linter": "lint",
    "max_budget": "budget",
    "model": "model",
    "playwright": "pw",
    "prompt_style": "prompt",
    "tool_edit": "tedit",
    "tool_glob": "tglob",
    "tool_grep": "tgrep",
    "tool_read": "tread",
    "tool_write": "twrite",
    "web_search": "web",
    # New axes
    "tests_provided": "tst",
    "strategy": "strat",
    "design_guidance": "dsgn",
    "architecture": "arch",
    "error_checking": "echk",
    "context_noise": "noise",
    "renderer": "rndr",
    "provider": "prov",
}

# Short value names for cell_id to keep paths under 255 chars
VALUE_ABBREV = {
    "creative_validate": "cv",
    "use_subagents": "usub",
    "plan_first": "plan",
    "split_work": "split",
    "best_practices": "bp",
    "separation": "sep",
    "self_verify": "sv",
    "instructed": "inst",
    "available": "avail",
    "a_few": "few",
    "unspecified": "uns",
    "typescript": "ts",
    "javascript": "js",
    "wikipedia_1k": "wiki1k",
    "wikipedia_10k": "wiki10k",
    "wikipedia_50k": "wiki50k",
    "wikipedia_100k": "wiki100k",
    "wikipedia_25": "wiki25",
    "wikipedia_50": "wiki50",
    "wikipedia_75": "wiki75",
    "lorem_1k": "lor1k",
    "lorem_10k": "lor10k",
    "lorem_50k": "lor50k",
    "lorem_100k": "lor100k",
    "lorem_25": "lor25",
    "lorem_50": "lor50",
    "lorem_75": "lor75",
    "glm-4.5-air": "glm45air",
    "glm-4.7": "glm47",
    "glm-5.1": "glm51",
    "haiku-4.5": "haiku45",
    "sonnet-4.6": "sonnet46",
    "opus-4.6": "opus46",
    "qwen-3.6-plus": "qwen36p",
    "gemma-4-26b": "gemma426b",
    "minimax-m2.7": "mmx27",
    "kimi-k2.5": "kimi25",
    "anthropic": "anth",
    "openrouter": "or",
}


def load_grid(path):
    """Load the grid definition from the YAML file at *path*.

    Returns:
        The parsed top-level mapping.

    Raises:
        ValueError: if the file does not contain a YAML mapping at top level
            (e.g. it is empty, which safe_load parses as None).
    """
    # Imported here rather than at module level so the pure grid helpers in
    # this module can be imported (e.g. by tests or other scripts) without
    # PyYAML installed.
    import yaml

    # Explicit encoding: the default depends on the platform locale.
    with open(path, encoding="utf-8") as f:
        grid = yaml.safe_load(f)
    if not isinstance(grid, dict):
        raise ValueError(
            f"{path}: expected a YAML mapping at top level, got {type(grid).__name__}"
        )
    return grid


def _axis_values(spec):
    """Return the value list of an axis spec.

    Accepts both the mapping form ``{"values": [...]}`` used by the top-level
    ``axes`` section and a bare list of values.
    """
    if isinstance(spec, dict):
        return spec["values"]
    return spec


def get_axes(grid, profile_name):
    """Return ``{axis_name: [values, ...]}`` for *profile_name*.

    Axes listed under the profile replace the top-level values of the same
    axis (or add new axes); a profile without an ``axes`` key uses the
    top-level grid unchanged.

    Raises:
        ValueError: if *profile_name* is not defined under ``profiles``.
    """
    top_axes = {name: _axis_values(spec) for name, spec in grid["axes"].items()}
    # `profiles:` with no body parses as None; treat it as "no profiles".
    profiles = grid.get("profiles") or {}

    if profile_name not in profiles:
        raise ValueError(
            f"unknown profile '{profile_name}'. Known profiles: {sorted(profiles.keys())}"
        )

    # A bare `name:` entry parses as None; treat it like a profile with no overrides.
    profile = profiles[profile_name] or {}
    if "axes" not in profile:
        # Profile intentionally omits axes (e.g. 'full') to use the full top-level grid.
        return top_axes

    # Profile axes override top-level axes
    axes = dict(top_axes)
    for name, spec in (profile["axes"] or {}).items():
        axes[name] = _axis_values(spec)
    return axes


def get_runs_per_cell(grid, profile_name):
    """Return the profile's ``runs_per_cell``, else ``defaults.runs_per_cell``."""
    profile = (grid.get("profiles") or {}).get(profile_name) or {}
    if "runs_per_cell" in profile:
        return profile["runs_per_cell"]
    return grid["defaults"]["runs_per_cell"]


def matches_exclusion(cell, exclusion):
    """Return True if every key/value pair in ``exclusion["when"]`` equals the cell's value.

    Keys missing from *cell* compare as None. An empty ``when`` matches every cell.
    """
    return all(cell.get(key) == value for key, value in exclusion["when"].items())


def apply_task_overrides(axes, task, grid):
    """Return *axes* with the ``task_overrides[task].axes`` entries applied.

    The input mapping is returned unchanged (same object) when the task has no
    axis overrides; otherwise a new dict is returned.
    """
    # Either level may be present-but-empty in YAML (parsed as None).
    overrides = (grid.get("task_overrides") or {}).get(task) or {}
    if "axes" not in overrides:
        return axes

    result = dict(axes)
    for axis_name, axis_spec in (overrides["axes"] or {}).items():
        result[axis_name] = _axis_values(axis_spec)
    return result


def _cell_id(task, cell, axis_names):
    """Build the deterministic cell ID: task, then ``abbr=value`` for each axis, joined by '_'."""
    parts = [task]
    parts.extend(
        f"{AXIS_ABBREV.get(name, name)}={VALUE_ABBREV.get(str(cell[name]), cell[name])}"
        for name in axis_names
    )
    return "_".join(parts)


def compute_cells(grid, profile_name):
    """Yield one cell dict at a time.

    Streams the cartesian product so peak memory stays at O(1 cell) regardless
    of profile size. Callers that need a list should wrap with list(...).

    Each yielded dict holds one value per axis plus ``actual_model``,
    ``cell_id``, ``task``, ``runs_per_cell``, ``max_budget_usd`` and
    ``timeout_seconds``. Combinations matching any rule in ``exclusions``
    are skipped; rules may refer to axis names and to ``task``.

    Raises:
        ValueError: (on first iteration) if *profile_name* is unknown.
    """
    base_axes = get_axes(grid, profile_name)
    runs_per_cell = get_runs_per_cell(grid, profile_name)
    exclusions = grid.get("exclusions") or []
    tasks = grid["tasks"]
    defaults = grid["defaults"]

    for task in tasks:
        axes = apply_task_overrides(base_axes, task, grid)
        # Sorted so both iteration order and cell_id are deterministic.
        axis_names = sorted(axes)
        axis_values = [axes[name] for name in axis_names]

        for combo in product(*axis_values):
            cell = dict(zip(axis_names, combo))

            if exclusions:
                # Exclusions must see the task too, otherwise a rule such as
                # {task: X, model: Y} can never match. The task goes into a
                # separate view so the emitted cell keeps its key order.
                view = {**cell, "task": task}
                if any(matches_exclusion(view, e) for e in exclusions):
                    continue

            # actual_model = model (no mapping needed, models are their real names)
            cell["actual_model"] = cell.get("model", "")

            # Cell ID from task + abbreviated axis values (deterministic, filename-safe)
            cell["cell_id"] = _cell_id(task, cell, axis_names)

            cell["task"] = task
            cell["runs_per_cell"] = runs_per_cell
            cell["max_budget_usd"] = defaults["budget"].get(cell.get("max_budget", "low"), 0.50)
            cell["timeout_seconds"] = defaults["timeout_seconds"]

            yield cell


DESIGNS = ("main_effects", "plackett_burman", "interaction_hunt")


def main():
    """CLI entry point: print the cells of a profile or design as JSONL on stdout.

    Exits with status 1 on bad usage or an unknown profile/design name.
    """
    if len(sys.argv) < 2:
        print("Usage: compute_grid.py <grid_file> [profile|design] [design_args]", file=sys.stderr)
        print("  interaction_hunt takes a 3rd arg: comma-separated axis names", file=sys.stderr)
        sys.exit(1)

    grid_file = sys.argv[1]
    name = sys.argv[2] if len(sys.argv) > 2 else "smoke"

    grid = load_grid(grid_file)
    profiles = grid.get("profiles") or {}

    if name in profiles:
        for cell in compute_cells(grid, name):
            print(json.dumps(cell))
        return

    if name in DESIGNS:
        # Imported lazily: only the design modes need this module.
        from experiment_design import (
            main_effects_plan,
            plackett_burman_plan,
            interaction_hunt_plan,
        )
        if name == "main_effects":
            cells = main_effects_plan(grid)
        elif name == "plackett_burman":
            cells = plackett_burman_plan(grid)
        else:  # interaction_hunt
            # Tolerate "a, b" and stray commas; an empty list is a usage error.
            axis_names = []
            if len(sys.argv) >= 4:
                axis_names = [part.strip() for part in sys.argv[3].split(",") if part.strip()]
            if not axis_names:
                print("ERROR: interaction_hunt requires comma-separated axis names as 3rd arg", file=sys.stderr)
                sys.exit(1)
            cells = interaction_hunt_plan(grid, axis_names)
        for cell in cells:
            print(json.dumps(cell))
        return

    known = sorted(profiles.keys()) + list(DESIGNS)
    print(f"ERROR: unknown profile or design '{name}'. Known: {known}", file=sys.stderr)
    sys.exit(1)


if __name__ == "__main__":
    main()