PCAPlot.tsx (19306B)
import React, { useState, useMemo, useRef, useCallback, useEffect } from "react";
import { Canvas, type ThreeEvent } from "@react-three/fiber";
import { OrbitControls, Html, Line } from "@react-three/drei";
import * as THREE from "three";
import { getModelColor, modelSortOrder } from "../lib/colors";

/** Keys of the principal components the UI can display (at most 10). */
type PCKey = "pc1" | "pc2" | "pc3" | "pc4" | "pc5" | "pc6" | "pc7" | "pc8" | "pc9" | "pc10";

/** All component keys in order; the loadings section renders a prefix of this list. */
const PC_KEYS: PCKey[] = ["pc1", "pc2", "pc3", "pc4", "pc5", "pc6", "pc7", "pc8", "pc9", "pc10"];

/**
 * Per-component values carried by loading and axis-importance rows.
 * PC1–PC3 are required. PC4–PC10 are optional because the loadings tables can
 * ask for up to `n_components` (max 10) components. Presumably the backend
 * only emits them when it computed that many components — TODO confirm.
 */
interface PCValues extends Partial<Record<PCKey, number>> {
  pc1: number;
  pc2: number;
  pc3: number;
}

interface PCAPoint {
  run_id: string;
  short_id: string;
  model: string;
  /** Run score; the tooltip renders it as a percentage via `score * 100`. */
  score: number;
  pc1: number;
  pc2: number;
  pc3: number;
  config_summary: string;
}

/** Loading of one encoded feature (grouped under a configuration axis) on each component. */
interface PCALoading extends PCValues {
  feature: string;
  axis: string;
}

/** Importance of one configuration axis on each component, plus a `total`. */
interface PCAAxisImportance extends PCValues {
  axis: string;
  total: number;
}

interface PCAData {
  n_runs: number;
  n_features: number;
  n_components: number;
  /** Variance explained per component, in percent (index 0 = PC1). */
  variance_explained: number[];
  /** Optional per-component variance percentages for the scree plot. */
  scree?: number[];
  points: PCAPoint[];
  loadings: PCALoading[];
  axis_importance: PCAAxisImportance[];
}

interface PCAPlotProps {
  data: PCAData;
}

const AXIS_LABELS: Record<string, string> = {
  model: "Model",
  effort: "Effort",
  prompt_style: "Prompt Style",
  language: "Language",
  human_language: "Human Language",
  tool_read: "Tool: Read",
  tool_write: "Tool: Write",
  tool_edit: "Tool: Edit",
  tool_glob: "Tool: Glob",
  tool_grep: "Tool: Grep",
  linter: "Linter",
  playwright: "Playwright",
  context_file: "Context File",
  web_search: "Web Search",
  max_budget: "Budget",
  tests_provided: "Tests Provided",
  strategy: "Strategy",
  design_guidance: "Design Guidance",
  architecture: "Architecture",
  error_checking: "Error Checking",
  context_noise: "Context Noise",
  renderer: "Renderer",
  provider: "Provider",
};

/**
 * Half-extent of the plotted point cloud: each PC is scaled into
 * [-POINT_SPREAD, POINT_SPREAD]. Note this is wider than the 1.5-unit axis
 * lines and the ±1.3 grid lines drawn by the scene.
 */
const POINT_SPREAD = 2.5;
/** Sphere radius of the lowest-scoring run; the highest-scoring run gets MIN + RANGE. */
const MIN_POINT_RADIUS = 0.04;
const POINT_RADIUS_RANGE = 0.06;
/** Hovered spheres are drawn this many times larger. */
const HOVER_SCALE = 1.4;
/** Maximum number of bars shown in the scree plot. */
const SCREE_MAX_BARS = 20;

/**
 * Matches `hsl(H S% L%)` / `hsl(H, S%, L%)` (and `hsla(`), with integer or
 * decimal components. Anything after the lightness, such as an alpha, is ignored.
 */
const HSL_PATTERN = /hsla?\(\s*(\d+(?:\.\d+)?)(?:deg)?[,\s]+(\d+(?:\.\d+)?)%[,\s]+(\d+(?:\.\d+)?)%/;

/**
 * Convert a CSS `hsl()` colour string to `#rrggbb` for Three.js materials.
 * Falls back to mid-grey (`#888888`) for strings the pattern does not match.
 */
function hslToHex(hsl: string): string {
  const match = HSL_PATTERN.exec(hsl);
  if (!match) return "#888888";
  const [, hueStr = "", satStr = "", lightStr = ""] = match;
  const h = Number(hueStr) / 360;
  // Clamp so an out-of-range percentage cannot produce a 3-digit hex channel.
  const s = Math.min(Number(satStr), 100) / 100;
  const l = Math.min(Number(lightStr), 100) / 100;

  // HSL -> RGB, as in the CSS Color 4 reference `hslToRgb` sample.
  const a = s * Math.min(l, 1 - l);
  const channel = (n: number): string => {
    const k = (n + h * 12) % 12;
    const value = l - a * Math.max(Math.min(k - 3, 9 - k, 1), -1);
    return Math.round(255 * value)
      .toString(16)
      .padStart(2, "0");
  };
  return `#${channel(0)}${channel(8)}${channel(4)}`;
}

/** Read one component's value from a row, treating a missing component as 0. */
function pcValue(values: PCValues, key: PCKey): number {
  return values[key] ?? 0;
}

/** Format a variance percentage to one decimal, or return null if it is missing / not finite. */
function formatVariance(value: number | undefined): string | null {
  return typeof value === "number" && Number.isFinite(value) ? value.toFixed(1) : null;
}

/** Axis label such as "PC1 (23.4%)"; the percentage is omitted when unavailable. */
function pcAxisLabel(index: number, varianceExplained: readonly number[]): string {
  const pct = formatVariance(varianceExplained[index]);
  return pct === null ? `PC${index + 1}` : `PC${index + 1} (${pct}%)`;
}

/** Sum of a list of numbers (0 for an empty list). */
function sum(values: readonly number[]): number {
  return values.reduce((acc, v) => acc + v, 0);
}

/**
 * Largest absolute value in `values`, used as a normalization divisor.
 * Returns 1 for an empty or all-zero list so callers never divide by 0 / -Infinity.
 */
function maxAbs(values: readonly number[]): number {
  let max = 0;
  for (const v of values) max = Math.max(max, Math.abs(v));
  return max || 1;
}

/** A PCA point positioned and styled for the 3D scene. */
interface NormalizedPoint extends PCAPoint {
  nx: number;
  ny: number;
  nz: number;
  radius: number;
  /** CSS colour from `getModelColor` (used in HTML). */
  color: string;
  /** Same colour converted to hex (used by Three.js materials). */
  hexColor: string;
}

interface SceneProps {
  points: NormalizedPoint[];
  varianceExplained: number[];
  onHover: (point: NormalizedPoint | null) => void;
  hoveredPoint: NormalizedPoint | null;
}

/** A solid axis line from `start` to `end`, labelled at its end point. */
function AxisLine({
  start,
  end,
  label,
}: {
  start: [number, number, number];
  end: [number, number, number];
  label: string;
}) {
  return (
    <>
      <Line points={[start, end]} color="#5a6078" lineWidth={1.5} />
      <Html position={end} style={{ pointerEvents: "none" }}>
        <div
          style={{
            fontFamily: "'JetBrains Mono', monospace",
            fontSize: "10px",
            color: "hsl(213 14% 55%)",
            whiteSpace: "nowrap",
            userSelect: "none",
          }}
        >
          {label}
        </div>
      </Html>
    </>
  );
}

/** Dashed reference lines through the origin along x, y and z. */
const GRID_LINES: Array<{ start: [number, number, number]; end: [number, number, number] }> = [
  { start: [-1.3, 0, 0], end: [1.3, 0, 0] },
  { start: [0, -1.3, 0], end: [0, 1.3, 0] },
  { start: [0, 0, -1.3], end: [0, 0, 1.3] },
];

/** Renders the dashed origin-crossing grid lines. */
function GridLines() {
  return (
    <>
      {GRID_LINES.map((line, i) => (
        <Line
          key={i}
          points={[line.start, line.end]}
          color="#2d3045"
          lineWidth={0.5}
          dashed
          dashSize={0.05}
          gapSize={0.05}
        />
      ))}
    </>
  );
}

/**
 * One run rendered as a sphere. It reports hover enter/leave through `onHover`;
 * the parent owns the hover state and the page cursor.
 */
function DataPoint({
  point,
  onHover,
  isHovered,
}: {
  point: NormalizedPoint;
  onHover: (point: NormalizedPoint | null) => void;
  isHovered: boolean;
}) {
  const meshRef = useRef<THREE.Mesh>(null);

  const handlePointerOver = useCallback(
    (e: ThreeEvent<PointerEvent>) => {
      // Only the front-most sphere under the cursor should react.
      e.stopPropagation();
      onHover(point);
    },
    [point, onHover]
  );

  const handlePointerOut = useCallback(
    (e: ThreeEvent<PointerEvent>) => {
      e.stopPropagation();
      onHover(null);
    },
    [onHover]
  );

  return (
    <mesh
      ref={meshRef}
      position={[point.nx, point.ny, point.nz]}
      // Scale the mesh instead of changing geometry args, so hovering does not rebuild the sphere.
      scale={isHovered ? HOVER_SCALE : 1}
      onPointerOver={handlePointerOver}
      onPointerOut={handlePointerOut}
    >
      <sphereGeometry args={[point.radius, 16, 16]} />
      <meshStandardMaterial
        color={point.hexColor}
        transparent
        opacity={isHovered ? 1.0 : 0.7}
        emissive={point.hexColor}
        emissiveIntensity={isHovered ? 0.4 : 0.15}
      />
    </mesh>
  );
}

/** HTML tooltip anchored at a point: model, score (colour-banded), id and config summary. */
function HoverTooltip({ point }: { point: NormalizedPoint }) {
  const scorePct = Math.round(point.score * 100);
  const scoreColor =
    scorePct >= 70 ? "hsl(92 28% 65%)" : scorePct >= 40 ? "hsl(40 71% 73%)" : "hsl(355 52% 64%)";

  return (
    <Html position={[point.nx, point.ny, point.nz]} style={{ pointerEvents: "none" }} center>
      <div
        style={{
          background: "hsl(217 16% 15.5%)",
          border: "1px solid hsl(217 17% 28%)",
          fontFamily: "'JetBrains Mono', monospace",
          fontSize: "11px",
          padding: "8px 10px",
          lineHeight: "1.6",
          color: "hsl(213 14% 80%)",
          maxWidth: 260,
          whiteSpace: "nowrap",
          transform: "translate(16px, -50%)",
          pointerEvents: "none",
        }}
      >
        <div style={{ display: "flex", justifyContent: "space-between", gap: 12 }}>
          <span style={{ fontWeight: 600, color: point.color }}>{point.model}</span>
          <span style={{ fontWeight: 600, color: scoreColor }}>{scorePct}%</span>
        </div>
        <div style={{ marginTop: 4, color: "hsl(213 14% 55%)", fontSize: "10px" }}>
          {point.short_id}
        </div>
        {point.config_summary && (
          <div
            style={{
              marginTop: 4,
              fontSize: "10px",
              color: "hsl(213 14% 65%)",
              whiteSpace: "normal",
              wordBreak: "break-word",
            }}
          >
            {point.config_summary}
          </div>
        )}
      </div>
    </Html>
  );
}

/** Lights, camera controls, axes, grid, the point cloud and the hover tooltip. */
function Scene({ points, varianceExplained, onHover, hoveredPoint }: SceneProps) {
  return (
    <>
      <ambientLight intensity={0.6} />
      <pointLight position={[5, 5, 5]} intensity={0.8} />
      <pointLight position={[-5, -5, -5]} intensity={0.3} />

      <OrbitControls enableDamping dampingFactor={0.1} minDistance={1} maxDistance={8} />

      <GridLines />

      {/* Axis lines */}
      <AxisLine start={[0, 0, 0]} end={[1.5, 0, 0]} label={pcAxisLabel(0, varianceExplained)} />
      <AxisLine start={[0, 0, 0]} end={[0, 1.5, 0]} label={pcAxisLabel(1, varianceExplained)} />
      <AxisLine start={[0, 0, 0]} end={[0, 0, 1.5]} label={pcAxisLabel(2, varianceExplained)} />

      {/* Data points */}
      {points.map((pt) => (
        <DataPoint
          key={pt.run_id}
          point={pt}
          onHover={onHover}
          isHovered={hoveredPoint?.run_id === pt.run_id}
        />
      ))}

      {/* Tooltip on hover */}
      {hoveredPoint && <HoverTooltip point={hoveredPoint} />}
    </>
  );
}

/**
 * Interpretation table for one component. It lists the 5 axes with the highest
 * importance on `pcKey`, each with its 3 features of largest absolute loading.
 * Components missing from a row are treated as 0 rather than crashing.
 */
function LoadingsTable({
  data,
  pcKey,
  varianceExplained,
}: {
  data: PCAData;
  pcKey: PCKey;
  varianceExplained: number | undefined;
}) {
  const pcLabel = pcKey.toUpperCase();
  const variancePct = formatVariance(varianceExplained);

  // Top 5 axes by importance for this PC.
  const topAxes = [...data.axis_importance]
    .sort((a, b) => pcValue(b, pcKey) - pcValue(a, pcKey))
    .slice(0, 5);

  // For each top axis, its most significant feature loadings (by magnitude).
  const axisDetails = topAxes.map((axEntry) => {
    const prefix = `${axEntry.axis}_`;
    const topFeatures = data.loadings
      .filter((l) => l.axis === axEntry.axis)
      .sort((a, b) => Math.abs(pcValue(b, pcKey)) - Math.abs(pcValue(a, pcKey)))
      .slice(0, 3)
      .map((l) => ({
        feature: l.feature,
        // Feature names appear to be "<axis>_<value>"; show just the value part.
        label: l.feature.startsWith(prefix) ? l.feature.slice(prefix.length) : l.feature,
        value: pcValue(l, pcKey),
      }));
    return {
      axis: axEntry.axis,
      label: AXIS_LABELS[axEntry.axis] ?? axEntry.axis,
      importance: pcValue(axEntry, pcKey),
      topFeatures,
    };
  });

  return (
    <div>
      <div
        style={{
          fontSize: "11px",
          fontWeight: 600,
          marginBottom: 6,
          color: "hsl(213 14% 80%)",
        }}
      >
        {pcLabel}{" "}
        {variancePct !== null && (
          <span style={{ fontWeight: 400, color: "hsl(213 14% 55%)" }}>
            ({variancePct}% variance)
          </span>
        )}
      </div>
      <table
        style={{
          width: "100%",
          borderCollapse: "collapse",
          fontSize: "11px",
          fontFamily: "'JetBrains Mono', monospace",
        }}
      >
        <thead>
          <tr
            style={{
              borderBottom: "1px solid hsl(217 17% 28%)",
              color: "hsl(213 14% 55%)",
              textAlign: "left",
            }}
          >
            <th style={{ padding: "4px 8px 4px 0", fontWeight: 500 }}>Axis</th>
            <th style={{ padding: "4px 8px", fontWeight: 500, textAlign: "right" }}>Weight</th>
            <th style={{ padding: "4px 0 4px 8px", fontWeight: 500 }}>Top Contributors</th>
          </tr>
        </thead>
        <tbody>
          {axisDetails.map((ax) => (
            <tr key={ax.axis} style={{ borderBottom: "1px solid hsl(217 17% 22%)" }}>
              <td style={{ padding: "4px 8px 4px 0", color: "hsl(213 14% 80%)" }}>{ax.label}</td>
              <td
                style={{
                  padding: "4px 8px",
                  textAlign: "right",
                  fontWeight: 600,
                  color: "hsl(193 44% 67%)",
                }}
              >
                {ax.importance.toFixed(3)}
              </td>
              <td
                style={{
                  padding: "4px 0 4px 8px",
                  color: "hsl(213 14% 55%)",
                  fontSize: "10px",
                }}
              >
                {ax.topFeatures.map((f, i) => {
                  const color = f.value > 0 ? "hsl(92 28% 65%)" : "hsl(355 52% 64%)";
                  return (
                    <span key={f.feature}>
                      {i > 0 && ", "}
                      <span style={{ color }}>
                        {f.value > 0 ? "+" : ""}
                        {f.value.toFixed(3)}
                      </span>{" "}
                      {f.label}
                    </span>
                  );
                })}
              </td>
            </tr>
          ))}
        </tbody>
      </table>
    </div>
  );
}

/**
 * PCA dashboard with three parts:
 * - an interactive 3D scatter of runs on PC1–PC3, coloured by model and sized by score;
 * - an optional scree plot;
 * - a per-component loadings interpretation.
 */
export default function PCAPlot({ data }: PCAPlotProps) {
  const [hoveredPoint, setHoveredPoint] = useState<NormalizedPoint | null>(null);

  // Scale each PC by its max |value| into [-POINT_SPREAD, POINT_SPREAD].
  // Map min..max score linearly onto the sphere radius.
  const { normalizedPoints, modelGroups } = useMemo(() => {
    const maxPc1 = maxAbs(data.points.map((p) => p.pc1));
    const maxPc2 = maxAbs(data.points.map((p) => p.pc2));
    const maxPc3 = maxAbs(data.points.map((p) => p.pc3));

    let minScore = Infinity;
    let maxScore = -Infinity;
    for (const p of data.points) {
      minScore = Math.min(minScore, p.score);
      maxScore = Math.max(maxScore, p.score);
    }
    // All-equal scores: avoid dividing by zero (every point gets the minimum radius).
    const scoreRange = maxScore - minScore || 1;

    const pts: NormalizedPoint[] = data.points.map((p) => {
      const color = getModelColor(p.model);
      const t = (p.score - minScore) / scoreRange;
      return {
        ...p,
        nx: (p.pc1 / maxPc1) * POINT_SPREAD,
        ny: (p.pc2 / maxPc2) * POINT_SPREAD,
        nz: (p.pc3 / maxPc3) * POINT_SPREAD,
        radius: MIN_POINT_RADIUS + t * POINT_RADIUS_RANGE,
        color,
        hexColor: hslToHex(color),
      };
    });

    // Group by model for the legend, ordered by the shared model sort order.
    const groups: Record<string, NormalizedPoint[]> = {};
    for (const pt of pts) {
      (groups[pt.model] ??= []).push(pt);
    }
    const sorted = Object.entries(groups).sort(
      ([a], [b]) => modelSortOrder(a) - modelSortOrder(b)
    );

    return { normalizedPoints: pts, modelGroups: sorted };
  }, [data.points]);

  const handleHover = useCallback((point: NormalizedPoint | null) => {
    setHoveredPoint(point);
  }, []);

  // Ignore a hover that refers to a point from a previous `data`; its sphere no longer exists.
  const activeHover =
    hoveredPoint !== null && normalizedPoints.includes(hoveredPoint) ? hoveredPoint : null;

  // Show a pointer cursor while a point is hovered. The cleanup restores it on
  // hover end AND on unmount, so the cursor can't get stuck as "pointer".
  useEffect(() => {
    if (!activeHover) return undefined;
    document.body.style.cursor = "pointer";
    return () => {
      document.body.style.cursor = "auto";
    };
  }, [activeHover]);

  // Scree data: copy the array so `scree` stays narrowed inside JSX callbacks.
  const scree = data.scree ?? [];
  const screeBars: Array<{ value: number; cumulative: number }> = [];
  let runningTotal = 0;
  for (const value of scree.slice(0, SCREE_MAX_BARS)) {
    runningTotal += value;
    screeBars.push({ value, cumulative: runningTotal });
  }
  // Scale bars against the largest visible value (not just the first) and never divide by 0.
  const screeMax = Math.max(0, ...screeBars.map((b) => b.value)) || 1;

  const loadingKeys = PC_KEYS.slice(0, Math.max(0, data.n_components));

  return (
    <div style={{ display: "flex", flexDirection: "column", gap: 24 }}>
      {/* 3D scatter plot card */}
      <div className="card" style={{ position: "relative" }}>
        <div
          style={{
            display: "flex",
            alignItems: "center",
            gap: 8,
            marginBottom: 12,
            flexWrap: "wrap",
          }}
        >
          <span
            style={{
              fontSize: "12px",
              fontWeight: 400,
              color: "hsl(213 14% 55%)",
            }}
          >
            {data.n_runs} runs, {data.n_features} features
          </span>
          <span
            style={{
              fontSize: "10px",
              color: "hsl(213 14% 45%)",
              marginLeft: "auto",
            }}
          >
            Drag to rotate. Scroll to zoom. Point size proportional to score.
          </span>
        </div>

        {/* 3D Canvas */}
        <div style={{ width: "100%", height: 500, background: "#1a1d27" }}>
          <Canvas
            camera={{ position: [2.5, 2, 2.5], near: 0.1, far: 100 }}
            style={{ background: "#1a1d27" }}
            gl={{ antialias: true }}
          >
            <Scene
              points={normalizedPoints}
              varianceExplained={data.variance_explained}
              onHover={handleHover}
              hoveredPoint={activeHover}
            />
          </Canvas>
        </div>

        {/* Legend */}
        <div
          style={{
            display: "flex",
            gap: 12,
            justifyContent: "center",
            marginTop: 12,
            flexWrap: "wrap",
          }}
        >
          {modelGroups.map(([model, pts]) => {
            const color = getModelColor(model);
            return (
              <div
                key={model}
                style={{
                  display: "flex",
                  alignItems: "center",
                  gap: 4,
                  fontSize: "11px",
                  fontFamily: "'JetBrains Mono', monospace",
                }}
              >
                <div style={{ width: 8, height: 8, background: color }} />
                <span style={{ color }}>{model}</span>
                <span style={{ color: "hsl(213 14% 45%)", fontSize: "10px" }}>({pts.length})</span>
              </div>
            );
          })}
        </div>
      </div>

      {/* Scree plot */}
      {scree.length > 3 && (
        <div className="card">
          <h3
            style={{
              fontSize: "13px",
              fontWeight: 600,
              marginBottom: 12,
              color: "hsl(213 14% 80%)",
              textTransform: "uppercase",
              letterSpacing: "0.5px",
            }}
          >
            Variance by Component (Scree Plot)
          </h3>
          <div style={{ display: "flex", alignItems: "flex-end", gap: 3, height: 200 }}>
            {screeBars.map(({ value, cumulative }, i) => {
              const barHeight = Math.max(4, (value / screeMax) * 180);
              const isUsed = i < 3; // PC1–PC3, drawn in the 3D view
              const isExplained = i < 10; // within the "First 10" summary
              return (
                <div
                  key={i}
                  style={{
                    flex: 1,
                    display: "flex",
                    flexDirection: "column",
                    alignItems: "center",
                    gap: 2,
                  }}
                >
                  <span
                    style={{
                      fontSize: "9px",
                      fontFamily: "'JetBrains Mono', monospace",
                      fontWeight: 600,
                      color: isUsed ? "hsl(193 44% 67%)" : "hsl(213 14% 55%)",
                    }}
                  >
                    {value.toFixed(1)}%
                  </span>
                  <div
                    style={{
                      width: "100%",
                      height: `${barHeight}px`,
                      background: isUsed
                        ? "hsl(193 44% 67%)"
                        : isExplained
                          ? "hsl(216 15% 40%)"
                          : "hsl(216 15% 25%)",
                    }}
                    title={`PC${i + 1}: ${value.toFixed(1)}% (cumulative: ${cumulative.toFixed(1)}%)`}
                  />
                  <span
                    style={{
                      fontSize: "9px",
                      fontFamily: "'JetBrains Mono', monospace",
                      color: isUsed ? "hsl(193 44% 67%)" : "hsl(213 14% 45%)",
                    }}
                  >
                    PC{i + 1}
                  </span>
                </div>
              );
            })}
          </div>
          <div
            style={{
              fontSize: "10px",
              color: "hsl(213 14% 55%)",
              marginTop: 8,
              textAlign: "center",
              fontFamily: "'JetBrains Mono', monospace",
            }}
          >
            3D view (cyan): {sum(data.variance_explained.slice(0, 3)).toFixed(1)}% | First 10:{" "}
            {sum(scree.slice(0, 10)).toFixed(1)}% | All {scree.length} components:{" "}
            {sum(scree).toFixed(1)}%
          </div>
        </div>
      )}

      {/* Loadings interpretation card */}
      <div className="card">
        <h3
          style={{
            fontSize: "13px",
            fontWeight: 600,
            marginBottom: 16,
            color: "hsl(213 14% 80%)",
            textTransform: "uppercase",
            letterSpacing: "0.5px",
          }}
        >
          What do these dimensions mean?
        </h3>
        <p
          style={{
            fontSize: "11px",
            color: "hsl(213 14% 55%)",
            marginBottom: 16,
            lineHeight: 1.6,
          }}
        >
          Each principal component is a weighted combination of all configuration axes.
          Higher weight means that axis contributes more to the variance in that dimension.
          Green/red values show the direction of influence.
        </p>
        <div
          style={{
            display: "grid",
            gridTemplateColumns: "repeat(auto-fit, minmax(320px, 1fr))",
            gap: 16,
          }}
        >
          {loadingKeys.map((pc, idx) => (
            <LoadingsTable
              key={pc}
              data={data}
              pcKey={pc}
              varianceExplained={data.variance_explained[idx]}
            />
          ))}
        </div>
      </div>
    </div>
  );
}