ai-research-survey

Systematic scan of agentic development research. What's signal, what's noise.
git clone https://git.shiptheloop.com/ai-research-survey.git
Log | Files | Refs

fit-weights.py (7185B)


      1 #!/usr/bin/env python3
      2 """
      3 Fit per-category weights for the rubric so that a small set of labeled
      4 anchor papers score within their target bands.
      5 
      6 Inputs:
      7   scripts/calibration/anchors.yaml     - Hand-labeled anchor set
      8   papers/<id>/scan.json                - Scan data for each anchor paper
      9 
     10 Output:
     11   scripts/calibration/weights.json     - Learned weights + metadata
     12 
     13 Usage:
     14   python3 scripts/calibration/fit-weights.py
     15 
     16 Apply via build-explorer-data.py:
     17   If scripts/calibration/weights.json exists, compute_overall_score and
     18   compute_category_score use its per-category weights. Otherwise they fall
     19   back to uniform weights (current behavior).
     20 """
     21 
     22 import json
     23 import sys
     24 from pathlib import Path
     25 
     26 try:
     27     import yaml
     28 except ImportError:
     29     sys.stderr.write(
     30         "pyyaml not installed. Run: pip install pyyaml\n"
     31         "Or port anchors.yaml to JSON and adjust this loader.\n"
     32     )
     33     sys.exit(1)
     34 
     35 try:
     36     from scipy.optimize import minimize
     37     import numpy as np
     38 except ImportError:
     39     sys.stderr.write(
     40         "scipy and numpy required. Run: pip install scipy numpy\n"
     41     )
     42     sys.exit(1)
     43 
     44 
     45 ROOT = Path(__file__).resolve().parent.parent.parent
     46 PAPERS_DIR = ROOT / "papers"
     47 ANCHORS_PATH = Path(__file__).resolve().parent / "anchors.yaml"
     48 OUT_PATH = Path(__file__).resolve().parent / "weights.json"
     49 
     50 
     51 # Same 14 categories as build-explorer-data.py ALL_CATEGORIES order.
     52 CATEGORIES = [
     53     "artifacts",
     54     "statistical_methodology",
     55     "evaluation_design",
     56     "claims_and_evidence",
     57     "setup_transparency",
     58     "limitations_and_scope",
     59     "data_integrity",
     60     "conflicts_of_interest",
     61     "contamination",
     62     "human_studies",
     63     "cost_and_practicality",
     64     "experimental_rigor",
     65     "data_leakage",
     66     "survey_methodology",
     67 ]
     68 
     69 
     70 def load_anchors():
     71     with open(ANCHORS_PATH) as f:
     72         data = yaml.safe_load(f)
     73     return data
     74 
     75 
     76 def load_scan(paper_id):
     77     path = PAPERS_DIR / paper_id / "scan.json"
     78     if not path.exists():
     79         return None
     80     with open(path) as f:
     81         return json.load(f)
     82 
     83 
     84 def category_counts(checklist):
     85     """Per-category (applicable, passed) counts."""
     86     result = {}
     87     for cat in CATEGORIES:
     88         data = checklist.get(cat, {})
     89         app = 0
     90         pas = 0
     91         if isinstance(data, dict):
     92             for q in data.values():
     93                 if isinstance(q, dict) and q.get("applies"):
     94                     app += 1
     95                     if q.get("answer"):
     96                         pas += 1
     97         result[cat] = (app, pas)
     98     return result
     99 
    100 
    101 def score_with_weights(counts, weights):
    102     """Weighted-mean category score. Categories with zero applicable questions
    103     drop out cleanly (no zero-fill bias)."""
    104     num = 0.0
    105     den = 0.0
    106     for cat, w in zip(CATEGORIES, weights):
    107         app, pas = counts[cat]
    108         if app == 0:
    109             continue
    110         cat_rate = pas / app
    111         num += w * cat_rate
    112         den += w
    113     if den == 0:
    114         return 0.0
    115     return (num / den) * 100.0
    116 
    117 
    118 def loss(weights, anchors_data, pairs_data, settings):
    119     total = 0.0
    120     pair_margin = settings.get("pair_margin", 20.0)
    121     pair_penalty = settings.get("pair_penalty", 2.0)
    122     l2 = settings.get("l2_reg", 0.1)
    123 
    124     # Band fit loss
    125     scores_by_id = {}
    126     for anchor in anchors_data:
    127         pid = anchor["id"]
    128         band = anchor["band"]
    129         counts = anchor["_counts"]
    130         s = score_with_weights(counts, weights)
    131         scores_by_id[pid] = s
    132         lo, hi = band
    133         target = (lo + hi) / 2
    134         # Quadratic toward target
    135         total += (s - target) ** 2
    136         # Hinge at band boundaries for soft enforcement
    137         if s < lo:
    138             total += (lo - s) ** 2 * 0.5
    139         if s > hi:
    140             total += (s - hi) ** 2 * 0.5
    141 
    142     # Pair ordering loss: first id should score < second id by >= pair_margin
    143     for lo_id, hi_id in pairs_data:
    144         if lo_id not in scores_by_id or hi_id not in scores_by_id:
    145             continue
    146         gap = scores_by_id[hi_id] - scores_by_id[lo_id]
    147         if gap < pair_margin:
    148             total += pair_penalty * (pair_margin - gap) ** 2
    149 
    150     # L2 regularization toward uniform weights (reference is 1.0 per category)
    151     for w in weights:
    152         total += l2 * (w - 1.0) ** 2
    153 
    154     return total
    155 
    156 
    157 def main():
    158     data = load_anchors()
    159     settings = data.get("settings", {})
    160     anchors = data.get("anchors", [])
    161     pairs = data.get("pairs", [])
    162 
    163     if len(anchors) < 5:
    164         sys.stderr.write(
    165             f"WARNING: only {len(anchors)} anchors labeled. Add more to "
    166             f"anchors.yaml before trusting the fit (aim for 15+).\n"
    167         )
    168 
    169     # Attach scan data to each anchor
    170     for anchor in anchors:
    171         pid = anchor["id"]
    172         scan = load_scan(pid)
    173         if scan is None:
    174             sys.stderr.write(f"SKIP: no scan.json for {pid}\n")
    175             anchor["_skip"] = True
    176             continue
    177         anchor["_counts"] = category_counts(scan.get("checklist", {}))
    178 
    179     anchors = [a for a in anchors if not a.get("_skip")]
    180     if not anchors:
    181         sys.stderr.write("No usable anchors.\n")
    182         sys.exit(1)
    183 
    184     np.random.seed(settings.get("seed", 42))
    185     x0 = np.ones(len(CATEGORIES))  # Start at uniform weights
    186 
    187     result = minimize(
    188         loss,
    189         x0,
    190         args=(anchors, pairs, settings),
    191         method="L-BFGS-B",
    192         bounds=[(settings.get("min_weight", 0.0), settings.get("max_weight", 5.0))] * len(CATEGORIES),
    193         options={"maxiter": 500},
    194     )
    195 
    196     weights = result.x.tolist()
    197     weight_map = dict(zip(CATEGORIES, [round(w, 4) for w in weights]))
    198 
    199     # Report
    200     print("=" * 70)
    201     print("LEARNED WEIGHTS")
    202     print("=" * 70)
    203     for cat, w in weight_map.items():
    204         bar = "#" * int(w * 10)
    205         print(f"  {cat:<28} {w:>6.3f}  {bar}")
    206     print()
    207 
    208     print("=" * 70)
    209     print("ANCHOR SCORES (predicted after fit)")
    210     print("=" * 70)
    211     for anchor in anchors:
    212         pid = anchor["id"]
    213         band = anchor["band"]
    214         pred = score_with_weights(anchor["_counts"], weights)
    215         in_band = "OK" if band[0] <= pred <= band[1] else "**"
    216         print(f"  {in_band} {pred:>6.1f}  target {band[0]:>3}-{band[1]:<3}  {pid}")
    217     print()
    218 
    219     print("=" * 70)
    220     print("PAIR ORDERING CHECK")
    221     print("=" * 70)
    222     scores_by_id = {
    223         a["id"]: score_with_weights(a["_counts"], weights) for a in anchors
    224     }
    225     for lo_id, hi_id in pairs:
    226         if lo_id not in scores_by_id or hi_id not in scores_by_id:
    227             print(f"  SKIP  {lo_id} < {hi_id}  (missing scan)")
    228             continue
    229         gap = scores_by_id[hi_id] - scores_by_id[lo_id]
    230         ok = "OK" if gap >= settings.get("pair_margin", 20.0) else "**"
    231         print(
    232             f"  {ok} {scores_by_id[lo_id]:>5.1f} < {scores_by_id[hi_id]:<5.1f} "
    233             f"(gap {gap:+.1f}) {lo_id}  <  {hi_id}"
    234         )
    235     print()
    236 
    237     out = {
    238         "weights": weight_map,
    239         "n_anchors": len(anchors),
    240         "n_pairs": len(pairs),
    241         "loss": float(result.fun),
    242         "converged": bool(result.success),
    243         "settings": settings,
    244     }
    245     with open(OUT_PATH, "w") as f:
    246         json.dump(out, f, indent=2)
    247     print(f"Wrote {OUT_PATH}")
    248 
    249 
    250 if __name__ == "__main__":
    251     main()

Impressum · Datenschutz