fit-weights.py (7185B)
#!/usr/bin/env python3
"""
Fit per-category weights for the rubric so that a small set of labeled
anchor papers score within their target bands.

Inputs:
  scripts/calibration/anchors.yaml  - Hand-labeled anchor set
  papers/<id>/scan.json             - Scan data for each anchor paper

Output:
  scripts/calibration/weights.json  - Learned weights + metadata

Usage:
  python3 scripts/calibration/fit-weights.py

Apply via build-explorer-data.py:
  If scripts/calibration/weights.json exists, compute_overall_score and
  compute_category_score use its per-category weights. Otherwise they fall
  back to uniform weights (current behavior).
"""

import json
import sys
from pathlib import Path

# Both optional dependency groups abort the script with exit status 1 and an
# install hint on stderr when they are missing.
try:
    import yaml
except ImportError:
    _yaml_hint = (
        "pyyaml not installed. Run: pip install pyyaml\n"
        "Or port anchors.yaml to JSON and adjust this loader.\n"
    )
    sys.stderr.write(_yaml_hint)
    raise SystemExit(1)

try:
    from scipy.optimize import minimize
    import numpy as np
except ImportError:
    _sci_hint = "scipy and numpy required. Run: pip install scipy numpy\n"
    sys.stderr.write(_sci_hint)
    raise SystemExit(1)


# Directory that holds this script; anchors.yaml and weights.json live beside it.
_SCRIPT_DIR = Path(__file__).resolve().parent

# Three levels above this file (two above the script directory).
ROOT = _SCRIPT_DIR.parent.parent
PAPERS_DIR = ROOT / "papers"
ANCHORS_PATH = _SCRIPT_DIR / "anchors.yaml"
OUT_PATH = _SCRIPT_DIR / "weights.json"


# Same 14 categories as build-explorer-data.py ALL_CATEGORIES order.
# Order is significant: weight vectors are positional and are paired with this
# list via zip() when scoring and when writing the output map.
CATEGORIES = [
    "artifacts",
    "statistical_methodology",
    "evaluation_design",
    "claims_and_evidence",
    "setup_transparency",
    "limitations_and_scope",
    "data_integrity",
    "conflicts_of_interest",
    "contamination",
    "human_studies",
    "cost_and_practicality",
    "experimental_rigor",
    "data_leakage",
    "survey_methodology",
]


def load_anchors():
    """Parse ANCHORS_PATH with ``yaml.safe_load`` and return the document.

    The return value is whatever the YAML parser yields; for an empty file
    that is ``None``, so callers must be prepared for it.
    """
    # Explicit encoding: the default depends on the platform locale.
    with open(ANCHORS_PATH, encoding="utf-8") as f:
        return yaml.safe_load(f)


def load_scan(paper_id):
    """Return the parsed ``papers/<paper_id>/scan.json``, or None if absent."""
    path = PAPERS_DIR / paper_id / "scan.json"
    if not path.exists():
        return None
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def category_counts(checklist):
    """Return ``{category: (applicable, passed)}`` for every entry of CATEGORIES.

    A question counts as applicable when it is a dict with a truthy
    ``"applies"`` value, and as passed when it additionally has a truthy
    ``"answer"`` value. Anything that is not a dict (a missing or null
    checklist, a non-dict category, a non-dict question) contributes nothing,
    so the result always has one ``(int, int)`` tuple per category.
    """
    if not isinstance(checklist, dict):
        # e.g. scan.json with "checklist": null -- treat as "no questions".
        checklist = {}

    result = {}
    for cat in CATEGORIES:
        questions = checklist.get(cat, {})
        if not isinstance(questions, dict):
            result[cat] = (0, 0)
            continue
        applicable = [
            q for q in questions.values()
            if isinstance(q, dict) and q.get("applies")
        ]
        passed = sum(1 for q in applicable if q.get("answer"))
        result[cat] = (len(applicable), passed)
    return result


def score_with_weights(counts, weights):
    """Weighted mean of per-category pass rates, scaled to 0-100.

    ``weights`` is positional and aligned with CATEGORIES. Categories with
    zero applicable questions are left out of both numerator and denominator
    (no zero-fill bias). Returns 0.0 when the remaining weights sum to zero.
    """
    num = 0.0
    den = 0.0
    for cat, w in zip(CATEGORIES, weights):
        app, pas = counts[cat]
        if app == 0:
            continue
        num += w * (pas / app)
        den += w
    if den == 0:
        return 0.0
    return (num / den) * 100.0


def loss(weights, anchors_data, pairs_data, settings):
    """Objective minimised by the fit.

    Sum of three terms:
      * band fit -- squared distance of each anchor's score from the midpoint
        of its ``band``, plus half the squared overshoot when the score lies
        outside the band;
      * pair ordering -- for each ``(lo_id, hi_id)`` pair whose ids were both
        scored, ``pair_penalty`` times the squared shortfall when the gap
        ``score(hi_id) - score(lo_id)`` is below ``pair_margin``;
      * L2 pull of every weight toward 1.0, scaled by ``l2_reg``.

    ``settings`` supplies ``pair_margin`` (default 20.0), ``pair_penalty``
    (default 2.0) and ``l2_reg`` (default 0.1). Each anchor must carry
    ``"id"``, ``"band"`` (a ``lo, hi`` pair) and ``"_counts"``.
    """
    pair_margin = settings.get("pair_margin", 20.0)
    pair_penalty = settings.get("pair_penalty", 2.0)
    l2 = settings.get("l2_reg", 0.1)

    total = 0.0

    # Band fit loss
    scores_by_id = {}
    for anchor in anchors_data:
        s = score_with_weights(anchor["_counts"], weights)
        scores_by_id[anchor["id"]] = s
        lo, hi = anchor["band"]
        target = (lo + hi) / 2
        # Quadratic toward target
        total += (s - target) ** 2
        # Hinge at band boundaries for soft enforcement
        if s < lo:
            total += (lo - s) ** 2 * 0.5
        if s > hi:
            total += (s - hi) ** 2 * 0.5

    # Pair ordering loss: first id should score < second id by >= pair_margin
    for lo_id, hi_id in pairs_data:
        if lo_id not in scores_by_id or hi_id not in scores_by_id:
            continue
        gap = scores_by_id[hi_id] - scores_by_id[lo_id]
        if gap < pair_margin:
            total += pair_penalty * (pair_margin - gap) ** 2

    # L2 regularization toward uniform weights (reference is 1.0 per category)
    for w in weights:
        total += l2 * (w - 1.0) ** 2

    return total


def main():
    """Fit the weights, print a report, and write OUT_PATH.

    Exits with status 1 when no anchor has a scan.json on disk.
    """
    # safe_load yields None for an empty file, and a bare "settings:" key
    # yields None for that section; normalise both to empty containers so the
    # "No usable anchors" path is reached instead of an AttributeError.
    data = load_anchors() or {}
    settings = data.get("settings") or {}
    anchors = data.get("anchors") or []
    pairs = data.get("pairs") or []

    if len(anchors) < 5:
        sys.stderr.write(
            f"WARNING: only {len(anchors)} anchors labeled. Add more to "
            f"anchors.yaml before trusting the fit (aim for 15+).\n"
        )

    # Attach scan data to each anchor
    for anchor in anchors:
        pid = anchor["id"]
        scan = load_scan(pid)
        if scan is None:
            sys.stderr.write(f"SKIP: no scan.json for {pid}\n")
            anchor["_skip"] = True
            continue
        anchor["_counts"] = category_counts(scan.get("checklist", {}))

    anchors = [a for a in anchors if not a.get("_skip")]
    if not anchors:
        sys.stderr.write("No usable anchors.\n")
        sys.exit(1)

    # NOTE(review): nothing visible in this file draws from NumPy's RNG; the
    # seed is kept so the global RNG state matches previous runs.
    np.random.seed(settings.get("seed", 42))
    x0 = np.ones(len(CATEGORIES))  # Start at uniform weights

    weight_bounds = (
        settings.get("min_weight", 0.0),
        settings.get("max_weight", 5.0),
    )
    result = minimize(
        loss,
        x0,
        args=(anchors, pairs, settings),
        method="L-BFGS-B",
        bounds=[weight_bounds] * len(CATEGORIES),
        options={"maxiter": 500},
    )
    if not result.success:
        # The weights are still written (with "converged": false); make the
        # failure visible instead of leaving it buried in the JSON.
        sys.stderr.write(
            f"WARNING: optimizer did not converge: {result.message}\n"
        )

    weights = result.x.tolist()
    weight_map = dict(zip(CATEGORIES, [round(w, 4) for w in weights]))

    # Report
    print("=" * 70)
    print("LEARNED WEIGHTS")
    print("=" * 70)
    for cat, w in weight_map.items():
        bar = "#" * int(w * 10)
        print(f" {cat:<28} {w:>6.3f} {bar}")
    print()

    print("=" * 70)
    print("ANCHOR SCORES (predicted after fit)")
    print("=" * 70)
    for anchor in anchors:
        pid = anchor["id"]
        band = anchor["band"]
        pred = score_with_weights(anchor["_counts"], weights)
        in_band = "OK" if band[0] <= pred <= band[1] else "**"
        print(f" {in_band} {pred:>6.1f} target {band[0]:>3}-{band[1]:<3} {pid}")
    print()

    print("=" * 70)
    print("PAIR ORDERING CHECK")
    print("=" * 70)
    pair_margin = settings.get("pair_margin", 20.0)
    scores_by_id = {
        a["id"]: score_with_weights(a["_counts"], weights) for a in anchors
    }
    for lo_id, hi_id in pairs:
        if lo_id not in scores_by_id or hi_id not in scores_by_id:
            print(f" SKIP {lo_id} < {hi_id} (missing scan)")
            continue
        gap = scores_by_id[hi_id] - scores_by_id[lo_id]
        ok = "OK" if gap >= pair_margin else "**"
        print(
            f" {ok} {scores_by_id[lo_id]:>5.1f} < {scores_by_id[hi_id]:<5.1f} "
            f"(gap {gap:+.1f}) {lo_id} < {hi_id}"
        )
    print()

    out = {
        "weights": weight_map,
        "n_anchors": len(anchors),
        "n_pairs": len(pairs),
        "loss": float(result.fun),
        "converged": bool(result.success),
        "settings": settings,
    }
    with open(OUT_PATH, "w", encoding="utf-8") as f:
        json.dump(out, f, indent=2)
    print(f"Wrote {OUT_PATH}")


if __name__ == "__main__":
    main()