classify-paper-type.py (6692B)
#!/usr/bin/env python3
"""
Preliminary paper type classification using Haiku.

Reads existing scan.json (title + key_findings + methodology_tags) and
classifies into: empirical, benchmark-creation, survey, position, theoretical.

Writes result to papers/{slug}/paper_type.json (separate file, non-destructive).

Usage:
    python3 scripts/classify-paper-type.py                      # All unclassified
    python3 scripts/classify-paper-type.py --limit 50           # First N
    python3 scripts/classify-paper-type.py --parallel 8         # Concurrent (Haiku is fast+cheap)
    python3 scripts/classify-paper-type.py --id metr-rct-2025   # Specific paper
    python3 scripts/classify-paper-type.py --force              # Re-classify all
"""

import argparse
import json
import subprocess
import sys
from collections import Counter
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path

ROOT = Path(__file__).resolve().parent.parent
PAPERS_DIR = ROOT / "papers"

# Closed set of labels accepted from the model; anything else is a failure.
VALID_TYPES = ("empirical", "benchmark-creation", "survey", "position", "theoretical")

# Only this many characters of key findings are sent, to keep the prompt short.
MAX_FINDINGS_CHARS = 500

# Wall-clock limit (seconds) for a single `claude` CLI call.
CLAUDE_TIMEOUT_S = 30

# Print a progress line after every this many processed papers.
PROGRESS_EVERY = 50

PROMPT = """Classify this research paper into exactly ONE category.

Categories:
1. **empirical** — runs experiments, reports quantitative results on benchmarks or datasets. The primary contribution is experimental findings.
2. **benchmark-creation** — introduces a new benchmark, dataset, or evaluation framework. May run baselines, but the primary contribution is the benchmark itself.
3. **survey** — reviews, surveys, or meta-analyzes existing work. Primary contribution is synthesis of the field.
4. **position** — argues a viewpoint, proposes a conceptual framework, or makes prescriptive claims without experimental validation. Includes vision papers and opinion pieces.
5. **theoretical** — proves something mathematically or analyzes properties formally. Primary contribution is theorems, proofs, or formal analysis.

Paper information:
Title: {title}
Methodology tags: {tags}
Key findings: {key_findings}

Respond with ONLY a JSON object:
{{"paper_type": "<one of: empirical, benchmark-creation, survey, position, theoretical>", "reason": "<one sentence>"}}"""


def _read_json(path):
    """Load and return the JSON document at *path*, decoded as UTF-8.

    Raises OSError if the file cannot be read and ValueError (which covers
    json.JSONDecodeError and UnicodeDecodeError) if it is not valid JSON.
    """
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def _write_json_atomic(path, data):
    """Write *data* as pretty-printed UTF-8 JSON to *path* via temp file + rename.

    Because the final file only appears through a rename, an interrupted run
    never leaves a truncated paper_type.json behind. Such a file would
    otherwise be skipped forever, since candidate selection only checks that
    the file exists.
    """
    tmp_path = path.with_name(path.name + ".tmp")
    tmp_path.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")
    tmp_path.replace(path)


def _scan_version(scan):
    """Return the numeric ``scan_version`` of a parsed scan.json (default 1).

    Returns 0 when *scan* is not a JSON object or the version is not a
    number, so such scans are treated like pre-v2 scans and skipped.
    """
    if not isinstance(scan, dict):
        return 0
    version = scan.get("scan_version", 1)
    return version if isinstance(version, (int, float)) else 0


def _as_text(value):
    """Render a scan.json field as prompt text.

    ``None`` becomes "" and a list/tuple is joined with "; ". Anything else
    is passed through ``str()``. NOTE(review): the exact schema of
    ``key_findings`` is not visible here, so both string and list forms are
    accepted.
    """
    if value is None:
        return ""
    if isinstance(value, (list, tuple)):
        return "; ".join(str(item) for item in value)
    return str(value)


def _build_prompt(scan):
    """Fill PROMPT from a parsed v2 scan.json dict.

    Malformed optional fields (missing or non-dict ``paper``, string or
    ``None`` tags, list-valued findings) degrade to sensible text instead
    of raising.
    """
    paper = scan.get("paper")
    if not isinstance(paper, dict):
        paper = {}
    tags = scan.get("methodology_tags") or []
    if not isinstance(tags, (list, tuple)):
        # A bare string would otherwise be joined character by character.
        tags = [tags]
    return PROMPT.format(
        title=_as_text(paper.get("title")),
        tags=", ".join(str(tag) for tag in tags),
        key_findings=_as_text(scan.get("key_findings"))[:MAX_FINDINGS_CHARS],
    )


def _parse_classification(output):
    """Extract the classification from raw model output.

    The model is asked for a bare JSON object but may surround it with prose
    or code fences, so the span from the first "{" to the last "}" is parsed.
    The label is stripped and lower-cased before being validated against
    VALID_TYPES.

    Returns:
        ``(paper_type, reason)`` on success, or ``(None, error_message)``.
    """
    start = output.find("{")
    end = output.rfind("}") + 1
    if start == -1 or end == 0:
        return None, "no JSON in output"
    try:
        # Text starting with "{" can only decode to a dict, so .get is safe.
        parsed = json.loads(output[start:end])
    except json.JSONDecodeError as e:
        return None, f"JSON parse error: {e}"

    paper_type = parsed.get("paper_type", "")
    if isinstance(paper_type, str):
        paper_type = paper_type.strip().lower()
    if paper_type not in VALID_TYPES:
        return None, f"invalid type: {paper_type}"
    return paper_type, parsed.get("reason", "")


def classify_one(paper_id, force=False):
    """Classify one paper with Haiku and write papers/{paper_id}/paper_type.json.

    Args:
        paper_id: directory name under PAPERS_DIR.
        force: re-classify even if a readable paper_type.json already exists.

    Returns:
        ``(paper_id, paper_type, reason)`` on success. The reason is
        "already classified" when an existing result is reused.
        ``(paper_id, None, error)`` on failure. Expected failures (unreadable
        files, timeouts, a missing CLI, bad model output) are returned rather
        than raised, so one bad paper cannot abort a batch.
    """
    paper_dir = PAPERS_DIR / paper_id
    scan_path = paper_dir / "scan.json"
    type_path = paper_dir / "paper_type.json"

    if type_path.exists() and not force:
        try:
            existing = _read_json(type_path)
        except (OSError, ValueError):
            existing = None  # corrupt/unreadable result: fall through and redo it
        if isinstance(existing, dict) and existing.get("paper_type"):
            return paper_id, existing["paper_type"], "already classified"

    if not scan_path.exists():
        return paper_id, None, "no scan.json"

    try:
        scan = _read_json(scan_path)
    except (OSError, ValueError) as e:
        return paper_id, None, f"unreadable scan.json: {e}"

    if _scan_version(scan) < 2:
        return paper_id, None, "v1 scan"

    prompt = _build_prompt(scan)

    try:
        result = subprocess.run(
            ["claude", "-p", "-", "--model", "haiku", "--max-turns", "1"],
            input=prompt,
            capture_output=True,
            text=True,
            # The prompt contains non-ASCII (em dashes), so don't depend on the locale.
            encoding="utf-8",
            errors="replace",
            timeout=CLAUDE_TIMEOUT_S,
            cwd=str(ROOT),
        )
    except subprocess.TimeoutExpired:
        return paper_id, None, "timeout"
    except OSError as e:  # e.g. the `claude` executable is not on PATH
        return paper_id, None, f"error: {e}"

    if result.returncode != 0:
        return paper_id, None, f"claude exit {result.returncode}"

    paper_type, detail = _parse_classification(result.stdout)
    if paper_type is None:
        return paper_id, None, detail

    # Separate file, so scan.json itself is never modified.
    try:
        _write_json_atomic(type_path, {"paper_type": paper_type, "reason": detail})
    except OSError as e:
        return paper_id, None, f"error: {e}"

    return paper_id, paper_type, detail


def _classify_safely(paper_id, force):
    """Call classify_one, turning any unexpected exception into a failure tuple."""
    try:
        return classify_one(paper_id, force)
    except Exception as e:  # top-level batch guard: one paper must not abort the run
        return paper_id, None, f"error: {e}"


def _parse_args(argv):
    """Parse CLI flags. Unknown arguments are ignored, as in the original parser."""
    parser = argparse.ArgumentParser(description="Classify papers by type using Haiku.")
    parser.add_argument("--limit", type=int, default=None,
                        help="classify at most N papers")
    parser.add_argument("--id", dest="specific_id", default=None,
                        help="only this paper (reports the existing result unless --force)")
    parser.add_argument("--parallel", type=int, default=1,
                        help="number of concurrent workers (default: 1)")
    parser.add_argument("--force", action="store_true",
                        help="re-classify papers that already have paper_type.json")
    args, _unknown = parser.parse_known_args(argv)
    return args


def _collect_candidates(force, specific_id):
    """Return sorted paper ids with a v2 scan.json that still need classifying.

    Papers that already have paper_type.json are skipped, unless *force* is
    set or the paper was requested explicitly with *specific_id*. Unreadable
    scan.json files are reported and skipped instead of aborting the run.
    """
    candidates = []
    for scan_path in sorted(PAPERS_DIR.glob("*/scan.json")):
        pid = scan_path.parent.name
        if specific_id and pid != specific_id:
            continue
        try:
            scan = _read_json(scan_path)
        except (OSError, ValueError) as e:
            print(f"  SKIP: {pid} — unreadable scan.json ({e})")
            continue
        if _scan_version(scan) < 2:
            continue
        type_path = scan_path.parent / "paper_type.json"
        if type_path.exists() and not force and not specific_id:
            continue
        candidates.append(pid)
    return candidates


def _iter_results(candidates, force, parallel):
    """Yield ``(pid, type, reason)`` for each candidate.

    Runs serially in input order when *parallel* <= 1. Otherwise uses a
    thread pool and yields in completion order.
    """
    if parallel <= 1:
        for pid in candidates:
            yield _classify_safely(pid, force)
        return
    with ThreadPoolExecutor(max_workers=parallel) as executor:
        futures = [executor.submit(_classify_safely, pid, force) for pid in candidates]
        for future in as_completed(futures):
            yield future.result()


def main(argv=None):
    """CLI entry point. *argv* defaults to ``sys.argv[1:]``."""
    args = _parse_args(sys.argv[1:] if argv is None else argv)

    candidates = _collect_candidates(args.force, args.specific_id)
    if args.limit is not None:
        # `--limit 0` now means "none", not "no limit"; negative values are treated as 0.
        candidates = candidates[:max(args.limit, 0)]

    if not candidates:
        print("No papers to classify.")
        return

    parallel = args.parallel
    print(f"Classifying {len(candidates)} papers"
          f"{f' (parallel={parallel})' if parallel > 1 else ''}:\n")

    type_counts = Counter()
    failures = 0

    results = _iter_results(candidates, args.force, parallel)
    for done, (pid, ptype, reason) in enumerate(results, start=1):
        if ptype:
            type_counts[ptype] += 1
        else:
            failures += 1
            print(f"  FAIL: {pid} — {reason}")
        if done % PROGRESS_EVERY == 0:
            print(f"  ... {done}/{len(candidates)} done")

    total = sum(type_counts.values())
    print(f"\nDone. Classified: {total}, Failed: {failures}")
    if total:
        print("Distribution:")
        for ptype, count in type_counts.most_common():
            print(f"  {ptype:20s} {count:>4d} ({count / total * 100:.0f}%)")


if __name__ == "__main__":
    main()