VariabilityViolin.tsx (10139B)
import { useMemo } from "react";
// Model colors and sort order come from the shared palette.
import { getModelColor, modelSortOrder } from "../lib/colors";
import {
  ComposedChart,
  XAxis,
  YAxis,
  CartesianGrid,
  Tooltip,
  ResponsiveContainer,
  Scatter,
  Cell,
  ZAxis,
  Bar,
} from "recharts";
import type { Run } from "../lib/types";
import { groupIntoCells } from "../lib/analysis";

interface VariabilityViolinProps {
  runs: Run[];
}

/** Local HSL palette for chart chrome (surfaces, borders, muted text, accents). */
const SMUI = {
  surface0: "hsl(213 16% 12%)",
  surface1: "hsl(217 16% 15.5%)",
  surface2: "hsl(216 15% 19%)",
  border: "hsl(217 17% 28%)",
  muted: "hsl(213 14% 65%)",
  frost1: "hsl(176 25% 65%)",
  frost2: "hsl(193 44% 67%)",
  frost3: "hsl(210 34% 63%)",
  frost4: "hsl(213 32% 52%)",
  green: "hsl(92 28% 65%)",
  yellow: "hsl(40 71% 73%)",
  red: "hsl(355 52% 64%)",
  purple: "hsl(311 24% 63%)",
};

const TOOLTIP_STYLE: React.CSSProperties = {
  background: SMUI.surface1,
  border: `1px solid ${SMUI.border}`,
  borderRadius: "0",
  fontFamily: "'JetBrains Mono', monospace",
  fontSize: "11px",
  padding: "8px 12px",
};

/** Per-model summary of cell-level CV% values, shaped for the stacked box-plot bars. */
interface ModelCVData {
  /** X-axis category, `${model}|(n=${cellCount})`; the tick renderer splits it on the last "|". */
  label: string;
  model: string;
  /** Raw CV% values, one per qualifying cell, in cell order (not sorted). */
  cvValues: number[];
  min: number;
  q1: number;
  median: number;
  q3: number;
  max: number;
  /** Height of the invisible spacer bar (equals q1) so the visible bar starts at q1. */
  base: number;
  /** Height of the visible bar (q3 - q1). */
  iqr: number;
  color: string;
  cellCount: number;
}

/**
 * Linearly interpolated quantile of an ascending-sorted array.
 *
 * @param sorted - Values sorted ascending.
 * @param q - Quantile in [0, 1].
 * @returns The interpolated value, or 0 for an empty array.
 */
function quantile(sorted: number[], q: number): number {
  if (sorted.length === 0) return 0;
  if (sorted.length === 1) return sorted[0];
  const pos = q * (sorted.length - 1);
  const lo = Math.floor(pos);
  const hi = Math.ceil(pos);
  if (lo === hi) return sorted[lo];
  return sorted[lo] + (pos - lo) * (sorted[hi] - sorted[lo]);
}

/**
 * Computes, for every model, the distribution of per-cell coefficients of variation.
 *
 * Runs are grouped with `groupIntoCells`. A cell contributes only if it has at least two
 * runs with a numeric `eval_results.score` and a non-zero mean score. The CV uses the
 * population standard deviation (divides by n) and is expressed as a percentage of |mean|.
 *
 * @param runs - All runs to analyse.
 * @returns One entry per model, ordered by `modelSortOrder` and then by model name.
 */
function computeModelCV(runs: Run[]): ModelCVData[] {
  const cells = groupIntoCells(runs);

  const byModel: Record<string, number[]> = {};

  for (const cell of cells) {
    if (cell.n < 2) continue;
    const scores = cell.runs
      .map((r) => r.eval_results?.score)
      .filter((s): s is number => s != null);
    if (scores.length < 2) continue;
    const mean = scores.reduce((a, b) => a + b, 0) / scores.length;
    if (mean === 0) continue;
    const stdDev = Math.sqrt(
      scores.reduce((s, v) => s + (v - mean) ** 2, 0) / scores.length
    );
    // Divide by |mean| so a negative mean cannot yield a negative CV, which the
    // [0, yMax] axis would clip away.
    const cv = (stdDev / Math.abs(mean)) * 100;
    // `||` also treats an empty-string actual_model as missing.
    const model = cell.meta.actual_model || cell.meta.model;
    (byModel[model] ??= []).push(cv);
  }

  const sortedEntries = Object.entries(byModel).sort(
    ([a], [b]) => modelSortOrder(a) - modelSortOrder(b) || a.localeCompare(b)
  );

  return sortedEntries.map(([model, cvs]) => {
    const sorted = [...cvs].sort((a, b) => a - b);
    const q1 = quantile(sorted, 0.25);
    const q3 = quantile(sorted, 0.75);
    return {
      label: `${model}|(n=${cvs.length})`,
      model,
      cvValues: cvs,
      min: sorted[0],
      q1,
      median: quantile(sorted, 0.5),
      q3,
      max: sorted[sorted.length - 1],
      base: q1,
      iqr: q3 - q1,
      color: getModelColor(model),
      cellCount: cvs.length,
    };
  });
}

/**
 * Deterministic horizontal offset for the i-th dot of a model, as a fraction of the box width.
 * Dots alternate sides, and the magnitude cycles through 0.06..0.30.
 */
function dotJitter(index: number): number {
  const side = index % 2 === 0 ? 1 : -1;
  const magnitude = ((index % 5) + 1) * 0.06;
  return side * magnitude;
}

/** Radius (px) of the per-cell CV dots. */
const DOT_RADIUS = 2.5;

/**
 * Props recharts passes to a custom `Bar` shape (only the fields used here), plus `yMax`,
 * which the chart supplies explicitly on the shape element.
 */
interface CVBoxPlotShapeProps {
  x?: number;
  y?: number;
  width?: number;
  height?: number;
  payload?: ModelCVData;
  /** Full-height column behind the bar; its height is taken as the y-axis pixel height. */
  background?: { x?: number; y?: number; width?: number; height?: number };
  /** Upper bound of the y-axis domain, whose lower bound is 0. */
  yMax?: number;
}

/**
 * Custom shape for the visible "iqr" bar. It draws:
 * - the min–max whisker with caps,
 * - the IQR box,
 * - one jittered dot per cell CV,
 * - the median line.
 *
 * Values map to pixels by linear interpolation across the box (q3 at the top, q1 at the
 * bottom). When the box has zero height (q1 === q3), that mapping is undefined. In that case
 * the axis scale (background height / yMax) is used, so whiskers and dots still spread out.
 */
function CVBoxPlotShape({
  x,
  y,
  width,
  height,
  payload,
  background,
  yMax,
}: CVBoxPlotShapeProps) {
  if (
    !payload ||
    x === undefined ||
    y === undefined ||
    width === undefined ||
    height === undefined
  ) {
    return null;
  }

  const { min, median, max, q1, q3, color, cvValues } = payload;
  const boxTop = y;
  const boxBottom = y + height;
  const centerX = x + width / 2;

  const axisHeight = background?.height;
  const pxPerUnit =
    axisHeight !== undefined && yMax !== undefined && yMax > 0
      ? axisHeight / yMax
      : undefined;

  const dataToY = (val: number): number => {
    if (q3 !== q1) {
      return boxTop + ((q3 - val) / (q3 - q1)) * (boxBottom - boxTop);
    }
    // Degenerate box: anchor at the box and use the axis scale instead.
    if (pxPerUnit !== undefined) return boxTop + (q3 - val) * pxPerUnit;
    return boxTop;
  };

  const minY = dataToY(min);
  const maxY = dataToY(max);
  const medianY = dataToY(median);
  const whiskerHalfW = width * 0.3;
  const sortedCVs = [...cvValues].sort((a, b) => a - b);

  return (
    <g>
      {/* Whisker line: min to max */}
      <line
        x1={centerX}
        y1={minY}
        x2={centerX}
        y2={maxY}
        stroke={SMUI.muted}
        strokeWidth={1}
      />
      {/* Min whisker cap */}
      <line
        x1={centerX - whiskerHalfW}
        y1={minY}
        x2={centerX + whiskerHalfW}
        y2={minY}
        stroke={SMUI.muted}
        strokeWidth={1}
      />
      {/* Max whisker cap */}
      <line
        x1={centerX - whiskerHalfW}
        y1={maxY}
        x2={centerX + whiskerHalfW}
        y2={maxY}
        stroke={SMUI.muted}
        strokeWidth={1}
      />
      {/* Box (IQR); at least 1px tall so a zero-height box stays visible */}
      <rect
        x={x}
        y={boxTop}
        width={width}
        height={Math.max(height, 1)}
        fill={color}
        fillOpacity={0.3}
        stroke={color}
        strokeWidth={1}
      />
      {/* One dot per cell CV, deterministically jittered around the box centre */}
      {sortedCVs.map((cv, i) => (
        <circle
          key={i}
          cx={centerX + dotJitter(i) * width}
          cy={dataToY(cv)}
          r={DOT_RADIUS}
          fill={color}
          fillOpacity={0.8}
        />
      ))}
      {/* Median line */}
      <line
        x1={x}
        y1={medianY}
        x2={x + width}
        y2={medianY}
        stroke={color}
        strokeWidth={2}
      />
    </g>
  );
}

/**
 * Tooltip body listing the model's cell count and CV quartiles.
 * It reads the first payload entry, whose `payload` is the hovered model's ModelCVData.
 */
function CVTooltipContent({
  active,
  payload,
}: {
  active?: boolean;
  payload?: Array<{ payload: ModelCVData }>;
  label?: string;
}) {
  if (!active || !payload || payload.length === 0) return null;
  const d = payload[0].payload;
  if (!d.model) return null;
  return (
    <div style={TOOLTIP_STYLE}>
      <div style={{ marginBottom: 4, fontWeight: 600 }}>{d.model}</div>
      <div>Cells: {d.cellCount}</div>
      <div>Max CV: {d.max.toFixed(1)}%</div>
      <div>Q3: {d.q3.toFixed(1)}%</div>
      <div>Median: {d.median.toFixed(1)}%</div>
      <div>Q1: {d.q1.toFixed(1)}%</div>
      <div>Min CV: {d.min.toFixed(1)}%</div>
    </div>
  );
}

/** Props recharts passes to a custom XAxis tick renderer (only the fields used here). */
interface ModelTickProps {
  x: number;
  y: number;
  payload: { value: unknown };
}

/**
 * Two-line x-axis tick: the model name, then a dimmer "(n=…)" suffix.
 * The label is split on its LAST "|", so model names containing "|" are not truncated.
 */
function renderModelTick({ x, y, payload }: ModelTickProps) {
  const text = String(payload.value);
  const sep = text.lastIndexOf("|");
  const name = sep === -1 ? text : text.slice(0, sep);
  const count = sep === -1 ? "" : text.slice(sep + 1);
  return (
    <g>
      <text
        x={x}
        y={y + 12}
        textAnchor="middle"
        fill={SMUI.muted}
        fontSize={10}
        fontFamily="'JetBrains Mono', monospace"
      >
        {name}
      </text>
      <text
        x={x}
        y={y + 24}
        textAnchor="middle"
        fill={SMUI.muted}
        fontSize={8}
        fontFamily="'JetBrains Mono', monospace"
        opacity={0.6}
      >
        {count}
      </text>
    </g>
  );
}

/**
 * Box-plot-with-dots chart of per-cell score CV% for each model.
 * If no cell qualifies (see computeModelCV), a placeholder card is shown instead.
 */
export default function VariabilityViolin({ runs }: VariabilityViolinProps) {
  const data = useMemo(() => computeModelCV(runs), [runs]);

  if (data.length === 0) {
    return (
      <div
        className="card"
        style={{
          textAlign: "center",
          padding: "40px",
          color: SMUI.muted,
          fontFamily: "'JetBrains Mono', monospace",
        }}
      >
        Not enough multi-run cells to compute variability.
      </div>
    );
  }

  // Axis ceiling: the largest CV rounded up to a multiple of 10, never below 10.
  const maxCV = Math.max(...data.map((d) => d.max), 10);
  const yMax = Math.ceil(maxCV / 10) * 10;

  return (
    <div className="card">
      <h3
        style={{
          marginBottom: "4px",
          fontFamily: "'JetBrains Mono', monospace",
        }}
      >
        Score Variability by Model (CV%)
      </h3>
      <p
        style={{
          color: SMUI.muted,
          fontSize: "11px",
          fontFamily: "'JetBrains Mono', monospace",
          marginBottom: "16px",
        }}
      >
        Lower = more consistent. Each dot is one cell's coefficient of
        variation.
      </p>
      <ResponsiveContainer width="100%" height={320}>
        <ComposedChart data={data} barCategoryGap="25%">
          <CartesianGrid
            strokeDasharray="3 3"
            stroke={SMUI.border}
            vertical={false}
          />
          <XAxis
            dataKey="label"
            stroke={SMUI.muted}
            tickLine={false}
            axisLine={{ stroke: SMUI.border }}
            interval={0}
            tick={renderModelTick}
            height={40}
          />
          <YAxis
            stroke={SMUI.muted}
            fontSize={11}
            fontFamily="'JetBrains Mono', monospace"
            domain={[0, yMax]}
            tickLine={false}
            axisLine={false}
            yAxisId="cv"
            label={{
              value: "CV%",
              angle: -90,
              position: "insideLeft",
              style: {
                fill: SMUI.muted,
                fontSize: 10,
                fontFamily: "'JetBrains Mono', monospace",
              },
            }}
          />
          <Tooltip
            content={<CVTooltipContent />}
            cursor={{ fill: "hsl(217 17% 28% / 0.3)" }}
          />
          {/* Invisible base bar to push the visible box up to q1 */}
          <Bar
            dataKey="base"
            stackId="box"
            fill="transparent"
            barSize={40}
            yAxisId="cv"
          />
          {/* Visible IQR box; the custom shape also draws whiskers, median and per-cell dots */}
          <Bar
            dataKey="iqr"
            stackId="box"
            barSize={40}
            yAxisId="cv"
            shape={<CVBoxPlotShape yMax={yMax} />}
          >
            {data.map((entry) => (
              <Cell key={entry.label} fill={entry.color} />
            ))}
          </Bar>
          {/* Empty scatter (renders no points), kept as in the original layout.
              NOTE(review): originally described as keeping the recharts scale consistent — confirm it is still needed. */}
          <Scatter data={[]} dataKey="cv" yAxisId="cv" fill="transparent" />
        </ComposedChart>
      </ResponsiveContainer>
    </div>
  );
}