ai-research-survey

Systematic scan of agentic development research. What's signal, what's noise.
git clone https://git.shiptheloop.com/ai-research-survey.git
Log | Files | Refs

classify-paper-type.py (6692B)


      1 #!/usr/bin/env python3
      2 """
      3 Preliminary paper type classification using Haiku.
      4 
      5 Reads existing scan.json (title + key_findings + methodology_tags) and
      6 classifies into: empirical, benchmark-creation, survey, position, theoretical.
      7 
      8 Writes result to papers/{slug}/paper_type.json (separate file, non-destructive).
      9 
     10 Usage:
     11     python3 scripts/classify-paper-type.py                    # All unclassified
     12     python3 scripts/classify-paper-type.py --limit 50         # First N
     13     python3 scripts/classify-paper-type.py --parallel 8       # Concurrent (Haiku is fast+cheap)
     14     python3 scripts/classify-paper-type.py --id metr-rct-2025 # Specific paper (add --force to re-classify)
     15     python3 scripts/classify-paper-type.py --force            # Re-classify all
     16 """
     17 
     18 import json
     19 import subprocess
     20 import sys
     21 from concurrent.futures import ThreadPoolExecutor, as_completed
     22 from pathlib import Path
     23 
     24 ROOT = Path(__file__).resolve().parent.parent
     25 PAPERS_DIR = ROOT / "papers"
     26 
# Prompt template piped to `claude -p -` for each paper. It is filled with
# str.format(), so {title}, {tags} and {key_findings} are substituted, and the
# doubled braces {{ ... }} render as literal braces in the JSON example the
# model is asked to return. The five category names listed here must stay in
# sync with the set of accepted "paper_type" values checked after parsing the
# model's reply.
PROMPT = """Classify this research paper into exactly ONE category.

Categories:
1. **empirical** — runs experiments, reports quantitative results on benchmarks or datasets. The primary contribution is experimental findings.
2. **benchmark-creation** — introduces a new benchmark, dataset, or evaluation framework. May run baselines, but the primary contribution is the benchmark itself.
3. **survey** — reviews, surveys, or meta-analyzes existing work. Primary contribution is synthesis of the field.
4. **position** — argues a viewpoint, proposes a conceptual framework, or makes prescriptive claims without experimental validation. Includes vision papers and opinion pieces.
5. **theoretical** — proves something mathematically or analyzes properties formally. Primary contribution is theorems, proofs, or formal analysis.

Paper information:
Title: {title}
Methodology tags: {tags}
Key findings: {key_findings}

Respond with ONLY a JSON object:
{{"paper_type": "<one of: empirical, benchmark-creation, survey, position, theoretical>", "reason": "<one sentence>"}}"""
     43 
     44 
     45 def classify_one(paper_id, force=False):
     46     """Classify one paper. Returns (paper_id, type, reason) or (paper_id, None, error)."""
     47     scan_path = PAPERS_DIR / paper_id / "scan.json"
     48     type_path = PAPERS_DIR / paper_id / "paper_type.json"
     49 
     50     if type_path.exists() and not force:
     51         with open(type_path) as f:
     52             existing = json.load(f)
     53         return paper_id, existing.get("paper_type"), "already classified"
     54 
     55     if not scan_path.exists():
     56         return paper_id, None, "no scan.json"
     57 
     58     with open(scan_path) as f:
     59         scan = json.load(f)
     60 
     61     if scan.get("scan_version", 1) < 2:
     62         return paper_id, None, "v1 scan"
     63 
     64     paper = scan.get("paper", {})
     65     prompt = PROMPT.format(
     66         title=paper.get("title", ""),
     67         tags=", ".join(scan.get("methodology_tags", [])),
     68         key_findings=scan.get("key_findings", "")[:500],
     69     )
     70 
     71     try:
     72         result = subprocess.run(
     73             ["claude", "-p", "-", "--model", "haiku", "--max-turns", "1"],
     74             input=prompt,
     75             capture_output=True, text=True, timeout=30,
     76             cwd=str(ROOT),
     77         )
     78 
     79         if result.returncode != 0:
     80             return paper_id, None, f"claude exit {result.returncode}"
     81 
     82         output = result.stdout.strip()
     83         json_start = output.find("{")
     84         json_end = output.rfind("}") + 1
     85         if json_start == -1 or json_end == 0:
     86             return paper_id, None, "no JSON in output"
     87 
     88         parsed = json.loads(output[json_start:json_end])
     89         paper_type = parsed.get("paper_type", "")
     90         reason = parsed.get("reason", "")
     91 
     92         valid_types = ["empirical", "benchmark-creation", "survey", "position", "theoretical"]
     93         if paper_type not in valid_types:
     94             return paper_id, None, f"invalid type: {paper_type}"
     95 
     96         # Write separate file (non-destructive)
     97         with open(type_path, "w") as f:
     98             json.dump({"paper_type": paper_type, "reason": reason}, f, ensure_ascii=False, indent=2)
     99 
    100         return paper_id, paper_type, reason
    101 
    102     except json.JSONDecodeError as e:
    103         return paper_id, None, f"JSON parse error: {e}"
    104     except subprocess.TimeoutExpired:
    105         return paper_id, None, "timeout"
    106     except Exception as e:
    107         return paper_id, None, f"error: {e}"
    108 
    109 
    110 def main():
    111     args = sys.argv[1:]
    112     force = "--force" in args
    113     limit = None
    114     specific_id = None
    115     parallel = 1
    116 
    117     for i, arg in enumerate(args):
    118         if arg == "--limit" and i + 1 < len(args):
    119             limit = int(args[i + 1])
    120         if arg == "--id" and i + 1 < len(args):
    121             specific_id = args[i + 1]
    122         if arg == "--parallel" and i + 1 < len(args):
    123             parallel = int(args[i + 1])
    124 
    125     # Collect candidates
    126     candidates = []
    127     for scan_path in sorted(PAPERS_DIR.glob("*/scan.json")):
    128         pid = scan_path.parent.name
    129         if specific_id and pid != specific_id:
    130             continue
    131         with open(scan_path) as f:
    132             s = json.load(f)
    133         if s.get("scan_version", 1) < 2:
    134             continue
    135         type_path = scan_path.parent / "paper_type.json"
    136         if type_path.exists() and not force and not specific_id:
    137             continue
    138         candidates.append(pid)
    139 
    140     if limit:
    141         candidates = candidates[:limit]
    142 
    143     if not candidates:
    144         print("No papers to classify.")
    145         return
    146 
    147     print(f"Classifying {len(candidates)} papers"
    148           f"{f' (parallel={parallel})' if parallel > 1 else ''}:\n")
    149 
    150     from collections import Counter
    151     type_counts = Counter()
    152     failures = 0
    153 
    154     if parallel > 1:
    155         with ThreadPoolExecutor(max_workers=parallel) as executor:
    156             futures = {executor.submit(classify_one, pid, force): pid for pid in candidates}
    157             for future in as_completed(futures):
    158                 pid, ptype, reason = future.result()
    159                 if ptype:
    160                     type_counts[ptype] += 1
    161                 else:
    162                     failures += 1
    163                     print(f"  FAIL: {pid} — {reason}")
    164     else:
    165         for i, pid in enumerate(candidates):
    166             _, ptype, reason = classify_one(pid, force)
    167             if ptype:
    168                 type_counts[ptype] += 1
    169             else:
    170                 failures += 1
    171                 print(f"  FAIL: {pid} — {reason}")
    172             if (i + 1) % 50 == 0:
    173                 print(f"  ... {i+1}/{len(candidates)} done")
    174 
    175     total = sum(type_counts.values())
    176     print(f"\nDone. Classified: {total}, Failed: {failures}")
    177     print(f"Distribution:")
    178     for ptype, count in type_counts.most_common():
    179         print(f"  {ptype:20s} {count:>4d} ({count/total*100:.0f}%)")
    180 
    181 
    182 if __name__ == "__main__":
    183     main()

Impressum · Datenschutz