ScatterPlot.tsx (14834B)
import React from "react";
import { getModelColor, getModelFillColor, modelSortOrder } from "../lib/colors";
import {
  ScatterChart,
  Scatter,
  XAxis,
  YAxis,
  ZAxis,
  CartesianGrid,
  ResponsiveContainer,
} from "recharts";
import { useXAxisScale, useYAxisScale } from "recharts";
import type { Run } from "../lib/types";
import { groupIntoCells } from "../lib/analysis";
import ModelSelector from "./ModelSelector";

interface ScatterPlotProps {
  runs: Run[];
  /** Initial x-axis metric; a key of METRIC_CONFIG (default "cost"). */
  defaultX?: string;
  /** Initial y-axis metric; a key of METRIC_CONFIG (default "outcome"). */
  defaultY?: string;
}

/** Fields of a cell returned by `groupIntoCells` that can be plotted. */
type CellMetricKey =
  | "cost"
  | "score"
  | "turns"
  | "wall_time"
  | "gameplay"
  | "quality"
  | "code_quality"
  | "structural"
  | "sonarqube"
  | "transcript";

interface MetricDef {
  /** Axis and tooltip label. */
  label: string;
  /** Which cell aggregate to read. */
  cellKey: CellMetricKey;
  /** Multiplier applied to the aggregate's average before plotting. */
  scale: number;
  /** Formats an already-scaled value for ticks and the tooltip. */
  format: (v: number) => string;
}

/** Cells per model below which the UI flags a model as "low n". */
const LOW_N_THRESHOLD = 5;

/** Shared formatter for the percentage metrics (values already scaled by 100). */
const formatPercent = (v: number): string => `${v.toFixed(0)}%`;

/**
 * Selectable metrics, keyed by the value used in the axis <select>s.
 * Metrics with `scale: 100` are shown as percentages; presumably their raw
 * cell values are fractions in [0, 1] — NOTE(review): confirm in groupIntoCells.
 */
const METRIC_CONFIG: Record<string, MetricDef> = {
  cost: {
    label: "Cost ($)",
    cellKey: "cost",
    scale: 1,
    format: (v) => `$${v.toFixed(2)}`,
  },
  outcome: {
    label: "Outcome Score (%)",
    cellKey: "score",
    scale: 100,
    format: formatPercent,
  },
  gameplay: {
    label: "Gameplay (%)",
    cellKey: "gameplay",
    scale: 100,
    format: formatPercent,
  },
  quality: {
    label: "Quality (%)",
    cellKey: "quality",
    scale: 100,
    format: formatPercent,
  },
  code_quality: {
    label: "Code Quality (%)",
    cellKey: "code_quality",
    scale: 100,
    format: formatPercent,
  },
  structural: {
    label: "Structural (%)",
    cellKey: "structural",
    scale: 100,
    format: formatPercent,
  },
  sonarqube: {
    label: "SonarQube (%)",
    cellKey: "sonarqube",
    scale: 100,
    format: formatPercent,
  },
  turns: {
    label: "Turns",
    cellKey: "turns",
    scale: 1,
    format: (v) => `${Math.round(v)}`,
  },
  wall_time: {
    label: "Time (s)",
    cellKey: "wall_time",
    scale: 1,
    format: (v) => `${Math.round(v)}s`,
  },
  transcript: {
    label: "Transcript (%)",
    cellKey: "transcript",
    scale: 100,
    format: formatPercent,
  },
};

/** <option> entries for both axis selectors, in METRIC_CONFIG order. */
const METRIC_OPTIONS = Object.entries(METRIC_CONFIG).map(([key, conf]) => ({
  value: key,
  label: conf.label,
}));

/** Solid colour for a model's dots and centroid. */
function fallbackColor(model: string): string {
  return getModelColor(model);
}

/**
 * Fill-colour prefix for a model. Callers complete it by appending
 * ` ${alpha})`, so the returned string is an unterminated colour function.
 */
function fallbackFillPrefix(model: string): string {
  return getModelFillColor(model);
}

// --- Convex hull (Andrew's monotone chain) ---

type Point = [number, number];

/** Z-component of (a - o) x (b - o); > 0 for a counter-clockwise turn o -> a -> b. */
function cross(o: Point, a: Point, b: Point): number {
  return (a[0] - o[0]) * (b[1] - o[1]) - (a[1] - o[1]) * (b[0] - o[0]);
}

/**
 * Convex hull of `points` via Andrew's monotone chain. Collinear points on
 * the boundary are dropped (`<= 0` pops them). Inputs of 0-2 points are
 * returned sorted; fully collinear inputs collapse to their two extremes.
 */
function convexHull(points: Point[]): Point[] {
  const sorted = [...points].sort((a, b) => a[0] - b[0] || a[1] - b[1]);
  if (sorted.length <= 2) return sorted;

  const lower: Point[] = [];
  for (const p of sorted) {
    while (lower.length >= 2 && cross(lower[lower.length - 2], lower[lower.length - 1], p) <= 0)
      lower.pop();
    lower.push(p);
  }

  const upper: Point[] = [];
  for (let i = sorted.length - 1; i >= 0; i--) {
    const p = sorted[i];
    while (upper.length >= 2 && cross(upper[upper.length - 2], upper[upper.length - 1], p) <= 0)
      upper.pop();
    upper.push(p);
  }

  // The last point of each chain is the first point of the other.
  return [...lower.slice(0, -1), ...upper.slice(0, -1)];
}

interface DensityHull {
  pct: number; // share of the model's points enclosed, e.g. 100, 75, 50, 25
  hull: Point[];
  opacity: number;
  strokeOpacity: number;
}

interface ModelRegion {
  model: string;
  points: Point[];
  centroid: Point;
  hulls: DensityHull[];
  n: number;
}

/** Nested density levels, outermost first so inner hulls paint on top. */
const DENSITY_LEVELS = [
  { pct: 100, opacity: 0.05, strokeOpacity: 0.20 },
  { pct: 75, opacity: 0.10, strokeOpacity: 0.25 },
  { pct: 50, opacity: 0.18, strokeOpacity: 0.35 },
  { pct: 25, opacity: 0.30, strokeOpacity: 0.50 },
];

/**
 * Builds one region per model. The centroid is the mean of the model's
 * points. Each density hull is the convex hull of the `pct`% of points
 * closest to that centroid (at least one point).
 *
 * NOTE(review): distance is measured in raw data units, so when the axes
 * have very different ranges (e.g. $ vs %), the larger-range axis dominates
 * which points fall inside the inner hulls.
 *
 * Models with no points are skipped. The result is ordered by `modelSortOrder`.
 */
function computeRegions(
  byModel: Record<string, { x: number; y: number }[]>
): ModelRegion[] {
  const regions: ModelRegion[] = [];

  for (const [model, data] of Object.entries(byModel)) {
    const points: Point[] = data.map((d) => [d.x, d.y]);
    if (points.length === 0) continue;

    const cx = points.reduce((s, p) => s + p[0], 0) / points.length;
    const cy = points.reduce((s, p) => s + p[1], 0) / points.length;
    const centroid: Point = [cx, cy];

    const byDist = [...points].sort((a, b) => {
      const da = (a[0] - cx) ** 2 + (a[1] - cy) ** 2;
      const db = (b[0] - cx) ** 2 + (b[1] - cy) ** 2;
      return da - db;
    });

    const hulls: DensityHull[] = [];
    for (const level of DENSITY_LEVELS) {
      const count = Math.max(1, Math.ceil(points.length * level.pct / 100));
      const subset = byDist.slice(0, count);
      const hull = subset.length >= 3 ? convexHull(subset) : [...subset];
      hulls.push({
        pct: level.pct,
        hull,
        opacity: level.opacity,
        strokeOpacity: level.strokeOpacity,
      });
    }

    regions.push({ model, points, centroid, hulls, n: points.length });
  }

  // Model sort order from shared colors
  return regions.sort((a, b) => modelSortOrder(a.model) - modelSortOrder(b.model));
}

/** Everything the hover tooltip needs to describe one model's centroid. */
interface CentroidDatum {
  model: string;
  cx: number;
  cy: number;
  n: number;
  xLabel: string;
  yLabel: string;
  xFormat: (v: number) => string;
  yFormat: (v: number) => string;
}

/**
 * SVG layer rendered inside the ScatterChart. It maps data coordinates to
 * pixels with the chart's axis scales and draws, per model:
 * - nested density hull polygons;
 * - individual cell dots when no polygon can be formed;
 * - a hoverable centroid marker, drawn on top of every model's hulls.
 */
function HullLayer({
  regions,
  centroids,
  setHover,
}: {
  regions: ModelRegion[];
  centroids: CentroidDatum[];
  setHover: (c: CentroidDatum | null) => void;
}) {
  const xScale = useXAxisScale();
  const yScale = useYAxisScale();

  // Render nothing if either axis scale is unavailable.
  if (!xScale || !yScale) return null;

  const toSvg = (x: number, y: number): { sx: number; sy: number } => ({
    sx: xScale(x) ?? 0,
    sy: yScale(y) ?? 0,
  });

  const hullToPolygonPoints = (hull: Point[]): string =>
    hull
      .map((p) => {
        const { sx, sy } = toSvg(p[0], p[1]);
        return `${sx},${sy}`;
      })
      .join(" ");

  return (
    <g>
      {regions.map((region) => {
        const color = fallbackColor(region.model);
        const fillPrefix = fallbackFillPrefix(region.model);
        // No hull can form a polygon for 1-2 cells, or when all cells are
        // collinear/coincident. Draw the cells as dots in that case so the
        // model is not reduced to its centroid alone.
        const showDots = region.hulls.every((dh) => dh.hull.length < 3);

        return (
          <g key={region.model}>
            {/* Density hulls - outer to inner */}
            {region.hulls.map((dh) =>
              dh.hull.length >= 3 ? (
                <polygon
                  key={dh.pct}
                  points={hullToPolygonPoints(dh.hull)}
                  fill={`${fillPrefix} ${dh.opacity})`}
                  stroke={`${fillPrefix} ${dh.strokeOpacity})`}
                  strokeWidth={1}
                />
              ) : null
            )}
            {showDots &&
              region.points.map((p, i) => {
                const { sx, sy } = toSvg(p[0], p[1]);
                return (
                  <circle
                    key={i}
                    cx={sx}
                    cy={sy}
                    r={5}
                    fill={color}
                    opacity={0.7}
                  />
                );
              })}
          </g>
        );
      })}
      {/* Centroid dots rendered on top of all hulls */}
      {regions.map((region) => {
        const { sx, sy } = toSvg(region.centroid[0], region.centroid[1]);
        const color = fallbackColor(region.model);
        const centroidData = centroids.find((c) => c.model === region.model);

        return (
          <circle
            key={`centroid-${region.model}`}
            cx={sx}
            cy={sy}
            r={8}
            fill={color}
            stroke="hsl(217 16% 15.5%)"
            strokeWidth={2}
            style={{ cursor: "pointer" }}
            onMouseEnter={() => centroidData && setHover(centroidData)}
            onMouseLeave={() => setHover(null)}
          />
        );
      })}
    </g>
  );
}

/** Absolutely-positioned tooltip (top-right of the card) for a hovered centroid. */
function CentroidTooltip({ data }: { data: CentroidDatum }) {
  return (
    <div
      style={{
        background: "hsl(217 16% 15.5%)",
        border: "1px solid hsl(217 17% 28%)",
        borderRadius: "2px",
        fontFamily: "'JetBrains Mono', monospace",
        fontSize: "11px",
        padding: "8px 10px",
        lineHeight: "1.6",
        color: "hsl(213 14% 80%)",
        position: "absolute",
        pointerEvents: "none",
        zIndex: 10,
        top: 8,
        right: 8,
      }}
    >
      <div style={{ fontWeight: 600, marginBottom: 4 }}>{data.model}</div>
      <div>
        {data.xLabel}: {data.xFormat(data.cx)}
      </div>
      <div>
        {data.yLabel}: {data.yFormat(data.cy)}
      </div>
      <div style={{ marginTop: 2, color: "hsl(213 14% 55%)" }}>
        centroid of n={data.n} cell{data.n !== 1 ? "s" : ""}{data.n < LOW_N_THRESHOLD ? " (low n)" : ""}
      </div>
    </div>
  );
}

const selectStyle: React.CSSProperties = {
  background: "hsl(217 16% 15.5%)",
  color: "hsl(213 14% 80%)",
  border: "1px solid hsl(217 17% 28%)",
  borderRadius: "2px",
  fontFamily: "'JetBrains Mono', monospace",
  fontSize: "11px",
  padding: "4px 6px",
  cursor: "pointer",
};

/**
 * Axis domain covering `values`, padded by `padFraction` of the range on
 * each side. A zero-width range gets a padding of 1 so the axis is not
 * degenerate. An empty list yields [0, 1] rather than [Infinity, -Infinity].
 * A plain loop is used instead of Math.min(...values), which can exceed the
 * engine's argument limit on very large arrays.
 */
function paddedDomain(values: number[], padFraction = 0.08): [number, number] {
  if (values.length === 0) return [0, 1];
  let lo = Infinity;
  let hi = -Infinity;
  for (const v of values) {
    if (v < lo) lo = v;
    if (v > hi) hi = v;
  }
  const pad = (hi - lo) * padFraction || 1;
  return [lo - pad, hi + pad];
}

/**
 * True when an aggregate is zero across avg/min/max. The chart treats this
 * as "no data" for the cell.
 * NOTE(review): this also drops cells whose metric is genuinely 0 — confirm intended.
 */
function isAllZero(agg: { avg: number; min: number; max: number }): boolean {
  return agg.avg === 0 && agg.min === 0 && agg.max === 0;
}

/**
 * Per-model scatter of cell averages for two selectable metrics.
 *
 * Runs are grouped into cells with `groupIntoCells`. Each cell contributes
 * one point (its averages, scaled per METRIC_CONFIG). Each model is drawn
 * as nested density hulls around its centroid. Axis domains come from all
 * models, so toggling model visibility does not rescale the chart.
 *
 * Returns null when the selected metric key is not in METRIC_CONFIG (only
 * reachable through an invalid `defaultX`/`defaultY`).
 */
export default function ScatterPlot({
  runs,
  defaultX = "cost",
  defaultY = "outcome",
}: ScatterPlotProps) {
  // All hooks run unconditionally, before the early return below (Rules of Hooks).
  const [xMetric, setXMetric] = React.useState(defaultX);
  const [yMetric, setYMetric] = React.useState(defaultY);
  // Track the models the user hid rather than the ones shown. Models that
  // gain data later (new runs, or a metric pair under which they have
  // non-zero values) are then visible by default.
  const [hiddenModels, setHiddenModels] = React.useState<Set<string>>(() => new Set());
  // Track hover by model name and resolve it against the current centroids.
  // The tooltip then never shows stale values after a metric change, and
  // disappears when its model is hidden.
  const [hoveredModel, setHoveredModel] = React.useState<string | null>(null);

  // Grouping depends only on `runs`; don't redo it on every hover re-render.
  const cells = React.useMemo(() => groupIntoCells(runs), [runs]);

  const xConf = METRIC_CONFIG[xMetric];
  const yConf = METRIC_CONFIG[yMetric];
  if (!xConf || !yConf) return null;

  const byModel: Record<string, { x: number; y: number }[]> = {};
  let totalCells = 0;

  for (const cell of cells) {
    const xAgg = cell[xConf.cellKey];
    const yAgg = cell[yConf.cellKey];
    if (isAllZero(xAgg) || isAllZero(yAgg)) continue;

    const model = cell.meta.model;
    let bucket = byModel[model];
    if (!bucket) {
      bucket = [];
      byModel[model] = bucket;
    }
    bucket.push({ x: xAgg.avg * xConf.scale, y: yAgg.avg * yConf.scale });
    totalCells++;
  }

  // All models present in data, in the shared model sort order.
  const allModels = Object.keys(byModel).sort(
    (a, b) => modelSortOrder(a) - modelSortOrder(b)
  );
  const presentModels = new Set(allModels);
  const effectiveVisible = new Set(allModels.filter((m) => !hiddenModels.has(m)));

  const handleModelChange = (selected: Set<string>) => {
    setHiddenModels((prev) => {
      // Keep earlier hide choices for models absent from the current metric
      // pair, then record which of the present models are now deselected.
      const next = new Set(Array.from(prev).filter((m) => !presentModels.has(m)));
      for (const m of allModels) {
        if (!selected.has(m)) next.add(m);
      }
      return next;
    });
  };

  // Compute regions from ALL data (for stable axis domains)
  const allRegions = computeRegions(byModel);

  // Filter to visible models for rendering
  const regions = allRegions.filter((r) => effectiveVisible.has(r.model));

  const centroids: CentroidDatum[] = regions.map((r) => ({
    model: r.model,
    cx: r.centroid[0],
    cy: r.centroid[1],
    n: r.n,
    xLabel: xConf.label,
    yLabel: yConf.label,
    xFormat: xConf.format,
    yFormat: yConf.format,
  }));

  const hovered =
    hoveredModel === null ? undefined : centroids.find((c) => c.model === hoveredModel);

  // Axis domains with padding (from ALL data, not just visible)
  const allPoints = allRegions.flatMap((r) => r.points);
  const xDomain = paddedDomain(allPoints.map((p) => p[0]));
  const yDomain = paddedDomain(allPoints.map((p) => p[1]));

  // Visible models with fewer than LOW_N_THRESHOLD cells get a warning.
  const lowNModels = allModels.filter((m) => {
    const n = byModel[m]?.length ?? 0;
    return effectiveVisible.has(m) && n > 0 && n < LOW_N_THRESHOLD;
  });

  return (
    <div className="card" style={{ position: "relative" }}>
      <div style={{ display: "flex", alignItems: "center", gap: "8px", marginBottom: "16px", flexWrap: "wrap" }}>
        <select
          value={xMetric}
          onChange={(e) => setXMetric(e.target.value)}
          style={selectStyle}
        >
          {METRIC_OPTIONS.map((opt) => (
            <option key={opt.value} value={opt.value}>{opt.label}</option>
          ))}
        </select>
        <span style={{ fontSize: "12px", color: "hsl(213 14% 55%)" }}>vs</span>
        <select
          value={yMetric}
          onChange={(e) => setYMetric(e.target.value)}
          style={selectStyle}
        >
          {METRIC_OPTIONS.map((opt) => (
            <option key={opt.value} value={opt.value}>{opt.label}</option>
          ))}
        </select>
        <span
          style={{
            fontSize: "10px",
            fontFamily: "'JetBrains Mono', monospace",
            fontWeight: 400,
            color: "hsl(213 14% 55%)",
          }}
        >
          (n={runs.length} runs across {totalCells} cells)
        </span>
      </div>

      {/* Model toggles */}
      <div style={{ marginBottom: "12px" }}>
        <ModelSelector
          allModels={allModels}
          selectedModels={effectiveVisible}
          onChange={handleModelChange}
        />
        {/* Low-n warnings for visible models with few cells */}
        {lowNModels.length > 0 && (
          <div style={{ marginTop: "4px", textAlign: "center" }}>
            {lowNModels.map((m) => (
              <span key={m} style={{ fontSize: "10px", fontFamily: "'JetBrains Mono', monospace", color: "var(--yellow, hsl(40 95% 64%))", marginRight: "12px" }}>
                {m}: n={byModel[m]?.length ?? 0} cells (low n)
              </span>
            ))}
          </div>
        )}
      </div>

      {hovered && <CentroidTooltip data={hovered} />}

      <ResponsiveContainer width="100%" height={350}>
        <ScatterChart margin={{ top: 10, right: 20, bottom: 10, left: 10 }}>
          <CartesianGrid
            strokeDasharray="3 3"
            stroke="hsl(217 17% 28%)"
          />
          <XAxis
            dataKey="x"
            name={xConf.label}
            type="number"
            domain={xDomain}
            stroke="hsl(213 14% 65%)"
            fontSize={11}
            tickFormatter={(v) => xConf.format(v)}
          />
          <YAxis
            dataKey="y"
            name={yConf.label}
            type="number"
            domain={yDomain}
            stroke="hsl(213 14% 65%)"
            fontSize={11}
            tickFormatter={(v) => yConf.format(v)}
          />
          {/* Hidden scatter to seed axis scales with data */}
          <ZAxis range={[0, 0]} />
          <Scatter
            data={allPoints.map((p) => ({ x: p[0], y: p[1] }))}
            fill="transparent"
            isAnimationActive={false}
          />
          <HullLayer
            regions={regions}
            centroids={centroids}
            setHover={(c) => setHoveredModel(c ? c.model : null)}
          />
        </ScatterChart>
      </ResponsiveContainer>
    </div>
  );
}