commit ed1dbd497fa3615edee1daf58e8889ffcd321d91
parent d206a13dbbf7342428cefad151df39d4847cbf3a
Author: Brian Graham <brian@buildingbetterteams.de>
Date: Tue, 7 Apr 2026 17:44:49 +0200
3D PCA scatter plot with react-three-fiber
Interactive 3D visualization: orbit controls, spheres colored by model,
size by score, axis labels with variance explained. Hover tooltips.
SMUI-themed dark background with Nord colors.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat:
2 files changed, 310 insertions(+), 233 deletions(-)
diff --git a/dashboard/src/components/PCAPlot.tsx b/dashboard/src/components/PCAPlot.tsx
@@ -1,15 +1,7 @@
-import React, { useState, useMemo } from "react";
-import {
- ScatterChart,
- Scatter,
- XAxis,
- YAxis,
- ZAxis,
- CartesianGrid,
- Tooltip,
- ResponsiveContainer,
- Cell,
-} from "recharts";
+import React, { useState, useMemo, useRef, useCallback } from "react";
+import { Canvas, type ThreeEvent } from "@react-three/fiber";
+import { OrbitControls, Html, Line } from "@react-three/drei";
+import * as THREE from "three";
import { getModelColor, modelSortOrder } from "../lib/colors";
interface PCAPoint {
@@ -55,23 +47,6 @@ interface PCAPlotProps {
type PCKey = "pc1" | "pc2" | "pc3";
-const PC_OPTIONS: { value: PCKey; label: string }[] = [
- { value: "pc1", label: "PC1" },
- { value: "pc2", label: "PC2" },
- { value: "pc3", label: "PC3" },
-];
-
-const selectStyle: React.CSSProperties = {
- background: "hsl(217 16% 15.5%)",
- color: "hsl(213 14% 80%)",
- border: "1px solid hsl(217 17% 28%)",
- borderRadius: "0",
- fontFamily: "'JetBrains Mono', monospace",
- fontSize: "11px",
- padding: "4px 6px",
- cursor: "pointer",
-};
-
const AXIS_LABELS: Record<string, string> = {
model: "Model",
effort: "Effort",
@@ -98,12 +73,153 @@ const AXIS_LABELS: Record<string, string> = {
provider: "Provider",
};
-function CustomTooltip({ active, payload }: any) {
- if (!active || !payload || payload.length === 0) return null;
- const d = payload[0]?.payload;
- if (!d) return null;
+// Convert HSL string to hex for Three.js
+function hslToHex(hsl: string): string {
+ // Handle both "hsl(193 44% 67%)" and "hsl(193, 44%, 67%)" formats
+ const match = hsl.match(/hsl\((\d+)[,\s]+(\d+)%[,\s]+(\d+)%\)/);
+ if (!match) return "#888888";
+ const h = parseInt(match[1]) / 360;
+ const s = parseInt(match[2]) / 100;
+ const l = parseInt(match[3]) / 100;
+
+ const a2 = s * Math.min(l, 1 - l);
+ const f = (n: number) => {
+ const k = (n + h * 12) % 12;
+ const color = l - a2 * Math.max(Math.min(k - 3, 9 - k, 1), -1);
+ return Math.round(255 * color)
+ .toString(16)
+ .padStart(2, "0");
+ };
+ return `#${f(0)}${f(8)}${f(4)}`;
+}
+
+interface NormalizedPoint extends PCAPoint {
+ nx: number;
+ ny: number;
+ nz: number;
+ radius: number;
+ color: string;
+ hexColor: string;
+}
+
+interface SceneProps {
+ points: NormalizedPoint[];
+ varianceExplained: number[];
+ onHover: (point: NormalizedPoint | null) => void;
+ hoveredPoint: NormalizedPoint | null;
+}
+
+function AxisLine({
+ start,
+ end,
+ label,
+}: {
+ start: [number, number, number];
+ end: [number, number, number];
+ label: string;
+}) {
+ return (
+ <>
+ <Line
+ points={[start, end]}
+ color="#5a6078"
+ lineWidth={1.5}
+ />
+ <Html position={end} style={{ pointerEvents: "none" }}>
+ <div
+ style={{
+ fontFamily: "'JetBrains Mono', monospace",
+ fontSize: "10px",
+ color: "hsl(213 14% 55%)",
+ whiteSpace: "nowrap",
+ userSelect: "none",
+ }}
+ >
+ {label}
+ </div>
+ </Html>
+ </>
+ );
+}
+
+function GridLines() {
+ const gridColor = "#2d3045";
+ const lines: Array<{ start: [number, number, number]; end: [number, number, number] }> = [
+ // Lines at y=0, z=0 plane (along x)
+ { start: [-1.3, 0, 0], end: [1.3, 0, 0] },
+ // Lines at x=0, z=0 plane (along y)
+ { start: [0, -1.3, 0], end: [0, 1.3, 0] },
+ // Lines at x=0, y=0 plane (along z)
+ { start: [0, 0, -1.3], end: [0, 0, 1.3] },
+ ];
+
+ return (
+ <>
+ {lines.map((line, i) => (
+ <Line
+ key={i}
+ points={[line.start, line.end]}
+ color={gridColor}
+ lineWidth={0.5}
+ dashed
+ dashSize={0.05}
+ gapSize={0.05}
+ />
+ ))}
+ </>
+ );
+}
+
+function DataPoint({
+ point,
+ onHover,
+ isHovered,
+}: {
+ point: NormalizedPoint;
+ onHover: (point: NormalizedPoint | null) => void;
+ isHovered: boolean;
+}) {
+ const meshRef = useRef<THREE.Mesh>(null);
+
+ const handlePointerOver = useCallback(
+ (e: ThreeEvent<PointerEvent>) => {
+ e.stopPropagation();
+ onHover(point);
+ document.body.style.cursor = "pointer";
+ },
+ [point, onHover]
+ );
+
+ const handlePointerOut = useCallback(
+ (e: ThreeEvent<PointerEvent>) => {
+ e.stopPropagation();
+ onHover(null);
+ document.body.style.cursor = "auto";
+ },
+ [onHover]
+ );
- const scorePct = Math.round(d.score * 100);
+ return (
+ <mesh
+ ref={meshRef}
+ position={[point.nx, point.ny, point.nz]}
+ onPointerOver={handlePointerOver}
+ onPointerOut={handlePointerOut}
+ >
+ <sphereGeometry args={[isHovered ? point.radius * 1.4 : point.radius, 16, 16]} />
+ <meshStandardMaterial
+ color={point.hexColor}
+ transparent
+ opacity={isHovered ? 1.0 : 0.7}
+ emissive={point.hexColor}
+ emissiveIntensity={isHovered ? 0.4 : 0.15}
+ />
+ </mesh>
+ );
+}
+
+function HoverTooltip({ point }: { point: NormalizedPoint }) {
+ const scorePct = Math.round(point.score * 100);
const scoreColor =
scorePct >= 70
? "hsl(92 28% 65%)"
@@ -112,43 +228,101 @@ function CustomTooltip({ active, payload }: any) {
: "hsl(355 52% 64%)";
return (
- <div
- style={{
- background: "hsl(217 16% 15.5%)",
- border: "1px solid hsl(217 17% 28%)",
- borderRadius: "0",
- fontFamily: "'JetBrains Mono', monospace",
- fontSize: "11px",
- padding: "8px 10px",
- lineHeight: "1.6",
- color: "hsl(213 14% 80%)",
- maxWidth: 300,
- }}
+ <Html
+ position={[point.nx, point.ny, point.nz]}
+ style={{ pointerEvents: "none" }}
+ center
>
- <div style={{ display: "flex", justifyContent: "space-between", gap: 12 }}>
- <span style={{ fontWeight: 600, color: getModelColor(d.model) }}>
- {d.model}
- </span>
- <span style={{ fontFamily: "'JetBrains Mono', monospace", fontWeight: 600, color: scoreColor }}>
- {scorePct}%
- </span>
- </div>
- <div style={{ marginTop: 4, color: "hsl(213 14% 55%)", fontSize: "10px" }}>
- {d.short_id}
- </div>
- {d.config_summary && (
- <div
- style={{
- marginTop: 4,
- fontSize: "10px",
- color: "hsl(213 14% 65%)",
- wordBreak: "break-word",
- }}
- >
- {d.config_summary}
+ <div
+ style={{
+ background: "hsl(217 16% 15.5%)",
+ border: "1px solid hsl(217 17% 28%)",
+ fontFamily: "'JetBrains Mono', monospace",
+ fontSize: "11px",
+ padding: "8px 10px",
+ lineHeight: "1.6",
+ color: "hsl(213 14% 80%)",
+ maxWidth: 260,
+ whiteSpace: "nowrap",
+ transform: "translate(16px, -50%)",
+ pointerEvents: "none",
+ }}
+ >
+ <div style={{ display: "flex", justifyContent: "space-between", gap: 12 }}>
+ <span style={{ fontWeight: 600, color: point.color }}>
+ {point.model}
+ </span>
+ <span style={{ fontWeight: 600, color: scoreColor }}>
+ {scorePct}%
+ </span>
</div>
- )}
- </div>
+ <div style={{ marginTop: 4, color: "hsl(213 14% 55%)", fontSize: "10px" }}>
+ {point.short_id}
+ </div>
+ {point.config_summary && (
+ <div
+ style={{
+ marginTop: 4,
+ fontSize: "10px",
+ color: "hsl(213 14% 65%)",
+ whiteSpace: "normal",
+ wordBreak: "break-word",
+ }}
+ >
+ {point.config_summary}
+ </div>
+ )}
+ </div>
+ </Html>
+ );
+}
+
+function Scene({ points, varianceExplained, onHover, hoveredPoint }: SceneProps) {
+ return (
+ <>
+ <ambientLight intensity={0.6} />
+ <pointLight position={[5, 5, 5]} intensity={0.8} />
+ <pointLight position={[-5, -5, -5]} intensity={0.3} />
+
+ <OrbitControls
+ enableDamping
+ dampingFactor={0.1}
+ minDistance={1}
+ maxDistance={8}
+ />
+
+ <GridLines />
+
+ {/* Axis lines */}
+ <AxisLine
+ start={[0, 0, 0]}
+ end={[1.5, 0, 0]}
+ label={`PC1 (${varianceExplained[0]}%)`}
+ />
+ <AxisLine
+ start={[0, 0, 0]}
+ end={[0, 1.5, 0]}
+ label={`PC2 (${varianceExplained[1]}%)`}
+ />
+ <AxisLine
+ start={[0, 0, 0]}
+ end={[0, 0, 1.5]}
+ label={`PC3 (${varianceExplained[2]}%)`}
+ />
+
+ {/* Data points */}
+ {points.map((pt) => (
+ <DataPoint
+ key={pt.run_id}
+ point={pt}
+ onHover={onHover}
+ isHovered={hoveredPoint?.run_id === pt.run_id}
+ />
+ ))}
+
+ {/* Tooltip on hover */}
+ {hoveredPoint && <HoverTooltip point={hoveredPoint} />}
+ </>
);
}
@@ -273,94 +447,99 @@ function LoadingsTable({
}
export default function PCAPlot({ data }: PCAPlotProps) {
- const [xPC, setXPC] = useState<PCKey>("pc1");
- const [yPC, setYPC] = useState<PCKey>("pc2");
+ const [hoveredPoint, setHoveredPoint] = useState<NormalizedPoint | null>(null);
+
+ // Normalize points to -1..1 range per axis and compute radii
+ const { normalizedPoints, modelGroups } = useMemo(() => {
+ const pc1Vals = data.points.map((p) => Math.abs(p.pc1));
+ const pc2Vals = data.points.map((p) => Math.abs(p.pc2));
+ const pc3Vals = data.points.map((p) => Math.abs(p.pc3));
+ const maxPc1 = Math.max(...pc1Vals) || 1;
+ const maxPc2 = Math.max(...pc2Vals) || 1;
+ const maxPc3 = Math.max(...pc3Vals) || 1;
+
+ const scores = data.points.map((p) => p.score);
+ const minScore = Math.min(...scores);
+ const maxScore = Math.max(...scores);
+ const scoreRange = maxScore - minScore || 1;
- // Group points by model
- const modelGroups = useMemo(() => {
- const groups: Record<string, PCAPoint[]> = {};
- for (const pt of data.points) {
+ const pts: NormalizedPoint[] = data.points.map((p) => {
+ const color = getModelColor(p.model);
+ const t = (p.score - minScore) / scoreRange;
+ return {
+ ...p,
+ nx: p.pc1 / maxPc1,
+ ny: p.pc2 / maxPc2,
+ nz: p.pc3 / maxPc3,
+ radius: 0.05 + t * 0.1,
+ color,
+ hexColor: hslToHex(color),
+ };
+ });
+
+ // Group by model
+ const groups: Record<string, NormalizedPoint[]> = {};
+ for (const pt of pts) {
(groups[pt.model] ??= []).push(pt);
}
- return Object.entries(groups).sort(
+ const sorted = Object.entries(groups).sort(
([a], [b]) => modelSortOrder(a) - modelSortOrder(b)
);
- }, [data.points]);
- // Compute axis domains
- const allX = data.points.map((p) => p[xPC]);
- const allY = data.points.map((p) => p[yPC]);
- const xMin = Math.min(...allX);
- const xMax = Math.max(...allX);
- const yMin = Math.min(...allY);
- const yMax = Math.max(...allY);
- const xPad = (xMax - xMin) * 0.08 || 1;
- const yPad = (yMax - yMin) * 0.08 || 1;
-
- // Score range for sizing
- const scores = data.points.map((p) => p.score);
- const minScore = Math.min(...scores);
- const maxScore = Math.max(...scores);
+ return { normalizedPoints: pts, modelGroups: sorted };
+ }, [data.points]);
- const xVarIdx = parseInt(xPC.replace("pc", "")) - 1;
- const yVarIdx = parseInt(yPC.replace("pc", "")) - 1;
+ const handleHover = useCallback((point: NormalizedPoint | null) => {
+ setHoveredPoint(point);
+ }, []);
return (
<div style={{ display: "flex", flexDirection: "column", gap: 24 }}>
- {/* Chart card */}
+ {/* 3D scatter plot card */}
<div className="card" style={{ position: "relative" }}>
<div
style={{
display: "flex",
alignItems: "center",
gap: 8,
- marginBottom: 16,
+ marginBottom: 12,
flexWrap: "wrap",
}}
>
<span
- style={{ fontSize: "11px", color: "hsl(213 14% 55%)", textTransform: "uppercase", letterSpacing: "0.5px" }}
- >
- X axis
- </span>
- <select
- value={xPC}
- onChange={(e) => setXPC(e.target.value as PCKey)}
- style={selectStyle}
- >
- {PC_OPTIONS.map((opt) => (
- <option key={opt.value} value={opt.value}>
- {opt.label} ({data.variance_explained[parseInt(opt.value.replace("pc", "")) - 1]}%)
- </option>
- ))}
- </select>
- <span style={{ fontSize: "12px", color: "hsl(213 14% 55%)" }}>vs</span>
- <span
- style={{ fontSize: "11px", color: "hsl(213 14% 55%)", textTransform: "uppercase", letterSpacing: "0.5px" }}
- >
- Y axis
- </span>
- <select
- value={yPC}
- onChange={(e) => setYPC(e.target.value as PCKey)}
- style={selectStyle}
- >
- {PC_OPTIONS.map((opt) => (
- <option key={opt.value} value={opt.value}>
- {opt.label} ({data.variance_explained[parseInt(opt.value.replace("pc", "")) - 1]}%)
- </option>
- ))}
- </select>
- <span
style={{
fontSize: "12px",
fontWeight: 400,
color: "hsl(213 14% 55%)",
- marginLeft: 8,
}}
>
{data.n_runs} runs, {data.n_features} features
</span>
+ <span
+ style={{
+ fontSize: "10px",
+ color: "hsl(213 14% 45%)",
+ marginLeft: "auto",
+ }}
+ >
+ Drag to rotate. Scroll to zoom. Point size proportional to score.
+ </span>
+ </div>
+
+ {/* 3D Canvas */}
+ <div style={{ width: "100%", height: 500, background: "#1a1d27" }}>
+ <Canvas
+ camera={{ position: [2.5, 2, 2.5], near: 0.1, far: 100 }}
+ style={{ background: "#1a1d27" }}
+ gl={{ antialias: true }}
+ >
+ <Scene
+ points={normalizedPoints}
+ varianceExplained={data.variance_explained}
+ onHover={handleHover}
+ hoveredPoint={hoveredPoint}
+ />
+ </Canvas>
</div>
{/* Legend */}
@@ -369,7 +548,7 @@ export default function PCAPlot({ data }: PCAPlotProps) {
display: "flex",
gap: 12,
justifyContent: "center",
- marginBottom: 12,
+ marginTop: 12,
flexWrap: "wrap",
}}
>
@@ -400,108 +579,6 @@ export default function PCAPlot({ data }: PCAPlotProps) {
</div>
))}
</div>
-
- <ResponsiveContainer width="100%" height={420}>
- <ScatterChart margin={{ top: 10, right: 20, bottom: 10, left: 10 }}>
- <CartesianGrid
- strokeDasharray="3 3"
- stroke="hsl(217 17% 28%)"
- />
- <XAxis
- dataKey="x"
- name={`PC${xVarIdx + 1}`}
- type="number"
- domain={[xMin - xPad, xMax + xPad]}
- stroke="hsl(213 14% 65%)"
- fontSize={11}
- tickFormatter={(v: number) => v.toFixed(1)}
- label={{
- value: `${xPC.toUpperCase()} (${data.variance_explained[xVarIdx]}%)`,
- position: "insideBottom",
- offset: -5,
- style: {
- fontSize: 11,
- fill: "hsl(213 14% 55%)",
- fontFamily: "'JetBrains Mono', monospace",
- },
- }}
- />
- <YAxis
- dataKey="y"
- name={`PC${yVarIdx + 1}`}
- type="number"
- domain={[yMin - yPad, yMax + yPad]}
- stroke="hsl(213 14% 65%)"
- fontSize={11}
- tickFormatter={(v: number) => v.toFixed(1)}
- label={{
- value: `${yPC.toUpperCase()} (${data.variance_explained[yVarIdx]}%)`,
- angle: -90,
- position: "insideLeft",
- style: {
- fontSize: 11,
- fill: "hsl(213 14% 55%)",
- fontFamily: "'JetBrains Mono', monospace",
- },
- }}
- />
- <ZAxis
- dataKey="z"
- range={[40, 200]}
- name="Score"
- />
- <Tooltip
- content={<CustomTooltip />}
- cursor={{ strokeDasharray: "3 3", stroke: "hsl(213 14% 35%)" }}
- />
- {modelGroups.map(([model, pts]) => {
- const chartData = pts.map((p) => ({
- x: p[xPC],
- y: p[yPC],
- z: maxScore > minScore
- ? ((p.score - minScore) / (maxScore - minScore)) * 100
- : 50,
- score: p.score,
- model: p.model,
- short_id: p.short_id,
- config_summary: p.config_summary,
- run_id: p.run_id,
- }));
-
- return (
- <Scatter
- key={model}
- name={model}
- data={chartData}
- fill={getModelColor(model)}
- opacity={0.7}
- isAnimationActive={false}
- >
- {chartData.map((_, idx) => (
- <Cell
- key={idx}
- fill={getModelColor(model)}
- stroke={getModelColor(model)}
- strokeWidth={1}
- opacity={0.7}
- />
- ))}
- </Scatter>
- );
- })}
- </ScatterChart>
- </ResponsiveContainer>
-
- <div
- style={{
- textAlign: "center",
- fontSize: "10px",
- color: "hsl(213 14% 45%)",
- marginTop: 4,
- }}
- >
- Point size proportional to score. Hover for details.
- </div>
</div>
{/* Loadings interpretation card */}
diff --git a/dashboard/src/pages/pca.astro b/dashboard/src/pages/pca.astro
@@ -18,7 +18,7 @@ if (fs.existsSync(pcaPath)) {
</p>
{pcaData ? (
- <PCAPlot client:load data={pcaData} />
+ <PCAPlot client:only="react" data={pcaData} />
) : (
<div class="card" style="text-align: center; padding: 40px; color: var(--text-muted);">
No PCA data yet. Run <code>python3 harness/pca-analysis.py</code> to generate.