From f8e381e511dc5018611879ba7342ad5484b917fc Mon Sep 17 00:00:00 2001 From: blader Date: Mon, 2 Feb 2026 01:47:06 -0800 Subject: [PATCH 01/16] Add SVR reconstruction + 3D volume viewer tooling --- .gitignore | 3 + frontend/src/components/AlignmentControls.tsx | 2 +- frontend/src/components/ComparisonMatrix.tsx | 132 +- frontend/src/components/DicomViewer.tsx | 99 +- .../components/GroundTruthPolygonOverlay.tsx | 560 +++++ frontend/src/components/Svr3DView.tsx | 1090 +++++++++ frontend/src/components/SvrModal.tsx | 793 +++++++ frontend/src/components/SvrVolume3DModal.tsx | 574 +++++ frontend/src/components/SvrVolume3DViewer.tsx | 1370 +++++++++++ .../components/TumorSegmentationOverlay.tsx | 2107 +++++++++++++++++ .../comparison/ComparisonFiltersSidebar.tsx | 46 +- .../src/components/comparison/GridCell.tsx | 235 ++ .../src/components/comparison/GridView.tsx | 145 +- .../src/components/comparison/OverlayView.tsx | 115 +- frontend/src/db/db.ts | 36 +- frontend/src/db/schema.ts | 141 ++ frontend/src/hooks/useAutoAlign.ts | 215 +- frontend/src/hooks/useComparisonFilters.ts | 32 +- frontend/src/hooks/useOverlayNavigation.ts | 18 +- frontend/src/hooks/useSvrReconstruction.ts | 116 + frontend/src/services/dicomIngestion.ts | 7 +- frontend/src/services/exportBackup.ts | 1 + frontend/src/services/svrHarness.ts | 86 + frontend/src/types/api.ts | 7 +- frontend/src/types/svr.ts | 143 ++ frontend/src/utils/alignment.ts | 607 ++++- .../src/utils/alignmentSliceScoreStore.ts | 93 + frontend/src/utils/debugSvr.ts | 33 + frontend/src/utils/elastixRegistration.ts | 127 +- frontend/src/utils/imageFeatures.ts | 83 + frontend/src/utils/localApi.ts | 145 +- frontend/src/utils/mind.ts | 334 +++ frontend/src/utils/mutualInformation.ts | 31 +- frontend/src/utils/phaseCorrelation.ts | 416 ++++ .../utils/segmentation/geodesicDistance.ts | 195 ++ .../src/utils/segmentation/gtBenchmark.ts | 614 +++++ .../src/utils/segmentation/harness/base64.ts | 62 + .../segmentation/harness/canonicalize.ts | 
37 + .../src/utils/segmentation/harness/dataset.ts | 98 + .../harness/exportTumorHarnessDataset.ts | 261 ++ .../harness/loadCornerstoneGrayscale.ts | 77 + .../segmentation/harness/runTumorHarness.ts | 443 ++++ .../segmentation/harness/syntheticPaint.ts | 185 ++ .../src/utils/segmentation/marchingSquares.ts | 201 ++ .../src/utils/segmentation/maskMetrics.ts | 51 + frontend/src/utils/segmentation/morphology.ts | 97 + .../segmentation/polygonBoundaryMetrics.ts | 174 ++ .../utils/segmentation/rasterizePolygon.ts | 71 + .../src/utils/segmentation/segmentTumor.ts | 1484 ++++++++++++ frontend/src/utils/segmentation/simplify.ts | 70 + frontend/src/utils/segmentation/smooth.ts | 41 + .../src/utils/segmentation/traceBoundary.ts | 148 ++ frontend/src/utils/ssim.ts | 243 ++ frontend/src/utils/svr/dicomGeometry.ts | 171 ++ frontend/src/utils/svr/downsample.ts | 31 + frontend/src/utils/svr/reconstructVolume.ts | 1565 ++++++++++++ frontend/src/utils/svr/reconstructionCore.ts | 473 ++++ frontend/src/utils/svr/resample2d.ts | 229 ++ frontend/src/utils/svr/sliceRoiCrop.ts | 112 + frontend/src/utils/svr/trilinear.ts | 142 ++ frontend/src/utils/svr/vec3.ts | 39 + frontend/src/utils/svr/volumePreview.ts | 112 + frontend/src/utils/tumorPropagation.ts | 210 ++ frontend/src/utils/tumorPropagationCore.ts | 115 + frontend/src/utils/viewTransform.ts | 149 ++ frontend/src/utils/viewportMapping.ts | 51 + frontend/tests/alignmentSliceSearch.test.ts | 159 ++ frontend/tests/geodesicDistance.test.ts | 57 + frontend/tests/mutualInformation.test.ts | 14 + frontend/tests/segmentation.test.ts | 74 + frontend/tests/svrDicomGeometry.test.ts | 59 + frontend/tests/svrDownsample.test.ts | 50 + frontend/tests/svrGeometryInvariants.test.ts | 119 + frontend/tests/svrPhantom.test.ts | 232 ++ frontend/tests/svrResample2d.test.ts | 89 + frontend/tests/svrSliceRoiCrop.test.ts | 86 + frontend/tests/svrTrilinear.test.ts | 46 + frontend/tests/tumorHarnessRunner.test.ts | 84 + frontend/tests/viewTransform.test.ts | 
98 + 79 files changed, 18779 insertions(+), 281 deletions(-) create mode 100644 frontend/src/components/GroundTruthPolygonOverlay.tsx create mode 100644 frontend/src/components/Svr3DView.tsx create mode 100644 frontend/src/components/SvrModal.tsx create mode 100644 frontend/src/components/SvrVolume3DModal.tsx create mode 100644 frontend/src/components/SvrVolume3DViewer.tsx create mode 100644 frontend/src/components/TumorSegmentationOverlay.tsx create mode 100644 frontend/src/components/comparison/GridCell.tsx create mode 100644 frontend/src/hooks/useSvrReconstruction.ts create mode 100644 frontend/src/services/svrHarness.ts create mode 100644 frontend/src/types/svr.ts create mode 100644 frontend/src/utils/alignmentSliceScoreStore.ts create mode 100644 frontend/src/utils/debugSvr.ts create mode 100644 frontend/src/utils/imageFeatures.ts create mode 100644 frontend/src/utils/mind.ts create mode 100644 frontend/src/utils/phaseCorrelation.ts create mode 100644 frontend/src/utils/segmentation/geodesicDistance.ts create mode 100644 frontend/src/utils/segmentation/gtBenchmark.ts create mode 100644 frontend/src/utils/segmentation/harness/base64.ts create mode 100644 frontend/src/utils/segmentation/harness/canonicalize.ts create mode 100644 frontend/src/utils/segmentation/harness/dataset.ts create mode 100644 frontend/src/utils/segmentation/harness/exportTumorHarnessDataset.ts create mode 100644 frontend/src/utils/segmentation/harness/loadCornerstoneGrayscale.ts create mode 100644 frontend/src/utils/segmentation/harness/runTumorHarness.ts create mode 100644 frontend/src/utils/segmentation/harness/syntheticPaint.ts create mode 100644 frontend/src/utils/segmentation/marchingSquares.ts create mode 100644 frontend/src/utils/segmentation/maskMetrics.ts create mode 100644 frontend/src/utils/segmentation/morphology.ts create mode 100644 frontend/src/utils/segmentation/polygonBoundaryMetrics.ts create mode 100644 frontend/src/utils/segmentation/rasterizePolygon.ts create mode 
100644 frontend/src/utils/segmentation/segmentTumor.ts create mode 100644 frontend/src/utils/segmentation/simplify.ts create mode 100644 frontend/src/utils/segmentation/smooth.ts create mode 100644 frontend/src/utils/segmentation/traceBoundary.ts create mode 100644 frontend/src/utils/ssim.ts create mode 100644 frontend/src/utils/svr/dicomGeometry.ts create mode 100644 frontend/src/utils/svr/downsample.ts create mode 100644 frontend/src/utils/svr/reconstructVolume.ts create mode 100644 frontend/src/utils/svr/reconstructionCore.ts create mode 100644 frontend/src/utils/svr/resample2d.ts create mode 100644 frontend/src/utils/svr/sliceRoiCrop.ts create mode 100644 frontend/src/utils/svr/trilinear.ts create mode 100644 frontend/src/utils/svr/vec3.ts create mode 100644 frontend/src/utils/svr/volumePreview.ts create mode 100644 frontend/src/utils/tumorPropagation.ts create mode 100644 frontend/src/utils/tumorPropagationCore.ts create mode 100644 frontend/src/utils/viewTransform.ts create mode 100644 frontend/src/utils/viewportMapping.ts create mode 100644 frontend/tests/alignmentSliceSearch.test.ts create mode 100644 frontend/tests/geodesicDistance.test.ts create mode 100644 frontend/tests/segmentation.test.ts create mode 100644 frontend/tests/svrDicomGeometry.test.ts create mode 100644 frontend/tests/svrDownsample.test.ts create mode 100644 frontend/tests/svrGeometryInvariants.test.ts create mode 100644 frontend/tests/svrPhantom.test.ts create mode 100644 frontend/tests/svrResample2d.test.ts create mode 100644 frontend/tests/svrSliceRoiCrop.test.ts create mode 100644 frontend/tests/svrTrilinear.test.ts create mode 100644 frontend/tests/tumorHarnessRunner.test.ts create mode 100644 frontend/tests/viewTransform.test.ts diff --git a/.gitignore b/.gitignore index e2d5022..d264113 100644 --- a/.gitignore +++ b/.gitignore @@ -21,5 +21,8 @@ Critical MRI Source Images (LLM Agent - do not delete)/ # Local test artifacts (downloaded DICOMs, zips, etc.) 
frontend/tmp/ +# Vercel local project metadata +.vercel/ + # Accidental root lockfile (repo does not have a root package.json) /package-lock.json diff --git a/frontend/src/components/AlignmentControls.tsx b/frontend/src/components/AlignmentControls.tsx index f7bcfd2..273ad03 100644 --- a/frontend/src/components/AlignmentControls.tsx +++ b/frontend/src/components/AlignmentControls.tsx @@ -104,7 +104,7 @@ export function AlignmentControls({ {progress.slicesChecked > 0 && ( - {progress.slicesChecked} slices · MI {progress.bestMiSoFar.toFixed(3)} + {progress.slicesChecked} slices · Score {progress.bestMiSoFar.toFixed(3)} )} diff --git a/frontend/src/components/ComparisonMatrix.tsx b/frontend/src/components/ComparisonMatrix.tsx index 699c250..5b48aa1 100644 --- a/frontend/src/components/ComparisonMatrix.tsx +++ b/frontend/src/components/ComparisonMatrix.tsx @@ -13,11 +13,13 @@ import { Trash2, MoreVertical, HelpCircle, + Box, } from 'lucide-react'; import { HelpModal } from './HelpModal'; import { UploadModal } from './UploadModal'; import { ExportModal } from './ExportModal'; import { ClearDataModal } from './ClearDataModal'; +import { Svr3DView } from './Svr3DView'; import { SliceLoopNavigator } from './comparison/SliceLoopNavigator'; import { GridView } from './comparison/GridView'; import { OverlayView } from './comparison/OverlayView'; @@ -139,6 +141,8 @@ export function ComparisonMatrix() { isAligning, progress: alignmentProgress, results: alignmentResults, + error: alignmentError, + clearState: clearAlignmentState, alignAllDates, abort: abortAlignment, } = useAutoAlign(); @@ -161,7 +165,8 @@ export function ComparisonMatrix() { const planeKey = (plane: string | null) => (plane && plane.trim() ? 
plane : 'Other'); return data.sequences - .filter(s => planeKey(s.plane) === selectedPlane) + .filter((s) => planeKey(s.plane) === selectedPlane) + .filter((s) => formatSequenceLabel(s) !== 'Unknown') .sort((a, b) => formatSequenceLabel(b).localeCompare(formatSequenceLabel(a))); // reverse alpha }, [data, selectedPlane]); @@ -280,6 +285,45 @@ export function ComparisonMatrix() { const overlayViewerSize = getOverlayViewerSize(gridSize); + // Seed SVR 3D ROI preview slice: + // - Prefer the currently displayed overlay slice when available. + // - Otherwise fall back to the newest enabled date in the grid. + const svr3dSeed = useMemo(() => { + if (overlayDisplayedDate && overlayDisplayedRef) { + return { + defaultDateIso: overlayDisplayedDate, + fallbackRoiSeriesUid: overlayDisplayedRef.series_uid, + fallbackRoiSliceIndex: overlayDisplayedEffectiveSliceIndex, + }; + } + + const first = columns.find((c) => c.ref); + if (!first?.ref) { + return { + defaultDateIso: null, + fallbackRoiSeriesUid: null, + fallbackRoiSliceIndex: null, + }; + } + + const settings = panelSettings.get(first.date) || DEFAULT_PANEL_SETTINGS; + const sliceIndex = getSliceIndex(first.ref.instance_count, progress, settings.offset); + const effectiveIndex = getEffectiveInstanceIndex(sliceIndex, first.ref.instance_count, settings.reverseSliceOrder); + + return { + defaultDateIso: first.date, + fallbackRoiSeriesUid: first.ref.series_uid, + fallbackRoiSliceIndex: effectiveIndex, + }; + }, [ + columns, + overlayDisplayedDate, + overlayDisplayedEffectiveSliceIndex, + overlayDisplayedRef, + panelSettings, + progress, + ]); + const startAlignAll = useCallback( async (reference: AlignmentReference, exclusionMask: ExclusionMask) => { if (isAligning) { @@ -327,6 +371,12 @@ export function ComparisonMatrix() { // preventDefault, and we skip them here via `e.defaultPrevented`. 
const wheelNavContextRef = useRef<{ instanceCount: number; offset: number } | null>(null); useEffect(() => { + if (viewMode === 'svr3d') { + // The SVR 3D view uses mousewheel for zoom; don't hijack wheel events for slice navigation. + wheelNavContextRef.current = null; + return; + } + let instanceCount = 1; let offset = DEFAULT_PANEL_SETTINGS.offset; @@ -438,12 +488,20 @@ export function ComparisonMatrix() { + @@ -582,20 +640,36 @@ export function ComparisonMatrix() { {/* Main area with sidebar */}
- setSidebarOpen((v) => !v)} - availablePlanes={availablePlanes} - selectedPlane={selectedPlane} - onSelectPlane={selectPlane} - sequencesForPlane={sequencesForPlane} - sequencesWithDataForDates={sequencesWithDataForDates} - selectedSeqId={selectedSeqId} - onSelectSequence={selectSequence} - /> - - {/* Main content area - Grid or Overlay */} + {viewMode !== 'svr3d' ? ( + setSidebarOpen((v) => !v)} + availablePlanes={availablePlanes} + selectedPlane={selectedPlane} + onSelectPlane={selectPlane} + sequencesForPlane={sequencesForPlane} + sequencesWithDataForDates={sequencesWithDataForDates} + selectedSeqId={selectedSeqId} + onSelectSequence={selectSequence} + /> + ) : null} + + {/* Main content area - Grid / Overlay / SVR 3D */}
+ {alignmentError && !isAligning ? ( +
+
+ Alignment failed: {alignmentError} +
+ +
+ ) : null} {!hasData ? ( /* Empty state */
@@ -628,6 +702,7 @@ export function ComparisonMatrix() {
) : viewMode === 'grid' ? ( - ) : ( + ) : viewMode === 'overlay' ? ( + ) : ( + )}
@@ -684,13 +768,15 @@ export function ComparisonMatrix() {
{/* Slice navigator with loop + speed controls */} - + {viewMode !== 'svr3d' ? ( + + ) : null} ); } diff --git a/frontend/src/components/DicomViewer.tsx b/frontend/src/components/DicomViewer.tsx index 1bddc5e..44a055f 100644 --- a/frontend/src/components/DicomViewer.tsx +++ b/frontend/src/components/DicomViewer.tsx @@ -11,6 +11,8 @@ import { getImageIdForInstance } from '../utils/localApi'; import cornerstone from 'cornerstone-core'; import { useWheelNavigation } from '../hooks/useWheelNavigation'; import { getEffectiveInstanceIndex } from '../utils/math'; +import { isDebugAlignmentEnabled } from '../utils/debugAlignment'; +import { getAlignmentSliceScore } from '../utils/alignmentSliceScoreStore'; export type DicomViewerCaptureOptions = { /** Max dimension (in CSS pixels) used for the capture output. Defaults to 512 for speed. */ @@ -25,6 +27,21 @@ export type DicomViewerHandle = { captureVisiblePng: (options?: DicomViewerCaptureOptions) => Promise; }; +function parseDicomViewerContentKey(contentKey: string): { seriesUid: string; instanceIndex: number } | null { + // Content key format: `${studyId}:${seriesUid}:${effectiveInstanceIndex}` + // + // We parse from the right so this keeps working even if study IDs ever contain ':' (unlikely). + const parts = contentKey.split(':'); + if (parts.length < 3) return null; + + const indexStr = parts[parts.length - 1]; + const seriesUid = parts[parts.length - 2]; + const instanceIndex = Number(indexStr); + if (!Number.isFinite(instanceIndex) || instanceIndex < 0) return null; + + return { seriesUid, instanceIndex }; +} + interface DicomViewerProps { studyId: string; seriesUid: string; @@ -124,6 +141,52 @@ export const DicomViewer = forwardRef(funct // Resolve imageId for Cornerstone (miradb:) const [imageId, setImageId] = useState(null); + // Track what slice is actually displayed in the viewer. + // CornerstoneImage intentionally keeps the previous image visible while the next slice loads. 
+ const [displayedContentKey, setDisplayedContentKey] = useState(null); + + const debugSliceScores = isDebugAlignmentEnabled(); + + // Only show the (very noisy) per-slice debug scores overlay while the user is holding 'Z'. + // This keeps the UI clean while still making it easy to inspect values on demand. + const [isZHeld, setIsZHeld] = useState(false); + useEffect(() => { + if (typeof window === 'undefined') return; + + const isZKey = (e: KeyboardEvent) => (e.key || '').toLowerCase() === 'z'; + + const onKeyDown = (e: KeyboardEvent) => { + // Ignore cmd/ctrl/alt modified combos (e.g. Cmd+Z) so we don't flash the overlay + // during common shortcuts. + if (!isZKey(e)) return; + if (e.metaKey || e.ctrlKey || e.altKey) return; + setIsZHeld(true); + }; + + const onKeyUp = (e: KeyboardEvent) => { + if (!isZKey(e)) return; + setIsZHeld(false); + }; + + const onBlur = () => { + setIsZHeld(false); + }; + + window.addEventListener('keydown', onKeyDown); + window.addEventListener('keyup', onKeyUp); + window.addEventListener('blur', onBlur); + return () => { + window.removeEventListener('keydown', onKeyDown); + window.removeEventListener('keyup', onKeyUp); + window.removeEventListener('blur', onBlur); + }; + }, []); + + const displayedForScores = displayedContentKey ? parseDicomViewerContentKey(displayedContentKey) : null; + const scoreSeriesUid = displayedForScores?.seriesUid ?? seriesUid; + const scoreInstanceIndex = displayedForScores?.instanceIndex ?? effectiveInstanceIndex; + const sliceScore = debugSliceScores ? getAlignmentSliceScore(scoreSeriesUid, scoreInstanceIndex) : null; + useEffect(() => { let cancelled = false; (async () => { @@ -373,12 +436,30 @@ export const DicomViewer = forwardRef(funct imageFilter={imageFilter} imageTransform={imageTransform} alt={`Slice ${instanceIndex + 1}`} + onDisplayedContentKey={setDisplayedContentKey} /> ) : (
Loading...
)} + + {debugSliceScores && isZHeld ? ( +
+
+
SSIM: {sliceScore ? sliceScore.ssim.toFixed(6) : '—'}
+
LNCC: {sliceScore ? sliceScore.lncc.toFixed(6) : '—'}
+
ZNCC: {sliceScore ? sliceScore.zncc.toFixed(6) : '—'}
+
NGF: {sliceScore ? sliceScore.ngf.toFixed(6) : '—'}
+
Census: {sliceScore ? sliceScore.census.toFixed(6) : '—'}
+
MIND: {sliceScore && sliceScore.mind != null ? sliceScore.mind.toFixed(6) : '—'}
+
Phase: {sliceScore && sliceScore.phase != null ? sliceScore.phase.toFixed(6) : '—'}
+
MI: {sliceScore ? sliceScore.mi.toFixed(6) : '—'}
+
NMI: {sliceScore ? sliceScore.nmi.toFixed(6) : '—'}
+
Score: {sliceScore ? sliceScore.score.toFixed(6) : '—'}
+
+
+ ) : null} ); @@ -400,6 +481,9 @@ interface CornerstoneImageProps { imageFilter: string; imageTransform: string; alt: string; + + /** Called after Cornerstone actually displays the requested image. */ + onDisplayedContentKey?: (contentKey: string) => void; } function DelayedSpinnerOverlay({ delayMs = 150 }: { delayMs?: number }) { @@ -435,7 +519,14 @@ function ErrorOverlay({ message }: { message: string }) { ); } -function CornerstoneImage({ imageId, contentKey, imageFilter, imageTransform, alt }: CornerstoneImageProps) { +function CornerstoneImage({ + imageId, + contentKey, + imageFilter, + imageTransform, + alt, + onDisplayedContentKey, +}: CornerstoneImageProps) { const elementRef = useRef(null); const enabledRef = useRef(false); @@ -472,6 +563,11 @@ function CornerstoneImage({ imageId, contentKey, imageFilter, imageTransform, al contentKeyRef.current = contentKey; }, [contentKey]); + const onDisplayedContentKeyRef = useRef(onDisplayedContentKey); + useEffect(() => { + onDisplayedContentKeyRef.current = onDisplayedContentKey; + }, [onDisplayedContentKey]); + // Derive status from comparison const status: 'loading' | 'loaded' | 'error' = errorImageId === imageId ? 
'error' : @@ -550,6 +646,7 @@ function CornerstoneImage({ imageId, contentKey, imageFilter, imageTransform, al setLoadedImageId(imageId); setLoadedContentKey(keyForThisLoad); setErrorImageId(null); + onDisplayedContentKeyRef.current?.(keyForThisLoad); } catch (err) { console.error('Failed to load DICOM image:', err); if (!cancelled) { diff --git a/frontend/src/components/GroundTruthPolygonOverlay.tsx b/frontend/src/components/GroundTruthPolygonOverlay.tsx new file mode 100644 index 0000000..28dad73 --- /dev/null +++ b/frontend/src/components/GroundTruthPolygonOverlay.tsx @@ -0,0 +1,560 @@ +import { useCallback, useEffect, useMemo, useRef, useState } from 'react'; +import { Pencil, Save, Trash2, Undo2, X } from 'lucide-react'; +import type { NormalizedPoint, TumorPolygon, ViewerTransform } from '../db/schema'; +import { + deleteTumorGroundTruth, + getSopInstanceUidForInstanceIndex, + getTumorGroundTruthForInstance, + saveTumorGroundTruth, +} from '../utils/localApi'; +import { + normalizeViewerTransform, + remapPointBetweenViewerTransforms, + remapPointsBetweenViewerTransforms, + remapPolygonBetweenViewerTransforms, +} from '../utils/viewTransform'; + +function clamp01(v: number) { + return Math.max(0, Math.min(1, v)); +} + +function polygonToSvgPath(p: TumorPolygon): string { + if (!p.points.length) return ''; + + const d = [`M ${p.points[0].x.toFixed(4)} ${p.points[0].y.toFixed(4)}`]; + for (let i = 1; i < p.points.length; i++) { + d.push(`L ${p.points[i].x.toFixed(4)} ${p.points[i].y.toFixed(4)}`); + } + d.push('Z'); + return d.join(' '); +} + +export type GroundTruthPolygonOverlayProps = { + enabled: boolean; + onRequestClose: () => void; + + comboId: string; + dateIso: string; + studyId: string; + seriesUid: string; + /** Instance index in effective slice ordering (i.e. after reverseSliceOrder mapping). */ + effectiveInstanceIndex: number; + + /** Current viewer transform (pan/zoom/rotation/affine). 
*/ + viewerTransform: ViewerTransform; +}; + +export function GroundTruthPolygonOverlay({ + enabled, + onRequestClose, + comboId, + dateIso, + studyId, + seriesUid, + effectiveInstanceIndex, + viewerTransform, +}: GroundTruthPolygonOverlayProps) { + const containerRef = useRef(null); + const [containerSize, setContainerSize] = useState<{ w: number; h: number }>({ w: 0, h: 0 }); + + // Keep the latest viewer transform in a ref so we can snapshot it at specific lifecycle moments + // (e.g. when enabling or when loading a saved polygon) without re-running those effects on every + // pan/zoom/rotation change. + const viewerTransformRef = useRef(viewerTransform); + useEffect(() => { + viewerTransformRef.current = viewerTransform; + }, [viewerTransform]); + + const [draftPoints, setDraftPoints] = useState([]); + const [isClosed, setIsClosed] = useState(false); + const [draftViewTransform, setDraftViewTransform] = useState(null); + + const [savedPolygon, setSavedPolygon] = useState(null); + const [savedViewTransform, setSavedViewTransform] = useState(null); + const [busy, setBusy] = useState(false); + const [error, setError] = useState(null); + + // Load existing saved polygon when enabled or when slice changes. + useEffect(() => { + if (!enabled) return; + + let cancelled = false; + (async () => { + try { + setError(null); + const sop = await getSopInstanceUidForInstanceIndex(seriesUid, effectiveInstanceIndex); + const row = await getTumorGroundTruthForInstance(seriesUid, sop); + if (cancelled) return; + setSavedPolygon(row?.polygon ?? null); + setSavedViewTransform(row?.viewTransform ?? normalizeViewerTransform(null)); + } catch (e) { + console.error(e); + } + })(); + + return () => { + cancelled = true; + }; + }, [enabled, seriesUid, effectiveInstanceIndex]); + + // Reset draft state when turning on. 
+ useEffect(() => { + if (!enabled) return; + setDraftPoints([]); + setIsClosed(false); + setDraftViewTransform({ ...viewerTransformRef.current }); + setError(null); + }, [enabled]); + + // Track container size (used for hit-testing / close threshold). + useEffect(() => { + if (!enabled) return; + const el = containerRef.current; + if (!el) return; + + const update = () => { + const r = el.getBoundingClientRect(); + setContainerSize({ w: r.width, h: r.height }); + }; + + update(); + const ro = new ResizeObserver(update); + ro.observe(el); + return () => ro.disconnect(); + }, [enabled]); + + const getLocalNormPoint = useCallback((e: PointerEvent | React.PointerEvent): NormalizedPoint | null => { + const el = containerRef.current; + if (!el) return null; + const r = el.getBoundingClientRect(); + if (r.width <= 0 || r.height <= 0) return null; + const x = ((e as PointerEvent).clientX - r.left) / r.width; + const y = ((e as PointerEvent).clientY - r.top) / r.height; + return { x: clamp01(x), y: clamp01(y) }; + }, []); + + const closeRadiusPx = 12; + + const isNearFirstPoint = useCallback( + (p: NormalizedPoint, first: NormalizedPoint) => { + if (containerSize.w <= 0 || containerSize.h <= 0) return false; + const dx = (p.x - first.x) * containerSize.w; + const dy = (p.y - first.y) * containerSize.h; + return Math.hypot(dx, dy) <= closeRadiusPx; + }, + [containerSize.h, containerSize.w] + ); + + const didClickRef = useRef(false); + + const onPointerDown = useCallback( + (e: React.PointerEvent) => { + if (!enabled) return; + if (!e.isPrimary) return; + if (e.button !== 0) return; + + // Avoid starting a polygon click on overlay buttons. + const target = e.target as HTMLElement | null; + if (target?.closest('[data-gt-ui="true"]')) return; + + const pCurrent = getLocalNormPoint(e); + if (!pCurrent) return; + + didClickRef.current = true; + setError(null); + + // If already closed, require the user to Clear before starting over. 
+ if (isClosed) return; + + // Keep draft points in a stable "creation" view transform so the polygon can be re-projected + // when the user pans/zooms/rotates. + let baseView = draftViewTransform; + if (!baseView) { + baseView = { ...viewerTransform }; + setDraftViewTransform(baseView); + } + + const size = { w: containerSize.w, h: containerSize.h }; + const pDraft = + size.w > 0 && size.h > 0 + ? remapPointBetweenViewerTransforms(pCurrent, size, viewerTransform, baseView) + : pCurrent; + + setDraftPoints((prev) => { + if (prev.length >= 3) { + const firstDraft = prev[0]!; + const firstCurrent = + size.w > 0 && size.h > 0 + ? remapPointBetweenViewerTransforms(firstDraft, size, baseView!, viewerTransform) + : firstDraft; + + if (isNearFirstPoint(pCurrent, firstCurrent)) { + // Close polygon by clicking near the first point. + setIsClosed(true); + return prev; + } + } + + // Avoid adding duplicate points (in draft/view space). + const last = prev[prev.length - 1]; + if (last && Math.hypot(last.x - pDraft.x, last.y - pDraft.y) < 0.0015) { + return prev; + } + + return [...prev, pDraft]; + }); + }, + [ + containerSize.h, + containerSize.w, + didClickRef, + draftViewTransform, + enabled, + getLocalNormPoint, + isClosed, + isNearFirstPoint, + viewerTransform, + ] + ); + + const onClickCapture = useCallback((e: React.MouseEvent) => { + if (!didClickRef.current) return; + didClickRef.current = false; + e.preventDefault(); + e.stopPropagation(); + }, []); + + const onUndo = useCallback(() => { + setError(null); + setIsClosed(false); + setDraftPoints((prev) => prev.slice(0, -1)); + }, []); + + const onClear = useCallback(() => { + setError(null); + setIsClosed(false); + setDraftPoints([]); + setDraftViewTransform({ ...viewerTransform }); + }, [viewerTransform]); + + const onSave = useCallback(async () => { + if (!enabled) return; + if (!isClosed || draftPoints.length < 3) { + setError('Close the polygon (click the first point) before saving'); + return; + } + + 
setBusy(true); + setError(null); + + try { + const sop = await getSopInstanceUidForInstanceIndex(seriesUid, effectiveInstanceIndex); + + const view = draftViewTransform ?? { ...viewerTransform }; + + const viewportSize = + containerSize.w > 0 && containerSize.h > 0 + ? { w: Math.round(containerSize.w), h: Math.round(containerSize.h) } + : undefined; + + await saveTumorGroundTruth({ + comboId, + dateIso, + studyId, + seriesUid, + sopInstanceUid: sop, + polygon: { points: draftPoints }, + viewTransform: view, + viewportSize, + }); + + setSavedPolygon({ points: draftPoints }); + setSavedViewTransform(view); + } catch (err) { + console.error(err); + setError(err instanceof Error ? err.message : 'Save failed'); + } finally { + setBusy(false); + } + }, [comboId, containerSize.h, containerSize.w, dateIso, draftPoints, draftViewTransform, effectiveInstanceIndex, enabled, isClosed, seriesUid, studyId, viewerTransform]); + + const onDelete = useCallback(async () => { + if (!enabled) return; + + setBusy(true); + setError(null); + + try { + const sop = await getSopInstanceUidForInstanceIndex(seriesUid, effectiveInstanceIndex); + await deleteTumorGroundTruth(seriesUid, sop); + setSavedPolygon(null); + setSavedViewTransform(null); + + // Also clear draft so there is no confusion about what's saved. + setDraftPoints([]); + setIsClosed(false); + } catch (err) { + console.error(err); + setError(err instanceof Error ? err.message : 'Delete failed'); + } finally { + setBusy(false); + } + }, [effectiveInstanceIndex, enabled, seriesUid]); + + // Keyboard shortcuts. + useEffect(() => { + if (!enabled) return; + + const onKeyDown = (e: KeyboardEvent) => { + if (e.key === 'Escape') { + // If the user is mid-draw, Esc cancels the draft. Otherwise it closes the tool. 
+ if (draftPoints.length > 0 && !isClosed) { + onClear(); + e.preventDefault(); + e.stopPropagation(); + return; + } + + onRequestClose(); + e.preventDefault(); + e.stopPropagation(); + return; + } + + if (e.key === 'Enter') { + if (!isClosed && draftPoints.length >= 3) { + setIsClosed(true); + e.preventDefault(); + e.stopPropagation(); + } + return; + } + + if (e.key === 'Backspace' || e.key === 'Delete' || (e.key.toLowerCase() === 'z' && (e.metaKey || e.ctrlKey))) { + if (draftPoints.length > 0) { + onUndo(); + e.preventDefault(); + e.stopPropagation(); + } + } + }; + + window.addEventListener('keydown', onKeyDown); + return () => window.removeEventListener('keydown', onKeyDown); + }, [draftPoints.length, enabled, isClosed, onClear, onRequestClose, onUndo]); + + const viewSize = useMemo(() => ({ w: containerSize.w, h: containerSize.h }), [containerSize.h, containerSize.w]); + + const savedPath = useMemo(() => { + if (!savedPolygon) return ''; + + const from = savedViewTransform ?? viewerTransform; + const displayPoly = + viewSize.w > 0 && viewSize.h > 0 + ? remapPolygonBetweenViewerTransforms(savedPolygon, viewSize, from, viewerTransform) + : savedPolygon; + + return polygonToSvgPath(displayPoly); + }, [savedPolygon, savedViewTransform, viewSize, viewerTransform]); + + const draftPointsDisplay = useMemo(() => { + if (draftPoints.length === 0) return []; + + const from = draftViewTransform ?? viewerTransform; + return viewSize.w > 0 && viewSize.h > 0 + ? 
remapPointsBetweenViewerTransforms(draftPoints, viewSize, from, viewerTransform) + : draftPoints; + }, [draftPoints, draftViewTransform, viewSize, viewerTransform]); + + const draftPath = useMemo(() => { + if (!isClosed || draftPointsDisplay.length < 3) return ''; + return polygonToSvgPath({ points: draftPointsDisplay }); + }, [draftPointsDisplay, isClosed]); + + if (!enabled) return null; + + const canUndo = draftPoints.length > 0 && !busy; + const canClear = (draftPoints.length > 0 || isClosed) && !busy; + const canSave = isClosed && draftPoints.length >= 3 && !busy; + + return ( +
{ + // Prevent the browser context menu while drawing. + if (!enabled) return; + e.preventDefault(); + e.stopPropagation(); + }} + > + {/* UI chrome */} + {/* + Position below the viewer's top hover controls (Tumor/GT buttons + ImageControls). + Otherwise it visually overlaps the control bar in GridView/OverlayView. + */} +
+
+ + GT Polygon + {busy ? : null} +
+ + +
+ +
+ + + + + + + {savedPolygon ? ( + + ) : null} +
+ + {/* Error / status */} + {error ? ( +
+ {error} +
+ ) : !isClosed ? ( +
+ Click to add points. Click the first point (or press Enter) to close. +
+ ) : ( +
+ Polygon closed. Save to persist. +
+ )} + + {/* Saved polygon */} + {savedPath ? ( + + + + ) : null} + + {/* Draft polyline (during drawing) */} + {!isClosed && draftPointsDisplay.length > 0 ? ( + + `${p.x.toFixed(4)},${p.y.toFixed(4)}`).join(' ')} + fill="none" + stroke="rgba(245, 158, 11, 0.95)" + strokeWidth={2} + vectorEffect="non-scaling-stroke" + strokeLinecap="round" + strokeLinejoin="round" + /> + + ) : null} + + {/* Draft closed polygon */} + {draftPath ? ( + + + + ) : null} + + {/* Vertex handles */} + {draftPointsDisplay.length > 0 ? ( + + {draftPointsDisplay.map((p, idx) => ( + + ))} + + ) : null} +
+ ); +} diff --git a/frontend/src/components/Svr3DView.tsx b/frontend/src/components/Svr3DView.tsx new file mode 100644 index 0000000..ccac75e --- /dev/null +++ b/frontend/src/components/Svr3DView.tsx @@ -0,0 +1,1090 @@ +import { useEffect, useMemo, useRef, useState } from 'react'; +import { ChevronLeft, ChevronRight, Loader2 } from 'lucide-react'; +import cornerstone from 'cornerstone-core'; +import { getDB } from '../db/db'; +import type { DicomInstance } from '../db/schema'; +import type { ComparisonData } from '../types/api'; +import type { SvrParams, SvrRoi, SvrRoiPlane, SvrSelectedSeries } from '../types/svr'; +import { formatSequenceLabel } from '../utils/clinicalData'; +import { DEFAULT_SVR_PARAMS } from '../types/svr'; +import { useSvrReconstruction } from '../hooks/useSvrReconstruction'; +import { getSortedSopInstanceUidsForSeries } from '../utils/localApi'; +import type { SliceGeometry } from '../utils/svr/dicomGeometry'; +import { getSliceGeometryFromInstance } from '../utils/svr/dicomGeometry'; +import { resample2dAreaAverage } from '../utils/svr/resample2d'; +import { SvrVolume3DViewer } from './SvrVolume3DViewer'; + +function sortedDatesDesc(dates: string[]): string[] { + return [...dates].sort((a, b) => b.localeCompare(a)); +} + +function formatSeriesLabel(seq: { plane: string | null; weight: string | null; sequence: string | null }): string { + const base = formatSequenceLabel(seq); + return [seq.plane, base].filter(Boolean).join(' ') || 'Unknown'; +} + +function sequenceGroupKey(seq: { weight: string | null; sequence: string | null }): string { + return `${seq.weight ?? ''}|||${seq.sequence ?? ''}`; +} + + +type RoiRect01 = { + x0: number; + y0: number; + x1: number; + y1: number; +}; + +function clamp01(x: number): number { + return x < 0 ? 0 : x > 1 ? 
1 : x; +} + +function normalizeRect01(rect: RoiRect01): { left: number; right: number; top: number; bottom: number } { + return { + left: Math.min(rect.x0, rect.x1), + right: Math.max(rect.x0, rect.x1), + top: Math.min(rect.y0, rect.y1), + bottom: Math.max(rect.y0, rect.y1), + }; +} + +function clampInt(x: number, min: number, max: number): number { + if (!Number.isFinite(x)) return min; + const xi = Math.round(x); + return xi < min ? min : xi > max ? max : xi; +} + +function inferRoiPlaneFromNormalDir(normalDir: SliceGeometry['normalDir']): SvrRoiPlane { + const ax = Math.abs(normalDir.x); + const ay = Math.abs(normalDir.y); + const az = Math.abs(normalDir.z); + + // DICOM patient/world axes: X=left-right, Y=posterior-anterior, Z=foot-head. + // Normal mostly along Z => axial slices. + if (az >= ax && az >= ay) return 'axial'; + if (ay >= ax && ay >= az) return 'coronal'; + return 'sagittal'; +} + +function computeCubeRoiFromDicomRect01(params: { + rect: RoiRect01; + geom: SliceGeometry; + sourceSeriesUid: string; +}): SvrRoi | null { + const { rect, geom, sourceSeriesUid } = params; + + const r = normalizeRect01(rect); + const w01 = r.right - r.left; + const h01 = r.bottom - r.top; + if (w01 <= 1e-4 || h01 <= 1e-4) return null; + + const rMax = Math.max(0, geom.rows - 1); + const cMax = Math.max(0, geom.cols - 1); + + // Pixel-space center. 
+ const rowCenter = (r.top + r.bottom) * 0.5 * rMax; + const colCenter = (r.left + r.right) * 0.5 * cMax; + + // World center using the same mapping used by the reconstruction: + // world(r,c) = IPP + colDir*(r*rowSpacing) + rowDir*(c*colSpacing) + const cx = + geom.ippMm.x + geom.colDir.x * (rowCenter * geom.rowSpacingMm) + geom.rowDir.x * (colCenter * geom.colSpacingMm); + const cy = + geom.ippMm.y + geom.colDir.y * (rowCenter * geom.rowSpacingMm) + geom.rowDir.y * (colCenter * geom.colSpacingMm); + const cz = + geom.ippMm.z + geom.colDir.z * (rowCenter * geom.rowSpacingMm) + geom.rowDir.z * (colCenter * geom.colSpacingMm); + + // In-plane box extents in mm. + const widthMm = w01 * cMax * geom.colSpacingMm; + const heightMm = h01 * rMax * geom.rowSpacingMm; + + // Expand to a cube (equal extents along X/Y/Z) for simplicity. + const sideMm = Math.max(widthMm, heightMm); + if (!(sideMm > 1e-6)) return null; + + const half = sideMm * 0.5; + return { + mode: 'cube', + sourcePlane: inferRoiPlaneFromNormalDir(geom.normalDir), + sourceSeriesUid, + boundsMm: { + min: [cx - half, cy - half, cz - half], + max: [cx + half, cy + half, cz + half], + }, + }; +} + +function computeDownsampleSize(rows: number, cols: number, maxSize: number): { dsRows: number; dsCols: number } { + const maxDim = Math.max(rows, cols); + if (!Number.isFinite(maxSize) || maxSize <= 1) { + return { dsRows: Math.max(1, rows), dsCols: Math.max(1, cols) }; + } + + const scale = maxDim > maxSize ? maxSize / maxDim : 1; + return { + dsRows: Math.max(1, Math.round(rows * scale)), + dsCols: Math.max(1, Math.round(cols * scale)), + }; +} + +function drawDicomPixelDataToCanvas(params: { + canvas: HTMLCanvasElement; + pixelData: ArrayLike; + rows: number; + cols: number; + maxSize: number; + slope?: number; + intercept?: number; +}): void { + const { canvas, pixelData, rows, cols, maxSize } = params; + const slope = typeof params.slope === 'number' ? 
params.slope : 1; + const intercept = typeof params.intercept === 'number' ? params.intercept : 0; + + const { dsRows, dsCols } = computeDownsampleSize(rows, cols, maxSize); + + // Higher-fidelity downsampling (box/area average) to reduce aliasing in the ROI preview. + const down = resample2dAreaAverage(pixelData, rows, cols, dsRows, dsCols); + + // Apply modality scaling when available. (Linear, so applying post-downsample is equivalent.) + if (slope !== 1 || intercept !== 0) { + for (let i = 0; i < down.length; i++) { + down[i] = down[i] * slope + intercept; + } + } + + if (canvas.width !== dsCols) canvas.width = dsCols; + if (canvas.height !== dsRows) canvas.height = dsRows; + + const ctx = canvas.getContext('2d'); + if (!ctx) return; + + // Robust windowing (percentiles) is less sensitive to background/outliers than raw min/max. + const finite: number[] = []; + for (let i = 0; i < down.length; i++) { + const v = down[i]; + if (Number.isFinite(v)) finite.push(v); + } + + finite.sort((a, b) => a - b); + + const quantileSorted = (sorted: number[], q: number): number => { + const n = sorted.length; + if (n === 0) return 0; + const qq = q < 0 ? 0 : q > 1 ? 1 : q; + const idx = qq * (n - 1); + const i0 = Math.floor(idx); + const i1 = Math.min(n - 1, i0 + 1); + const t = idx - i0; + const a = sorted[i0] ?? 0; + const b = sorted[i1] ?? a; + return a + (b - a) * t; + }; + + let lo = quantileSorted(finite, 0.01); + let hi = quantileSorted(finite, 0.99); + + if (!(hi > lo + 1e-12)) { + lo = finite[0] ?? 0; + hi = finite[finite.length - 1] ?? lo + 1; + } + + const invRange = hi > lo + 1e-12 ? 1 / (hi - lo) : 0; + + const img = ctx.createImageData(dsCols, dsRows); + const out = img.data; + + for (let i = 0; i < down.length; i++) { + const v = down[i]; + const n = Number.isFinite(v) && invRange > 0 ? 
(v - lo) * invRange : 0; + const b = Math.round(clamp01(n) * 255); + + const idx = i * 4; + out[idx] = b; + out[idx + 1] = b; + out[idx + 2] = b; + out[idx + 3] = 255; + } + + ctx.putImageData(img, 0, 0); +} + +export function DicomRoiSlicePreview(props: { + slice: { sopInstanceUid: string; geom: SliceGeometry } | null; + sourceSeriesUid: string | null; + maxSize: number; + roiRect: RoiRect01 | null; + setRoiRect: (next: RoiRect01 | null) => void; + roiDragRef: { current: { x0: number; y0: number } | null }; + onSliceDelta: (delta: number) => void; + onRoiFinalized: (roi: SvrRoi | null) => void; + disabled?: boolean; +}) { + const { slice, sourceSeriesUid, maxSize, roiRect, setRoiRect, roiDragRef, onSliceDelta, onRoiFinalized, disabled } = props; + + const rect = roiRect ? normalizeRect01(roiRect) : null; + + const canvasRef = useRef(null); + const [renderError, setRenderError] = useState(null); + + useEffect(() => { + const canvas = canvasRef.current; + if (!canvas) return; + + setRenderError(null); + + if (!slice) { + // Clear canvas. + const ctx = canvas.getContext('2d'); + if (ctx) ctx.clearRect(0, 0, canvas.width, canvas.height); + return; + } + + let alive = true; + + const run = async () => { + try { + const imageId = `miradb:${slice.sopInstanceUid}`; + const image = await cornerstone.loadImage(imageId); + + const getPixelData = (image as unknown as { getPixelData?: () => ArrayLike }).getPixelData; + if (typeof getPixelData !== 'function') { + throw new Error('Cornerstone image did not expose getPixelData()'); + } + + const pixelData = getPixelData.call(image); + + if (!alive) return; + + const slope = typeof (image as unknown as { slope?: unknown }).slope === 'number' ? (image as unknown as { slope: number }).slope : 1; + const intercept = + typeof (image as unknown as { intercept?: unknown }).intercept === 'number' ? 
(image as unknown as { intercept: number }).intercept : 0; + + drawDicomPixelDataToCanvas({ + canvas, + pixelData, + rows: slice.geom.rows, + cols: slice.geom.cols, + maxSize, + slope, + intercept, + }); + } catch (e) { + const msg = e instanceof Error ? e.message : String(e); + if (!alive) return; + setRenderError(msg); + } + }; + + void run(); + + return () => { + alive = false; + }; + }, [maxSize, slice]); + + const wheelAccumRef = useRef(0); + const wheelTargetRef = useRef(null); + useEffect(() => { + const el = wheelTargetRef.current; + if (!el) return; + + const onWheel = (e: WheelEvent) => { + if (disabled) return; + if (!Number.isFinite(e.deltaY) || e.deltaY === 0) return; + + // Trackpads generate many small deltas; accumulate and step in whole slices. + wheelAccumRef.current += e.deltaY; + + const stepPx = 60; + while (Math.abs(wheelAccumRef.current) >= stepPx) { + const dir = wheelAccumRef.current > 0 ? 1 : -1; + wheelAccumRef.current -= dir * stepPx; + + // Convention: wheel down (deltaY>0) => next slice. + onSliceDelta(dir); + } + + e.preventDefault(); + e.stopPropagation(); + }; + + el.addEventListener('wheel', onWheel, { passive: false }); + return () => el.removeEventListener('wheel', onWheel); + }, [disabled, onSliceDelta]); + + const aspect = slice ? { w: slice.geom.cols, h: slice.geom.rows } : { w: 1, h: 1 }; + + return ( +
+
+ + + {rect ? ( +
+ ) : null} + + {renderError ? ( +
+ {renderError} +
+ ) : !slice ? ( +
+ Select a series to preview an input slice. +
+ ) : null} + +
{ + if (disabled || !slice || !sourceSeriesUid) return; + const box = e.currentTarget.getBoundingClientRect(); + const x = clamp01((e.clientX - box.left) / box.width); + const y = clamp01((e.clientY - box.top) / box.height); + + roiDragRef.current = { x0: x, y0: y }; + setRoiRect({ x0: x, y0: y, x1: x, y1: y }); + onRoiFinalized(null); + + e.currentTarget.setPointerCapture(e.pointerId); + e.preventDefault(); + e.stopPropagation(); + }} + onPointerMove={(e) => { + const drag = roiDragRef.current; + if (disabled || !slice || !drag) return; + + const box = e.currentTarget.getBoundingClientRect(); + const x = clamp01((e.clientX - box.left) / box.width); + const y = clamp01((e.clientY - box.top) / box.height); + + setRoiRect({ x0: drag.x0, y0: drag.y0, x1: x, y1: y }); + e.preventDefault(); + e.stopPropagation(); + }} + onPointerUp={(e) => { + const drag = roiDragRef.current; + roiDragRef.current = null; + + if (!drag || !slice || !sourceSeriesUid) { + e.preventDefault(); + e.stopPropagation(); + return; + } + + const box = e.currentTarget.getBoundingClientRect(); + const x = clamp01((e.clientX - box.left) / box.width); + const y = clamp01((e.clientY - box.top) / box.height); + + const finalRect: RoiRect01 = { x0: drag.x0, y0: drag.y0, x1: x, y1: y }; + setRoiRect(finalRect); + + const roi = computeCubeRoiFromDicomRect01({ rect: finalRect, geom: slice.geom, sourceSeriesUid }); + onRoiFinalized(roi); + + e.preventDefault(); + e.stopPropagation(); + }} + onPointerCancel={(e) => { + roiDragRef.current = null; + e.preventDefault(); + e.stopPropagation(); + }} + /> +
+ +
+ Input slice + {roiRect ? Box : null} +
+
+ ); +} + +const lastRoiPreviewSliceIndexBySeriesUid = new Map(); + +export type Svr3DViewProps = { + data: ComparisonData; + defaultDateIso?: string | null; + defaultSeqId?: string | null; + /** + * Fallback slice selection for the ROI preview. + * Usually comes from the last-viewed slice in the grid/overlay views. + */ + fallbackRoiSeriesUid?: string | null; + fallbackRoiSliceIndex?: number | null; +}; + +export function Svr3DView({ data, defaultDateIso, defaultSeqId, fallbackRoiSeriesUid, fallbackRoiSliceIndex }: Svr3DViewProps) { + const dates = useMemo(() => sortedDatesDesc(data.dates), [data.dates]); + const dateIso = defaultDateIso && dates.includes(defaultDateIso) ? defaultDateIso : dates[0] || null; + + const [params, setParams] = useState(() => ({ + ...DEFAULT_SVR_PARAMS, + sliceDownsampleMode: 'voxel-aware', + seriesRegistrationMode: 'roi-rigid', + })); + const [generationCollapsed, setGenerationCollapsed] = useState(false); + + const { isRunning, progress, result, error, run, cancel, clear } = useSvrReconstruction(); + + const sequenceGroupsForDate = useMemo(() => { + if (!dateIso) return []; + + const byKey = new Map< + string, + { + label: string; + weight: string | null; + sequence: string | null; + series: SvrSelectedSeries[]; + planeSet: Set; + sliceCount: number; + } + >(); + + for (const seq of data.sequences) { + const ref = data.series_map[seq.id]?.[dateIso]; + if (!ref) continue; + + const seqLabel = formatSequenceLabel(seq); + if (seqLabel === 'Unknown') continue; + + const key = sequenceGroupKey(seq); + let g = byKey.get(key); + if (!g) { + g = { + label: seqLabel, + weight: seq.weight, + sequence: seq.sequence, + series: [], + planeSet: new Set(), + sliceCount: 0, + }; + byKey.set(key, g); + } + + g.series.push({ + seriesUid: ref.series_uid, + studyId: ref.study_id, + dateIso, + instanceCount: ref.instance_count, + label: formatSeriesLabel(seq), + plane: seq.plane, + weight: seq.weight, + sequence: seq.sequence, + }); + + 
g.planeSet.add(seq.plane || 'Unknown'); + g.sliceCount += ref.instance_count; + } + + const out = Array.from(byKey, ([key, g]) => { + // Keep stable ordering within a group: plane, then label. + g.series.sort((a, b) => { + const pa = a.plane || ''; + const pb = b.plane || ''; + if (pa !== pb) return pa.localeCompare(pb); + return a.label.localeCompare(b.label); + }); + + return { + key, + label: g.label, + weight: g.weight, + sequence: g.sequence, + series: g.series, + planeCount: g.planeSet.size, + sliceCount: g.sliceCount, + }; + }); + + out.sort((a, b) => a.label.localeCompare(b.label)); + return out; + }, [data.sequences, data.series_map, dateIso]); + + const defaultSelectedSequenceKey = useMemo(() => { + if (!dateIso) return null; + + const fallback = sequenceGroupsForDate[0]?.key ?? null; + if (!defaultSeqId) return fallback; + + const currentSeq = data.sequences.find((s) => s.id === defaultSeqId); + if (!currentSeq) return fallback; + + if (formatSequenceLabel(currentSeq) === 'Unknown') return fallback; + + const key = sequenceGroupKey(currentSeq); + return sequenceGroupsForDate.some((g) => g.key === key) ? key : fallback; + }, [data.sequences, dateIso, defaultSeqId, sequenceGroupsForDate]); + + const [selectedSequenceKey, setSelectedSequenceKey] = useState(defaultSelectedSequenceKey); + + useEffect(() => { + setSelectedSequenceKey(defaultSelectedSequenceKey); + }, [defaultSelectedSequenceKey]); + + const selectedGroup = useMemo(() => { + if (!selectedSequenceKey) return null; + return sequenceGroupsForDate.find((g) => g.key === selectedSequenceKey) ?? null; + }, [selectedSequenceKey, sequenceGroupsForDate]); + + const selectedSeries = useMemo(() => selectedGroup?.series ?? [], [selectedGroup]); + + // ROI-first flow: pick a ROI on an input slice, then run SVR restricted to that cube. 
+ const [roiSeriesUid, setRoiSeriesUid] = useState(null); + + const preferredRoiSeriesUid = useMemo(() => { + if (!defaultSeqId) return null; + const seq = data.sequences.find((s) => s.id === defaultSeqId); + if (!seq) return null; + + // Prefer the same plane the user was looking at in the comparison grid/overlay. + const match = selectedSeries.find((s) => (s.plane ?? null) === (seq.plane ?? null)); + return match?.seriesUid ?? null; + }, [data.sequences, defaultSeqId, selectedSeries]); + + const effectiveRoiSeriesUid = useMemo(() => { + if (roiSeriesUid && selectedSeries.some((s) => s.seriesUid === roiSeriesUid)) { + return roiSeriesUid; + } + return preferredRoiSeriesUid ?? selectedSeries[0]?.seriesUid ?? null; + }, [preferredRoiSeriesUid, roiSeriesUid, selectedSeries]); + + const roiSeries = useMemo(() => { + if (!effectiveRoiSeriesUid) return null; + return selectedSeries.find((s) => s.seriesUid === effectiveRoiSeriesUid) ?? null; + }, [effectiveRoiSeriesUid, selectedSeries]); + + const [roiSeriesSopUids, setRoiSeriesSopUids] = useState(null); + const [roiSeriesSopUidsError, setRoiSeriesSopUidsError] = useState(null); + + // Use -1 as a sentinel meaning "auto (middle slice)". + const [roiSliceIndex, setRoiSliceIndex] = useState(-1); + + const [roiSliceGeom, setRoiSliceGeom] = useState(null); + const [roiSliceGeomError, setRoiSliceGeomError] = useState(null); + + const [roiRect, setRoiRect] = useState(null); + const roiDragRef = useRef<{ x0: number; y0: number } | null>(null); + + // Keep fallback slice inputs in refs so ROI-series effects don't retrigger on every slice tick. 
+ const fallbackRoiSeriesUidRef = useRef(fallbackRoiSeriesUid); + const fallbackRoiSliceIndexRef = useRef(fallbackRoiSliceIndex); + useEffect(() => { + fallbackRoiSeriesUidRef.current = fallbackRoiSeriesUid; + fallbackRoiSliceIndexRef.current = fallbackRoiSliceIndex; + }, [fallbackRoiSeriesUid, fallbackRoiSliceIndex]); + + // Canonical ROI used for reconstruction (stays valid even if the user scrolls away from the selection slice). + const [roiWorld, setRoiWorld] = useState(null); + + // Date is controlled by the surrounding UI (Dates sidebar). When it changes, clear local selection/ROI/run results. + const prevDateIsoRef = useRef(dateIso); + useEffect(() => { + if (prevDateIsoRef.current === dateIso) return; + prevDateIsoRef.current = dateIso; + + setRoiSeriesUid(null); + setRoiRect(null); + roiDragRef.current = null; + setRoiWorld(null); + + clear(); + }, [clear, dateIso]); + + useEffect(() => { + setRoiSeriesSopUids(null); + setRoiSeriesSopUidsError(null); + + // Slice selection priority: + // 1) The last slice the user viewed in the SVR ROI preview for this series. + // 2) The last slice the user viewed in the grid/overlay views (if it matches this series). + // 3) Default to the middle slice. + const saved = effectiveRoiSeriesUid ? 
lastRoiPreviewSliceIndexBySeriesUid.get(effectiveRoiSeriesUid) : undefined; + + let nextSliceIndex = -1; + if (typeof saved === 'number' && Number.isFinite(saved)) { + nextSliceIndex = Math.round(saved); + } else { + const fallbackSeries = fallbackRoiSeriesUidRef.current; + const fallbackSlice = fallbackRoiSliceIndexRef.current; + + if ( + effectiveRoiSeriesUid && + fallbackSeries && + effectiveRoiSeriesUid === fallbackSeries && + typeof fallbackSlice === 'number' && + Number.isFinite(fallbackSlice) + ) { + nextSliceIndex = Math.round(fallbackSlice); + } + } + + setRoiSliceIndex(nextSliceIndex); + setRoiSliceGeom(null); + setRoiSliceGeomError(null); + + setRoiRect(null); + roiDragRef.current = null; + setRoiWorld(null); + + if (!effectiveRoiSeriesUid) return; + + let alive = true; + const run = async () => { + try { + const uids = await getSortedSopInstanceUidsForSeries(effectiveRoiSeriesUid); + if (!alive) return; + setRoiSeriesSopUids(uids); + } catch (e) { + const msg = e instanceof Error ? e.message : String(e); + if (!alive) return; + setRoiSeriesSopUidsError(msg); + } + }; + + void run(); + return () => { + alive = false; + }; + }, [effectiveRoiSeriesUid]); + + // Persist explicit slice selection (>=0) so leaving/re-entering SVR preserves ROI preview position. + const roiSeriesCount = roiSeriesSopUids?.length ?? 0; + useEffect(() => { + if (!effectiveRoiSeriesUid) return; + if (roiSliceIndex < 0) return; + + const idx = roiSeriesCount > 0 ? clampInt(roiSliceIndex, 0, roiSeriesCount - 1) : roiSliceIndex; + lastRoiPreviewSliceIndexBySeriesUid.set(effectiveRoiSeriesUid, idx); + }, [effectiveRoiSeriesUid, roiSeriesCount, roiSliceIndex]); + + const effectiveRoiSliceIndex = useMemo(() => { + if (roiSeriesCount <= 0) return 0; + + const dflt = Math.floor(roiSeriesCount / 2); + return roiSliceIndex >= 0 ? clampInt(roiSliceIndex, 0, roiSeriesCount - 1) : dflt; + }, [roiSeriesCount, roiSliceIndex]); + + const roiSopInstanceUid = roiSeriesSopUids ? 
(roiSeriesSopUids[effectiveRoiSliceIndex] ?? null) : null; + + useEffect(() => { + setRoiSliceGeom(null); + setRoiSliceGeomError(null); + + // The selection rectangle is tied to a specific slice; clear it when the slice changes. + setRoiRect(null); + roiDragRef.current = null; + + if (!roiSopInstanceUid) return; + + let alive = true; + const run = async () => { + try { + const db = await getDB(); + const inst = (await db.get('instances', roiSopInstanceUid)) as DicomInstance | undefined; + if (!inst) { + throw new Error('Missing DICOM instance for ROI preview'); + } + + const geom = getSliceGeometryFromInstance(inst); + if (!alive) return; + setRoiSliceGeom(geom); + } catch (e) { + const msg = e instanceof Error ? e.message : String(e); + if (!alive) return; + setRoiSliceGeomError(msg); + } + }; + + void run(); + return () => { + alive = false; + }; + }, [roiSopInstanceUid]); + + const roiPreviewSlice = useMemo(() => { + if (!roiSopInstanceUid || !roiSliceGeom) return null; + return { sopInstanceUid: roiSopInstanceUid, geom: roiSliceGeom }; + }, [roiSopInstanceUid, roiSliceGeom]); + + const roiSideMm = useMemo(() => { + if (!roiWorld) return null; + const dx = roiWorld.boundsMm.max[0] - roiWorld.boundsMm.min[0]; + const dy = roiWorld.boundsMm.max[1] - roiWorld.boundsMm.min[1]; + const dz = roiWorld.boundsMm.max[2] - roiWorld.boundsMm.min[2]; + return Math.max(dx, dy, dz); + }, [roiWorld]); + + const selectedPlaneCount = selectedGroup?.planeCount ?? 0; + const canRun = !isRunning && selectedSeries.length >= 2 && selectedPlaneCount >= 2; + const percent = progress ? Math.round((progress.current / Math.max(1, progress.total)) * 100) : 0; + + const progressMessage = progress ? progress.message : ''; + + return ( +
+
+ {generationCollapsed ? null : ( +
+
+
+ Sequence on this date (uses all planes) +
+
+ {sequenceGroupsForDate.length === 0 ? ( +
No series found for this date.
+ ) : ( +
+ {sequenceGroupsForDate.map((g) => { + const checked = selectedSequenceKey === g.key; + + const planeLabel = `${g.planeCount} plane${g.planeCount === 1 ? '' : 's'}`; + const sliceLabel = `${g.sliceCount} slice${g.sliceCount === 1 ? '' : 's'}`; + + return ( + + ); + })} +
+ )} +
+
+ +
+ + Advanced SVR settings + + +
+
+ + + + +
+ +
+
+ Voxel size: Target isotropic output spacing. Smaller = more detail but slower/heavier. The voxel size may be + increased automatically to respect Max volume dim. +
+
+ Iterations: How many SVR refinement passes to run. 0 = quick “splat/average only”; higher can reduce + slice-to-slice inconsistency but costs time. +
+
+ Slice downsample max: Each input slice may be downsampled before reconstruction, but we won't downsample so far + that in-plane spacing becomes worse than the target voxel size. +
+
+ Max volume dim: Caps each output grid dimension (in voxels) by increasing voxel size if needed. Lower = + faster/smaller; higher = more memory/time. +
+
+ Tip: draw a box on an input slice and run Run SVR (box) to keep the volume smaller + faster. +
+
+
+
+ +
+ Focus box (optional) +
+
+ + +
+ + {roiSeriesSopUidsError ? ( +
{roiSeriesSopUidsError}
+ ) : roiSliceGeomError ? ( +
{roiSliceGeomError}
+ ) : null} + +
+
+ {roiSeriesSopUids && roiSeriesSopUids.length > 0 + ? `Slice ${effectiveRoiSliceIndex + 1} / ${roiSeriesSopUids.length}` + : roiSeries + ? 'Loading slices…' + : 'Select a series to preview'} +
+ +
+ + + +
+
+ + { + if (!roiSeriesSopUids || roiSeriesSopUids.length === 0) return; + const cur = roiSliceIndex >= 0 ? roiSliceIndex : effectiveRoiSliceIndex; + setRoiSliceIndex(clampInt(cur + delta, 0, roiSeriesSopUids.length - 1)); + }} + onRoiFinalized={(roi) => { + setRoiWorld(roi); + if (!roi) return; + + setParams((p) => { + const clamp = (x: number, min: number, max: number) => (x < min ? min : x > max ? max : x); + + const inPlaneMm = roiSliceGeom ? Math.min(roiSliceGeom.rowSpacingMm, roiSliceGeom.colSpacingMm) : p.targetVoxelSizeMm; + const nextVoxel = clamp(inPlaneMm, 0.25, 1.0); + + return { + ...p, + // Favor voxel size at (or slightly above) the best in-plane spacing. + targetVoxelSizeMm: nextVoxel, + // Ensure we don't downsample to a coarser spacing than the output voxels. + sliceDownsampleMode: 'voxel-aware', + // Allow near-native resolution (especially once ROI cropping is in place). + sliceDownsampleMaxSize: Math.max(p.sliceDownsampleMaxSize, 512), + // Allow higher-res grids for ROI work. + maxVolumeDim: Math.max(p.maxVolumeDim, 320), + // More refinement iterations for detail. + iterations: Math.max(p.iterations, 6), + stepSize: 0.5, + // Always use ROI rigid alignment. + seriesRegistrationMode: 'roi-rigid', + }; + }); + }} + disabled={isRunning} + /> + +
+ + + {roiWorld && roiSideMm ? ( +
+ Box: ~{roiSideMm.toFixed(1)}mm cube ({roiWorld.sourcePlane}) +
+ ) : null} +
+ +
+ Drag to draw a box on an input slice. When a box is set, Run SVR will reconstruct only that box. Starting + with a smaller box lets you decrease voxel size for more detail without making the volume huge. +
+
+
+ +
+ + +
+ {isRunning ? ( + + ) : null} + + +
+
+ + {progress && ( +
+
+ + {progressMessage} + {percent}% +
+
+
+
+
+ )} + + {error && ( +
+ {error} +
+ )} + + {!result ? ( +
Run SVR to generate a 3D volume (uses a focus box when set).
+ ) : ( +
+ Volume: {result.volume.dims[0]}×{result.volume.dims[1]}×{result.volume.dims[2]} @ {result.volume.voxelSizeMm[0]}mm +
+ )} +
+ )} + +
+ + + +
+
+
+ ); +} diff --git a/frontend/src/components/SvrModal.tsx b/frontend/src/components/SvrModal.tsx new file mode 100644 index 0000000..48d765a --- /dev/null +++ b/frontend/src/components/SvrModal.tsx @@ -0,0 +1,793 @@ +import { useEffect, useMemo, useRef, useState } from 'react'; +import { X, Layers3, Loader2, Download } from 'lucide-react'; +import type { ComparisonData } from '../types/api'; +import type { SvrParams, SvrRoi, SvrRoiPlane, SvrSelectedSeries } from '../types/svr'; +import { DEFAULT_SVR_PARAMS } from '../types/svr'; +import { useSvrReconstruction } from '../hooks/useSvrReconstruction'; +import { SvrVolume3DModal } from './SvrVolume3DModal'; + +export type SvrModalProps = { + data: ComparisonData; + defaultDateIso?: string | null; + defaultSeqId?: string | null; + onClose: () => void; +}; + +function sortedDatesDesc(dates: string[]): string[] { + return [...dates].sort((a, b) => b.localeCompare(a)); +} + +function formatSeqLabel(seq: { plane: string | null; weight: string | null; sequence: string | null }): string { + return [seq.plane, seq.weight, seq.sequence].filter(Boolean).join(' ') || 'Unknown'; +} + +function downloadBlob(blob: Blob, filename: string): void { + const url = URL.createObjectURL(blob); + const a = document.createElement('a'); + a.href = url; + a.download = filename; + a.click(); + // Best-effort cleanup. + setTimeout(() => URL.revokeObjectURL(url), 2000); +} + +function toArrayBuffer(view: ArrayBufferView): ArrayBuffer { + // TS's DOM lib types require BlobParts to be backed by ArrayBuffer (not SharedArrayBuffer). + // Copying also ensures the bytes match the view's byteOffset/byteLength. + const bytes = new Uint8Array(view.buffer, view.byteOffset, view.byteLength); + const copy = new Uint8Array(bytes.byteLength); + copy.set(bytes); + return copy.buffer; +} + +type RoiRect01 = { + plane: SvrRoiPlane; + x0: number; + y0: number; + x1: number; + y1: number; +}; + +function clamp01(x: number): number { + return x < 0 ? 0 : x > 1 ? 
1 : x; +} + +function normalizeRect01(rect: RoiRect01): { left: number; right: number; top: number; bottom: number } { + return { + left: Math.min(rect.x0, rect.x1), + right: Math.max(rect.x0, rect.x1), + top: Math.min(rect.y0, rect.y1), + bottom: Math.max(rect.y0, rect.y1), + }; +} + +function fitIntervalToBounds(params: { min: number; max: number; boundMin: number; boundMax: number }): { min: number; max: number } { + let { min, max } = params; + const { boundMin, boundMax } = params; + + const len = max - min; + const boundLen = boundMax - boundMin; + if (!(len > 0) || !(boundLen > 0)) { + return { min: boundMin, max: boundMax }; + } + + // If the interval is larger than bounds, clamp (shrinks). + if (len >= boundLen) { + return { min: boundMin, max: boundMax }; + } + + // Otherwise, shift the interval so it fits without shrinking. + if (min < boundMin) { + max += boundMin - min; + min = boundMin; + } + if (max > boundMax) { + min -= max - boundMax; + max = boundMax; + } + + // Final safety clamp. + if (min < boundMin) min = boundMin; + if (max > boundMax) max = boundMax; + + return { min, max }; +} + +function computeCubeRoiFromRect01(rect: RoiRect01, volume: { dims: [number, number, number]; voxelSizeMm: [number, number, number]; originMm: [number, number, number] }): SvrRoi | null { + const { dims, voxelSizeMm, originMm } = volume; + const [nx, ny, nz] = dims; + + // SVR output is isotropic today. 
+ const vox = voxelSizeMm[0]; + if (!Number.isFinite(vox) || vox <= 0) return null; + + const [ox, oy, oz] = originMm; + + const fullX = { min: ox, max: ox + (nx - 1) * vox }; + const fullY = { min: oy, max: oy + (ny - 1) * vox }; + const fullZ = { min: oz, max: oz + (nz - 1) * vox }; + + const r = normalizeRect01(rect); + const w01 = r.right - r.left; + const h01 = r.bottom - r.top; + if (w01 <= 1e-4 || h01 <= 1e-4) return null; + + const midX = ox + Math.floor(nx / 2) * vox; + const midY = oy + Math.floor(ny / 2) * vox; + const midZ = oz + Math.floor(nz / 2) * vox; + + // Map the 2D rect to two world axes (A,B); then expand to a cube by making A/B square + // and adding equal extent along the through-plane axis (C). + let a0 = 0; + let a1 = 0; + let b0 = 0; + let b1 = 0; + + // Centers for the through-plane axis. + let cx = midX; + let cy = midY; + let cz = midZ; + + if (rect.plane === 'axial') { + a0 = ox + r.left * (nx - 1) * vox; + a1 = ox + r.right * (nx - 1) * vox; + b0 = oy + r.top * (ny - 1) * vox; + b1 = oy + r.bottom * (ny - 1) * vox; + // through-plane is Z + cz = midZ; + } else if (rect.plane === 'coronal') { + a0 = ox + r.left * (nx - 1) * vox; + a1 = ox + r.right * (nx - 1) * vox; + b0 = oz + r.top * (nz - 1) * vox; + b1 = oz + r.bottom * (nz - 1) * vox; + // through-plane is Y + cy = midY; + } else { + // sagittal + a0 = oy + r.left * (ny - 1) * vox; + a1 = oy + r.right * (ny - 1) * vox; + b0 = oz + r.top * (nz - 1) * vox; + b1 = oz + r.bottom * (nz - 1) * vox; + // through-plane is X + cx = midX; + } + + const aCenter = (a0 + a1) * 0.5; + const bCenter = (b0 + b1) * 0.5; + const sideMm = Math.max(Math.abs(a1 - a0), Math.abs(b1 - b0)); + if (!(sideMm > 1e-6)) return null; + + // Start with a cube centered on the drawn box (in-plane) and the mid-slice (through-plane). + // Then shift it as needed to keep it inside the current volume bounds. 
+ let xMin = 0; + let xMax = 0; + let yMin = 0; + let yMax = 0; + let zMin = 0; + let zMax = 0; + + const half = sideMm * 0.5; + + if (rect.plane === 'axial') { + const xi = fitIntervalToBounds({ min: aCenter - half, max: aCenter + half, boundMin: fullX.min, boundMax: fullX.max }); + const yi = fitIntervalToBounds({ min: bCenter - half, max: bCenter + half, boundMin: fullY.min, boundMax: fullY.max }); + const zi = fitIntervalToBounds({ min: cz - half, max: cz + half, boundMin: fullZ.min, boundMax: fullZ.max }); + xMin = xi.min; + xMax = xi.max; + yMin = yi.min; + yMax = yi.max; + zMin = zi.min; + zMax = zi.max; + } else if (rect.plane === 'coronal') { + const xi = fitIntervalToBounds({ min: aCenter - half, max: aCenter + half, boundMin: fullX.min, boundMax: fullX.max }); + const zi = fitIntervalToBounds({ min: bCenter - half, max: bCenter + half, boundMin: fullZ.min, boundMax: fullZ.max }); + const yi = fitIntervalToBounds({ min: cy - half, max: cy + half, boundMin: fullY.min, boundMax: fullY.max }); + xMin = xi.min; + xMax = xi.max; + yMin = yi.min; + yMax = yi.max; + zMin = zi.min; + zMax = zi.max; + } else { + const yi = fitIntervalToBounds({ min: aCenter - half, max: aCenter + half, boundMin: fullY.min, boundMax: fullY.max }); + const zi = fitIntervalToBounds({ min: bCenter - half, max: bCenter + half, boundMin: fullZ.min, boundMax: fullZ.max }); + const xi = fitIntervalToBounds({ min: cx - half, max: cx + half, boundMin: fullX.min, boundMax: fullX.max }); + xMin = xi.min; + xMax = xi.max; + yMin = yi.min; + yMax = yi.max; + zMin = zi.min; + zMax = zi.max; + } + + return { + mode: 'cube', + sourcePlane: rect.plane, + boundsMm: { + min: [xMin, yMin, zMin], + max: [xMax, yMax, zMax], + }, + }; +} + +function RoiSelectablePreview(props: { + plane: SvrRoiPlane; + label: string; + url: string | undefined; + aspectW: number; + aspectH: number; + roiRect: RoiRect01 | null; + setRoiRect: (next: RoiRect01 | null) => void; + roiDragRef: { current: { plane: SvrRoiPlane; 
x0: number; y0: number } | null }; + disabled?: boolean; +}) { + const { plane, label, url, aspectW, aspectH, roiRect, setRoiRect, roiDragRef, disabled } = props; + + const rect = roiRect?.plane === plane ? normalizeRect01(roiRect) : null; + + return ( +
+
+ {url ? {label} : null} + + {rect ? ( +
+ ) : null} + + {url ? ( +
{ + if (disabled) return; + const box = e.currentTarget.getBoundingClientRect(); + const x = clamp01((e.clientX - box.left) / box.width); + const y = clamp01((e.clientY - box.top) / box.height); + + roiDragRef.current = { plane, x0: x, y0: y }; + setRoiRect({ plane, x0: x, y0: y, x1: x, y1: y }); + + e.currentTarget.setPointerCapture(e.pointerId); + e.preventDefault(); + e.stopPropagation(); + }} + onPointerMove={(e) => { + const drag = roiDragRef.current; + if (disabled || !drag || drag.plane !== plane) return; + + const box = e.currentTarget.getBoundingClientRect(); + const x = clamp01((e.clientX - box.left) / box.width); + const y = clamp01((e.clientY - box.top) / box.height); + + setRoiRect({ plane, x0: drag.x0, y0: drag.y0, x1: x, y1: y }); + e.preventDefault(); + e.stopPropagation(); + }} + onPointerUp={(e) => { + const drag = roiDragRef.current; + if (drag?.plane === plane) { + roiDragRef.current = null; + } + e.preventDefault(); + e.stopPropagation(); + }} + onPointerCancel={(e) => { + const drag = roiDragRef.current; + if (drag?.plane === plane) { + roiDragRef.current = null; + } + e.preventDefault(); + e.stopPropagation(); + }} + /> + ) : null} +
+ +
+ {label} + {roiRect?.plane === plane ? ROI : null} +
+
+ ); +} + +export function SvrModal({ data, defaultDateIso, defaultSeqId, onClose }: SvrModalProps) { + const dates = useMemo(() => sortedDatesDesc(data.dates), [data.dates]); + const initialDate = defaultDateIso && dates.includes(defaultDateIso) ? defaultDateIso : dates[0] || null; + + const [dateIso, setDateIso] = useState(initialDate); + const [params, setParams] = useState(DEFAULT_SVR_PARAMS); + + const { isRunning, progress, result, error, run, cancel, clear } = useSvrReconstruction(); + + const [viewer3dOpen, setViewer3dOpen] = useState(false); + + const [roiRect, setRoiRect] = useState(null); + const roiDragRef = useRef<{ plane: SvrRoiPlane; x0: number; y0: number } | null>(null); + + const [lastRunMeta, setLastRunMeta] = useState<{ params: SvrParams; selectedSeries: SvrSelectedSeries[] } | null>(null); + + const roi = useMemo(() => { + if (!roiRect || !result) return null; + return computeCubeRoiFromRect01(roiRect, result.volume); + }, [roiRect, result]); + + const roiSideMm = useMemo(() => { + if (!roi) return null; + const dx = roi.boundsMm.max[0] - roi.boundsMm.min[0]; + const dy = roi.boundsMm.max[1] - roi.boundsMm.min[1]; + const dz = roi.boundsMm.max[2] - roi.boundsMm.min[2]; + return Math.max(dx, dy, dz); + }, [roi]); + + const optionsForDate: SvrSelectedSeries[] = useMemo(() => { + if (!dateIso) return []; + + const out: SvrSelectedSeries[] = []; + + for (const seq of data.sequences) { + const ref = data.series_map[seq.id]?.[dateIso]; + if (!ref) continue; + + out.push({ + seriesUid: ref.series_uid, + studyId: ref.study_id, + dateIso, + instanceCount: ref.instance_count, + label: formatSeqLabel(seq), + plane: seq.plane, + weight: seq.weight, + sequence: seq.sequence, + }); + } + + // Keep stable ordering: plane, then label. 
+ out.sort((a, b) => { + const pa = a.plane || ''; + const pb = b.plane || ''; + if (pa !== pb) return pa.localeCompare(pb); + return a.label.localeCompare(b.label); + }); + + return out; + }, [data.sequences, data.series_map, dateIso]); + + const [selectedUids, setSelectedUids] = useState>(new Set()); + + // Preselect: when opening, select all planes matching the current weight/sequence (not plane). + const didInitSelectionRef = useRef(false); + useEffect(() => { + if (didInitSelectionRef.current) return; + + // Mark as initialized even if we don't end up selecting anything, + // so we don't keep re-running this effect. + didInitSelectionRef.current = true; + + if (!dateIso || !defaultSeqId) return; + + const currentSeq = data.sequences.find((s) => s.id === defaultSeqId); + if (!currentSeq) return; + + const next = new Set(); + for (const opt of optionsForDate) { + if (opt.weight === currentSeq.weight && opt.sequence === currentSeq.sequence) { + next.add(opt.seriesUid); + } + } + + if (next.size > 0) { + setSelectedUids(next); + } + }, [data.sequences, dateIso, defaultSeqId, optionsForDate]); + + // Object URLs for previews + const [previewUrls, setPreviewUrls] = useState<{ axial?: string; coronal?: string; sagittal?: string }>({}); + useEffect(() => { + // Cleanup previous URLs. 
+ for (const url of Object.values(previewUrls)) { + if (url) URL.revokeObjectURL(url); + } + + if (!result) { + setPreviewUrls({}); + return; + } + + const next = { + axial: URL.createObjectURL(result.previews.axial), + coronal: URL.createObjectURL(result.previews.coronal), + sagittal: URL.createObjectURL(result.previews.sagittal), + }; + + setPreviewUrls(next); + + return () => { + for (const url of Object.values(next)) { + if (url) URL.revokeObjectURL(url); + } + }; + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [result]); + + const selectedSeries = useMemo(() => { + const m = new Map(optionsForDate.map((o) => [o.seriesUid, o] as const)); + return Array.from(selectedUids) + .map((uid) => m.get(uid)) + .filter((x): x is SvrSelectedSeries => !!x); + }, [optionsForDate, selectedUids]); + + const canRun = !isRunning && selectedSeries.length >= 2; + + const percent = progress ? Math.round((progress.current / Math.max(1, progress.total)) * 100) : 0; + + return ( + <> + {viewer3dOpen && result ? ( + { + setViewer3dOpen(false); + }} + /> + ) : null} + +
+
+
+

+ + Slice-to-Volume Reconstruction (SVR) +

+ +
+ +
+
+
+ Select multiple series from different planes for a single date, then run iterative SVR (multi-plane fusion + refinement). +
+ +
+ + +
+ +
+
+ Series on this date (pick 2+ across planes) +
+
+ {optionsForDate.length === 0 ? ( +
No series found for this date.
+ ) : ( +
+ {optionsForDate.map((opt) => { + const checked = selectedUids.has(opt.seriesUid); + return ( + + ); + })} +
+ )} +
+
+ +
+ + + + +
+ +
+ + +
+ {isRunning ? ( + + ) : null} + + +
+
+ + {progress && ( +
+
+ + {progress.message} + {percent}% +
+
+
+
+
+ )} + + {error && ( +
+ {error} +
+ )} +
+ +
+
Result
+ + {!result ? ( +
+ Run SVR to generate a reconstructed volume and orthogonal previews. +
+ ) : ( + <> +
+ Volume: {result.volume.dims[0]}×{result.volume.dims[1]}×{result.volume.dims[2]} @ {result.volume.voxelSizeMm[0]}mm +
+ +
+ + + +
+ +
+ Drag a box on a preview to define a cube ROI (the box is expanded to a cube automatically), then run SVR in ROI. +
+ +
+ + + + + {roi && roiSideMm ? ( +
+ ROI: ~{roiSideMm.toFixed(1)}mm cube ({roi.sourcePlane}) +
+ ) : null} +
+ +
+ + + + + +
+ +
+ Note: .f32 is raw Float32 voxels in x-fastest order. Use the JSON sidecar for dims/spacing/origin. +
+ + )} +
+
+
+
+ + ); +} diff --git a/frontend/src/components/SvrVolume3DModal.tsx b/frontend/src/components/SvrVolume3DModal.tsx new file mode 100644 index 0000000..14ce4d4 --- /dev/null +++ b/frontend/src/components/SvrVolume3DModal.tsx @@ -0,0 +1,574 @@ +import { useCallback, useEffect, useMemo, useRef, useState } from 'react'; +import { X } from 'lucide-react'; +import type { SvrVolume } from '../types/svr'; + +function clamp(x: number, min: number, max: number): number { + return x < min ? min : x > max ? max : x; +} + +function mat3FromYawPitch(yaw: number, pitch: number): Float32Array { + const cy = Math.cos(yaw); + const sy = Math.sin(yaw); + const cp = Math.cos(pitch); + const sp = Math.sin(pitch); + + // Column-major mat3 (WebGL expects column-major when transpose=false). + // R = Ry(yaw) * Rx(pitch) + return new Float32Array([ + cy, + 0, + -sy, + + sy * sp, + cp, + cy * sp, + + sy * cp, + -sp, + cy * cp, + ]); +} + +function toUint8Volume(data: Float32Array): Uint8Array { + const out = new Uint8Array(data.length); + for (let i = 0; i < data.length; i++) { + const v = data[i] ?? 0; + const b = Math.round(clamp(v, 0, 1) * 255); + out[i] = b; + } + return out; +} + +type VolumeTextureFormat = + | { kind: 'f32'; internalFormat: number; format: number; type: number; minMagFilter: number } + | { kind: 'u8'; internalFormat: number; format: number; type: number; minMagFilter: number }; + +function chooseVolumeTextureFormat(gl: WebGL2RenderingContext): { + primary: VolumeTextureFormat; + fallback: VolumeTextureFormat; +} { + const floatLinear = !!gl.getExtension('OES_texture_float_linear'); + + const primary: VolumeTextureFormat = { + kind: 'f32', + internalFormat: gl.R32F, + format: gl.RED, + type: gl.FLOAT, + minMagFilter: floatLinear ? 
gl.LINEAR : gl.NEAREST, + }; + + const fallback: VolumeTextureFormat = { + kind: 'u8', + internalFormat: gl.R8, + format: gl.RED, + type: gl.UNSIGNED_BYTE, + minMagFilter: gl.LINEAR, + }; + + return { primary, fallback }; +} + +function compileShader(gl: WebGL2RenderingContext, type: number, src: string): WebGLShader { + const sh = gl.createShader(type); + if (!sh) throw new Error('Failed to create shader'); + gl.shaderSource(sh, src); + gl.compileShader(sh); + if (!gl.getShaderParameter(sh, gl.COMPILE_STATUS)) { + const log = gl.getShaderInfoLog(sh) || '(no log)'; + gl.deleteShader(sh); + throw new Error(log); + } + return sh; +} + +function createProgram(gl: WebGL2RenderingContext, vsSrc: string, fsSrc: string): WebGLProgram { + const vs = compileShader(gl, gl.VERTEX_SHADER, vsSrc); + const fs = compileShader(gl, gl.FRAGMENT_SHADER, fsSrc); + + const prog = gl.createProgram(); + if (!prog) throw new Error('Failed to create program'); + gl.attachShader(prog, vs); + gl.attachShader(prog, fs); + gl.linkProgram(prog); + + gl.deleteShader(vs); + gl.deleteShader(fs); + + if (!gl.getProgramParameter(prog, gl.LINK_STATUS)) { + const log = gl.getProgramInfoLog(prog) || '(no log)'; + gl.deleteProgram(prog); + throw new Error(log); + } + + return prog; +} + +export type SvrVolume3DModalProps = { + volume: SvrVolume; + onClose: () => void; +}; + +export function SvrVolume3DModal({ volume, onClose }: SvrVolume3DModalProps) { + const canvasRef = useRef(null); + + const [initError, setInitError] = useState(null); + + // Viewer controls + const [threshold, setThreshold] = useState(0.05); + const [steps, setSteps] = useState(160); + const [gamma, setGamma] = useState(1.0); + const [zoom, setZoom] = useState(1.0); + const [yaw, setYaw] = useState(0); + const [pitch, setPitch] = useState(0); + + const paramsRef = useRef({ threshold, steps, gamma, zoom, yaw, pitch }); + useEffect(() => { + paramsRef.current = { threshold, steps, gamma, zoom, yaw, pitch }; + }, [gamma, pitch, steps, 
threshold, yaw, zoom]); + + const { boxScale, dims } = useMemo(() => { + const [nx, ny, nz] = volume.dims; + const maxDim = Math.max(1, nx, ny, nz); + return { + dims: { nx, ny, nz }, + boxScale: [nx / maxDim, ny / maxDim, nz / maxDim] as const, + }; + }, [volume.dims]); + + const resetView = useCallback(() => { + setYaw(0); + setPitch(0); + setZoom(1.0); + }, []); + + // Pointer drag rotation (simple yaw/pitch trackball). + const dragRef = useRef<{ x: number; y: number; yaw: number; pitch: number } | null>(null); + + const onPointerDown = useCallback((e: React.PointerEvent) => { + const canvas = canvasRef.current; + if (!canvas) return; + + dragRef.current = { + x: e.clientX, + y: e.clientY, + yaw, + pitch, + }; + + canvas.setPointerCapture(e.pointerId); + e.preventDefault(); + e.stopPropagation(); + }, [pitch, yaw]); + + const onPointerMove = useCallback((e: React.PointerEvent) => { + const d = dragRef.current; + if (!d) return; + + const dx = e.clientX - d.x; + const dy = e.clientY - d.y; + + const nextYaw = d.yaw + dx * 0.01; + const nextPitch = clamp(d.pitch + dy * 0.01, -Math.PI / 2 + 1e-3, Math.PI / 2 - 1e-3); + + setYaw(nextYaw); + setPitch(nextPitch); + + e.preventDefault(); + e.stopPropagation(); + }, []); + + const onPointerUp = useCallback((e: React.PointerEvent) => { + dragRef.current = null; + e.preventDefault(); + e.stopPropagation(); + }, []); + + useEffect(() => { + setInitError(null); + + const canvas = canvasRef.current; + if (!canvas) return; + + const gl = canvas.getContext('webgl2', { + antialias: true, + alpha: false, + depth: false, + stencil: false, + preserveDrawingBuffer: false, + }); + + if (!gl) { + setInitError('WebGL2 is not available in this browser/environment.'); + return; + } + + const { primary, fallback } = chooseVolumeTextureFormat(gl); + + const vsSrc = `#version 300 es +in vec2 a_pos; +out vec2 v_uv; +void main() { + v_uv = a_pos * 0.5 + 0.5; + gl_Position = vec4(a_pos, 0.0, 1.0); +}`; + + const fsSrc = `#version 300 es 
+precision highp float; +precision highp sampler3D; + +in vec2 v_uv; +out vec4 outColor; + +uniform sampler3D u_vol; +uniform mat3 u_rot; +uniform vec3 u_box; +uniform float u_aspect; +uniform float u_zoom; +uniform float u_thr; +uniform int u_steps; +uniform float u_gamma; + +float saturate(float x) { + return clamp(x, 0.0, 1.0); +} + +bool intersectBox(vec3 ro, vec3 rd, vec3 bmin, vec3 bmax, out float t0, out float t1) { + vec3 invD = 1.0 / rd; + vec3 tbot = (bmin - ro) * invD; + vec3 ttop = (bmax - ro) * invD; + vec3 tmin = min(ttop, tbot); + vec3 tmax = max(ttop, tbot); + t0 = max(max(tmin.x, tmin.y), tmin.z); + t1 = min(min(tmax.x, tmax.y), tmax.z); + return t1 >= max(t0, 0.0); +} + +void main() { + // NDC in [-1, 1] + vec2 p = v_uv * 2.0 - 1.0; + p.x *= u_aspect; + p /= max(1e-3, u_zoom); + + // World/view ray + vec3 roW = vec3(0.0, 0.0, 1.6); + vec3 rdW = normalize(vec3(p, -1.2)); + + // Rotate ray into volume/object space (volume is rotated by u_rot). + mat3 invR = transpose(u_rot); + vec3 ro = invR * roW; + vec3 rd = invR * rdW; + + vec3 bmin = -0.5 * u_box; + vec3 bmax = 0.5 * u_box; + + float t0; + float t1; + if (!intersectBox(ro, rd, bmin, bmax, t0, t1)) { + outColor = vec4(0.0, 0.0, 0.0, 1.0); + return; + } + + // MIP raymarch + const int MAX_STEPS = 256; + int n = clamp(u_steps, 8, MAX_STEPS); + float dt = (t1 - t0) / float(n); + + float m = 0.0; + float t = max(t0, 0.0); + + for (int i = 0; i < MAX_STEPS; i++) { + if (i >= n) break; + vec3 pos = ro + rd * (t + float(i) * dt); + + // Map object-space box to texture coords [0,1] + vec3 tc = pos / u_box + 0.5; + + float v = texture(u_vol, tc).r; + if (v >= u_thr) { + m = max(m, v); + } + } + + float g = max(1e-3, u_gamma); + float c = pow(saturate(m), 1.0 / g); + outColor = vec4(vec3(c), 1.0); +}`; + + let program: WebGLProgram | null = null; + let vao: WebGLVertexArrayObject | null = null; + let vbo: WebGLBuffer | null = null; + let tex: WebGLTexture | null = null; + let raf = 0; + + try { + program = 
createProgram(gl, vsSrc, fsSrc); + + // Full-screen triangle (2D clip space) + vao = gl.createVertexArray(); + vbo = gl.createBuffer(); + if (!vao || !vbo) throw new Error('Failed to allocate GL buffers'); + + gl.bindVertexArray(vao); + gl.bindBuffer(gl.ARRAY_BUFFER, vbo); + + // Triangle: (-1,-1), (3,-1), (-1,3) + const verts = new Float32Array([-1, -1, 3, -1, -1, 3]); + gl.bufferData(gl.ARRAY_BUFFER, verts, gl.STATIC_DRAW); + + const aPos = gl.getAttribLocation(program, 'a_pos'); + gl.enableVertexAttribArray(aPos); + gl.vertexAttribPointer(aPos, 2, gl.FLOAT, false, 0, 0); + + gl.bindVertexArray(null); + gl.bindBuffer(gl.ARRAY_BUFFER, null); + + // Volume texture (prefer float for fidelity; fall back to 8-bit for compatibility) + tex = gl.createTexture(); + if (!tex) throw new Error('Failed to allocate 3D texture'); + + gl.activeTexture(gl.TEXTURE0); + gl.bindTexture(gl.TEXTURE_3D, tex); + gl.pixelStorei(gl.UNPACK_ALIGNMENT, 1); + + let fmt: VolumeTextureFormat = primary; + + gl.texParameteri(gl.TEXTURE_3D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE); + gl.texParameteri(gl.TEXTURE_3D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE); + gl.texParameteri(gl.TEXTURE_3D, gl.TEXTURE_WRAP_R, gl.CLAMP_TO_EDGE); + + const tryUpload = (candidate: VolumeTextureFormat, candidateData: ArrayBufferView) => { + gl.texParameteri(gl.TEXTURE_3D, gl.TEXTURE_MIN_FILTER, candidate.minMagFilter); + gl.texParameteri(gl.TEXTURE_3D, gl.TEXTURE_MAG_FILTER, candidate.minMagFilter); + + gl.texImage3D( + gl.TEXTURE_3D, + 0, + candidate.internalFormat, + dims.nx, + dims.ny, + dims.nz, + 0, + candidate.format, + candidate.type, + candidateData + ); + + const err = gl.getError(); + return err === gl.NO_ERROR; + }; + + try { + const ok = tryUpload(primary, volume.data); + if (!ok) { + const u8 = toUint8Volume(volume.data); + fmt = fallback; + tryUpload(fallback, u8); + } + } catch { + const u8 = toUint8Volume(volume.data); + fmt = fallback; + tryUpload(fallback, u8); + } + + console.info('[svr3d] Volume texture 
format', { kind: fmt.kind, dims }); + + gl.bindTexture(gl.TEXTURE_3D, null); + + const uVolLoc = gl.getUniformLocation(program, 'u_vol'); + const uRotLoc = gl.getUniformLocation(program, 'u_rot'); + const uBoxLoc = gl.getUniformLocation(program, 'u_box'); + const uAspectLoc = gl.getUniformLocation(program, 'u_aspect'); + const uZoomLoc = gl.getUniformLocation(program, 'u_zoom'); + const uThrLoc = gl.getUniformLocation(program, 'u_thr'); + const uStepsLoc = gl.getUniformLocation(program, 'u_steps'); + const uGammaLoc = gl.getUniformLocation(program, 'u_gamma'); + + const resizeAndViewport = () => { + const dpr = window.devicePixelRatio || 1; + const w = Math.max(1, Math.floor(canvas.clientWidth * dpr)); + const h = Math.max(1, Math.floor(canvas.clientHeight * dpr)); + if (canvas.width !== w || canvas.height !== h) { + canvas.width = w; + canvas.height = h; + } + gl.viewport(0, 0, canvas.width, canvas.height); + }; + + const draw = () => { + resizeAndViewport(); + + const { threshold, steps, gamma, zoom, yaw, pitch } = paramsRef.current; + + gl.disable(gl.DEPTH_TEST); + gl.disable(gl.CULL_FACE); + + gl.useProgram(program); + gl.bindVertexArray(vao); + + // Bind texture + gl.activeTexture(gl.TEXTURE0); + gl.bindTexture(gl.TEXTURE_3D, tex); + gl.uniform1i(uVolLoc, 0); + + // Uniforms + const rot = mat3FromYawPitch(yaw, pitch); + gl.uniformMatrix3fv(uRotLoc, false, rot); + gl.uniform3f(uBoxLoc, boxScale[0], boxScale[1], boxScale[2]); + gl.uniform1f(uAspectLoc, canvas.width / Math.max(1, canvas.height)); + gl.uniform1f(uZoomLoc, zoom); + gl.uniform1f(uThrLoc, clamp(threshold, 0, 1)); + gl.uniform1i(uStepsLoc, Math.round(clamp(steps, 8, 256))); + gl.uniform1f(uGammaLoc, clamp(gamma, 0.25, 4)); + + gl.drawArrays(gl.TRIANGLES, 0, 3); + + gl.bindTexture(gl.TEXTURE_3D, null); + gl.bindVertexArray(null); + + raf = window.requestAnimationFrame(draw); + }; + + raf = window.requestAnimationFrame(draw); + } catch (e) { + const msg = e instanceof Error ? 
e.message : String(e); + console.error('[SVR3D] Failed to initialize:', e); + setInitError(msg); + } + + return () => { + if (raf) window.cancelAnimationFrame(raf); + + if (gl) { + if (tex) gl.deleteTexture(tex); + if (vbo) gl.deleteBuffer(vbo); + if (vao) gl.deleteVertexArray(vao); + if (program) gl.deleteProgram(program); + } + }; + // We intentionally re-init when volume changes. + }, [boxScale, dims, volume.data]); + + return ( +
+
+
+
+
SVR 3D Viewer
+
+ MIP volume render · {dims.nx}×{dims.ny}×{dims.nz} +
+
+ + +
+ +
+
+
+
+ + + {initError ? ( +
+ {initError} +
+ ) : ( +
+ Drag to rotate +
+ )} +
+
+
+ +
+
Controls
+ + + + + + + + + +
+ +
+ +
+ This is a lightweight in-browser volume render (MIP). It’s meant for quick visual inspection of the SVR output. +
+
+
+
+
+ ); +} diff --git a/frontend/src/components/SvrVolume3DViewer.tsx b/frontend/src/components/SvrVolume3DViewer.tsx new file mode 100644 index 0000000..0f000fb --- /dev/null +++ b/frontend/src/components/SvrVolume3DViewer.tsx @@ -0,0 +1,1370 @@ +import { forwardRef, useCallback, useEffect, useImperativeHandle, useMemo, useRef, useState } from 'react'; +import { ChevronLeft, ChevronRight } from 'lucide-react'; +import type { SvrVolume } from '../types/svr'; +import { resample2dAreaAverage } from '../utils/svr/resample2d'; + +function clamp(x: number, min: number, max: number): number { + return x < min ? min : x > max ? max : x; +} + +/** + * Camera model constants used by both: + * - the fragment shader (ray origin + image plane) + * - the 2D axes overlay projection helper (`projectWorldToCanvas`) + * + * Keep these in sync or the overlay will drift relative to the 3D render. + */ +const SVR3D_CAMERA_Z = 1.6; +const SVR3D_FOCAL_Z = 1.2; + +async function rgbaToPngBlob(params: { rgba: Uint8ClampedArray; width: number; height: number }): Promise { + const { rgba, width, height } = params; + + const canvas = document.createElement('canvas'); + canvas.width = width; + canvas.height = height; + + const ctx = canvas.getContext('2d'); + if (!ctx) { + throw new Error('Failed to create 2D canvas context'); + } + + const img = ctx.createImageData(width, height); + img.data.set(rgba); + ctx.putImageData(img, 0, 0); + + const blob = await new Promise((resolve, reject) => { + canvas.toBlob((b) => { + if (!b) { + reject(new Error('canvas.toBlob() returned null')); + return; + } + resolve(b); + }, 'image/png'); + }); + + return blob; +} + +type Vec3 = { x: number; y: number; z: number }; +// Quaternion [x, y, z, w] +type Quat = [number, number, number, number]; + +function v3Add(a: Vec3, b: Vec3): Vec3 { + return { x: a.x + b.x, y: a.y + b.y, z: a.z + b.z }; +} + +function v3Scale(v: Vec3, s: number): Vec3 { + return { x: v.x * s, y: v.y * s, z: v.z * s }; +} + +function 
v3ApplyMat3(m: Float32Array, v: Vec3): Vec3 { + // Column-major 3x3. + return { + x: m[0]! * v.x + m[3]! * v.y + m[6]! * v.z, + y: m[1]! * v.x + m[4]! * v.y + m[7]! * v.z, + z: m[2]! * v.x + m[5]! * v.y + m[8]! * v.z, + }; +} + +function projectWorldToCanvas(params: { + world: Vec3; + canvasW: number; + canvasH: number; + aspect: number; + zoom: number; +}): { x: number; y: number } | null { + const { world, canvasW, canvasH, aspect, zoom } = params; + + // Must match the simple camera model used in the fragment shader: + // roW = (0,0,CAM_Z) + // rdW = normalize(vec3(p, -FOCAL_Z)) + const CAM_Z = SVR3D_CAMERA_Z; + const FOCAL_Z = SVR3D_FOCAL_Z; + + const vz = world.z - CAM_Z; + if (!(vz < -1e-6)) { + // Point is at/behind the camera plane; skip. + return null; + } + + // Intersect the ray from camera origin through the point with the image plane at z = CAM_Z - FOCAL_Z. + const t = -FOCAL_Z / vz; + + const px = world.x * t; + const py = world.y * t; + + // In shader: p.x *= aspect; p /= zoom. + // So inverse mapping is: ndc.x = px * zoom / aspect; ndc.y = py * zoom. + const ndcX = (px * zoom) / Math.max(1e-6, aspect); + const ndcY = py * zoom; + + return { + x: (ndcX * 0.5 + 0.5) * canvasW, + y: (1 - (ndcY * 0.5 + 0.5)) * canvasH, + }; +} + +function niceStepMm(rangeMm: number, targetTicks: number): number { + const r = Math.abs(rangeMm); + if (!(r > 1e-6) || !(targetTicks > 0)) return 1; + + const raw = r / targetTicks; + const pow10 = Math.pow(10, Math.floor(Math.log10(raw))); + const x = raw / pow10; + + const nice = x <= 1 ? 1 : x <= 2 ? 2 : x <= 5 ? 
5 : 10; + return nice * pow10; +} + +type DrawAxesOverlayParams = { + axesCanvas: HTMLCanvasElement; + axesCtx: CanvasRenderingContext2D; + canvas: HTMLCanvasElement; + volume: SvrVolume; + boxScale: readonly [number, number, number]; + rotMat: Float32Array; + zoom: number; +}; + +function drawAxesOverlay(params: DrawAxesOverlayParams): void { + const { axesCanvas, axesCtx, canvas, volume, boxScale, rotMat, zoom } = params; + + const w = axesCanvas.width; + const h = axesCanvas.height; + if (!(w > 0 && h > 0)) return; + + // Clear. + axesCtx.clearRect(0, 0, w, h); + + // Volume physical size in mm. + const [nx, ny, nz] = volume.dims; + const [vx, vy, vz] = volume.voxelSizeMm; + + const sizeMm = { + x: Math.abs(nx * vx), + y: Math.abs(ny * vy), + z: Math.abs(nz * vz), + }; + + // Object-space box extents used by the shader. + const box = { x: boxScale[0], y: boxScale[1], z: boxScale[2] }; + + // Place axes on the (x-, y-, z+) corner of the box. + const originObj: Vec3 = { + x: -0.5 * box.x, + y: -0.5 * box.y, + z: 0.5 * box.z, + }; + + const aspect = w / Math.max(1, h); + const dpr = canvas.clientWidth > 0 ? canvas.width / canvas.clientWidth : window.devicePixelRatio || 1; + + // NOTE: `rotMat` here is the same u_rot we send to the shader, so object->world matches. + const projectObj = (obj: Vec3) => { + const world = v3ApplyMat3(rotMat, obj); + return projectWorldToCanvas({ world, canvasW: w, canvasH: h, aspect, zoom }); + }; + + // 2D styling. 
+ axesCtx.save(); + axesCtx.lineCap = 'round'; + axesCtx.lineJoin = 'round'; + + const fontPx = Math.max(10, Math.round(10 * dpr)); + axesCtx.font = `${fontPx}px ui-sans-serif, system-ui`; + axesCtx.textBaseline = 'middle'; + + const tickMajorPx = 7 * dpr; + const tickMinorPx = 4 * dpr; + const labelOffsetPx = 10 * dpr; + + const axes: Array<{ + name: 'X' | 'Y' | 'Z'; + dirObj: Vec3; + lenObj: number; + lenMm: number; + rgba: string; + }> = [ + { name: 'X', dirObj: { x: 1, y: 0, z: 0 }, lenObj: box.x, lenMm: sizeMm.x, rgba: 'rgba(255,80,80,0.9)' }, + { name: 'Y', dirObj: { x: 0, y: 1, z: 0 }, lenObj: box.y, lenMm: sizeMm.y, rgba: 'rgba(80,255,80,0.9)' }, + // Use -Z so the axis spans the full box depth from the front face into the volume. + { name: 'Z', dirObj: { x: 0, y: 0, z: -1 }, lenObj: box.z, lenMm: sizeMm.z, rgba: 'rgba(80,160,255,0.9)' }, + ]; + + for (const axis of axes) { + if (!(axis.lenObj > 1e-9) || !(axis.lenMm > 1e-6)) continue; + + const p0 = projectObj(originObj); + const p1 = projectObj(v3Add(originObj, v3Scale(axis.dirObj, axis.lenObj))); + if (!p0 || !p1) continue; + + const dx = p1.x - p0.x; + const dy = p1.y - p0.y; + const dLen = Math.hypot(dx, dy); + if (!(dLen > 1e-6)) continue; + + const ux = dx / dLen; + const uy = dy / dLen; + const px = -uy; + const py = ux; + + // Main axis line. + axesCtx.lineWidth = 1.25 * dpr; + axesCtx.strokeStyle = axis.rgba; + axesCtx.beginPath(); + axesCtx.moveTo(p0.x, p0.y); + axesCtx.lineTo(p1.x, p1.y); + axesCtx.stroke(); + + // Ticks. + const majorStepMm = niceStepMm(axis.lenMm, 5); + const minorStepMm = majorStepMm >= 10 ? 
majorStepMm / 5 : majorStepMm / 2; + + const stepObj = axis.lenObj / axis.lenMm; + + const isNear = (a: number, b: number) => Math.abs(a - b) <= 1e-6 * Math.max(1, axis.lenMm); + + const drawTickAt = (tMm: number, isMajor: boolean) => { + const tObj = tMm * stepObj; + const ptObj = v3Add(originObj, v3Scale(axis.dirObj, tObj)); + const p = projectObj(ptObj); + if (!p) return; + + const half = (isMajor ? tickMajorPx : tickMinorPx) * 0.5; + axesCtx.lineWidth = (isMajor ? 1.25 : 1.0) * dpr; + axesCtx.strokeStyle = axis.rgba; + axesCtx.beginPath(); + axesCtx.moveTo(p.x - px * half, p.y - py * half); + axesCtx.lineTo(p.x + px * half, p.y + py * half); + axesCtx.stroke(); + + if (isMajor && tMm > 0) { + const text = `${Math.round(tMm)}mm`; + const lx = p.x + px * labelOffsetPx; + const ly = p.y + py * labelOffsetPx; + + axesCtx.textAlign = px >= 0 ? 'left' : 'right'; + axesCtx.lineWidth = 3 * dpr; + axesCtx.strokeStyle = 'rgba(0,0,0,0.8)'; + axesCtx.strokeText(text, lx, ly); + axesCtx.fillStyle = axis.rgba; + axesCtx.fillText(text, lx, ly); + } + }; + + // Minor ticks. + for (let t = 0; t <= axis.lenMm + minorStepMm * 0.25; t += minorStepMm) { + // Skip ticks that coincide with major ticks. + const q = Math.round(t / majorStepMm); + const isMajor = isNear(t, q * majorStepMm); + drawTickAt(Math.min(t, axis.lenMm), isMajor); + } + + // Axis label at end. + { + const text = `${axis.name}: ${Math.round(axis.lenMm)}mm`; + const lx = p1.x + px * (labelOffsetPx * 1.2) + ux * (6 * dpr); + const ly = p1.y + py * (labelOffsetPx * 1.2) + uy * (6 * dpr); + axesCtx.textAlign = px >= 0 ? 
'left' : 'right'; + axesCtx.lineWidth = 3 * dpr; + axesCtx.strokeStyle = 'rgba(0,0,0,0.8)'; + axesCtx.strokeText(text, lx, ly); + axesCtx.fillStyle = axis.rgba; + axesCtx.fillText(text, lx, ly); + } + } + + axesCtx.restore(); +} + +function v3Normalize(v: Vec3): Vec3 { + const len = Math.sqrt(v.x * v.x + v.y * v.y + v.z * v.z); + if (len <= 1e-12) return { x: 0, y: 0, z: 1 }; + const inv = 1 / len; + return { x: v.x * inv, y: v.y * inv, z: v.z * inv }; +} + +function quatNormalize(q: Quat): Quat { + const [x, y, z, w] = q; + const len = Math.sqrt(x * x + y * y + z * z + w * w); + if (len <= 1e-12) return [0, 0, 0, 1]; + const inv = 1 / len; + return [x * inv, y * inv, z * inv, w * inv]; +} + +function quatMultiply(a: Quat, b: Quat): Quat { + // Hamilton product (composition) + const ax = a[0]; + const ay = a[1]; + const az = a[2]; + const aw = a[3]; + + const bx = b[0]; + const by = b[1]; + const bz = b[2]; + const bw = b[3]; + + return [ + aw * bx + ax * bw + ay * bz - az * by, + aw * by - ax * bz + ay * bw + az * bx, + aw * bz + ax * by - ay * bx + az * bw, + aw * bw - ax * bx - ay * by - az * bz, + ]; +} + +function quatFromAxisAngle(axis: Vec3, angleRad: number): Quat { + const a = v3Normalize(axis); + const half = angleRad * 0.5; + const s = Math.sin(half); + const c = Math.cos(half); + return quatNormalize([a.x * s, a.y * s, a.z * s, c]); +} + +function mat3FromQuat(q: Quat, out: Float32Array): void { + const x = q[0]; + const y = q[1]; + const z = q[2]; + const w = q[3]; + + const x2 = x + x; + const y2 = y + y; + const z2 = z + z; + + const xx = x * x2; + const yy = y * y2; + const zz = z * z2; + + const xy = x * y2; + const xz = x * z2; + const yz = y * z2; + + const wx = w * x2; + const wy = w * y2; + const wz = w * z2; + + // WebGL expects column-major layout when transpose=false. + // These are the standard quaternion->matrix terms (row/column layout handled below). 
+ const m00 = 1 - (yy + zz); + const m01 = xy - wz; + const m02 = xz + wy; + + const m10 = xy + wz; + const m11 = 1 - (xx + zz); + const m12 = yz - wx; + + const m20 = xz - wy; + const m21 = yz + wx; + const m22 = 1 - (xx + yy); + + // Column-major mat3 for WebGL. + out[0] = m00; + out[1] = m10; + out[2] = m20; + + out[3] = m01; + out[4] = m11; + out[5] = m21; + + out[6] = m02; + out[7] = m12; + out[8] = m22; +} + + +function toUint8Volume(data: Float32Array): Uint8Array { + const out = new Uint8Array(data.length); + for (let i = 0; i < data.length; i++) { + const v = data[i] ?? 0; + const b = Math.round(clamp(v, 0, 1) * 255); + out[i] = b; + } + return out; +} + +type VolumeTextureFormat = + | { kind: 'f32'; internalFormat: number; format: number; type: number; minMagFilter: number } + | { kind: 'u8'; internalFormat: number; format: number; type: number; minMagFilter: number }; + +function chooseVolumeTextureFormat(gl: WebGL2RenderingContext): { + primary: VolumeTextureFormat; + fallback: VolumeTextureFormat; +} { + // Float textures preserve subtle contrast; if linear filtering isn't supported we can still sample with NEAREST. + const floatLinear = !!gl.getExtension('OES_texture_float_linear'); + + const primary: VolumeTextureFormat = { + kind: 'f32', + internalFormat: gl.R32F, + format: gl.RED, + type: gl.FLOAT, + minMagFilter: floatLinear ? 
gl.LINEAR : gl.NEAREST, + }; + + const fallback: VolumeTextureFormat = { + kind: 'u8', + internalFormat: gl.R8, + format: gl.RED, + type: gl.UNSIGNED_BYTE, + minMagFilter: gl.LINEAR, + }; + + return { primary, fallback }; +} + +function compileShader(gl: WebGL2RenderingContext, type: number, src: string): WebGLShader { + const sh = gl.createShader(type); + if (!sh) throw new Error('Failed to create shader'); + gl.shaderSource(sh, src); + gl.compileShader(sh); + if (!gl.getShaderParameter(sh, gl.COMPILE_STATUS)) { + const log = gl.getShaderInfoLog(sh) || '(no log)'; + gl.deleteShader(sh); + throw new Error(log); + } + return sh; +} + +function createProgram(gl: WebGL2RenderingContext, vsSrc: string, fsSrc: string): WebGLProgram { + const vs = compileShader(gl, gl.VERTEX_SHADER, vsSrc); + const fs = compileShader(gl, gl.FRAGMENT_SHADER, fsSrc); + + const prog = gl.createProgram(); + if (!prog) throw new Error('Failed to create program'); + gl.attachShader(prog, vs); + gl.attachShader(prog, fs); + gl.linkProgram(prog); + + gl.deleteShader(vs); + gl.deleteShader(fs); + + if (!gl.getProgramParameter(prog, gl.LINK_STATUS)) { + const log = gl.getProgramInfoLog(prog) || '(no log)'; + gl.deleteProgram(prog); + throw new Error(log); + } + + return prog; +} + +export type SvrVolume3DViewerProps = { + volume: SvrVolume | null; +}; + +export type SvrVolume3DViewerHandle = { + /** Capture the current 3D canvas frame as a PNG (best-effort). */ + capture3dPng: () => Promise; + /** Reset view + controls to a stable preset for reproducible harness captures. 
*/ + applyHarnessPreset: () => void; +}; + +export const SvrVolume3DViewer = forwardRef(function SvrVolume3DViewer( + { volume }, + ref +) { + const canvasRef = useRef(null); + const axesCanvasRef = useRef(null); + const pendingCapture3dRef = useRef<{ resolve: (b: Blob | null) => void } | null>(null); + + const [initError, setInitError] = useState(null); + + // Viewer controls (composite-only) + const [controlsCollapsed, setControlsCollapsed] = useState(false); + const [threshold, setThreshold] = useState(0.05); + const [steps, setSteps] = useState(160); + const [gamma, setGamma] = useState(1.0); + const [opacity, setOpacity] = useState(4.0); + const [zoom, setZoom] = useState(1.0); + + // Slice inspector (orthogonal slices). + const sliceCanvasRef = useRef(null); + const [inspectPlane, setInspectPlane] = useState<'axial' | 'coronal' | 'sagittal'>('axial'); + const [inspectIndex, setInspectIndex] = useState(0); + + const paramsRef = useRef({ threshold, steps, gamma, opacity, zoom }); + useEffect(() => { + paramsRef.current = { threshold, steps, gamma, opacity, zoom }; + }, [gamma, opacity, steps, threshold, zoom]); + + const rotationRef = useRef([0, 0, 0, 1]); + + const { boxScale, dims } = useMemo(() => { + if (!volume) { + return { + dims: { nx: 1, ny: 1, nz: 1 }, + boxScale: [1, 1, 1] as const, + }; + } + + const [nx, ny, nz] = volume.dims; + const maxDim = Math.max(1, nx, ny, nz); + return { + dims: { nx, ny, nz }, + boxScale: [nx / maxDim, ny / maxDim, nz / maxDim] as const, + }; + }, [volume]); + + const resetView = useCallback(() => { + rotationRef.current = [0, 0, 0, 1]; + setZoom(1.0); + }, []); + + useImperativeHandle( + ref, + () => ({ + capture3dPng: () => { + if (!volume) return Promise.resolve(null); + if (!canvasRef.current) return Promise.resolve(null); + + return new Promise((resolve) => { + // Only allow one pending capture; resolve any previous request. 
+ if (pendingCapture3dRef.current) { + pendingCapture3dRef.current.resolve(null); + } + + pendingCapture3dRef.current = { resolve }; + + // Safety: don't leave callers hanging if the GL loop isn't running. + window.setTimeout(() => { + if (pendingCapture3dRef.current?.resolve === resolve) { + pendingCapture3dRef.current = null; + resolve(null); + } + }, 1500); + }); + }, + applyHarnessPreset: () => { + // Stable defaults for harness screenshots. + setThreshold(0.05); + setSteps(160); + setGamma(1.0); + setOpacity(4.0); + setControlsCollapsed(false); + resetView(); + }, + }), + [resetView, volume] + ); + + // Pointer drag rotation (viewport-relative yaw/pitch). + // + // Goal: keep controls constant relative to the viewport: + // - horizontal mouse movement => yaw about screen vertical axis + // - vertical mouse movement => pitch about screen horizontal axis + const dragRef = useRef<{ lastX: number; lastY: number; pointerId: number } | null>(null); + + const onPointerDown = useCallback((e: React.PointerEvent) => { + const canvas = canvasRef.current; + if (!canvas) return; + + dragRef.current = { + lastX: e.clientX, + lastY: e.clientY, + pointerId: e.pointerId, + }; + + canvas.setPointerCapture(e.pointerId); + e.preventDefault(); + e.stopPropagation(); + }, []); + + const onPointerMove = useCallback((e: React.PointerEvent) => { + const canvas = canvasRef.current; + const d = dragRef.current; + if (!canvas || !d || d.pointerId !== e.pointerId) return; + + const dx = e.clientX - d.lastX; + const dy = e.clientY - d.lastY; + + d.lastX = e.clientX; + d.lastY = e.clientY; + + const minDim = Math.max(1, Math.min(canvas.clientWidth, canvas.clientHeight)); + const anglePerPx = Math.PI / minDim; + + // Apply *delta* rotations about fixed viewport/world axes. + // + // Important: composing absolute yaw/pitch as `R = R_pitch * R_yaw` makes yaw behave like a local-axis + // rotation once pitch != 0 (unintuitive). 
Pre-multiplying the current rotation with world-axis deltas + // keeps both axes fixed relative to the viewport. + // NOTE: positive clientY is down, so `deltaPitch = +dy` feels like “drag down -> tilt down”. + const deltaYaw = dx * anglePerPx; + const deltaPitch = dy * anglePerPx; + + const qYaw = quatFromAxisAngle({ x: 0, y: 1, z: 0 }, deltaYaw); + const qPitch = quatFromAxisAngle({ x: 1, y: 0, z: 0 }, deltaPitch); + + // Apply yaw first (screen vertical axis), then pitch (screen horizontal axis). + const qDelta = quatMultiply(qPitch, qYaw); + rotationRef.current = quatNormalize(quatMultiply(qDelta, rotationRef.current)); + + e.preventDefault(); + e.stopPropagation(); + }, []); + + const onPointerUp = useCallback((e: React.PointerEvent) => { + if (dragRef.current?.pointerId === e.pointerId) { + dragRef.current = null; + } + e.preventDefault(); + e.stopPropagation(); + }, []); + + // Mousewheel zoom on the canvas. + useEffect(() => { + const canvas = canvasRef.current; + if (!canvas) return; + + const onWheel = (e: WheelEvent) => { + if (!Number.isFinite(e.deltaY) || e.deltaY === 0) return; + + // Multiplicative zoom feels better across trackpads (small deltas) and mouse wheels (large deltas). 
+ const factor = Math.exp(-e.deltaY * 0.001); + setZoom((z) => clamp(z * factor, 0.6, 10.0)); + + e.preventDefault(); + e.stopPropagation(); + }; + + canvas.addEventListener('wheel', onWheel, { passive: false }); + return () => canvas.removeEventListener('wheel', onWheel); + }, []); + + const inspectorInfo = useMemo(() => { + if (!volume) { + return { + maxIndex: 0, + srcRows: 1, + srcCols: 1, + }; + } + + const [nx, ny, nz] = volume.dims; + + if (inspectPlane === 'axial') { + return { + maxIndex: Math.max(0, nz - 1), + srcRows: ny, + srcCols: nx, + }; + } + + if (inspectPlane === 'coronal') { + return { + maxIndex: Math.max(0, ny - 1), + srcRows: nz, + srcCols: nx, + }; + } + + // sagittal + return { + maxIndex: Math.max(0, nx - 1), + srcRows: nz, + srcCols: ny, + }; + }, [inspectPlane, volume]); + + // Default the inspector to the mid-slice when the volume or plane changes. + useEffect(() => { + if (!volume) return; + setInspectIndex(Math.floor(inspectorInfo.maxIndex / 2)); + }, [inspectPlane, inspectorInfo.maxIndex, volume]); + + // Draw the inspector slice to a 2D canvas. + useEffect(() => { + const canvas = sliceCanvasRef.current; + if (!canvas) return; + if (!volume) return; + + const ctx = canvas.getContext('2d'); + if (!ctx) return; + + const [nx, ny, nz] = volume.dims; + const data = volume.data; + + const idx = Math.round(clamp(inspectIndex, 0, inspectorInfo.maxIndex)); + + const srcRows = inspectorInfo.srcRows; + const srcCols = inspectorInfo.srcCols; + + const src = new Float32Array(srcRows * srcCols); + + const strideY = nx; + const strideZ = nx * ny; + + if (inspectPlane === 'axial') { + const z = idx; + const zBase = z * strideZ; + for (let y = 0; y < ny; y++) { + const inBase = zBase + y * strideY; + const outBase = y * nx; + for (let x = 0; x < nx; x++) { + src[outBase + x] = data[inBase + x] ?? 
0; + } + } + } else if (inspectPlane === 'coronal') { + const y = idx; + for (let z = 0; z < nz; z++) { + const inBase = z * strideZ + y * strideY; + const outBase = z * nx; + for (let x = 0; x < nx; x++) { + src[outBase + x] = data[inBase + x] ?? 0; + } + } + } else { + // sagittal + const x = idx; + for (let z = 0; z < nz; z++) { + const zBase = z * strideZ; + const outBase = z * ny; + for (let y = 0; y < ny; y++) { + src[outBase + y] = data[zBase + y * strideY + x] ?? 0; + } + } + } + + // Downsample for interactive rendering (avoid huge canvases). + const MAX_SIZE = 256; + const maxDim = Math.max(srcRows, srcCols); + const scale = maxDim > MAX_SIZE ? MAX_SIZE / maxDim : 1; + const dsRows = Math.max(1, Math.round(srcRows * scale)); + const dsCols = Math.max(1, Math.round(srcCols * scale)); + + const down = resample2dAreaAverage(src, srcRows, srcCols, dsRows, dsCols); + + if (canvas.width !== dsCols) canvas.width = dsCols; + if (canvas.height !== dsRows) canvas.height = dsRows; + + const img = ctx.createImageData(dsCols, dsRows); + const out = img.data; + + for (let i = 0; i < down.length; i++) { + const v = down[i] ?? 0; + const b = Math.round(clamp(v, 0, 1) * 255); + + const j = i * 4; + out[j] = b; + out[j + 1] = b; + out[j + 2] = b; + out[j + 3] = 255; + } + + ctx.putImageData(img, 0, 0); + }, [inspectIndex, inspectPlane, inspectorInfo.maxIndex, inspectorInfo.srcCols, inspectorInfo.srcRows, volume]); + + useEffect(() => { + setInitError(null); + + const canvas = canvasRef.current; + if (!canvas) return; + + if (!volume) { + // No volume yet; nothing to initialize. + return; + } + + const gl = canvas.getContext('webgl2', { + antialias: true, + alpha: false, + depth: false, + stencil: false, + preserveDrawingBuffer: false, + }); + + if (!gl) { + setInitError('WebGL2 is not available in this browser/environment.'); + return; + } + + // Prefer float textures for fidelity; fall back to 8-bit if unavailable. 
+ const { primary, fallback } = chooseVolumeTextureFormat(gl); + + const vsSrc = `#version 300 es +in vec2 a_pos; +out vec2 v_uv; +void main() { + v_uv = a_pos * 0.5 + 0.5; + gl_Position = vec4(a_pos, 0.0, 1.0); +}`; + + const fsSrc = `#version 300 es +precision highp float; +precision highp sampler3D; + +in vec2 v_uv; +out vec4 outColor; + +uniform sampler3D u_vol; +uniform mat3 u_rot; +uniform vec3 u_box; +uniform float u_aspect; +uniform float u_zoom; +uniform float u_thr; +uniform int u_steps; +uniform float u_gamma; +uniform float u_opacity; +uniform vec3 u_texel; + +const float CAM_Z = ${SVR3D_CAMERA_Z}; +const float FOCAL_Z = ${SVR3D_FOCAL_Z}; + +float saturate(float x) { + return clamp(x, 0.0, 1.0); +} + +float radial01(vec3 pos) { + // pos is in object space centered at the volume centroid. + // Normalize by the half box extents so that r=1 is approximately the box surface (clamped). + vec3 halfBox = 0.5 * u_box; + vec3 q = pos / max(halfBox, vec3(1e-6)); + return saturate(length(q)); +} + +bool intersectBox(vec3 ro, vec3 rd, vec3 bmin, vec3 bmax, out float t0, out float t1) { + vec3 invD = 1.0 / rd; + vec3 tbot = (bmin - ro) * invD; + vec3 ttop = (bmax - ro) * invD; + vec3 tmin = min(ttop, tbot); + vec3 tmax = max(ttop, tbot); + t0 = max(max(tmin.x, tmin.y), tmin.z); + t1 = min(min(tmax.x, tmax.y), tmax.z); + return t1 >= max(t0, 0.0); +} + +void main() { + // NDC in [-1, 1] + vec2 p = v_uv * 2.0 - 1.0; + p.x *= u_aspect; + p /= max(1e-3, u_zoom); + + // World/view ray + vec3 roW = vec3(0.0, 0.0, CAM_Z); + vec3 rdW = normalize(vec3(p, -FOCAL_Z)); + + // Rotate ray into volume/object space (volume is rotated by u_rot). 
+ mat3 invR = transpose(u_rot); + vec3 ro = invR * roW; + vec3 rd = invR * rdW; + + vec3 bmin = -0.5 * u_box; + vec3 bmax = 0.5 * u_box; + + float t0; + float t1; + if (!intersectBox(ro, rd, bmin, bmax, t0, t1)) { + outColor = vec4(0.0, 0.0, 0.0, 1.0); + return; + } + + // Raymarch (front-to-back compositing) + const int MAX_STEPS = 256; + int n = clamp(u_steps, 8, MAX_STEPS); + float dt = (t1 - t0) / float(n); + + // Radial prior + gradient-based shading. + // + // Prior: the center of the box is more likely to contain the structure of interest. + // We use that to: + // - keep the intensity threshold low near the center and higher near the edges + // - boost edge shading near the center + // + // NOTE: Use *linear* radial ramps for predictability. + const float EDGE_K = 14.0; + const float CENTER_EDGE_GAIN = 2.5; + + float accum = 0.0; + float aAccum = 0.0; + + float t = max(t0, 0.0); + + // View direction in object space (toward the camera). + vec3 vDir = normalize(-rd); + + for (int i = 0; i < MAX_STEPS; i++) { + if (i >= n) break; + vec3 pos = ro + rd * (t + float(i) * dt); + + // Map object-space box to texture coords [0,1] + vec3 tc = pos / u_box + 0.5; + + float r = radial01(pos); + + // thrW ramps 0 at center -> 1 at edge. + float thrW = r; + // centerW ramps 1 at center -> 0 at edge. + float centerW = 1.0 - r; + + float thr = saturate(u_thr * thrW); + + float v = saturate(texture(u_vol, tc).r); + + if (v >= thr) { + float val = saturate((v - thr) / max(1e-6, 1.0 - thr)); + + // Gradient in object/texture space (central differences). 
+ vec3 d = u_texel; + float vx1 = saturate(texture(u_vol, clamp(tc + vec3(d.x, 0.0, 0.0), 0.0, 1.0)).r); + float vx0 = saturate(texture(u_vol, clamp(tc - vec3(d.x, 0.0, 0.0), 0.0, 1.0)).r); + float vy1 = saturate(texture(u_vol, clamp(tc + vec3(0.0, d.y, 0.0), 0.0, 1.0)).r); + float vy0 = saturate(texture(u_vol, clamp(tc - vec3(0.0, d.y, 0.0), 0.0, 1.0)).r); + float vz1 = saturate(texture(u_vol, clamp(tc + vec3(0.0, 0.0, d.z), 0.0, 1.0)).r); + float vz0 = saturate(texture(u_vol, clamp(tc - vec3(0.0, 0.0, d.z), 0.0, 1.0)).r); + + vec3 grad = vec3(vx1 - vx0, vy1 - vy0, vz1 - vz0); + float gmag = length(grad); + + // Edge factor (boosted near the center). + // + // IMPORTANT: use an exponential mapping so the "Edge strength" slider stays responsive + // instead of quickly saturating to 1.0 for most edges. + float centerGain = mix(1.0, CENTER_EDGE_GAIN, saturate(centerW)); + float edgeRaw = gmag * EDGE_K * centerGain; + float edge = 1.0 - exp(-edgeRaw * u_gamma); + edge = saturate(edge); + edge = edge * edge; + + // Simple shading using the gradient as a normal (view-aligned light). + vec3 nrm = normalize(grad + vec3(1e-6)); + float diff = abs(dot(nrm, vDir)); + float shade = 0.25 + 0.75 * diff; + + // Make edges matter for visibility (opacity) and for perceived contrast (brightness). + float a = saturate(val * (0.15 + 0.85 * edge)); + + // Convert to per-step opacity; dt keeps opacity roughly stable as step count changes. 
+ float aStep = 1.0 - exp(-u_opacity * a * dt * 4.0); + aStep = saturate(aStep); + + float sampleV = v * shade * (0.6 + 0.4 * edge); + + accum += (1.0 - aAccum) * sampleV * aStep; + aAccum += (1.0 - aAccum) * aStep; + + if (aAccum > 0.98) { + break; + } + } + } + + outColor = vec4(vec3(saturate(accum)), 1.0); +}`; + + let program: WebGLProgram | null = null; + let vao: WebGLVertexArrayObject | null = null; + let vbo: WebGLBuffer | null = null; + let tex: WebGLTexture | null = null; + let raf = 0; + + try { + program = createProgram(gl, vsSrc, fsSrc); + + // Full-screen triangle (2D clip space) + vao = gl.createVertexArray(); + vbo = gl.createBuffer(); + if (!vao || !vbo) throw new Error('Failed to allocate GL buffers'); + + gl.bindVertexArray(vao); + gl.bindBuffer(gl.ARRAY_BUFFER, vbo); + + // Triangle: (-1,-1), (3,-1), (-1,3) + const verts = new Float32Array([-1, -1, 3, -1, -1, 3]); + gl.bufferData(gl.ARRAY_BUFFER, verts, gl.STATIC_DRAW); + + const aPos = gl.getAttribLocation(program, 'a_pos'); + gl.enableVertexAttribArray(aPos); + gl.vertexAttribPointer(aPos, 2, gl.FLOAT, false, 0, 0); + + gl.bindVertexArray(null); + gl.bindBuffer(gl.ARRAY_BUFFER, null); + + // Volume texture (prefer float for fidelity; fall back to 8-bit for compatibility) + tex = gl.createTexture(); + if (!tex) throw new Error('Failed to allocate 3D texture'); + + gl.activeTexture(gl.TEXTURE0); + gl.bindTexture(gl.TEXTURE_3D, tex); + gl.pixelStorei(gl.UNPACK_ALIGNMENT, 1); + + // We'll try float first; if WebGL rejects it, re-upload as R8. 
+ let fmt: VolumeTextureFormat = primary; + let data: ArrayBufferView = volume.data; + + gl.texParameteri(gl.TEXTURE_3D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE); + gl.texParameteri(gl.TEXTURE_3D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE); + gl.texParameteri(gl.TEXTURE_3D, gl.TEXTURE_WRAP_R, gl.CLAMP_TO_EDGE); + + const tryUpload = (candidate: VolumeTextureFormat, candidateData: ArrayBufferView) => { + gl.texParameteri(gl.TEXTURE_3D, gl.TEXTURE_MIN_FILTER, candidate.minMagFilter); + gl.texParameteri(gl.TEXTURE_3D, gl.TEXTURE_MAG_FILTER, candidate.minMagFilter); + + gl.texImage3D( + gl.TEXTURE_3D, + 0, + candidate.internalFormat, + dims.nx, + dims.ny, + dims.nz, + 0, + candidate.format, + candidate.type, + candidateData + ); + + const err = gl.getError(); + return err === gl.NO_ERROR; + }; + + try { + const ok = tryUpload(primary, volume.data); + if (!ok) { + // Fall back to 8-bit normalized. + const u8 = toUint8Volume(volume.data); + fmt = fallback; + data = u8; + tryUpload(fallback, data); + } + } catch { + const u8 = toUint8Volume(volume.data); + fmt = fallback; + data = u8; + tryUpload(fallback, data); + } + + console.info('[svr3d] Volume texture format', { kind: fmt.kind, dims }); + + gl.bindTexture(gl.TEXTURE_3D, null); + + const u = { + vol: gl.getUniformLocation(program, 'u_vol'), + rot: gl.getUniformLocation(program, 'u_rot'), + box: gl.getUniformLocation(program, 'u_box'), + aspect: gl.getUniformLocation(program, 'u_aspect'), + zoom: gl.getUniformLocation(program, 'u_zoom'), + thr: gl.getUniformLocation(program, 'u_thr'), + steps: gl.getUniformLocation(program, 'u_steps'), + gamma: gl.getUniformLocation(program, 'u_gamma'), + opacity: gl.getUniformLocation(program, 'u_opacity'), + texel: gl.getUniformLocation(program, 'u_texel'), + } as const; + + const rotMat = new Float32Array(9); + + const axesCanvas = axesCanvasRef.current; + const axesCtx = axesCanvas ? 
axesCanvas.getContext('2d') : null; + + const resizeAndViewport = () => { + const dpr = window.devicePixelRatio || 1; + const w = Math.max(1, Math.floor(canvas.clientWidth * dpr)); + const h = Math.max(1, Math.floor(canvas.clientHeight * dpr)); + if (canvas.width !== w || canvas.height !== h) { + canvas.width = w; + canvas.height = h; + } + + if (axesCanvas) { + if (axesCanvas.width !== w || axesCanvas.height !== h) { + axesCanvas.width = w; + axesCanvas.height = h; + } + } + + gl.viewport(0, 0, canvas.width, canvas.height); + }; + + + const draw = () => { + resizeAndViewport(); + + const { threshold, steps, gamma, opacity, zoom } = paramsRef.current; + + gl.disable(gl.DEPTH_TEST); + gl.disable(gl.CULL_FACE); + + gl.useProgram(program); + gl.bindVertexArray(vao); + + // Bind texture + gl.activeTexture(gl.TEXTURE0); + gl.bindTexture(gl.TEXTURE_3D, tex); + gl.uniform1i(u.vol, 0); + + // Uniforms + mat3FromQuat(rotationRef.current, rotMat); + gl.uniformMatrix3fv(u.rot, false, rotMat); + gl.uniform3f(u.box, boxScale[0], boxScale[1], boxScale[2]); + gl.uniform1f(u.aspect, canvas.width / Math.max(1, canvas.height)); + gl.uniform1f(u.zoom, zoom); + // Threshold is an edge "scale" (0..5). The shader maps it to a linear 0-at-center threshold. + gl.uniform1f(u.thr, clamp(threshold, 0, 5)); + gl.uniform1i(u.steps, Math.round(clamp(steps, 8, 256))); + gl.uniform1f(u.gamma, clamp(gamma, 0.1, 10)); + gl.uniform1f(u.opacity, clamp(opacity, 0.1, 20)); + gl.uniform3f( + u.texel, + 1 / Math.max(1, dims.nx), + 1 / Math.max(1, dims.ny), + 1 / Math.max(1, dims.nz) + ); + + gl.drawArrays(gl.TRIANGLES, 0, 3); + + // Overlay reference axes with mm tick marks for gauging physical size. + if (axesCanvas && axesCtx) { + drawAxesOverlay({ axesCanvas, axesCtx, canvas, volume, boxScale, rotMat, zoom }); + } + + // One-shot capture for the harness export: read pixels from the current frame. 
+ const pending = pendingCapture3dRef.current; + if (pending) { + pendingCapture3dRef.current = null; + + try { + const w = canvas.width; + const h = canvas.height; + + const rgba = new Uint8Array(w * h * 4); + gl.readPixels(0, 0, w, h, gl.RGBA, gl.UNSIGNED_BYTE, rgba); + + // Flip Y (WebGL origin is bottom-left; ImageData expects top-left). + const flipped = new Uint8ClampedArray(rgba.length); + const rowBytes = w * 4; + for (let y = 0; y < h; y++) { + const srcStart = (h - 1 - y) * rowBytes; + const dstStart = y * rowBytes; + flipped.set(rgba.subarray(srcStart, srcStart + rowBytes), dstStart); + } + + void rgbaToPngBlob({ rgba: flipped, width: w, height: h }) + .then((b) => pending.resolve(b)) + .catch(() => pending.resolve(null)); + } catch (e) { + console.warn('[svr3d] Failed to capture screenshot', e); + pending.resolve(null); + } + } + + gl.bindTexture(gl.TEXTURE_3D, null); + gl.bindVertexArray(null); + + raf = window.requestAnimationFrame(draw); + }; + + raf = window.requestAnimationFrame(draw); + } catch (e) { + const msg = e instanceof Error ? e.message : String(e); + console.error('[SVR3D] Failed to initialize:', e); + setInitError(msg); + } + + return () => { + if (pendingCapture3dRef.current) { + pendingCapture3dRef.current.resolve(null); + pendingCapture3dRef.current = null; + } + + if (raf) window.cancelAnimationFrame(raf); + + if (gl) { + if (tex) gl.deleteTexture(tex); + if (vbo) gl.deleteBuffer(vbo); + if (vao) gl.deleteVertexArray(vao); + if (program) gl.deleteProgram(program); + } + }; + }, [boxScale, dims, volume]); + + return ( +
+
+
+
+ + + + + + + {!volume ? ( +
+ Run SVR to generate a volume for 3D viewing. +
+ ) : initError ? ( +
+ {initError} +
+ ) : ( +
+ Drag to rotate · Wheel to zoom +
+ )} +
+
+
+ + {controlsCollapsed ? null : ( +
+
3D Controls
+ + + + + + + + + + + +
+ +
+ +
+ Composite rendering with edge shading: tune opacity/threshold, and increase edge strength to make boundaries pop (stronger near the box center). +
+ +
+
Slice Inspector
+
+
+ + + +
+ +
+ Intensities are shown with a fixed 0 to 1 mapping. +
+ +
+ +
+ + {volume ? ( +
+ Volume dims: {dims.nx}×{dims.ny}×{dims.nz} +
+ ) : null} +
+
+
+ )} +
+ ); +}); diff --git a/frontend/src/components/TumorSegmentationOverlay.tsx b/frontend/src/components/TumorSegmentationOverlay.tsx new file mode 100644 index 0000000..de61a12 --- /dev/null +++ b/frontend/src/components/TumorSegmentationOverlay.tsx @@ -0,0 +1,2107 @@ +import { useCallback, useEffect, useMemo, useRef, useState } from 'react'; +import { BarChart3, Copy, Download, Eye, EyeOff, RotateCcw, Save, Sparkles, Wand2, X } from 'lucide-react'; +import { propagateTumorAcrossSeries } from '../utils/tumorPropagation'; +import type { NormalizedPoint, TumorPolygon, TumorThreshold, ViewerTransform } from '../db/schema'; +import type { DicomViewerHandle } from './DicomViewer'; +import { + getAllTumorGroundTruth, + getSopInstanceUidForInstanceIndex, + getTumorGroundTruthForInstance, + getTumorSegmentationForInstance, + saveTumorSegmentation, +} from '../utils/localApi'; +import { + decodeCapturedPngToGrayscale, + estimateThresholdFromSeedPoints, + segmentTumorFromGrayscale, + type SegmentTumorOptions, +} from '../utils/segmentation/segmentTumor'; +import { runGtBenchmark } from '../utils/segmentation/gtBenchmark'; +import { exportTumorHarnessDatasetAndDownload } from '../utils/segmentation/harness/exportTumorHarnessDataset'; +import { computeMaskMetrics, type MaskMetrics } from '../utils/segmentation/maskMetrics'; +import { + computePolygonBoundaryMetrics, + type PolygonBoundaryMetrics, +} from '../utils/segmentation/polygonBoundaryMetrics'; +import { rasterizePolygonToMask } from '../utils/segmentation/rasterizePolygon'; +import { + normalizeViewerTransform, + remapPolygonBetweenViewerTransforms, + remapPointsBetweenViewerTransforms, +} from '../utils/viewTransform'; + +function clamp01(v: number) { + return Math.max(0, Math.min(1, v)); +} + +function polygonToSvgPath(p: TumorPolygon): string { + if (!p.points.length) return ''; + + const d = [`M ${p.points[0].x.toFixed(4)} ${p.points[0].y.toFixed(4)}`]; + for (let i = 1; i < p.points.length; i++) { + d.push(`L 
${p.points[i].x.toFixed(4)} ${p.points[i].y.toFixed(4)}`); + } + d.push('Z'); + return d.join(' '); +} + +function polygonBounds01(p: TumorPolygon): { minX: number; minY: number; maxX: number; maxY: number } { + let minX = Number.POSITIVE_INFINITY; + let minY = Number.POSITIVE_INFINITY; + let maxX = Number.NEGATIVE_INFINITY; + let maxY = Number.NEGATIVE_INFINITY; + + for (const pt of p.points) { + if (pt.x < minX) minX = pt.x; + if (pt.y < minY) minY = pt.y; + if (pt.x > maxX) maxX = pt.x; + if (pt.y > maxY) maxY = pt.y; + } + + return { + minX: clamp01(minX), + minY: clamp01(minY), + maxX: clamp01(maxX), + maxY: clamp01(maxY), + }; +} + +function pointsBounds01(points: NormalizedPoint[]): { minX: number; minY: number; maxX: number; maxY: number } { + let minX = Number.POSITIVE_INFINITY; + let minY = Number.POSITIVE_INFINITY; + let maxX = Number.NEGATIVE_INFINITY; + let maxY = Number.NEGATIVE_INFINITY; + + for (const pt of points) { + if (pt.x < minX) minX = pt.x; + if (pt.y < minY) minY = pt.y; + if (pt.x > maxX) maxX = pt.x; + if (pt.y > maxY) maxY = pt.y; + } + + return { + minX: clamp01(minX), + minY: clamp01(minY), + maxX: clamp01(maxX), + maxY: clamp01(maxY), + }; +} + +export type TumorSegmentationOverlayProps = { + enabled: boolean; + onRequestClose: () => void; + + viewerRef: React.RefObject; + + comboId: string; + dateIso: string; + studyId: string; + seriesUid: string; + /** Instance index in effective slice ordering (i.e. after reverseSliceOrder mapping). */ + effectiveInstanceIndex: number; + + /** Current viewer transform (pan/zoom/rotation/affine). */ + viewerTransform: ViewerTransform; +}; + +export function TumorSegmentationOverlay({ + enabled, + onRequestClose, + viewerRef, + comboId, + dateIso, + studyId, + seriesUid, + effectiveInstanceIndex, + viewerTransform, +}: TumorSegmentationOverlayProps) { + const containerRef = useRef(null); + + // Keep the latest viewer transform in a ref so we can snapshot it at specific lifecycle moments + // (e.g. 
when capturing a PNG) without re-running those effects on every pan/zoom/rotation change. + const viewerTransformRef = useRef(viewerTransform); + useEffect(() => { + viewerTransformRef.current = viewerTransform; + }, [viewerTransform]); + + const [paintPoints, setPaintPoints] = useState([]); + const [paintPointsViewTransform, setPaintPointsViewTransform] = useState(null); + const [isPainting, setIsPainting] = useState(false); + + const [draftThreshold, setDraftThreshold] = useState(null); + const [draftPolygon, setDraftPolygon] = useState(null); + const [draftPolygonViewTransform, setDraftPolygonViewTransform] = useState(null); + const [draftSeed, setDraftSeed] = useState(null); + + const [savedPolygon, setSavedPolygon] = useState(null); + const [savedPolygonViewTransform, setSavedPolygonViewTransform] = useState(null); + const [savedSeed, setSavedSeed] = useState(null); + const [savedThreshold, setSavedThreshold] = useState(null); + + const [groundTruthPolygon, setGroundTruthPolygon] = useState(null); + const [groundTruthPolygonViewTransform, setGroundTruthPolygonViewTransform] = useState(null); + + const [tunedOptions, setTunedOptions] = useState(null); + + const [gtMetrics, setGtMetrics] = useState(null); + const [gtBoundaryMetrics, setGtBoundaryMetrics] = useState(null); + + const [diffOverlayEnabled, setDiffOverlayEnabled] = useState(true); + const diffCanvasRef = useRef(null); + + const [autoTuneStatus, setAutoTuneStatus] = useState<{ running: boolean; message?: string }>( + () => ({ running: false }) + ); + const [gtBenchmarkStatus, setGtBenchmarkStatus] = useState<{ running: boolean; message?: string }>( + () => ({ running: false }) + ); + const [harnessExportStatus, setHarnessExportStatus] = useState<{ running: boolean; message?: string }>( + () => ({ running: false }) + ); + const [autoTuneLastStats, setAutoTuneLastStats] = useState< + | { + evals: { + stage1TolSweep: number; + stage2ParamTune: number; + stage3TolRefine: number; + stage4PolyTune: number; 
+ total: number; + }; + ms: { + stage1TolSweep: number; + stage2ParamTune: number; + stage3TolRefine: number; + stage4PolyTune: number; + total: number; + }; + } + | null + >(null); + const [autoTuneLastBest, setAutoTuneLastBest] = useState< + | { + anchor: number; + tol: number; + opts: SegmentTumorOptions | undefined; + metrics: MaskMetrics; + boundary: PolygonBoundaryMetrics; + paintLeakPx: number; + paintDistMeanPx: number; + paintDistP95Px: number; + paintDistMaxPx: number; + } + | null + >(null); + + const [containerSize, setContainerSize] = useState<{ w: number; h: number }>({ w: 0, h: 0 }); + + // Cache the grayscale pixels captured after the user paints so threshold tuning doesn't + // re-capture PNGs (which can be slow/flaky and was causing "Error 5"-style crashes). + const capturedRef = useRef<{ gray: Uint8Array; w: number; h: number; viewTransform: ViewerTransform } | null>(null); + const [captureVersion, setCaptureVersion] = useState(0); + + const [busy, setBusy] = useState(false); + const busyRef = useRef(false); // Track busy state in ref for use in effects + const [error, setError] = useState(null); + + // Tolerance slider: anchor stays fixed, tolerance changes. + // This makes the segmentation area monotonic with slider movement. + const [thresholdAnchor, setThresholdAnchor] = useState(null); + const [thresholdTolerance, setThresholdTolerance] = useState(24); + + const effectiveThresholdFromSlider: TumorThreshold = useMemo(() => { + const anchor = Math.max(0, Math.min(255, Math.round(thresholdAnchor ?? 128))); + const tolerance = Math.max(0, Math.min(127, Math.round(thresholdTolerance))); + return { + low: Math.max(0, anchor - tolerance), + high: Math.min(255, anchor + tolerance), + anchor, + tolerance, + }; + }, [thresholdAnchor, thresholdTolerance]); + + const computeDraftFromCurrentCapture = useCallback( + (threshold: TumorThreshold, overrideOpts?: SegmentTumorOptions) => { + const opts = overrideOpts ?? tunedOptions ?? 
undefined; + + console.log('[TumorOverlay] computeDraftFromCurrentCapture START', { + threshold, + paintPointsCount: paintPoints.length, + opts, + }); + const t0 = performance.now(); + + const cap = capturedRef.current; + if (!cap) { + console.error('[TumorOverlay] No captured image available'); + throw new Error('No captured image available'); + } + + console.log('[TumorOverlay] Captured image:', { w: cap.w, h: cap.h, grayLength: cap.gray.length }); + + try { + const result = segmentTumorFromGrayscale(cap.gray, cap.w, cap.h, paintPoints, threshold, opts); + console.log('[TumorOverlay] Segmentation result:', { pointsCount: result.polygon.points.length, area: result.meta.areaPx }); + setDraftPolygon(result.polygon); + setDraftPolygonViewTransform(cap.viewTransform); + setDraftThreshold(threshold); + setDraftSeed(result.seed); + const elapsed = performance.now() - t0; + console.log('[TumorOverlay] computeDraftFromCurrentCapture DONE', { elapsed: elapsed.toFixed(1) + 'ms' }); + } catch (err) { + console.error('[TumorOverlay] Segmentation failed:', err); + throw err; + } + }, + [paintPoints, tunedOptions] + ); + + // Load existing saved segmentation when enabled or when slice changes. + useEffect(() => { + if (!enabled) return; + + let cancelled = false; + (async () => { + try { + setError(null); + const sop = await getSopInstanceUidForInstanceIndex(seriesUid, effectiveInstanceIndex); + + const [row, gt] = await Promise.all([ + getTumorSegmentationForInstance(seriesUid, sop), + getTumorGroundTruthForInstance(seriesUid, sop), + ]); + + if (cancelled) return; + + const fallbackView = normalizeViewerTransform(null); + + setSavedPolygon(row?.polygon ?? null); + setSavedPolygonViewTransform(row?.meta?.viewTransform ?? fallbackView); + setSavedSeed(row?.seed ?? null); + setSavedThreshold(row?.threshold ?? null); + + setGroundTruthPolygon(gt?.polygon ?? null); + setGroundTruthPolygonViewTransform(gt?.viewTransform ?? 
fallbackView);
      } catch (e) {
        console.error(e);
      }
    })();

    return () => {
      cancelled = true;
    };
  }, [enabled, seriesUid, effectiveInstanceIndex]);

  // Reset all draft/derived state whenever the overlay is switched on, so a
  // previous session's strokes, polygons, or metrics never leak into a new one.
  useEffect(() => {
    if (!enabled) return;
    setPaintPoints([]);
    setPaintPointsViewTransform(null);
    setDraftPolygon(null);
    setDraftPolygonViewTransform(null);
    setDraftThreshold(null);
    setDraftSeed(null);
    setError(null);
    setGroundTruthPolygon(null);
    setGroundTruthPolygonViewTransform(null);
    capturedRef.current = null;
    setCaptureVersion((v) => v + 1);
    setGtMetrics(null);
    setGtBoundaryMetrics(null);
  }, [enabled]);

  // Track container size (used for brush sizing + UI placement).
  useEffect(() => {
    if (!enabled) return;
    const el = containerRef.current;
    if (!el) return;

    const update = () => {
      const rect = el.getBoundingClientRect();
      setContainerSize({ w: rect.width, h: rect.height });
    };

    update();
    const observer = new ResizeObserver(update);
    observer.observe(el);
    return () => observer.disconnect();
  }, [enabled]);

  // Convert a pointer event into container-relative, clamped [0, 1] coordinates.
  // Returns null when the container is missing or has a degenerate size.
  const getLocalNormPoint = useCallback((e: PointerEvent | React.PointerEvent): NormalizedPoint | null => {
    const el = containerRef.current;
    if (!el) return null;
    const rect = el.getBoundingClientRect();
    if (rect.width <= 0 || rect.height <= 0) return null;
    const nx = (e.clientX - rect.left) / rect.width;
    const ny = (e.clientY - rect.top) / rect.height;
    return { x: clamp01(nx), y: clamp01(ny) };
  }, []);

  // Begin a paint gesture: clear any previous draft state, record the current
  // view transform for later re-projection, and capture the pointer.
  const onPointerDown = useCallback(
    (e: React.PointerEvent) => {
      if (!enabled || !e.isPrimary || e.button !== 0) return;

      didPaintRef.current = false;

      // Avoid starting a paint gesture on overlay buttons.
      const target = e.target as HTMLElement | null;
      if (target?.closest('[data-tumor-ui="true"]')) return;

      const p = getLocalNormPoint(e);
      if (!p) return;

      setError(null);
      setPaintPointsViewTransform({ ...viewerTransformRef.current });
      setPaintPoints([p]);
      setDraftPolygon(null);
      setDraftPolygonViewTransform(null);
      setDraftThreshold(null);
      setDraftSeed(null);
      setIsPainting(true);
      capturedRef.current = null;
      setCaptureVersion((v) => v + 1);
      setGtMetrics(null);
      setGtBoundaryMetrics(null);

      try {
        containerRef.current?.setPointerCapture(e.pointerId);
      } catch {
        // ignore
      }
    },
    [enabled, getLocalNormPoint]
  );

  const didPaintRef = useRef(false);

  const onPointerMove = useCallback(
    (e: React.PointerEvent) => {
      if (!enabled || !isPainting) return;
      const p = getLocalNormPoint(e);
      if (!p) return;
      setPaintPoints((prev) => {
        // Drop near-duplicate samples so slow drags don't bloat the stroke.
        const last = prev[prev.length - 1];
        if (last && Math.hypot(last.x - p.x, last.y - p.y) < 0.002) {
          return prev;
        }
        didPaintRef.current = true;
        return [...prev, p];
      });
    },
    [enabled, getLocalNormPoint, isPainting]
  );

  // Prevent the underlying viewer from interpreting a paint drag as a click.
  const onClickCapture = useCallback((e: React.MouseEvent) => {
    if (!didPaintRef.current) return;
    didPaintRef.current = false;
    e.preventDefault();
    e.stopPropagation();
  }, []);

  // Finish the paint gesture: capture the visible image as grayscale pixels,
  // estimate a threshold from the stroke, and run the initial segmentation.
  const onPointerUp = useCallback(
    async (e: React.PointerEvent) => {
      if (!enabled || !isPainting) return;

      try {
        containerRef.current?.releasePointerCapture(e.pointerId);
      } catch {
        // ignore
      }

      setIsPainting(false);

      // Ignore click-only gestures.
      if (paintPoints.length < 4) {
        setError('Click and drag to paint over the tumor region');
        return;
      }

      // Compute initial segmentation.
      busyRef.current = true;
      setBusy(true);
      try {
        console.log('[TumorOverlay] Starting initial segmentation after paint');
        const viewer = viewerRef.current;
        if (!viewer) throw new Error('Viewer not ready');

        const png = await viewer.captureVisiblePng({ maxSize: 512 }); // Higher resolution for smoother polygons
        console.log('[TumorOverlay] PNG captured, decoding...');
        const decoded = await decodeCapturedPngToGrayscale(png);
        console.log('[TumorOverlay] Decoded grayscale:', { w: decoded.width, h: decoded.height });
        capturedRef.current = {
          gray: decoded.gray,
          w: decoded.width,
          h: decoded.height,
          viewTransform: { ...viewerTransformRef.current },
        };
        setCaptureVersion((v) => v + 1);

        const initialThreshold = estimateThresholdFromSeedPoints(
          decoded.gray,
          decoded.width,
          decoded.height,
          paintPoints
        );
        console.log('[TumorOverlay] Initial threshold:', initialThreshold);

        // Initialize tolerance-mode slider state.
        const anchor =
          typeof initialThreshold.anchor === 'number'
            ? initialThreshold.anchor
            : Math.round((initialThreshold.low + initialThreshold.high) / 2);
        const tolerance =
          typeof initialThreshold.tolerance === 'number'
            ? initialThreshold.tolerance
            : Math.round((initialThreshold.high - initialThreshold.low) / 2);

        setThresholdAnchor(anchor);
        setThresholdTolerance(tolerance);

        // Use the threshold derived from (anchor, tolerance) so the slider starts "in sync".
        const thresholdForSeg: TumorThreshold = {
          low: Math.max(0, Math.min(255, Math.round(anchor - tolerance))),
          high: Math.max(0, Math.min(255, Math.round(anchor + tolerance))),
          anchor,
          tolerance,
        };

        computeDraftFromCurrentCapture(thresholdForSeg);

        console.log('[TumorOverlay] Initial segmentation complete');
      } catch (err) {
        console.error('[TumorOverlay] Initial segmentation error:', err);
        setError(err instanceof Error ?
err.message : 'Segmentation failed');
      } finally {
        busyRef.current = false;
        setBusy(false);
      }
    },
    [computeDraftFromCurrentCapture, enabled, isPainting, paintPoints, viewerRef]
  );

  // Live update segmentation when threshold changes (debounced).
  //
  // Important: we reuse the grayscale pixels captured after painting, instead of re-capturing
  // PNGs on every slider move (which can be slow/flaky and was causing crashes).
  useEffect(() => {
    if (!enabled) return;
    if (!draftPolygon) return;
    if (!draftThreshold) return;
    if (paintPoints.length < 4) return;
    if (!capturedRef.current) return;

    const sliderMatchesDraft =
      draftThreshold.low === effectiveThresholdFromSlider.low &&
      draftThreshold.high === effectiveThresholdFromSlider.high;
    if (sliderMatchesDraft) return;

    // Don't trigger new segmentation while one is in progress.
    if (busyRef.current) {
      console.log('[TumorOverlay] Skipping threshold update - busy');
      return;
    }

    // Debounce slider changes to avoid firing too frequently.
    const timeout = window.setTimeout(() => {
      // Double-check busy state at execution time.
      if (busyRef.current) {
        console.log('[TumorOverlay] Skipping threshold update at exec time - busy');
        return;
      }

      busyRef.current = true;
      setBusy(true);

      // Use requestAnimationFrame to give UI a chance to update.
      requestAnimationFrame(() => {
        try {
          console.log('[TumorOverlay] Running threshold-triggered segmentation');
          computeDraftFromCurrentCapture(effectiveThresholdFromSlider);
        } catch (err) {
          console.error('[TumorOverlay] Threshold segmentation error:', err);
          setError(err instanceof Error ? err.message : 'Segmentation failed');
        } finally {
          busyRef.current = false;
          setBusy(false);
        }
      });
    }, 150); // Increased debounce to 150ms

    return () => window.clearTimeout(timeout);
  }, [draftPolygon, draftThreshold, effectiveThresholdFromSlider, enabled, paintPoints.length, computeDraftFromCurrentCapture]);

  const viewSize = useMemo(() => ({ w: containerSize.w, h: containerSize.h }), [containerSize.h, containerSize.w]);

  // The *Display memos below re-project stored geometry from the transform it
  // was captured under into the current viewer transform, so overlays stay
  // glued to the image while the user pans/zooms.
  const paintPointsDisplay = useMemo(() => {
    if (paintPoints.length === 0) return [];
    const from = paintPointsViewTransform ?? viewerTransform;
    return viewSize.w > 0 && viewSize.h > 0
      ? remapPointsBetweenViewerTransforms(paintPoints, viewSize, from, viewerTransform)
      : paintPoints;
  }, [paintPoints, paintPointsViewTransform, viewSize, viewerTransform]);

  const draftPolygonDisplay = useMemo(() => {
    if (!draftPolygon) return null;
    const from = draftPolygonViewTransform ?? viewerTransform;
    return viewSize.w > 0 && viewSize.h > 0
      ? remapPolygonBetweenViewerTransforms(draftPolygon, viewSize, from, viewerTransform)
      : draftPolygon;
  }, [draftPolygon, draftPolygonViewTransform, viewSize, viewerTransform]);

  const savedPolygonDisplay = useMemo(() => {
    if (!savedPolygon) return null;
    const from = savedPolygonViewTransform ?? viewerTransform;
    return viewSize.w > 0 && viewSize.h > 0
      ? remapPolygonBetweenViewerTransforms(savedPolygon, viewSize, from, viewerTransform)
      : savedPolygon;
  }, [savedPolygon, savedPolygonViewTransform, viewSize, viewerTransform]);

  const groundTruthPolygonDisplay = useMemo(() => {
    if (!groundTruthPolygon) return null;
    const from = groundTruthPolygonViewTransform ?? viewerTransform;
    return viewSize.w > 0 && viewSize.h > 0
      ? remapPolygonBetweenViewerTransforms(groundTruthPolygon, viewSize, from, viewerTransform)
      : groundTruthPolygon;
  }, [groundTruthPolygon, groundTruthPolygonViewTransform, viewSize, viewerTransform]);

  // SVG path strings for the three overlay polygons ('' when absent).
  const draftPath = useMemo(
    () => (draftPolygonDisplay ? polygonToSvgPath(draftPolygonDisplay) : ''),
    [draftPolygonDisplay]
  );

  const savedPath = useMemo(
    () => (savedPolygonDisplay ? polygonToSvgPath(savedPolygonDisplay) : ''),
    [savedPolygonDisplay]
  );

  const groundTruthPath = useMemo(
    () => (groundTruthPolygonDisplay ? polygonToSvgPath(groundTruthPolygonDisplay) : ''),
    [groundTruthPolygonDisplay]
  );

  const [propStatus, setPropStatus] = useState<{ running: boolean; saved: number; message?: string }>(
    { running: false, saved: 0 }
  );

  // Persist the current draft polygon (plus the threshold/seed that produced it)
  // for this slice, remembering the view transform so it can be re-projected later.
  const onSave = useCallback(async () => {
    if (!draftPolygon || !draftThreshold || !draftSeed) return;

    setBusy(true);
    setError(null);
    try {
      const sop = await getSopInstanceUidForInstanceIndex(seriesUid, effectiveInstanceIndex);

      const view =
        draftPolygonViewTransform ??
        capturedRef.current?.viewTransform ??
        ({ ...viewerTransformRef.current } as ViewerTransform);

      const viewportSize =
        viewSize.w > 0 && viewSize.h > 0
          ? { w: Math.round(viewSize.w), h: Math.round(viewSize.h) }
          : undefined;

      await saveTumorSegmentation({
        comboId,
        dateIso,
        studyId,
        seriesUid,
        sopInstanceUid: sop,
        polygon: draftPolygon,
        threshold: draftThreshold,
        seed: draftSeed,
        meta: { viewTransform: view, viewportSize },
      });

      setSavedPolygon(draftPolygon);
      setSavedPolygonViewTransform(view);
      setSavedSeed(draftSeed);
      setSavedThreshold(draftThreshold);
    } catch (err) {
      console.error(err);
      setError(err instanceof Error ? err.message : 'Save failed');
    } finally {
      setBusy(false);
    }
  }, [
    comboId,
    dateIso,
    draftPolygon,
    draftPolygonViewTransform,
    draftSeed,
    draftThreshold,
    effectiveInstanceIndex,
    seriesUid,
    studyId,
    viewSize,
  ]);

  // Re-run the saved segmentation across neighboring slices of the series,
  // reporting progress and stopping on small areas / repeated misses.
  const onPropagateSeries = useCallback(async () => {
    if (!savedSeed || !savedThreshold) return;

    if (viewSize.w <= 0 || viewSize.h <= 0) {
      setPropStatus({ running: false, saved: 0, message: 'Viewer size not ready (try again).' });
      return;
    }

    setPropStatus({ running: true, saved: 0, message: 'Propagating…' });
    try {
      const result = await propagateTumorAcrossSeries({
        comboId,
        dateIso,
        studyId,
        seriesUid,
        viewportSize: viewSize,
        startEffectiveIndex: effectiveInstanceIndex,
        seed: savedSeed,
        seedViewTransform: savedPolygonViewTransform ?? { ...viewerTransformRef.current },
        threshold: savedThreshold,
        stop: {
          minAreaPx: 80,
          maxMissesInARow: 3,
        },
        onProgress: ({ direction, index, saved }) => {
          setPropStatus({
            running: true,
            saved,
            message: `Propagating ${direction} (slice ${index + 1})…`,
          });
        },
      });

      setPropStatus({
        running: false,
        saved: result.saved,
        message: `Propagation complete (saved ${result.saved} slices).`,
      });
    } catch (err) {
      console.error(err);
      setPropStatus({ running: false, saved: 0, message: err instanceof Error ?
err.message : 'Propagation failed' });
    }
  }, [comboId, dateIso, effectiveInstanceIndex, savedPolygonViewTransform, savedSeed, savedThreshold, seriesUid, studyId, viewSize]);

  const clearDiffOverlay = useCallback(() => {
    const canvas = diffCanvasRef.current;
    if (!canvas) return;
    const ctx = canvas.getContext('2d');
    if (!ctx) return;
    ctx.clearRect(0, 0, canvas.width, canvas.height);
  }, []);

  // Paint a per-pixel prediction-vs-GT difference image onto the diff canvas.
  // FN (miss): red. FP (over-seg): magenta.
  const drawDiffOverlay = useCallback(
    (predMask: Uint8Array, gtMask: Uint8Array, w: number, h: number) => {
      const canvas = diffCanvasRef.current;
      if (!canvas) return;

      if (canvas.width !== w || canvas.height !== h) {
        canvas.width = w;
        canvas.height = h;
      }

      const ctx = canvas.getContext('2d');
      if (!ctx) return;

      const rgba = new Uint8ClampedArray(w * h * 4);

      for (let i = 0; i < predMask.length; i++) {
        const p = predMask[i] ? 1 : 0;
        const g = gtMask[i] ? 1 : 0;
        const o = i * 4;

        if (g && !p) {
          // False negative: GT pixel we missed -> red.
          rgba[o] = 255;
          rgba[o + 1] = 0;
          rgba[o + 2] = 0;
          rgba[o + 3] = 150;
        } else if (!g && p) {
          // False positive: predicted pixel outside GT -> magenta.
          rgba[o] = 255;
          rgba[o + 1] = 0;
          rgba[o + 2] = 255;
          rgba[o + 3] = 110;
        }
      }

      const img =
        typeof ImageData !== 'undefined'
          ? new ImageData(rgba, w, h)
          : (() => {
              const id = ctx.createImageData(w, h);
              id.data.set(rgba);
              return id;
            })();

      ctx.putImageData(img, 0, 0);
    },
    []
  );

  // Evaluate the current prediction (draft preferred, else saved) against the
  // ground-truth polygon, and keep the diff overlay in sync.
  useEffect(() => {
    if (!enabled) return;

    const cap = capturedRef.current;
    const gtRaw = groundTruthPolygon;
    const predRaw = draftPolygon ?? savedPolygon;

    if (!cap || !gtRaw || !predRaw) {
      setGtMetrics(null);
      setGtBoundaryMetrics(null);
      clearDiffOverlay();
      return;
    }

    const size = { w: cap.w, h: cap.h };
    const evalView = cap.viewTransform;

    const predFrom = draftPolygon
      ? draftPolygonViewTransform ?? evalView
      : savedPolygonViewTransform ?? evalView;

    const gtFrom = groundTruthPolygonViewTransform ?? evalView;

    try {
      // Metrics are computed in the capture/eval view so they stay stable as the user pans/zooms.
      const gtEval = remapPolygonBetweenViewerTransforms(gtRaw, size, gtFrom, evalView);
      const predEval = remapPolygonBetweenViewerTransforms(predRaw, size, predFrom, evalView);

      const gtMaskEval = rasterizePolygonToMask(gtEval, cap.w, cap.h);
      const predMaskEval = rasterizePolygonToMask(predEval, cap.w, cap.h);

      const metrics = computeMaskMetrics(predMaskEval, gtMaskEval);
      const boundary = computePolygonBoundaryMetrics(predEval, gtEval, cap.w, cap.h);

      setGtMetrics(metrics);
      setGtBoundaryMetrics(boundary);

      // Diff overlay is drawn in the *current* viewer transform so it stays visually aligned.
      if (diffOverlayEnabled) {
        const gtDisplay = remapPolygonBetweenViewerTransforms(gtRaw, size, gtFrom, viewerTransform);
        const predDisplay = remapPolygonBetweenViewerTransforms(predRaw, size, predFrom, viewerTransform);

        const gtMaskDisplay = rasterizePolygonToMask(gtDisplay, cap.w, cap.h);
        const predMaskDisplay = rasterizePolygonToMask(predDisplay, cap.w, cap.h);

        drawDiffOverlay(predMaskDisplay, gtMaskDisplay, cap.w, cap.h);
      } else {
        clearDiffOverlay();
      }
    } catch (e) {
      console.error('[TumorOverlay] GT evaluation failed:', e);
      setGtMetrics(null);
      setGtBoundaryMetrics(null);
      clearDiffOverlay();
    }
  }, [
    // captureVersion is not read in the body: it forces a re-run whenever the
    // contents of capturedRef change (ref writes don't trigger effects).
    captureVersion,
    clearDiffOverlay,
    diffOverlayEnabled,
    drawDiffOverlay,
    draftPolygon,
    draftPolygonViewTransform,
    enabled,
    groundTruthPolygon,
    groundTruthPolygonViewTransform,
    savedPolygon,
    savedPolygonViewTransform,
    viewerTransform,
  ]);

  // Copy a JSON debugging report (thresholds, tuning state, GT metrics) to the clipboard.
  const onCopyGtReport = useCallback(async () => {
    const effectiveThreshold = (() => {
      const t = effectiveThresholdFromSlider;
      if (!t) return null;

      // The slider threshold is symmetric (anchor ± tolerance), but the segmentation can apply
      // asymmetric scaling (toleranceLowScale/toleranceHighScale) via tunedOptions.
      const tol = t.tolerance ?? Math.round((t.high - t.low) / 2);

      // NOTE: These defaults must match `segmentTumorFromGrayscale` (segmentTumor.ts). Otherwise
      // the report can be misleading: the UI slider shows a symmetric range, but segmentation may
      // apply asymmetric scaling even when the user hasn't explicitly tuned it.
      const lowScale = tunedOptions?.toleranceLowScale ?? 1;
      const highScale = tunedOptions?.toleranceHighScale ?? 1;

      const clamp8 = (v: number) => Math.max(0, Math.min(255, Math.round(v)));

      const anchorRaw =
        typeof t.anchor === 'number' && Number.isFinite(t.anchor) ? t.anchor : Math.round((t.low + t.high) * 0.5);
      const anchor = clamp8(anchorRaw);

      return {
        anchor,
        tolerance: tol,
        toleranceLowScale: lowScale,
        toleranceHighScale: highScale,
        low: clamp8(anchor - tol * lowScale),
        high: clamp8(anchor + tol * highScale),
      };
    })();

    const report = {
      comboId,
      dateIso,
      seriesUid,
      effectiveInstanceIndex,
      capture: capturedRef.current ? { w: capturedRef.current.w, h: capturedRef.current.h } : null,
      threshold: effectiveThresholdFromSlider,
      effectiveThreshold,
      tunedOptions,
      paintPointsCount: paintPoints.length,
      draftPolygonPoints: draftPolygon?.points.length ?? 0,
      savedPolygonPoints: savedPolygon?.points.length ?? 0,
      gtPolygonPoints: groundTruthPolygon?.points.length ?? 0,
      metrics: gtMetrics,
      boundaryMetrics: gtBoundaryMetrics,
      autoTuneLastStats,
      autoTuneLastBest,
      note: 'Auto-tune uses a recall guardrail, then prioritizes staying near paint + reducing FP + boundary overshoot. Metrics are vs GT when available. FN=miss (red), FP=over-seg (magenta).',
    };

    const text = JSON.stringify(report, null, 2);

    try {
      await navigator.clipboard.writeText(text);
    } catch (e) {
      console.warn('[TumorOverlay] Failed to write to clipboard; logging report instead', e);
      console.log('[TumorOverlay] GT report:', report);
      setError('Failed to copy; report was logged to console.');
    }
  }, [
    autoTuneLastBest,
    autoTuneLastStats,
    comboId,
    dateIso,
    effectiveInstanceIndex,
    effectiveThresholdFromSlider,
    draftPolygon,
    groundTruthPolygon,
    gtMetrics,
    gtBoundaryMetrics,
    paintPoints,
    savedPolygon,
    seriesUid,
    tunedOptions,
  ]);

  // Benchmark the segmentation configurations against every saved GT polygon
  // and copy the resulting JSON report to the clipboard.
  const onCopyGtBenchmark = useCallback(async () => {
    if (busyRef.current || gtBenchmarkStatus.running) return;

    const prevBusy = busyRef.current;
    busyRef.current = true;
    setBusy(true);
    setError(null);
    setGtBenchmarkStatus({ running: true, message: 'Benchmark: loading GT rows…' });

    try {
      const gtRows = await getAllTumorGroundTruth();
      const cases = gtRows
        .filter((r) => (r.polygon?.points?.length ?? 0) >= 3)
        .map((r) => ({
          id: r.id,
          comboId: r.comboId,
          dateIso: r.dateIso,
          seriesUid: r.seriesUid,
          sopInstanceUid: r.sopInstanceUid,
          gtPolygon: r.polygon,
          gtViewTransform: r.viewTransform,
          gtViewportSize: r.viewportSize,
        }));

      if (cases.length === 0) {
        throw new Error('No ground truth polygons found in IndexedDB.');
      }

      // Feature matrix: v2 features explicitly off / bg-model only / bg+geodesic.
      const v2Off: SegmentTumorOptions = {
        bgModel: { enabled: false },
        geodesic: { enabled: false },
      };

      const v2Bg: SegmentTumorOptions = {
        bgModel: { enabled: true },
        geodesic: { enabled: false },
      };

      const v2BgGeo: SegmentTumorOptions = {
        bgModel: { enabled: true },
        geodesic: { enabled: true },
      };

      const configs = [
        { name: 'baseline', opts: v2Off },
        { name: 'v2:bg', opts: v2Bg },
        { name: 'v2:bg+geo', opts: v2BgGeo },
        ...(tunedOptions ?
[
              { name: 'tuned', opts: { ...tunedOptions, ...v2Off } },
              { name: 'tuned+v2:bg', opts: { ...tunedOptions, ...v2Bg } },
              { name: 'tuned+v2:bg+geo', opts: { ...tunedOptions, ...v2BgGeo } },
            ]
          : []),
      ];

      const report = await runGtBenchmark({
        cases,
        configs,
        maxEvalDim: 256,
        paintPointsPerCase: 24,
        yieldEveryCases: 1,
        onProgress: (p) => setGtBenchmarkStatus({ running: true, message: p.message }),
      });

      const wrapped = {
        comboId,
        dateIso,
        tunedOptions,
        report,
        note: 'Benchmark uses deterministic auto-generated paint points derived from GT polygons. baseline forces v2 features off so results are comparable even if localStorage segmentation flags are set. v2:* enables brush-only background model and/or geodesic edge-aware gating.',
      };

      const text = JSON.stringify(wrapped, null, 2);

      try {
        await navigator.clipboard.writeText(text);
        setGtBenchmarkStatus({ running: false, message: `Benchmark copied (${cases.length} cases).` });
      } catch (e) {
        console.warn('[TumorOverlay] Failed to write benchmark to clipboard; logging instead', e);
        console.log('[TumorOverlay] GT benchmark report:', wrapped);
        setGtBenchmarkStatus({ running: false, message: 'Benchmark done (logged to console).' });
        setError('Failed to copy; benchmark report was logged to console.');
      }
    } catch (e) {
      console.error('[TumorOverlay] GT benchmark failed:', e);
      setGtBenchmarkStatus({ running: false, message: 'Benchmark failed (see console).' });
      setError(e instanceof Error ? e.message : 'GT benchmark failed');
    } finally {
      busyRef.current = prevBusy;
      setBusy(false);
    }
  }, [comboId, dateIso, gtBenchmarkStatus.running, tunedOptions]);

  // Export the offline harness dataset (GT rows + an optional propagation
  // scenario built from the current paint stroke) as a downloadable zip.
  const onExportHarnessDataset = useCallback(async () => {
    if (busyRef.current || harnessExportStatus.running) return;

    const prevBusy = busyRef.current;
    busyRef.current = true;
    setBusy(true);
    setError(null);
    setHarnessExportStatus({ running: true, message: 'Export: loading ground truth rows…' });

    try {
      const gtRows = await getAllTumorGroundTruth();

      const startSopInstanceUid = await getSopInstanceUidForInstanceIndex(seriesUid, effectiveInstanceIndex);

      await exportTumorHarnessDatasetAndDownload({
        maxEvalDim: 256,
        gtRows,
        paintPointsPerCase: 24,
        propagationScenario:
          paintPoints.length >= 4 && viewSize.w > 0 && viewSize.h > 0
            ? {
                comboId,
                dateIso,
                studyId,
                seriesUid,
                startEffectiveIndex: effectiveInstanceIndex,
                startSopInstanceUid,
                paintPointsViewer01: paintPoints,
                paintPointsViewTransform: paintPointsViewTransform,
                viewportSize: viewSize,
                threshold: effectiveThresholdFromSlider,
                stop: { minAreaPx: 80, maxMissesInARow: 3 },
                marginSlices: 2,
              }
            : undefined,
        onProgress: (message) => setHarnessExportStatus({ running: true, message }),
      });

      setHarnessExportStatus({ running: false, message: 'Export complete (downloaded zip).' });
    } catch (e) {
      console.error('[TumorOverlay] Harness export failed:', e);
      setHarnessExportStatus({ running: false, message: 'Export failed (see console).' });
      setError(e instanceof Error ? e.message : 'Export failed');
    } finally {
      busyRef.current = prevBusy;
      setBusy(false);
    }
  }, [
    comboId,
    dateIso,
    effectiveInstanceIndex,
    effectiveThresholdFromSlider,
    harnessExportStatus.running,
    paintPoints,
    paintPointsViewTransform,
    seriesUid,
    studyId,
    viewSize,
  ]);

  // Drop tuned segmentation options and recompute the draft with defaults.
  const onResetTuning = useCallback(() => {
    // Force a recompute with default options so the UI reflects the reset immediately.
    try {
      if (capturedRef.current && paintPoints.length >= 4) {
        computeDraftFromCurrentCapture(effectiveThresholdFromSlider, {});
      }
    } catch (e) {
      console.error('[TumorOverlay] Failed to recompute after tuning reset:', e);
    }

    setTunedOptions(null);
  }, [computeDraftFromCurrentCapture, effectiveThresholdFromSlider, paintPoints]);

  const onAutoTune = useCallback(async () => {
    if (busyRef.current || autoTuneStatus.running) return;

    const cap = capturedRef.current;
    if (!cap) {
      setError('Paint first (we need a captured image to evaluate against GT).');
      return;
    }

    if (paintPoints.length < 4) {
      setError('Paint first (not enough paint points to run segmentation).');
      return;
    }

    if (!groundTruthPolygon) {
      setError('No ground truth polygon found for this slice. Use the GT tool to draw + save one.');
      return;
    }

    const anchor = Math.round(
      thresholdAnchor ??
        (draftThreshold?.anchor ??
          (draftThreshold ? (draftThreshold.low + draftThreshold.high) / 2 : 128))
    );
    const baseTol = Math.max(0, Math.min(127, Math.round(thresholdTolerance)));

    // Auto-tune can be expensive (many candidates). Use a downsampled grid for mask metrics.
    // Polygons are in normalized coords, so rasterizing at lower res is a good approximation.
    const evalW = Math.max(128, Math.round(cap.w / 2));
    const evalH = Math.max(128, Math.round(cap.h / 2));

    const gtFrom = groundTruthPolygonViewTransform ?? cap.viewTransform;

    // Auto-tune evaluates candidates in the captured-image coordinate system (cap.viewTransform).
+ // If the GT polygon was drawn under a different viewer transform, re-project it into cap space. + const gtPolyEval = remapPolygonBetweenViewerTransforms( + groundTruthPolygon, + { w: evalW, h: evalH }, + gtFrom, + cap.viewTransform + ); + const gtMask = rasterizePolygonToMask(gtPolyEval, evalW, evalH); + + const gtPolyCap = remapPolygonBetweenViewerTransforms( + groundTruthPolygon, + { w: cap.w, h: cap.h }, + gtFrom, + cap.viewTransform + ); + + // Heuristic: penalize candidates whose predicted polygon drifts far outside the user's painted bbox. + // This correlates strongly with "leaky" segmentations (low precision) and matches what feels wrong in the UI. + const paintBounds = pointsBounds01(paintPoints); + + // Additional heuristic: measure how far predicted pixels are from the painted stroke itself. + // This catches the common failure mode where the polygon stays within the paint bbox but still + // "fills" large same-intensity regions far from the actual stroke. + const paintDistEval = (() => { + const out = new Int32Array(evalW * evalH); + out.fill(-1); + + const qx = new Int32Array(evalW * evalH); + const qy = new Int32Array(evalW * evalH); + let qh = 0; + let qt = 0; + + const push = (x: number, y: number, d: number) => { + const i = y * evalW + x; + if (out[i] !== -1) return; + out[i] = d; + qx[qt] = x; + qy[qt] = y; + qt++; + }; + + for (const p of paintPoints) { + const x = Math.max(0, Math.min(evalW - 1, Math.round(p.x * (evalW - 1)))); + const y = Math.max(0, Math.min(evalH - 1, Math.round(p.y * (evalH - 1)))); + push(x, y, 0); + } + + // Fallback: if somehow we have no seeds, treat everything as far away. 
+ if (qt === 0) { + out.fill(999999); + return out; + } + + while (qh < qt) { + const x = qx[qh]!; + const y = qy[qh]!; + qh++; + + const base = out[y * evalW + x]!; + const nd = base + 1; + + if (x > 0) push(x - 1, y, nd); + if (x < evalW - 1) push(x + 1, y, nd); + if (y > 0) push(x, y - 1, nd); + if (y < evalH - 1) push(x, y + 1, nd); + } + + return out; + })(); + + const mkThreshold = (thAnchor: number, tol: number): TumorThreshold => { + const a = Math.max(0, Math.min(255, Math.round(thAnchor))); + const t = Math.max(0, Math.min(127, Math.round(tol))); + return { + low: Math.max(0, Math.min(255, a - t)), + high: Math.max(0, Math.min(255, a + t)), + anchor: a, + tolerance: t, + }; + }; + + type Candidate = { + anchor: number; + tol: number; + opts: SegmentTumorOptions | undefined; + metrics: MaskMetrics; + boundary: PolygonBoundaryMetrics; + /** Max outward expansion beyond the painted bbox (pixels). Lower is better. */ + paintLeakPx: number; + /** Mean Manhattan distance (in eval pixels) from predicted pixels to the painted stroke. Lower is better. */ + paintDistMeanPx: number; + /** 95th percentile Manhattan distance (eval px) from predicted pixels to the painted stroke. Lower is better. */ + paintDistP95Px: number; + /** Maximum Manhattan distance (eval px) from predicted pixels to the painted stroke. Lower is better. */ + paintDistMaxPx: number; + }; + + const isBetter = (a: Candidate, b: Candidate) => { + // IMPORTANT: use near-tie thresholds so we don't choose a meaningfully worse boundary + // just to gain ~0.0001 of overlap metric. + const EPS_F2 = 0.003; + const EPS_RECALL = 0.003; + const EPS_DICE = 0.003; + const EPS_BND_MEAN = 0.25; // px + const EPS_BND_MAX = 0.75; // px + + // Paint-leak/dist thresholds: + // - Keep them fairly small so auto-tune actively searches for candidates that stay near the paint. + // - We still allow a little slack because paint points are noisy (pointer sampling + brush width). 
+ const EPS_PAINT_LEAK = 4; // px (full-res) + const EPS_PAINT_DIST_MEAN = 0.35; // eval px + const EPS_PAINT_DIST_P95 = 0.75; // eval px + const EPS_PAINT_DIST_MAX = 1.5; // eval px + + // When FP is extremely close (near-tie), prefer fewer FN to avoid under-segmentation. + // NOTE: This is at eval resolution, so values are smaller than full-res. + const EPS_FP_TIE = 3; + + // Guardrail: keep recall above a minimum, then aggressively optimize precision / boundary fit. + // + // Why: + // - In the UI, users can usually fix small FN by painting a bit more. + // - Huge FP (low precision) is much harder to correct and often corresponds to "escaping" the paint bbox. + const MIN_RECALL = 0.97; + const aMeetsRecall = a.metrics.recall >= MIN_RECALL; + const bMeetsRecall = b.metrics.recall >= MIN_RECALL; + + if (aMeetsRecall && !bMeetsRecall) return true; + if (bMeetsRecall && !aMeetsRecall) return false; + + if (aMeetsRecall && bMeetsRecall) { + // Within the acceptable-recall region, prioritize: + // 1) staying near the paint (prevents "fill" leaks) + // 2) fewer false positives (precision) + // 3) then boundary fit + // 4) then overlap + if (a.paintLeakPx < b.paintLeakPx - EPS_PAINT_LEAK) return true; + if (b.paintLeakPx < a.paintLeakPx - EPS_PAINT_LEAK) return false; + + // Use a tail metric first: mean can hide a small-but-important leak far from paint. 
+ if (a.paintDistP95Px < b.paintDistP95Px - EPS_PAINT_DIST_P95) return true; + if (b.paintDistP95Px < a.paintDistP95Px - EPS_PAINT_DIST_P95) return false; + + if (a.paintDistMaxPx < b.paintDistMaxPx - EPS_PAINT_DIST_MAX) return true; + if (b.paintDistMaxPx < a.paintDistMaxPx - EPS_PAINT_DIST_MAX) return false; + + if (a.paintDistMeanPx < b.paintDistMeanPx - EPS_PAINT_DIST_MEAN) return true; + if (b.paintDistMeanPx < a.paintDistMeanPx - EPS_PAINT_DIST_MEAN) return false; + + if (Math.abs(a.metrics.fp - b.metrics.fp) <= EPS_FP_TIE) { + if (a.metrics.fn !== b.metrics.fn) return a.metrics.fn < b.metrics.fn; + } + + if (a.metrics.fp !== b.metrics.fp) return a.metrics.fp < b.metrics.fp; + + // Reduce outward leakage / overshoot. + if (a.boundary.meanPredToGtPx < b.boundary.meanPredToGtPx - EPS_BND_MEAN) return true; + if (b.boundary.meanPredToGtPx < a.boundary.meanPredToGtPx - EPS_BND_MEAN) return false; + + if (a.metrics.fn !== b.metrics.fn) return a.metrics.fn < b.metrics.fn; + + if (a.boundary.maxPredToGtPx < b.boundary.maxPredToGtPx - EPS_BND_MAX) return true; + if (b.boundary.maxPredToGtPx < a.boundary.maxPredToGtPx - EPS_BND_MAX) return false; + + if (a.boundary.meanSymPx < b.boundary.meanSymPx - EPS_BND_MEAN) return true; + if (b.boundary.meanSymPx < a.boundary.meanSymPx - EPS_BND_MEAN) return false; + + if (a.metrics.dice > b.metrics.dice + EPS_DICE) return true; + if (b.metrics.dice > a.metrics.dice + EPS_DICE) return false; + + return a.metrics.iou > b.metrics.iou; + } + + // Below the recall guardrail, keep optimizing for recall-weighted overlap. 
+ if (a.metrics.f2 > b.metrics.f2 + EPS_F2) return true; + if (b.metrics.f2 > a.metrics.f2 + EPS_F2) return false; + + if (a.metrics.recall > b.metrics.recall + EPS_RECALL) return true; + if (b.metrics.recall > a.metrics.recall + EPS_RECALL) return false; + + if (a.metrics.fn !== b.metrics.fn) return a.metrics.fn < b.metrics.fn; + + if (a.boundary.meanPredToGtPx < b.boundary.meanPredToGtPx - EPS_BND_MEAN) return true; + if (b.boundary.meanPredToGtPx < a.boundary.meanPredToGtPx - EPS_BND_MEAN) return false; + + if (a.metrics.fp !== b.metrics.fp) return a.metrics.fp < b.metrics.fp; + + if (a.boundary.maxPredToGtPx < b.boundary.maxPredToGtPx - EPS_BND_MAX) return true; + if (b.boundary.maxPredToGtPx < a.boundary.maxPredToGtPx - EPS_BND_MAX) return false; + + if (a.metrics.dice > b.metrics.dice + EPS_DICE) return true; + if (b.metrics.dice > a.metrics.dice + EPS_DICE) return false; + + return a.metrics.iou > b.metrics.iou; + }; + + const paintDistMaxPossible = evalW + evalH; + const paintDistHist = new Int32Array(paintDistMaxPossible + 1); + + const evalCandidate = ( + thAnchor: number, + tol: number, + opts: SegmentTumorOptions | undefined + ): Candidate | null => { + const threshold = mkThreshold(thAnchor, tol); + try { + const res = segmentTumorFromGrayscale(cap.gray, cap.w, cap.h, paintPoints, threshold, opts); + const predMask = rasterizePolygonToMask(res.polygon, evalW, evalH); + + // Compute overlap metrics and paint-distance metrics in a single pass. + let tp = 0; + let fp = 0; + let fn = 0; + let tn = 0; + + paintDistHist.fill(0); + let predCount = 0; + let sumPaintDist = 0; + let paintDistMaxPx = 0; + + for (let i = 0; i < predMask.length; i++) { + const p = predMask[i] ? 1 : 0; + const g = gtMask[i] ? 1 : 0; + + if (p) { + predCount++; + + const dRaw = paintDistEval[i] ?? 0; + const d = Math.max(0, Math.min(paintDistMaxPossible, dRaw)); + + sumPaintDist += d; + paintDistHist[d] = (paintDistHist[d] ?? 
0) + 1; + if (d > paintDistMaxPx) paintDistMaxPx = d; + } + + if (p && g) tp++; + else if (p && !g) fp++; + else if (!p && g) fn++; + else tn++; + } + + const safeDiv = (num: number, den: number) => (den > 0 ? num / den : 0); + const precision = safeDiv(tp, tp + fp); + const recall = safeDiv(tp, tp + fn); + const dice = safeDiv(2 * tp, 2 * tp + fp + fn); + const iou = safeDiv(tp, tp + fp + fn); + + const beta2 = 4; + const f2 = safeDiv((1 + beta2) * precision * recall, beta2 * precision + recall); + + const metrics: MaskMetrics = { tp, fp, fn, tn, precision, recall, dice, iou, f2 }; + + const paintDistMeanPx = predCount > 0 ? sumPaintDist / predCount : Number.POSITIVE_INFINITY; + + // Tail distance metric: approximate via histogram to avoid storing per-pixel distances. + let paintDistP95Px = Number.POSITIVE_INFINITY; + if (predCount > 0) { + const target = Math.ceil(predCount * 0.95); + let cum = 0; + for (let d = 0; d < paintDistHist.length; d++) { + cum += paintDistHist[d] ?? 0; + if (cum >= target) { + paintDistP95Px = d; + break; + } + } + } + + const boundary = computePolygonBoundaryMetrics(res.polygon, gtPolyCap, cap.w, cap.h); + + const predBounds = polygonBounds01(res.polygon); + const leakLeftPx = predBounds.minX < paintBounds.minX ? (paintBounds.minX - predBounds.minX) * cap.w : 0; + const leakRightPx = predBounds.maxX > paintBounds.maxX ? (predBounds.maxX - paintBounds.maxX) * cap.w : 0; + const leakTopPx = predBounds.minY < paintBounds.minY ? (paintBounds.minY - predBounds.minY) * cap.h : 0; + const leakBottomPx = predBounds.maxY > paintBounds.maxY ? (predBounds.maxY - paintBounds.maxY) * cap.h : 0; + const paintLeakPx = Math.max(leakLeftPx, leakRightPx, leakTopPx, leakBottomPx); + + return { + anchor: thAnchor, + tol: threshold.tolerance ?? 
tol, + opts, + metrics, + boundary, + paintLeakPx, + paintDistMeanPx, + paintDistP95Px, + paintDistMaxPx, + }; + } catch { + return null; + } + }; + + const yieldToUi = () => new Promise((resolve) => window.setTimeout(resolve, 0)); + + const stats = { + evals: { + stage1TolSweep: 0, + stage2ParamTune: 0, + stage3TolRefine: 0, + stage4PolyTune: 0, + total: 0, + }, + ms: { + stage1TolSweep: 0, + stage2ParamTune: 0, + stage3TolRefine: 0, + stage4PolyTune: 0, + total: 0, + }, + }; + + busyRef.current = true; + setBusy(true); + setError(null); + + try { + const baselineOpts = tunedOptions ?? undefined; + + // Stage 1: sweep (anchor, tolerance) around the current estimate. + setAutoTuneStatus({ running: true, message: 'Auto-tune: sweeping anchor+tolerance…' }); + + let best: Candidate | null = null; + + const anchorCandidates: number[] = []; + // Include both parities. Otherwise, when `anchor` is odd we'd only test odd anchors (and vice versa), + // which can miss materially better solutions (e.g. true best anchor at anchor±7). + for (let a = anchor - 12; a <= anchor + 12; a += 2) { + anchorCandidates.push(Math.max(0, Math.min(255, a))); + } + for (let a = anchor - 11; a <= anchor + 11; a += 2) { + anchorCandidates.push(Math.max(0, Math.min(255, a))); + } + const uniqAnchor = Array.from(new Set(anchorCandidates)).sort((a, b) => a - b); + + const tolCandidates: number[] = []; + for (let t = baseTol - 30; t <= baseTol + 30; t += 2) { + tolCandidates.push(Math.max(0, Math.min(127, t))); + } + // Ensure unique + deterministic. 
+ const uniqTol = Array.from(new Set(tolCandidates)).sort((a, b) => a - b); + + const stage1Total = uniqAnchor.length * uniqTol.length; + let stage1Done = 0; + + const stage1Start = performance.now(); + for (const a of uniqAnchor) { + for (const tol of uniqTol) { + const cand = evalCandidate(a, tol, baselineOpts); + stats.evals.stage1TolSweep++; + stats.evals.total++; + stage1Done++; + + if (cand && (!best || isBetter(cand, best))) best = cand; + + if (stage1Done % 60 === 0) { + setAutoTuneStatus({ + running: true, + message: `Auto-tune: sweeping anchor+tolerance… (${stage1Done}/${stage1Total})`, + }); + await yieldToUi(); + } + } + } + stats.ms.stage1TolSweep = performance.now() - stage1Start; + + if (!best) { + throw new Error('Auto-tune failed: no valid segmentations produced.'); + } + + // Stage 2: parameter tuning. + // + // IMPORTANT: Full grid search is combinatorially expensive and can freeze the UI. + // We use a small, deterministic coordinate-descent search instead. + // + // Key quality improvement: couple tolerance with parameter updates by evaluating a small + // local tolerance window per candidate. + setAutoTuneStatus({ running: true, message: 'Auto-tune: tuning parameters…' }); + + const stage2Start = performance.now(); + const optsKey = (o: SegmentTumorOptions | undefined) => JSON.stringify(o ?? null); + const stage2StartOptsKey = optsKey(best.opts); + + // Bias the search toward tighter distance gating. + // This matters a lot for FLAIR-like cases where leakage can explode FP. 
+ const baseMins = [2, 4, 8]; + const paintFactors = [0.25, 0.35, 0.6]; + const widthFactors = [0.05, 0.1, 0.2]; + + const maxDistTriples: Array<{ baseMin: number; paintScaleFactor: number; thresholdWidthFactor: number }> = []; + for (const baseMin of baseMins) { + for (const paintScaleFactor of paintFactors) { + for (const thresholdWidthFactor of widthFactors) { + maxDistTriples.push({ baseMin, paintScaleFactor, thresholdWidthFactor }); + } + } + } + + const openIters = [0, 1]; + const closeIters = [0, 1, 2]; + const adaptiveFlags = [false, true]; + + // Asymmetric tolerance can reduce leakage when only one side is problematic. + const tolLowScales = [0.6, 0.8, 1, 1.25]; + const tolHighScales = [0.6, 0.8, 1, 1.25]; + + // Include mild values near 1.0 so we can get "just a bit" tighter without increasing FN. + const distTolScaleMins = [1, 0.85, 0.7, 0.55, 0.4, 0.25]; + + // Include a gentle edge penalty; 0.35/0.65 were sometimes too coarse. + const edgePenaltyStrengths = [0, 0.15, 0.35, 0.55]; + + // When options change, the best tolerance can move a lot. Use wider offsets so parameter tuning + // can "pull" the search toward a different tolerance regime. + const tolOffsets = [-16, -8, 0, 8, 16]; + const uniqTolAround = (center: number) => { + const out: number[] = []; + for (const off of tolOffsets) { + out.push(Math.max(0, Math.min(127, Math.round(center + off)))); + } + return Array.from(new Set(out)).sort((a, b) => a - b); + }; + + // If Stage 1 found a good (anchor,tol) using baselineOpts, Stage 2 can still need to "jump" + // to a different anchor once maxDist/morph/asymmetry changes. + const anchorOffsets = [-12, 0, 12]; + const uniqAnchorAround = (center: number) => { + const out: number[] = []; + for (const off of anchorOffsets) { + out.push(Math.max(0, Math.min(255, Math.round(center + off)))); + } + return Array.from(new Set(out)).sort((a, b) => a - b); + }; + + // One pass is usually enough once we allow anchor to move; keep it fast. 
+ const PASSES = 1; + const totalUpperBound = + PASSES * + (maxDistTriples.length + + openIters.length + + closeIters.length + + adaptiveFlags.length + + tolLowScales.length + + tolHighScales.length + + distTolScaleMins.length + + edgePenaltyStrengths.length) * + tolOffsets.length * + anchorOffsets.length; + + let idx = 0; + const maybeYield = async () => { + if (idx % 32 === 0) { + setAutoTuneStatus({ + running: true, + message: `Auto-tune: tuning parameters… (${idx}/${totalUpperBound})`, + }); + await yieldToUi(); + } + }; + + const tryUpdate = async (opts: SegmentTumorOptions | undefined) => { + if (!best) return; + + const centerTol = best.tol; + const tols = uniqTolAround(centerTol); + + const centerAnchor = best.anchor; + const anchors = uniqAnchorAround(centerAnchor); + + let localBest: Candidate | null = null; + for (const anchorCand of anchors) { + for (const tol of tols) { + const cand = evalCandidate(anchorCand, tol, opts); + stats.evals.stage2ParamTune++; + stats.evals.total++; + idx++; + + if (cand && (!localBest || isBetter(cand, localBest))) localBest = cand; + await maybeYield(); + } + } + + if (localBest && isBetter(localBest, best)) best = localBest; + }; + + for (let pass = 0; pass < PASSES; pass++) { + // 2a) Tune the distance gate triple. + for (const t of maxDistTriples) { + await tryUpdate({ ...(best.opts ?? {}), maxDistToPaint: t }); + } + + // 2b) Tune morphology. + for (const morphologicalOpenIterations of openIters) { + await tryUpdate({ ...(best.opts ?? {}), morphologicalOpenIterations }); + } + for (const morphologicalCloseIterations of closeIters) { + await tryUpdate({ ...(best.opts ?? {}), morphologicalCloseIterations }); + } + + // 2c) Tune adaptive flag (edge penalty is disabled for adaptive candidates). + for (const adaptiveEnabled of adaptiveFlags) { + await tryUpdate({ + ...(best.opts ?? {}), + adaptiveEnabled, + edgePenaltyStrength: adaptiveEnabled ? 0 : (best.opts?.edgePenaltyStrength ?? 
0), + }); + } + + // 2d) Tune asymmetric tolerance. + for (const toleranceLowScale of tolLowScales) { + await tryUpdate({ ...(best.opts ?? {}), toleranceLowScale }); + } + for (const toleranceHighScale of tolHighScales) { + await tryUpdate({ ...(best.opts ?? {}), toleranceHighScale }); + } + + // 2e) Tune soft distance penalty. + for (const distanceToleranceScaleMin of distTolScaleMins) { + await tryUpdate({ ...(best.opts ?? {}), distanceToleranceScaleMin }); + } + + // 2f) Tune edge penalty (only relevant when adaptive is off). + if (best.opts?.adaptiveEnabled !== true) { + for (const edgePenaltyStrength of edgePenaltyStrengths) { + await tryUpdate({ ...(best.opts ?? {}), edgePenaltyStrength }); + } + } + } + stats.ms.stage2ParamTune = performance.now() - stage2Start; + + // Stage 3: refine anchor + tolerance around the best. + // + // Why: + // - Stage 1 chooses (anchor,tol) using baselineOpts. + // - Stage 2 may change opts materially (distance gating / morphology / asymmetry), which can shift + // the best (anchor,tol) pair. + // - Anchor/tolerance interact, so do a small local 2D search (instead of independent 1D passes). + setAutoTuneStatus({ running: true, message: 'Auto-tune: refining threshold…' }); + + const stage3Start = performance.now(); + const stage2EndOptsKey = optsKey(best.opts); + const tolRefineRadius = stage2EndOptsKey === stage2StartOptsKey ? 8 : 20; + const anchorRefineRadius = stage2EndOptsKey === stage2StartOptsKey ? 4 : 10; + + const tolStep = tolRefineRadius > 12 ? 2 : 1; + const anchorStep = anchorRefineRadius > 6 ? 
2 : 1; + + for (let a = best.anchor - anchorRefineRadius; a <= best.anchor + anchorRefineRadius; a += anchorStep) { + const anchorCand = Math.max(0, Math.min(255, a)); + + for (let t = best.tol - tolRefineRadius; t <= best.tol + tolRefineRadius; t += tolStep) { + const tol = Math.max(0, Math.min(127, t)); + const cand = evalCandidate(anchorCand, tol, best.opts); + stats.evals.stage3TolRefine++; + stats.evals.total++; + + if (cand && isBetter(cand, best)) best = cand; + } + } + + stats.ms.stage3TolRefine = performance.now() - stage3Start; + + // Stage 4: tune display-side smoothing/simplification. + setAutoTuneStatus({ running: true, message: 'Auto-tune: tuning polygon smoothing…' }); + + const stage4Start = performance.now(); + const smoothingCandidates = [0, 1, 2]; + const epsCandidates = [0.0003, 0.0005, 0.0008, 0.0012, 0.0018, 0.0024]; + + for (const smoothingIterations of smoothingCandidates) { + for (const simplifyEpsilon of epsCandidates) { + const opts: SegmentTumorOptions = { + ...(best.opts ?? 
{}), + smoothingIterations, + simplifyEpsilon, + }; + const cand = evalCandidate(best.anchor, best.tol, opts); + stats.evals.stage4PolyTune++; + stats.evals.total++; + + if (cand && isBetter(cand, best)) best = cand; + } + } + stats.ms.stage4PolyTune = performance.now() - stage4Start; + + stats.ms.total = performance.now() - stage1Start; + setAutoTuneLastStats(stats); + setAutoTuneLastBest({ + anchor: best.anchor, + tol: best.tol, + opts: best.opts, + metrics: best.metrics, + boundary: best.boundary, + paintLeakPx: best.paintLeakPx, + paintDistMeanPx: best.paintDistMeanPx, + paintDistP95Px: best.paintDistP95Px, + paintDistMaxPx: best.paintDistMaxPx, + }); + + setAutoTuneStatus({ + running: false, + message: `Auto-tune done (F2 ${best.metrics.f2.toFixed(3)}, recall ${best.metrics.recall.toFixed(3)}).`, + }); + + console.log('[TumorOverlay] Auto-tune BEST', { + anchor: best.anchor, + tol: best.tol, + opts: best.opts, + metrics: best.metrics, + boundary: best.boundary, + paintLeakPx: best.paintLeakPx, + paintDistMeanPx: best.paintDistMeanPx, + paintDistP95Px: best.paintDistP95Px, + paintDistMaxPx: best.paintDistMaxPx, + }); + console.log('[TumorOverlay] Auto-tune STATS', stats); + + setTunedOptions(best.opts ?? null); + setThresholdAnchor(best.anchor); + setThresholdTolerance(best.tol); + + computeDraftFromCurrentCapture(mkThreshold(best.anchor, best.tol), best.opts); + } catch (e) { + console.error('[TumorOverlay] Auto-tune failed:', e); + setAutoTuneStatus({ running: false, message: 'Auto-tune failed (see console).' }); + setError(e instanceof Error ? e.message : 'Auto-tune failed'); + } finally { + busyRef.current = false; + setBusy(false); + } + }, [ + autoTuneStatus.running, + computeDraftFromCurrentCapture, + draftThreshold, + groundTruthPolygon, + groundTruthPolygonViewTransform, + paintPoints, + thresholdAnchor, + thresholdTolerance, + tunedOptions, + ]); + + if (!enabled) return null; + + return ( +
+ {/* UI chrome */} + {/* + Position below the viewer's top hover controls (Tumor/GT buttons + ImageControls). + Otherwise it visually overlaps the control bar in GridView/OverlayView. + */} +
+
+ + Tumor + {busy ? : null} +
+ + +
+ + {/* Threshold + save controls (only after painting / draft segmentation exists) */} + {draftPolygon && containerSize.w > 0 && containerSize.h > 0 + ? (() => { + // Anchor threshold controls next to the *user-painted area*, not the draft polygon. + // The polygon can change substantially as the threshold moves, but the painted region + // is the user's mental anchor for where they were working. + const bbox = paintPointsDisplay.length + ? pointsBounds01(paintPointsDisplay) + : polygonBounds01(draftPolygonDisplay ?? draftPolygon); + + // Place the control panel next to the painted region, preferring the right side. + const panelWidth = 176; + const sliderHeight = Math.max( + 120, + Math.min(260, Math.round((bbox.maxY - bbox.minY) * containerSize.h)) + ); + + const gtSectionHeight = groundTruthPolygon ? 175 : 0; + const panelHeight = sliderHeight + 78 + gtSectionHeight; + + const minXpx = bbox.minX * containerSize.w; + const maxXpx = bbox.maxX * containerSize.w; + const minYpx = bbox.minY * containerSize.h; + + let left = maxXpx + 12; + if (left + panelWidth > containerSize.w - 8) { + left = minXpx - panelWidth - 12; + } + left = Math.max(8, Math.min(containerSize.w - panelWidth - 8, left)); + + let top = minYpx; + top = Math.max(8, Math.min(containerSize.h - panelHeight - 8, top)); + + return ( +
+ {/* Vertical threshold slider */} +
+ setThresholdTolerance(parseInt(e.target.value, 10))} + className="absolute" + style={{ + width: sliderHeight, + transform: 'rotate(-90deg)', + transformOrigin: 'center', + }} + aria-label="Tolerance" + /> +
+ +
+ {effectiveThresholdFromSlider.low}–{effectiveThresholdFromSlider.high} +
+ + + + {savedSeed && savedThreshold ? ( + + ) : null} + + {groundTruthPolygon ? ( +
+
+
+ GT Eval + {tunedOptions ? (tuned) : null} +
+ + +
+ + {gtMetrics ? ( +
+
+ F2 + {gtMetrics.f2.toFixed(3)} +
+
+ Recall + {gtMetrics.recall.toFixed(3)} +
+
+ Prec + {gtMetrics.precision.toFixed(3)} +
+
+ IoU + {gtMetrics.iou.toFixed(3)} +
+
+ Dice + {gtMetrics.dice.toFixed(3)} +
+ + {gtBoundaryMetrics ? ( + <> +
+ Bnd out μ + {gtBoundaryMetrics.meanPredToGtPx.toFixed(2)} px +
+
+ Bnd in μ + {gtBoundaryMetrics.meanGtToPredPx.toFixed(2)} px +
+
+ Bnd μ + {gtBoundaryMetrics.meanSymPx.toFixed(2)} px +
+
+ Bnd max + {gtBoundaryMetrics.maxSymPx.toFixed(1)} px +
+ + ) : null} + +
+ FN + {gtMetrics.fn} +
+
+ FP + {gtMetrics.fp} +
+
+ ) : ( +
Paint + segment to evaluate vs GT.
+ )} + +
+ + + + + {import.meta.env.DEV ? ( + + ) : null} + + + + +
+ + {autoTuneStatus.message ? ( +
{autoTuneStatus.message}
+ ) : null} + {gtBenchmarkStatus.message ? ( +
{gtBenchmarkStatus.message}
+ ) : null} + {harnessExportStatus.message ? ( +
{harnessExportStatus.message}
+ ) : null} +
+ ) : null} +
+ ); + })() + : null} + + {/* Error / status */} + {error ? ( +
+ {error} +
+ ) : propStatus.message ? ( +
+ {propStatus.message} +
+ ) : !draftPolygon ? ( +
+ Click and drag to paint the tumor region. +
+ ) : null} + + {/* GT diff overlay (FN red, FP magenta) */} + {groundTruthPolygon && diffOverlayEnabled ? ( + + ) : null} + + {/* Paint stroke preview (transparent pink brush) */} + {paintPointsDisplay.length > 1 ? ( + + `${p.x.toFixed(4)},${p.y.toFixed(4)}`).join(' ')} + fill="none" + stroke="rgba(236, 72, 153, 0.55)" + strokeWidth={Math.max(2, Math.round(Math.min(containerSize.w, containerSize.h) * 0.02))} + strokeLinecap="round" + strokeLinejoin="round" + vectorEffect="non-scaling-stroke" + /> + + ) : null} + + {/* Ground truth polygon (debug) */} + {groundTruthPath ? ( + + + + ) : null} + + {/* Saved polygon */} + {savedPath ? ( + + + + ) : null} + + {/* Draft polygon (overlays saved) */} + {draftPath ? ( + + + + ) : null} +
+ ); +} diff --git a/frontend/src/components/comparison/ComparisonFiltersSidebar.tsx b/frontend/src/components/comparison/ComparisonFiltersSidebar.tsx index 822f7b9..63a2c08 100644 --- a/frontend/src/components/comparison/ComparisonFiltersSidebar.tsx +++ b/frontend/src/components/comparison/ComparisonFiltersSidebar.tsx @@ -61,29 +61,31 @@ export function ComparisonFiltersSidebar({
Sequence
- {sequencesForPlane.map((seq) => { - const hasData = sequencesWithDataForDates.has(seq.id); - const isSelected = selectedSeqId === seq.id; + {sequencesForPlane + .filter((seq) => formatSequenceLabel(seq) !== 'Unknown') + .map((seq) => { + const hasData = sequencesWithDataForDates.has(seq.id); + const isSelected = selectedSeqId === seq.id; - return ( - - ); - })} + return ( + + ); + })}
diff --git a/frontend/src/components/comparison/GridCell.tsx b/frontend/src/components/comparison/GridCell.tsx new file mode 100644 index 0000000..75f0b71 --- /dev/null +++ b/frontend/src/components/comparison/GridCell.tsx @@ -0,0 +1,235 @@ +import { useState, useRef } from 'react'; +import { Pencil, Sparkles } from 'lucide-react'; +import type { AlignmentReference, ExclusionMask, PanelSettings, SeriesRef } from '../../types/api'; +import { formatDate } from '../../utils/format'; +import { getSliceIndex, getEffectiveInstanceIndex, getProgressFromSlice } from '../../utils/math'; +import { ImageControls } from '../ImageControls'; +import { StepControl } from '../StepControl'; +import { DragRectActionOverlay } from '../DragRectActionOverlay'; +import { DicomViewer, type DicomViewerHandle } from '../DicomViewer'; +import { GroundTruthPolygonOverlay } from '../GroundTruthPolygonOverlay'; +import { TumorSegmentationOverlay } from '../TumorSegmentationOverlay'; + +export type GridCellProps = { + comboId: string; + date: string; + refData: SeriesRef | undefined; + settings: PanelSettings; + progress: number; + setProgress: (next: number) => void; + updatePanelSetting: (date: string, update: Partial) => void; + + isHovered: boolean; + + overlayColumns: { date: string; ref?: SeriesRef }[]; + isAligning: boolean; + + startAlignAll: (reference: AlignmentReference, exclusion: ExclusionMask) => Promise; +}; + +export function GridCell({ + comboId, + date, + refData, + settings, + progress, + setProgress, + updatePanelSetting, + isHovered, + overlayColumns, + isAligning, + startAlignAll, +}: GridCellProps) { + const [tumorToolOpen, setTumorToolOpen] = useState(false); + const [gtPolygonToolOpen, setGtPolygonToolOpen] = useState(false); + const tumorViewerRef = useRef(null); + + if (!refData) { + return ( +
+
+ {formatDate(date)} +
+
No series
+
+ ); + } + + const idx = getSliceIndex(refData.instance_count, progress, settings.offset); + const effectiveIdx = getEffectiveInstanceIndex(idx, refData.instance_count, settings.reverseSliceOrder); + + return ( +
+ {/* Cell controls (shown on hover) */} +
+
+
+ + + +
+ + { + updatePanelSetting(date, update); + }} + showSliceControl={false} + /> +
+
+ + {/* Slice selector (shown on hover, bottom-right corner) */} +
e.stopPropagation()} + > +
+ { + updatePanelSetting(date, { offset: settings.offset - 1 }); + }} + onIncrement={() => { + updatePanelSetting(date, { offset: settings.offset + 1 }); + }} + /> +
+
+ +
+ { + void startAlignAll( + { + date, + seriesUid: refData.series_uid, + sliceIndex: effectiveIdx, + sliceCount: refData.instance_count, + settings, + }, + mask + ); + }} + actionTitle={`Align all other dates to ${formatDate(date)}`} + > + { + setProgress(getProgressFromSlice(i, refData.instance_count, settings.offset)); + }} + brightness={settings.brightness} + contrast={settings.contrast} + zoom={settings.zoom} + rotation={settings.rotation} + panX={settings.panX} + panY={settings.panY} + affine00={settings.affine00} + affine01={settings.affine01} + affine10={settings.affine10} + affine11={settings.affine11} + onPanChange={(newPanX, newPanY) => { + updatePanelSetting(date, { panX: newPanX, panY: newPanY }); + }} + /> + + setTumorToolOpen(false)} + viewerRef={tumorViewerRef} + comboId={comboId} + dateIso={date} + studyId={refData.study_id} + seriesUid={refData.series_uid} + effectiveInstanceIndex={effectiveIdx} + viewerTransform={settings} + /> + + setGtPolygonToolOpen(false)} + comboId={comboId} + dateIso={date} + studyId={refData.study_id} + seriesUid={refData.series_uid} + effectiveInstanceIndex={effectiveIdx} + viewerTransform={settings} + /> + + {/* Date overlay (matches overlay view style) */} +
+ {formatDate(date)} +
+
+
+
+ ); +} diff --git a/frontend/src/components/comparison/GridView.tsx b/frontend/src/components/comparison/GridView.tsx index 156e631..a213459 100644 --- a/frontend/src/components/comparison/GridView.tsx +++ b/frontend/src/components/comparison/GridView.tsx @@ -1,6 +1,7 @@ import { useCallback, useState } from 'react'; import type { MouseEvent } from 'react'; import { Loader2 } from 'lucide-react'; +import { GridCell } from './GridCell'; import type { AlignmentProgress, AlignmentReference, @@ -10,13 +11,10 @@ import type { } from '../../types/api'; import { formatDate } from '../../utils/format'; import { DEFAULT_PANEL_SETTINGS } from '../../utils/constants'; -import { getSliceIndex, getEffectiveInstanceIndex, getProgressFromSlice } from '../../utils/math'; -import { ImageControls } from '../ImageControls'; -import { StepControl } from '../StepControl'; -import { DragRectActionOverlay } from '../DragRectActionOverlay'; -import { DicomViewer } from '../DicomViewer'; export type GridViewProps = { + comboId: string; + columns: { date: string; ref?: SeriesRef }[]; gridCols: number; gridCellSize: number; @@ -32,6 +30,7 @@ export type GridViewProps = { }; export function GridView({ + comboId, columns, gridCols, gridCellSize, @@ -80,7 +79,7 @@ export function GridView({
{alignmentProgress.phase !== 'capturing' && alignmentProgress.slicesChecked ? (
- {alignmentProgress.slicesChecked} slices · MI {alignmentProgress.bestMiSoFar.toFixed(3)} + {alignmentProgress.slicesChecked} slices · Score {alignmentProgress.bestMiSoFar.toFixed(3)}
) : null}
@@ -108,131 +107,23 @@ export function GridView({ > {columns.map(({ date, ref }) => { const settings = panelSettings.get(date) || DEFAULT_PANEL_SETTINGS; - - if (!ref) { - return ( -
-
- {formatDate(date)} -
-
No series
-
- ); - } - - const idx = getSliceIndex(ref.instance_count, progress, settings.offset); - const effectiveIdx = getEffectiveInstanceIndex(idx, ref.instance_count, settings.reverseSliceOrder); - const isHovered = hoveredGridCellDate === date; return ( -
- {/* Cell controls (shown on hover) */} -
-
- { - updatePanelSetting(date, update); - }} - showSliceControl={false} - /> -
-
- - {/* Slice selector (shown on hover, bottom-right corner) */} -
e.stopPropagation()} - > -
- { - updatePanelSetting(date, { offset: settings.offset - 1 }); - }} - onIncrement={() => { - updatePanelSetting(date, { offset: settings.offset + 1 }); - }} - /> -
-
- -
- { - void startAlignAll( - { - date, - seriesUid: ref.series_uid, - sliceIndex: effectiveIdx, - sliceCount: ref.instance_count, - settings, - }, - mask - ); - }} - actionTitle={`Align all other dates to ${formatDate(date)}`} - > - { - setProgress(getProgressFromSlice(i, ref.instance_count, settings.offset)); - }} - brightness={settings.brightness} - contrast={settings.contrast} - zoom={settings.zoom} - rotation={settings.rotation} - panX={settings.panX} - panY={settings.panY} - affine00={settings.affine00} - affine01={settings.affine01} - affine10={settings.affine10} - affine11={settings.affine11} - onPanChange={(newPanX, newPanY) => { - updatePanelSetting(date, { panX: newPanX, panY: newPanY }); - }} - /> - - {/* Date overlay (matches overlay view style) */} -
- {formatDate(date)} -
-
-
-
+ comboId={comboId} + date={date} + refData={ref} + settings={settings} + progress={progress} + setProgress={setProgress} + updatePanelSetting={updatePanelSetting} + isHovered={isHovered} + overlayColumns={overlayColumns} + isAligning={isAligning} + startAlignAll={startAlignAll} + /> ); })} {columns.length === 0 &&
Select dates to view
} diff --git a/frontend/src/components/comparison/OverlayView.tsx b/frontend/src/components/comparison/OverlayView.tsx index 7c81538..708ccd9 100644 --- a/frontend/src/components/comparison/OverlayView.tsx +++ b/frontend/src/components/comparison/OverlayView.tsx @@ -1,5 +1,5 @@ -import { useState } from 'react'; -import { Loader2 } from 'lucide-react'; +import { useEffect, useRef, useState } from 'react'; +import { Loader2, Pencil, Sparkles } from 'lucide-react'; import type { AlignmentProgress, AlignmentReference, @@ -8,13 +8,17 @@ import type { SeriesRef, } from '../../types/api'; import { formatDate } from '../../utils/format'; -import { getProgressFromSlice } from '../../utils/math'; +import { getEffectiveInstanceIndex, getProgressFromSlice } from '../../utils/math'; import { ImageControls } from '../ImageControls'; import { StepControl } from '../StepControl'; import { DragRectActionOverlay } from '../DragRectActionOverlay'; -import { DicomViewer } from '../DicomViewer'; +import { DicomViewer, type DicomViewerHandle } from '../DicomViewer'; +import { GroundTruthPolygonOverlay } from '../GroundTruthPolygonOverlay'; +import { TumorSegmentationOverlay } from '../TumorSegmentationOverlay'; export type OverlayViewProps = { + comboId: string; + overlayColumns: { date: string; ref?: SeriesRef }[]; overlayViewerSize: number; @@ -47,6 +51,7 @@ export type OverlayViewProps = { }; export function OverlayView({ + comboId, overlayColumns, overlayViewerSize, overlayDisplayedRef, @@ -72,6 +77,32 @@ export function OverlayView({ setProgress, }: OverlayViewProps) { const [isOverlayViewerHovered, setIsOverlayViewerHovered] = useState(false); + const [tumorToolOpen, setTumorToolOpen] = useState(false); + const [gtPolygonToolOpen, setGtPolygonToolOpen] = useState(false); + const tumorViewerRef = useRef(null); + + // Compare mode is read-only: ensure the tumor tool isn't active. + // We schedule the close to avoid calling setState synchronously inside the effect body. 
+ useEffect(() => { + if (!isOverlayComparing) return; + + const t = window.setTimeout(() => { + setTumorToolOpen(false); + setGtPolygonToolOpen(false); + }, 0); + + return () => window.clearTimeout(t); + }, [isOverlayComparing]); + + // Note: the tool only operates on the *selected* date when not comparing. + const tumorEffectiveSliceIndex = + overlaySelectedRef && overlaySelectedDate + ? getEffectiveInstanceIndex( + overlaySelectedSliceIndex, + overlaySelectedRef.instance_count, + overlaySelectedSettings.reverseSliceOrder + ) + : 0; return (
@@ -94,7 +125,49 @@ export function OverlayView({ : 'opacity-0 pointer-events-none' }`} > -
+
+
+ + + +
+ { void startAlignAll( { @@ -175,7 +248,9 @@ export function OverlayView({ className={`absolute inset-0 ${isOverlayComparing ? 'opacity-0 pointer-events-none' : 'opacity-100'}`} > {overlaySelectedRef && overlaySelectedDate ? ( - + + + setTumorToolOpen(false)} + viewerRef={tumorViewerRef} + comboId={comboId} + dateIso={overlaySelectedDate} + studyId={overlaySelectedRef.study_id} + seriesUid={overlaySelectedRef.series_uid} + effectiveInstanceIndex={tumorEffectiveSliceIndex} + viewerTransform={overlaySelectedSettings} + /> + + setGtPolygonToolOpen(false)} + comboId={comboId} + dateIso={overlaySelectedDate} + studyId={overlaySelectedRef.study_id} + seriesUid={overlaySelectedRef.series_uid} + effectiveInstanceIndex={tumorEffectiveSliceIndex} + viewerTransform={overlaySelectedSettings} + /> + ) : null}
@@ -255,7 +354,7 @@ export function OverlayView({
{alignmentProgress.phase !== 'capturing' && alignmentProgress.slicesChecked ? (
- {alignmentProgress.slicesChecked} slices · MI {alignmentProgress.bestMiSoFar.toFixed(3)} + {alignmentProgress.slicesChecked} slices · Score {alignmentProgress.bestMiSoFar.toFixed(3)}
) : null}
diff --git a/frontend/src/db/db.ts b/frontend/src/db/db.ts index 19bba6f..307792e 100644 --- a/frontend/src/db/db.ts +++ b/frontend/src/db/db.ts @@ -3,7 +3,7 @@ import type { IDBPDatabase } from 'idb'; import type { MiraDB } from './schema'; const DB_NAME = 'MiraViewerDB'; -const DB_VERSION = 2; +const DB_VERSION = 4; let dbPromise: Promise> | null = null; @@ -80,6 +80,40 @@ export function getDB() { if (!db.objectStoreNames.contains('panel_settings')) { db.createObjectStore('panel_settings', { keyPath: 'comboId' }); } + + // Tumor segmentations + { + const segStore = db.objectStoreNames.contains('tumor_segmentations') + ? transaction.objectStore('tumor_segmentations') + : db.createObjectStore('tumor_segmentations', { keyPath: 'id' }); + + if (!segStore.indexNames.contains('by-series')) { + segStore.createIndex('by-series', 'seriesUid'); + } + if (!segStore.indexNames.contains('by-sop')) { + segStore.createIndex('by-sop', 'sopInstanceUid'); + } + if (!segStore.indexNames.contains('by-combo-date')) { + segStore.createIndex('by-combo-date', ['comboId', 'dateIso']); + } + } + + // Tumor ground truth (manual polygon) + { + const gtStore = db.objectStoreNames.contains('tumor_ground_truth') + ? 
transaction.objectStore('tumor_ground_truth') + : db.createObjectStore('tumor_ground_truth', { keyPath: 'id' }); + + if (!gtStore.indexNames.contains('by-series')) { + gtStore.createIndex('by-series', 'seriesUid'); + } + if (!gtStore.indexNames.contains('by-sop')) { + gtStore.createIndex('by-sop', 'sopInstanceUid'); + } + if (!gtStore.indexNames.contains('by-combo-date')) { + gtStore.createIndex('by-combo-date', ['comboId', 'dateIso']); + } + } }, }); } diff --git a/frontend/src/db/schema.ts b/frontend/src/db/schema.ts index b65bcab..bf33d41 100644 --- a/frontend/src/db/schema.ts +++ b/frontend/src/db/schema.ts @@ -8,6 +8,127 @@ export interface DicomStudy { accessionNumber?: string; } +export type NormalizedPoint = { x: number; y: number }; + +// Viewport size in CSS pixels when the user authored an overlay. +// +// This is needed to correctly re-project viewer-normalized points/polygons into image coordinates +// because the "contain" mapping depends on both viewport size and image aspect ratio. +export type ViewportSize = { w: number; h: number }; + +export type ViewerTransform = { + /** Zoom factor (1 = 100%). */ + zoom: number; + /** Rotation in degrees. */ + rotation: number; + /** Normalized pan (fraction of viewport width). */ + panX: number; + /** Normalized pan (fraction of viewport height). */ + panY: number; + + /** Hidden affine residual (shear / anisotropic scale), row-major 2x2. */ + affine00: number; + affine01: number; + affine10: number; + affine11: number; +}; + +export type TumorPolygon = { + /** + * Polygon points in normalized viewer coordinates. + * + * IMPORTANT: + * These points are stored in the viewer's coordinate system at the time they were created. + * To render them correctly under a different pan/zoom/rotation/affine, re-project using the + * saved `viewTransform` metadata. + */ + points: NormalizedPoint[]; +}; + +export type TumorThreshold = { + /** Inclusive lower bound in segmentation pixel domain (typically 0..255). 
*/ + low: number; + /** Inclusive upper bound in segmentation pixel domain (typically 0..255). */ + high: number; + + /** + * Optional fixed "anchor" intensity used when the UI operates in tolerance mode. + * + * Stored so the slider can stay monotonic (tolerance expands/contracts around a fixed anchor). + * Older rows may omit this. + */ + anchor?: number; + + /** + * Optional tolerance (half-width) around `anchor` (0..127-ish). + * Older rows may omit this. + */ + tolerance?: number; +}; + +export interface TumorSegmentationRow { + /** Stable ID (composite encoded). */ + id: string; + + /** Sequence combo id (plane+weight+sequence). */ + comboId: string; + /** ISO-ish date key used by the comparison view (see localApi date formatting). */ + dateIso: string; + + studyId: string; + seriesUid: string; + sopInstanceUid: string; + + /** Version for future algorithm migrations. */ + algorithmVersion: string; + + polygon: TumorPolygon; + threshold: TumorThreshold; + + /** Optional seed point used for region growing (normalized). */ + seed?: NormalizedPoint; + + createdAtMs: number; + updatedAtMs: number; + + meta?: { + areaPx?: number; + areaNorm?: number; + + /** Viewer transform at the time this polygon was saved (used to re-project overlays). */ + viewTransform?: ViewerTransform; + + /** Viewport size (CSS pixels) at the time this polygon was saved. */ + viewportSize?: ViewportSize; + }; +} + +export interface TumorGroundTruthRow { + /** Stable ID (composite encoded). */ + id: string; + + /** Sequence combo id (plane+weight+sequence). */ + comboId: string; + /** ISO-ish date key used by the comparison view (see localApi date formatting). */ + dateIso: string; + + studyId: string; + seriesUid: string; + sopInstanceUid: string; + + /** Manually drawn polygon points in normalized viewer coordinates. */ + polygon: TumorPolygon; + + /** Viewer transform at the time this polygon was saved (used to re-project overlays). 
*/ + viewTransform?: ViewerTransform; + + /** Viewport size (CSS pixels) at the time this polygon was saved. */ + viewportSize?: ViewportSize; + + createdAtMs: number; + updatedAtMs: number; +} + export interface DicomSeries { seriesInstanceUid: string; studyInstanceUid: string; @@ -39,6 +160,7 @@ export interface DicomInstance { imageOrientationPatient?: string; // [rowX, rowY, rowZ, colX, colY, colZ] as string pixelSpacing?: string; // [row, col] as string sliceThickness?: number; + spacingBetweenSlices?: number; // Windowing windowCenter?: number; @@ -93,4 +215,23 @@ export interface MiraDB { key: string; // comboId value: PanelSettingsRow; }; + tumor_segmentations: { + key: string; // id + value: TumorSegmentationRow; + indexes: { + 'by-series': string; + 'by-sop': string; + 'by-combo-date': [string, string]; + }; + }; + + tumor_ground_truth: { + key: string; // id + value: TumorGroundTruthRow; + indexes: { + 'by-series': string; + 'by-sop': string; + 'by-combo-date': [string, string]; + }; + }; } diff --git a/frontend/src/hooks/useAutoAlign.ts b/frontend/src/hooks/useAutoAlign.ts index 72b9935..899e527 100644 --- a/frontend/src/hooks/useAutoAlign.ts +++ b/frontend/src/hooks/useAutoAlign.ts @@ -12,6 +12,10 @@ import { import { clamp, nowMs } from '../utils/math'; import { registerAffine2DWithElastix } from '../utils/elastixRegistration'; import { warpGrayscaleAffine } from '../utils/warpAffine'; +import { + computeGradientMagnitudeL1Square, + buildInclusionMaskFromThresholdSquare, +} from '../utils/imageFeatures'; import { affineAboutOriginToStandard, composeStandardAffine2D, @@ -23,6 +27,7 @@ import { type PanelGeometry, } from '../utils/panelTransform'; import { isDebugAlignmentEnabled, debugAlignmentLog } from '../utils/debugAlignment'; +import { recordAlignmentSliceScore, resetAlignmentSliceScoreStore } from '../utils/alignmentSliceScoreStore'; // Perf tuning for the MI-based slice search. 
// @@ -33,14 +38,49 @@ import { isDebugAlignmentEnabled, debugAlignmentLog } from '../utils/debugAlignm // // Type note: keep this typed as `number` (not a numeric literal) so we can compare it to // ALIGNMENT_IMAGE_SIZE without TS treating the comparison as always-false. -const SLICE_SEARCH_IMAGE_SIZE: number = 128; -const SLICE_SEARCH_MI_BINS: number = 32; -const SLICE_SEARCH_STOP_DECREASE_STREAK: number = 4; +const SLICE_SEARCH_IMAGE_SIZE: number = 512; + +// MI scoring tuning. +const SLICE_SEARCH_MI_BINS: number = 64; + +// Stop logic tuning. +const SLICE_SEARCH_STOP_DECREASE_STREAK: number = 3; +// Ensure we search far enough to avoid “off by ~5 slices” misses due to early noisy dips. +const SLICE_SEARCH_MIN_SEARCH_RADIUS: number = 5; + +// Keep UI responsive during heavy 512px scoring. +const SLICE_SEARCH_YIELD_EVERY_SLICES: number = 2; + +// Background suppression in scoring. +const SLICE_SEARCH_FOREGROUND_THRESHOLD: number = 0.02; + +// Gradient scoring (MI/NMI on grad magnitude). +// +// Note: currently disabled because grad magnitude was not correlating well with correct slice +// matches in practice. +const SLICE_SEARCH_GRADIENT_WEIGHT: number = 0; + +// Similarity metric used for coarse slice search. +const SLICE_SEARCH_SCORE_METRIC: 'ssim' | 'lncc' | 'zncc' | 'ngf' | 'census' | 'mind' | 'phase' = 'phase'; +const SLICE_SEARCH_SSIM_BLOCK_SIZE: number = 16; + +// Downsample sizes for more expensive metrics. +const SLICE_SEARCH_MIND_SIZE: number = 64; +const SLICE_SEARCH_PHASE_SIZE: number = 64; + +// Optional: constrain the slice search to a window around the best guess. +const SLICE_SEARCH_WINDOW_RADIUS: number = 40; // Registration perf tuning. // // User-requested: run single-pass registrations (no multi-resolution pyramid). This is the // fastest configuration but can reduce robustness on some inputs. +// +// Important: keep the Elastix seed registration at a smaller size than the 512px slice search. 
+// We previously attempted seed registration at 512 and observed failures in practice (progress +// flashing + immediate abort). The seed transform is only used to pre-warp candidates; it does +// not need to be computed at full slice-search resolution. +const SEED_REGISTRATION_IMAGE_SIZE: number = ALIGNMENT_IMAGE_SIZE; const SEED_REGISTRATION_RESOLUTIONS = 1; const REFINEMENT_REGISTRATION_RESOLUTIONS = 1; @@ -136,7 +176,11 @@ export function useAutoAlign() { // Single render element used for all captures. We also keep scratch canvases around to // avoid allocating a new + ImageData buffers on every slice capture. - const renderElement = createCornerstoneRenderElement(ALIGNMENT_IMAGE_SIZE); + // + // Important: The element size must be >= any target capture size, otherwise larger captures + // would just upsample a smaller source canvas ("fake 512"). + const renderElementSize = Math.max(ALIGNMENT_IMAGE_SIZE, SLICE_SEARCH_IMAGE_SIZE); + const renderElement = createCornerstoneRenderElement(renderElementSize); const captureScratchFull = createPixelCaptureScratch(ALIGNMENT_IMAGE_SIZE); const captureScratchSliceSearch = createPixelCaptureScratch(SLICE_SEARCH_IMAGE_SIZE); @@ -152,6 +196,12 @@ export function useAutoAlign() { debug: debugAlignment, }); + // In-memory store used by the per-cell debug overlay (SSIM/LNCC + MI/NMI breakdown). + resetAlignmentSliceScoreStore({ + referenceSeriesUid: reference.seriesUid, + referenceSliceIndex: reference.sliceIndex, + }); + if (!debugAlignment) { console.info( "[alignment] Tip: enable verbose logs with localStorage.setItem('miraviewer:debug-alignment', '1')" @@ -188,6 +238,42 @@ export function useAutoAlign() { const referencePixelsForSliceSearch = referenceRenderForSliceSearch.pixels; + // Slice-search feature prep (shared across all target dates). 
+ const inclusion = buildInclusionMaskFromThresholdSquare( + referencePixelsForSliceSearch, + SLICE_SEARCH_IMAGE_SIZE, + SLICE_SEARCH_FOREGROUND_THRESHOLD, + { minIncludedFrac: 0.05 } + ); + const sliceSearchInclusionMask = inclusion?.mask; + + const referenceGradPixelsForSliceSearch = + SLICE_SEARCH_GRADIENT_WEIGHT !== 0 + ? computeGradientMagnitudeL1Square(referencePixelsForSliceSearch, SLICE_SEARCH_IMAGE_SIZE) + : null; + + if (debugAlignment) { + console.info('[alignment] Slice-search scoring config', { + sliceSearchImageSize: SLICE_SEARCH_IMAGE_SIZE, + scoreMetric: SLICE_SEARCH_SCORE_METRIC, + // SSIM/LNCC are the primary score used for bestIndex selection. + ssimBlockSize: SLICE_SEARCH_SSIM_BLOCK_SIZE, + // Downsample config for heavier metrics. + mindSize: SLICE_SEARCH_MIND_SIZE, + phaseSize: SLICE_SEARCH_PHASE_SIZE, + // MI/NMI are still computed in debug mode so we can compare metrics. + miBins: SLICE_SEARCH_MI_BINS, + stopDecreaseStreak: SLICE_SEARCH_STOP_DECREASE_STREAK, + minSearchRadius: SLICE_SEARCH_MIN_SEARCH_RADIUS, + yieldEverySlices: SLICE_SEARCH_YIELD_EVERY_SLICES, + foregroundThreshold: SLICE_SEARCH_FOREGROUND_THRESHOLD, + inclusionMask: inclusion + ? { includedFrac: Number(inclusion.includedFrac.toFixed(4)), includedCount: inclusion.includedCount } + : null, + gradientWeight: SLICE_SEARCH_GRADIENT_WEIGHT, + }); + } + const referenceDisplayedPixels = applyBrightnessContrastToPixels( referencePixels, reference.settings.brightness, @@ -239,19 +325,30 @@ export function useAutoAlign() { ); const startIdx = clamp(startIdxUnclamped, 0, Math.max(0, seriesRef.instance_count - 1)); + // Best initial guess: normalized index mapping from reference -> target. 
+ const seedIdx = startIdx; + + const sliceSearchMinIndex = clamp(seedIdx - SLICE_SEARCH_WINDOW_RADIUS, 0, Math.max(0, seriesRef.instance_count - 1)); + const sliceSearchMaxIndex = clamp(seedIdx + SLICE_SEARCH_WINDOW_RADIUS, 0, Math.max(0, seriesRef.instance_count - 1)); + debugAlignmentLog( 'date.plan', { date, startIdx, + seedIdx, strategy: { // User-requested: seed the slice search with a coarse 2D affine transform. sliceSearchWarp: true, - seedImageSize: SLICE_SEARCH_IMAGE_SIZE, + seedImageSize: SEED_REGISTRATION_IMAGE_SIZE, seedResolutions: SEED_REGISTRATION_RESOLUTIONS, sliceSearchImageSize: SLICE_SEARCH_IMAGE_SIZE, sliceSearchMiBins: SLICE_SEARCH_MI_BINS, sliceSearchStopDecreaseStreak: SLICE_SEARCH_STOP_DECREASE_STREAK, + sliceSearchMinSearchRadius: SLICE_SEARCH_MIN_SEARCH_RADIUS, + sliceSearchWindowRadius: SLICE_SEARCH_WINDOW_RADIUS, + sliceSearchYieldEverySlices: SLICE_SEARCH_YIELD_EVERY_SLICES, + gradientWeight: SLICE_SEARCH_GRADIENT_WEIGHT, refinementImageSize: ALIGNMENT_IMAGE_SIZE, refinementResolutions: REFINEMENT_REGISTRATION_RESOLUTIONS, }, @@ -268,7 +365,12 @@ export function useAutoAlign() { seriesUid: seriesRef.series_uid, instanceCount: seriesRef.instance_count, startIdx, - seedImageSize: SLICE_SEARCH_IMAGE_SIZE, + seedIdx, + sliceSearchBounds: { + minIndex: sliceSearchMinIndex, + maxIndex: sliceSearchMaxIndex, + }, + seedImageSize: SEED_REGISTRATION_IMAGE_SIZE, refinementImageSize: ALIGNMENT_IMAGE_SIZE, resolutions: { seed: SEED_REGISTRATION_RESOLUTIONS, @@ -280,12 +382,11 @@ export function useAutoAlign() { // // This is used to pre-warp slices during the slice search so the similarity metric is // less dominated by in-plane pose differences. 
- const seedIdx = startIdx; console.info('[alignment] Seed registration starting', { date, seedIdx, - size: SLICE_SEARCH_IMAGE_SIZE, + size: SEED_REGISTRATION_IMAGE_SIZE, numberOfResolutions: SEED_REGISTRATION_RESOLUTIONS, }); @@ -293,29 +394,31 @@ export function useAutoAlign() { renderElement, seriesRef.series_uid, seedIdx, - SLICE_SEARCH_IMAGE_SIZE, - captureScratchSliceSearch + SEED_REGISTRATION_IMAGE_SIZE, + captureScratchFull ); const tSeed0 = nowMs(); - const seedReg = await registerAffine2DWithElastix( - referencePixelsForSliceSearch, - seedRender.pixels, - SLICE_SEARCH_IMAGE_SIZE, - { - numberOfResolutions: SEED_REGISTRATION_RESOLUTIONS, - webWorker: sharedWebWorker, - } - ); + const seedReg = await registerAffine2DWithElastix(referencePixels, seedRender.pixels, SEED_REGISTRATION_IMAGE_SIZE, { + numberOfResolutions: SEED_REGISTRATION_RESOLUTIONS, + webWorker: sharedWebWorker, + exclusionRect: reference.exclusionMask, + }); const seedRegistrationMs = nowMs() - tSeed0; sharedWebWorker = seedReg.webWorker; + // Seed transform is computed at SEED_REGISTRATION_IMAGE_SIZE, but applied to the + // slice-search grid (SLICE_SEARCH_IMAGE_SIZE). Translation is in pixels, so scale it. + const seedScale = SLICE_SEARCH_IMAGE_SIZE / SEED_REGISTRATION_IMAGE_SIZE; const seed: SeedRegistrationResult = { idx: seedIdx, nmi: seedReg.quality.nmi, transformA: seedReg.A, - transformT: seedReg.translatePx, + transformT: { + x: seedReg.translatePx.x * seedScale, + y: seedReg.translatePx.y * seedScale, + }, transformParameterObject: seedReg.transformParameterObject, }; @@ -359,7 +462,7 @@ export function useAutoAlign() { debugAlignment ); - // 2) Use the seed transform to drive a fast MI-based slice search. + // 2) Use the seed transform to drive a fast similarity-based slice search. // // We pre-warp each candidate slice by the seed transform before scoring against the // reference. 
This helps slice search focus on the through-plane match instead of @@ -403,9 +506,39 @@ export function useAutoAlign() { const onSliceScored = debugAlignment ? ( index: number, - metrics: { mi: number; nmi: number }, + metrics: { + ssim: number; + lncc: number; + zncc: number; + ngf: number; + census: number; + mind?: number; + phase?: number; + mi: number; + nmi: number; + score: number; + miGrad?: number; + nmiGrad?: number; + pixelsUsed?: number; + }, direction: 'start' | 'left' | 'right' ) => { + // Store per-slice metrics for UI debugging overlays. + recordAlignmentSliceScore(seriesRef.series_uid, index, { + ssim: metrics.ssim, + lncc: metrics.lncc, + zncc: metrics.zncc, + ngf: metrics.ngf, + census: metrics.census, + mind: metrics.mind ?? null, + phase: metrics.phase ?? null, + mi: metrics.mi, + nmi: metrics.nmi, + miGrad: metrics.miGrad ?? null, + nmiGrad: metrics.nmiGrad ?? null, + score: metrics.score, + }); + // Extremely verbose: log per-slice similarity metrics only when debug alignment is enabled. debugAlignmentLog( 'slice-search.score', @@ -413,8 +546,19 @@ export function useAutoAlign() { date, direction, index, + score: Number(metrics.score.toFixed(6)), + ssim: Number(metrics.ssim.toFixed(6)), + lncc: Number(metrics.lncc.toFixed(6)), + zncc: Number(metrics.zncc.toFixed(6)), + ngf: Number(metrics.ngf.toFixed(6)), + census: Number(metrics.census.toFixed(6)), + mind: metrics.mind != null ? Number(metrics.mind.toFixed(6)) : null, + phase: metrics.phase != null ? Number(metrics.phase.toFixed(6)) : null, mi: Number(metrics.mi.toFixed(6)), nmi: Number(metrics.nmi.toFixed(6)), + miGrad: metrics.miGrad != null ? Number(metrics.miGrad.toFixed(6)) : null, + nmiGrad: metrics.nmiGrad != null ? Number(metrics.nmiGrad.toFixed(6)) : null, + pixelsUsed: metrics.pixelsUsed ?? 
null, }, debugAlignment ); @@ -448,13 +592,27 @@ export function useAutoAlign() { }, { startIndexOverride: seedIdx, + minIndex: sliceSearchMinIndex, + maxIndex: sliceSearchMaxIndex, + scoreMetric: SLICE_SEARCH_SCORE_METRIC, + ssimBlockSize: SLICE_SEARCH_SSIM_BLOCK_SIZE, + mindSize: SLICE_SEARCH_MIND_SIZE, + phaseSize: SLICE_SEARCH_PHASE_SIZE, miBins: SLICE_SEARCH_MI_BINS, stopDecreaseStreak: SLICE_SEARCH_STOP_DECREASE_STREAK, + minSearchRadius: SLICE_SEARCH_MIN_SEARCH_RADIUS, + yieldEverySlices: SLICE_SEARCH_YIELD_EVERY_SLICES, + yieldFn: yieldToMain, onSliceScored, + inclusionMask: sliceSearchInclusionMask, // Pass exclusion mask for tumor avoidance. exclusionRect: reference.exclusionMask, imageWidth: SLICE_SEARCH_IMAGE_SIZE, imageHeight: SLICE_SEARCH_IMAGE_SIZE, + gradient: + referenceGradPixelsForSliceSearch && SLICE_SEARCH_GRADIENT_WEIGHT !== 0 + ? { referenceGradPixels: referenceGradPixelsForSliceSearch, weight: SLICE_SEARCH_GRADIENT_WEIGHT } + : undefined, } ); @@ -473,8 +631,9 @@ export function useAutoAlign() { console.info('[alignment] Slice search finished', { date, strategy: 'seeded', + scoreMetric: SLICE_SEARCH_SCORE_METRIC, bestIndex: searchResult.bestIndex, - bestMi: Number(searchResult.bestMI.toFixed(6)), + bestScore: Number(searchResult.bestMI.toFixed(6)), slicesChecked: searchResult.slicesChecked, }); @@ -483,9 +642,18 @@ export function useAutoAlign() { { date, strategy: 'seeded', + scoreMetric: SLICE_SEARCH_SCORE_METRIC, size: SLICE_SEARCH_IMAGE_SIZE, + ssimBlockSize: SLICE_SEARCH_SSIM_BLOCK_SIZE, bins: SLICE_SEARCH_MI_BINS, stopDecreaseStreak: SLICE_SEARCH_STOP_DECREASE_STREAK, + minSearchRadius: SLICE_SEARCH_MIN_SEARCH_RADIUS, + bounds: { + minIndex: sliceSearchMinIndex, + maxIndex: sliceSearchMaxIndex, + }, + yieldEverySlices: SLICE_SEARCH_YIELD_EVERY_SLICES, + gradientWeight: SLICE_SEARCH_GRADIENT_WEIGHT, slicesChecked: searchResult.slicesChecked, scoreMs: searchResult.timingMs?.scoreMs, renderMs: sliceSearchRenderMs, @@ -524,6 +692,7 @@ 
export function useAutoAlign() { const refined = await registerAffine2DWithElastix(referencePixels, bestRender.pixels, ALIGNMENT_IMAGE_SIZE, { numberOfResolutions: REFINEMENT_REGISTRATION_RESOLUTIONS, webWorker: sharedWebWorker, + exclusionRect: reference.exclusionMask, }); const refinementMs = nowMs() - tRefine0; diff --git a/frontend/src/hooks/useComparisonFilters.ts b/frontend/src/hooks/useComparisonFilters.ts index 093fedd..cb148f0 100644 --- a/frontend/src/hooks/useComparisonFilters.ts +++ b/frontend/src/hooks/useComparisonFilters.ts @@ -16,10 +16,17 @@ function normalizePlane(plane: string | null): string { return plane && plane.trim() ? plane : OTHER_PLANE; } +function isKnownSequence(seq: { weight: string | null; sequence: string | null }): boolean { + return Boolean((seq.weight && seq.weight.trim()) || (seq.sequence && seq.sequence.trim())); +} + function getAvailablePlanes(data: ComparisonData): string[] { - const hasOther = - data.planes.some((p) => p === OTHER_PLANE || !p.trim()) || - data.sequences.some((s) => !s.plane || !s.plane.trim()); + // Filter out "Unknown" sequences (no weight + no sequence) from plane + sequence selectors. + const knownPlanes = new Set(); + for (const seq of data.sequences) { + if (!isKnownSequence(seq)) continue; + knownPlanes.add(normalizePlane(seq.plane)); + } const planes: string[] = []; const seen = new Set(); @@ -27,6 +34,7 @@ function getAvailablePlanes(data: ComparisonData): string[] { // Keep the plane ordering provided by the dataset, but move Other to the end. for (const p of data.planes) { if (p === OTHER_PLANE || !p.trim()) continue; + if (!knownPlanes.has(p)) continue; if (!seen.has(p)) { seen.add(p); planes.push(p); @@ -35,6 +43,7 @@ function getAvailablePlanes(data: ComparisonData): string[] { // Defensive: include any planes that appear on sequences but not in data.planes. 
for (const seq of data.sequences) { + if (!isKnownSequence(seq)) continue; const p = normalizePlane(seq.plane); if (p === OTHER_PLANE) continue; if (!seen.has(p)) { @@ -43,7 +52,7 @@ function getAvailablePlanes(data: ComparisonData): string[] { } } - if (hasOther) planes.push(OTHER_PLANE); + if (knownPlanes.has(OTHER_PLANE)) planes.push(OTHER_PLANE); return planes; } @@ -55,7 +64,7 @@ function pickDefaultPlane(planes: string[]): string | null { } function pickDefaultSequence(data: ComparisonData, plane: string): string | null { - const seq = data.sequences.find((s) => normalizePlane(s.plane) === plane) || data.sequences[0]; + const seq = data.sequences.find((s) => isKnownSequence(s) && normalizePlane(s.plane) === plane) || null; return seq ? seq.id : null; } @@ -64,18 +73,21 @@ function findMatchingSequence(data: ComparisonData, newPlane: string, currentSeq if (!currentSeqId) return pickDefaultSequence(data, newPlane); const currentSeq = data.sequences.find((s) => s.id === currentSeqId); - if (!currentSeq) return pickDefaultSequence(data, newPlane); + if (!currentSeq || !isKnownSequence(currentSeq)) return pickDefaultSequence(data, newPlane); // Try to find a sequence in the new plane with same weight and sequence type const exactMatch = data.sequences.find( (s) => - normalizePlane(s.plane) === newPlane && s.weight === currentSeq.weight && s.sequence === currentSeq.sequence + isKnownSequence(s) && + normalizePlane(s.plane) === newPlane && + s.weight === currentSeq.weight && + s.sequence === currentSeq.sequence ); if (exactMatch) return exactMatch.id; // Try matching just the weight const weightMatch = data.sequences.find( - (s) => normalizePlane(s.plane) === newPlane && s.weight === currentSeq.weight + (s) => isKnownSequence(s) && normalizePlane(s.plane) === newPlane && s.weight === currentSeq.weight ); if (weightMatch) return weightMatch.id; @@ -120,7 +132,9 @@ export function useComparisonFilters(data: ComparisonData | null) { const selectedSeqId = useMemo(() => 
{ if (!data || !selectedPlane) return null; const seqIdsForPlane = new Set( - data.sequences.filter((s) => normalizePlane(s.plane) === selectedPlane).map((s) => s.id) + data.sequences + .filter((s) => isKnownSequence(s) && normalizePlane(s.plane) === selectedPlane) + .map((s) => s.id) ); return filters.seqId && seqIdsForPlane.has(filters.seqId) ? filters.seqId diff --git a/frontend/src/hooks/useOverlayNavigation.ts b/frontend/src/hooks/useOverlayNavigation.ts index 7d8af5d..e44d941 100644 --- a/frontend/src/hooks/useOverlayNavigation.ts +++ b/frontend/src/hooks/useOverlayNavigation.ts @@ -4,7 +4,7 @@ import { readLocalStorageJson, writeLocalStorageJson } from '../utils/persistenc import { OVERLAY_NAV_STORAGE_KEY } from '../utils/storageKeys'; type PersistedOverlayNav = { - viewMode?: 'grid' | 'overlay'; + viewMode?: 'grid' | 'overlay' | 'svr3d'; overlayDate?: string; playSpeed?: number; }; @@ -15,7 +15,14 @@ function readPersistedOverlayNav(): PersistedOverlayNav { const obj = parsed as Record<string, unknown>; - const viewMode = obj.viewMode === 'overlay' ? 'overlay' : obj.viewMode === 'grid' ? 'grid' : undefined; + const viewMode = + obj.viewMode === 'overlay' + ? 'overlay' + : obj.viewMode === 'grid' + ? 'grid' + : obj.viewMode === 'svr3d' + ? 'svr3d' + : undefined; const overlayDate = typeof obj.overlayDate === 'string' ? obj.overlayDate : undefined; const playSpeed = typeof obj.playSpeed === 'number' && Number.isFinite(obj.playSpeed) ? obj.playSpeed : undefined; @@ -44,8 +51,9 @@ export function useOverlayNavigation( writeLocalStorageJson(OVERLAY_NAV_STORAGE_KEY, next); }, []); - const [viewMode, setViewModeState] = useState<'grid' | 'overlay'>(() => { - return readPersistedOverlayNav().viewMode === 'overlay' ? 'overlay' : 'grid'; + const [viewMode, setViewModeState] = useState<'grid' | 'overlay' | 'svr3d'>(() => { + const restored = readPersistedOverlayNav().viewMode; + return restored === 'overlay' ? 'overlay' : restored === 'svr3d' ?
'svr3d' : 'grid'; }); const [overlayDateIndex, setOverlayDateIndexState] = useState(0); const [previousOverlayDateIndex, setPreviousOverlayDateIndex] = useState<number | null>(null); @@ -57,7 +65,7 @@ export function useOverlayNavigation( // Track spacebar held state for compare feature const [spaceHeld, setSpaceHeld] = useState(false); - const setViewMode = useCallback((next: 'grid' | 'overlay') => { + const setViewMode = useCallback((next: 'grid' | 'overlay' | 'svr3d') => { setViewModeState(next); // Avoid getting stuck in compare mode if the user releases Space while not in overlay mode. diff --git a/frontend/src/hooks/useSvrReconstruction.ts b/frontend/src/hooks/useSvrReconstruction.ts new file mode 100644 index 0000000..d458179 --- /dev/null +++ b/frontend/src/hooks/useSvrReconstruction.ts @@ -0,0 +1,116 @@ +import { useCallback, useRef, useState } from 'react'; +import type { SvrParams, SvrProgress, SvrResult, SvrSelectedSeries } from '../types/svr'; +import { DEFAULT_SVR_PARAMS } from '../types/svr'; +import { reconstructVolumeMultiPlane } from '../utils/svr/reconstructVolume'; + +export type UseSvrReconstructionState = { + isRunning: boolean; + progress: SvrProgress | null; + result: SvrResult | null; + error: string | null; +}; + +export type SvrRunOutcome = { + result: SvrResult | null; + error: string | null; + durationMs: number; +}; + +export function useSvrReconstruction() { + const [state, setState] = useState<UseSvrReconstructionState>({ + isRunning: false, + progress: null, + result: null, + error: null, + }); + + const abortRef = useRef<AbortController | null>(null); + const lastProgressUpdateMsRef = useRef(0); + + const cancel = useCallback(() => { + abortRef.current?.abort(); + }, []); + + const clear = useCallback(() => { + setState({ isRunning: false, progress: null, result: null, error: null }); + }, []); + + const run = useCallback(async (selectedSeries: SvrSelectedSeries[], params?: Partial<SvrParams>): Promise<SvrRunOutcome> => { + abortRef.current?.abort(); + + const controller = new AbortController(); + abortRef.current =
controller; + + const svrParams: SvrParams = { ...DEFAULT_SVR_PARAMS, ...(params || {}) }; + + setState({ + isRunning: true, + progress: { phase: 'idle', current: 0, total: 100, message: 'Starting…' }, + result: null, + error: null, + }); + + lastProgressUpdateMsRef.current = 0; + + const started = performance.now(); + + try { + const result = await reconstructVolumeMultiPlane({ + selectedSeries, + svrParams, + signal: controller.signal, + onProgress: (p) => { + const now = Date.now(); + const isFinal = p.current >= p.total; + + // Avoid spamming React renders. + if (!isFinal && now - lastProgressUpdateMsRef.current < 100) { + return; + } + lastProgressUpdateMsRef.current = now; + + setState((s) => ({ + ...s, + progress: p, + })); + }, + }); + + setState({ + isRunning: false, + progress: { phase: 'finalizing', current: 100, total: 100, message: 'Done' }, + result, + error: null, + }); + + return { + result, + error: null, + durationMs: performance.now() - started, + }; + } catch (err) { + const msg = err instanceof Error ? 
err.message : String(err); + setState({ + isRunning: false, + progress: null, + result: null, + error: msg, + }); + + return { + result: null, + error: msg, + durationMs: performance.now() - started, + }; + } finally { + abortRef.current = null; + } + }, []); + + return { + ...state, + run, + cancel, + clear, + }; +} diff --git a/frontend/src/services/dicomIngestion.ts b/frontend/src/services/dicomIngestion.ts index b6a0ed6..40301a5 100644 --- a/frontend/src/services/dicomIngestion.ts +++ b/frontend/src/services/dicomIngestion.ts @@ -201,6 +201,7 @@ const TAGS = { ImageOrientationPatient: 'x00200037', PixelSpacing: 'x00280030', SliceThickness: 'x00180050', + SpacingBetweenSlices: 'x00180088', WindowCenter: 'x00281050', WindowWidth: 'x00281051', }; @@ -332,6 +333,9 @@ export async function processDicomFile(file: File): Promise { const wc = getNumber(dataSet, TAGS.WindowCenter); const ww = getNumber(dataSet, TAGS.WindowWidth); + const sliceThickness = getNumber(dataSet, TAGS.SliceThickness); + const spacingBetweenSlices = getNumber(dataSet, TAGS.SpacingBetweenSlices); + const instanceBase = { sopInstanceUid: instanceUid, seriesInstanceUid: seriesUid, @@ -343,7 +347,8 @@ export async function processDicomFile(file: File): Promise { imagePositionPatient: getText(dataSet, TAGS.ImagePositionPatient), imageOrientationPatient: getText(dataSet, TAGS.ImageOrientationPatient), pixelSpacing: pixelSpacing, - sliceThickness: getNumber(dataSet, TAGS.SliceThickness), + sliceThickness: sliceThickness > 0 ? sliceThickness : undefined, + spacingBetweenSlices: spacingBetweenSlices > 0 ? 
spacingBetweenSlices : undefined, windowCenter: wc, windowWidth: ww, }; diff --git a/frontend/src/services/exportBackup.ts b/frontend/src/services/exportBackup.ts index e0294ab..75a0dfb 100644 --- a/frontend/src/services/exportBackup.ts +++ b/frontend/src/services/exportBackup.ts @@ -39,6 +39,7 @@ function toMetadata(instance: DicomInstance) { imageOrientationPatient: instance.imageOrientationPatient ?? null, pixelSpacing: instance.pixelSpacing ?? null, sliceThickness: instance.sliceThickness ?? null, + spacingBetweenSlices: instance.spacingBetweenSlices ?? null, windowCenter: instance.windowCenter ?? null, windowWidth: instance.windowWidth ?? null, }; diff --git a/frontend/src/services/svrHarness.ts b/frontend/src/services/svrHarness.ts new file mode 100644 index 0000000..ac5c84f --- /dev/null +++ b/frontend/src/services/svrHarness.ts @@ -0,0 +1,86 @@ +import JSZip from 'jszip'; +import type { SvrParams, SvrResult, SvrSelectedSeries } from '../types/svr'; + +type HarnessRunName = 'baseline' | 'high-detail'; + +export type SvrHarnessRun = { + name: HarnessRunName; + params: SvrParams; + durationMs: number; + result: SvrResult; + + /** Optional 3D render capture (from the WebGL viewer), if available. */ + render3dPng?: Blob | null; +}; + +export type ExportSvrHarnessZipParams = { + dateIso: string | null; + selectedSeries: SvrSelectedSeries[]; + runs: [SvrHarnessRun, SvrHarnessRun]; +}; + +async function toArrayBuffer(value: Blob): Promise<ArrayBuffer> { + // JSZip can consume blobs, but ArrayBuffer is the most compatible across runtimes.
+ return value.arrayBuffer(); +} + +export async function exportSvrHarnessZip(params: ExportSvrHarnessZipParams): Promise<Blob> { + const { dateIso, selectedSeries, runs } = params; + + const zip = new JSZip(); + const exportedAt = new Date().toISOString(); + + const manifest = { + exportedAt, + dateIso, + selectedSeries, + runs: runs.map((r) => ({ + name: r.name, + durationMs: Math.round(r.durationMs), + dims: r.result.volume.dims, + voxelSizeMm: r.result.volume.voxelSizeMm, + originMm: r.result.volume.originMm, + boundsMm: r.result.volume.boundsMm, + params: r.params, + })), + version: 1, + }; + + zip.file('svr_harness.json', JSON.stringify(manifest, null, 2)); + + for (const r of runs) { + const folder = zip.folder(r.name); + if (!folder) continue; + + folder.file( + 'meta.json', + JSON.stringify( + { + name: r.name, + durationMs: Math.round(r.durationMs), + dims: r.result.volume.dims, + voxelSizeMm: r.result.volume.voxelSizeMm, + originMm: r.result.volume.originMm, + boundsMm: r.result.volume.boundsMm, + params: r.params, + selectedSeries, + }, + null, + 2 + ) + ); + + const previewsFolder = folder.folder('previews'); + if (!previewsFolder) continue; + + previewsFolder.file('axial.png', await toArrayBuffer(r.result.previews.axial)); + previewsFolder.file('coronal.png', await toArrayBuffer(r.result.previews.coronal)); + previewsFolder.file('sagittal.png', await toArrayBuffer(r.result.previews.sagittal)); + + if (r.render3dPng) { + previewsFolder.file('render3d.png', await toArrayBuffer(r.render3dPng)); + } + } + + return zip.generateAsync({ type: 'blob', compression: 'DEFLATE', compressionOptions: { level: 6 } }); +} diff --git a/frontend/src/types/api.ts b/frontend/src/types/api.ts index 0565859..3625aa2 100644 --- a/frontend/src/types/api.ts +++ b/frontend/src/types/api.ts @@ -125,6 +125,11 @@ export interface AlignmentProgress { dateIndex: number; totalDates: number; slicesChecked: number; - /** Mutual information (natural log) from the coarse slice search.
Higher is better. */ + /** + * Slice-search score. Higher is better. + * + * This value corresponds to whatever metric is being used for slice search (e.g. SSIM or + * LNCC). It is not necessarily MI/NMI. + */ bestMiSoFar: number; } diff --git a/frontend/src/types/svr.ts b/frontend/src/types/svr.ts new file mode 100644 index 0000000..79ab140 --- /dev/null +++ b/frontend/src/types/svr.ts @@ -0,0 +1,143 @@ +export type SvrPhase = 'idle' | 'loading' | 'initializing' | 'reconstructing' | 'finalizing'; + +export type SvrProgress = { + phase: SvrPhase; + current: number; + total: number; + message: string; +}; + +export type SvrRoiPlane = 'axial' | 'coronal' | 'sagittal'; + +export type SvrRoi = { + /** For now we only support cube ROIs (square in-plane + equal extent through-plane). */ + mode: 'cube'; + /** Which preview plane the user drew the ROI on (used for metadata / debugging). */ + sourcePlane: SvrRoiPlane; + /** + * Which input series the ROI was defined against. + * + * When using `seriesRegistrationMode: 'bounds-center'`, we use this series as the alignment reference so the ROI stays + * in the same coordinate frame. + */ + sourceSeriesUid?: string; + /** ROI bounds in world/patient mm coordinates (same frame as DICOM IPP/IOP). */ + boundsMm: { + min: [number, number, number]; + max: [number, number, number]; + }; +}; + +export type SvrParams = { + /** Target isotropic voxel size in mm (may be increased automatically to fit within maxVolumeDim). */ + targetVoxelSizeMm: number; + /** Clamp each output dimension (x/y/z) to this maximum by increasing voxel size if needed. */ + maxVolumeDim: number; + + /** Downsample behavior for input slices before reconstruction. */ + sliceDownsampleMode: 'fixed' | 'voxel-aware'; + + /** Downsample each slice (keeping aspect) so max(rows, cols) <= this value before reconstruction. */ + sliceDownsampleMaxSize: number; + + /** + * Inter-series registration mode applied before fusion. + * + * - 'none': trust DICOM geometry as-is. 
+ * - 'bounds-center': translate each series so its 3D bounds center matches the reference series. + * This is a coarse but cheap stabilization when the scanner's spatial tags are inconsistent. + */ + seriesRegistrationMode: 'none' | 'bounds-center' | 'roi-rigid'; + + /** SVR refinement iterations (forward-project residuals back into the volume). */ + iterations: number; + /** Step size for each refinement iteration (0..1-ish). */ + stepSize: number; + + /** Clamp output voxel intensities to [0, 1]. */ + clampOutput: boolean; + + /** + * Slice-thickness forward model. + * + * - 'none': treat each pixel as a point sample on the slice plane. + * - 'box': integrate uniformly across the slice thickness support. + * - 'gaussian': distance-to-plane weighting within the thickness support. + */ + psfMode?: 'none' | 'box' | 'gaussian'; + + /** Robust loss applied to residuals during refinement iterations. */ + robustLoss?: 'none' | 'huber' | 'tukey'; + /** Residual scale parameter for robust loss (in normalized intensity units [0,1]). */ + robustDelta?: number; + + /** + * Light 3D Laplacian smoothing between iterations. + * 0 disables regularization. + */ + laplacianWeight?: number; + + /** Multi-resolution schedule: coarse grid bootstrapping before fine iterations. */ + multiResolution?: boolean; + /** Coarse voxel size factor relative to target voxel size (e.g. 2 -> 2x coarser). */ + multiResolutionFactor?: number; + /** How many iterations to run at the coarse level (0 disables coarse refinement). */ + multiResolutionCoarseIterations?: number; + + /** Optional reconstruction ROI. If set, the output grid is restricted to this region (faster + smaller). 
*/ + roi?: SvrRoi | null; +}; + +export const DEFAULT_SVR_PARAMS: SvrParams = { + targetVoxelSizeMm: 1.0, + maxVolumeDim: 192, + sliceDownsampleMode: 'voxel-aware', + sliceDownsampleMaxSize: 128, + seriesRegistrationMode: 'roi-rigid', + + // Core solver defaults (chosen to be conservative but higher-fidelity than point-sample SVR). + psfMode: 'gaussian', + robustLoss: 'huber', + robustDelta: 0.1, + laplacianWeight: 0.02, + multiResolution: true, + multiResolutionFactor: 2, + multiResolutionCoarseIterations: 1, + + iterations: 3, + stepSize: 0.6, + clampOutput: true, +}; + +export type SvrVolume = { + data: Float32Array; + dims: [number, number, number]; + voxelSizeMm: [number, number, number]; + originMm: [number, number, number]; + boundsMm: { + min: [number, number, number]; + max: [number, number, number]; + }; +}; + +export type SvrPreviewImages = { + axial: Blob; + coronal: Blob; + sagittal: Blob; +}; + +export type SvrResult = { + volume: SvrVolume; + previews: SvrPreviewImages; +}; + +export type SvrSelectedSeries = { + seriesUid: string; + studyId: string; + dateIso: string; + instanceCount: number; + label: string; + plane?: string | null; + weight?: string | null; + sequence?: string | null; +}; diff --git a/frontend/src/utils/alignment.ts b/frontend/src/utils/alignment.ts index e37085d..95b0e1c 100644 --- a/frontend/src/utils/alignment.ts +++ b/frontend/src/utils/alignment.ts @@ -2,6 +2,30 @@ import type { ExclusionMask, HistogramStats, PanelSettings } from '../types/api' import { CONTROL_LIMITS, DEFAULT_PANEL_SETTINGS } from './constants'; import { clamp, nowMs } from './math'; import { computeMutualInformation, type MutualInformationOptions } from './mutualInformation'; +import { computeGradientMagnitudeL1Square } from './imageFeatures'; +import { computeBlockSimilarity } from './ssim'; +import { prepareMindReference, computeMindSimilarity } from './mind'; +import { + preparePhaseCorrelationReference, + computePhaseCorrelationSimilarity, + 
createPhaseCorrelationScratch, +} from './phaseCorrelation'; +import { resample2dAreaAverage } from './svr/resample2d'; + +const POPCOUNT_8 = (() => { + // 8-bit popcount lookup (used for Census Hamming distance). + const t = new Uint8Array(256); + for (let i = 0; i < 256; i++) { + let v = i; + let c = 0; + while (v) { + v &= v - 1; + c++; + } + t[i] = c; + } + return t; +})(); /** * Compute normalized mutual information (NMI) between two grayscale images. @@ -20,12 +44,16 @@ export function computeNMI(imageA: Float32Array, imageB: Float32Array, bins: num */ export interface SliceSearchResult { bestIndex: number; - /** Mutual information (natural log). Higher is better. */ + /** + * Slice-search score. Higher is better. + * + * Note: the meaning of this value depends on the selected `scoreMetric`. + */ bestMI: number; slicesChecked: number; /** Optional perf counters for profiling/debugging. */ timingMs?: { - /** Time spent computing MI/NMI scores (excludes rendering / warping). */ + /** Time spent computing similarity scores (excludes rendering / warping). */ scoreMs: number; }; } @@ -35,21 +63,117 @@ type SliceScoreDirection = 'start' | 'left' | 'right'; type FindBestMatchingSliceOptions = { /** Override the starting index with a better initial guess (e.g. from a coarse seed). */ startIndexOverride?: number; + + /** Optional search bounds (inclusive). Useful when applying a prior / window constraint. */ + minIndex?: number; + maxIndex?: number; + /** Histogram bins for MI/NMI scoring. Lower values are faster but less sensitive. */ miBins?: number; + /** How many consecutive decreases are required before stopping a direction. */ stopDecreaseStreak?: number; + + /** + * Minimum number of slices to score in *each* direction before early-stop logic is allowed + * to terminate that direction. + * + * This directly addresses “off by ~5 slices” misses when the metric has a noisy dip early. 
+ */ + minSearchRadius?: number; + + /** + * Score metric to use for bestIndex selection. + * + * Notes: + * - SSIM tends to correspond best to perceived similarity. + * - LNCC/ZNCC can be strong baselines for MRI when intensity changes are mostly affine. + * - NGF focuses on gradient *direction* alignment (edge orientation). + * - Census is a rank-based local descriptor (robust to monotonic intensity changes). + * - MIND is a modality-robust self-similarity descriptor commonly used in medical registration. + * - Phase correlation is FFT-based and is most sensitive to translation agreement. + */ + scoreMetric?: 'ssim' | 'lncc' | 'zncc' | 'ngf' | 'census' | 'mind' | 'phase'; + + /** + * SSIM / LNCC block config. + * + * Note: we use fast block-based approximations (not Gaussian-window SSIM). + */ + ssimBlockSize?: number; + + /** + * Downsample size (square) used for the MIND descriptor metric. + * Default: 64. + */ + mindSize?: number; + + /** + * Downsample size (square, power-of-two) used for phase correlation. + * Default: 64. + */ + phaseSize?: number; + /** Optional hook for logging or debugging slice-level scores. */ - onSliceScored?: (index: number, metrics: { mi: number; nmi: number }, direction: SliceScoreDirection) => void; + onSliceScored?: ( + index: number, + metrics: { + /** Block-based SSIM similarity on intensity images (higher is better). */ + ssim: number; + /** Block-based LNCC similarity on intensity images (higher is better). */ + lncc: number; + /** Global ZNCC similarity on intensity images (higher is better). */ + zncc: number; + /** Normalized gradient field similarity (higher is better). */ + ngf: number; + /** Census similarity (higher is better). */ + census: number; + /** MIND-like descriptor similarity (higher is better). */ + mind?: number; + /** Phase correlation similarity (higher is better). */ + phase?: number; + /** Raw MI/NMI on intensity images (debug-only; can be used for comparison). 
*/ + mi: number; + nmi: number; + /** Combined slice-search score used for bestIndex selection. */ + score: number; + /** Optional MI/NMI on gradient magnitude images (debug-only, when enabled). */ + miGrad?: number; + nmiGrad?: number; + /** Pixels used for scoring (after masks). */ + pixelsUsed?: number; + }, + direction: SliceScoreDirection + ) => void; + + /** + * Optional inclusion mask. + * If provided, only pixels where inclusionMask[idx] != 0 are used for scoring. + */ + inclusionMask?: Uint8Array; + /** * Optional exclusion rectangle in normalized [0,1] image coordinates. - * Pixels inside this rect are excluded from MI scoring (useful for ignoring tumors). + * Pixels inside this rect are excluded from scoring (useful for ignoring tumors). */ + exclusionRect?: ExclusionMask; + /** Image width in pixels (required if exclusionRect is provided). */ imageWidth?: number; /** Image height in pixels (required if exclusionRect is provided). */ imageHeight?: number; + + /** + * Optional gradient-magnitude scoring (MI/NMI), used for debugging/comparison. + */ + gradient?: { + referenceGradPixels: Float32Array; + weight: number; + }; + + /** Optional yielding to keep UI responsive during heavy 512px scoring. */ + yieldEverySlices?: number; + yieldFn?: () => Promise<void>; }; /** @@ -58,13 +182,13 @@ type FindBestMatchingSliceOptions = { * Strategy: * - Start at the normalized slice depth (refIndex/refCount mapped into targetCount) * - Search outward in both directions - * - Stop in each direction only after N consecutive MI decreases (per-direction) + * - Stop in each direction only after N consecutive score decreases (per-direction) * * Rationale: * - Adjacent slices can be noisy; a single decrease is not sufficient to stop. - * - We intentionally do NOT early-exit based on bestMI, and we do NOT enforce a minimum + * - We intentionally do NOT early-exit based on bestScore, and we do NOT enforce a minimum * search window.
That keeps behavior deterministic and avoids premature termination when - * MI happens to spike early. + * the metric happens to spike early. */ export async function findBestMatchingSlice( referencePixels: Float32Array, @@ -82,26 +206,404 @@ export async function findBestMatchingSlice( const STOP_DECREASE_STREAK = options?.stopDecreaseStreak ?? 2; const MI_BINS = options?.miBins ?? 64; const exclusionRect = options?.exclusionRect; + const inclusionMask = options?.inclusionMask; const imageWidth = options?.imageWidth; const imageHeight = options?.imageHeight; + const squareSize = (() => { + if (typeof imageWidth === 'number' && typeof imageHeight === 'number' && imageWidth === imageHeight) { + return imageWidth; + } + + const s = Math.round(Math.sqrt(referencePixels.length)); + if (s <= 0 || s * s !== referencePixels.length) { + throw new Error('findBestMatchingSlice: expected square referencePixels (provide imageWidth/imageHeight)'); + } + return s; + })(); + + const minSearchRadius = Math.max(0, Math.round(options?.minSearchRadius ?? 0)); + + const minIndexBound = + typeof options?.minIndex === 'number' && Number.isFinite(options.minIndex) ? Math.round(options.minIndex) : 0; + const maxIndexBound = + typeof options?.maxIndex === 'number' && Number.isFinite(options.maxIndex) + ? Math.round(options.maxIndex) + : targetSliceCount - 1; + + let minIndex = clamp(minIndexBound, 0, Math.max(0, targetSliceCount - 1)); + let maxIndex = clamp(maxIndexBound, 0, Math.max(0, targetSliceCount - 1)); + + if (minIndex > maxIndex) { + // Defensive: if the caller provides inverted bounds, fall back to the full range. + minIndex = 0; + maxIndex = Math.max(0, targetSliceCount - 1); + } + + const yieldEverySlices = Math.max(0, Math.round(options?.yieldEverySlices ?? 
0)); + const yieldFn = options?.yieldFn; + let slicesSinceYield = 0; + let scoreMs = 0; - const computeMetrics = (targetPixels: Float32Array): { mi: number; nmi: number } => { + const grad = options?.gradient; + if (grad && grad.referenceGradPixels.length !== referencePixels.length) { + throw new Error('findBestMatchingSlice: referenceGradPixels size mismatch'); + } + + const scoreMetric = options?.scoreMetric ?? 'ssim'; + + const wantDebugMetrics = typeof options?.onSliceScored === 'function'; + const wantBlockSimilarity = wantDebugMetrics || scoreMetric === 'ssim' || scoreMetric === 'lncc' || scoreMetric === 'zncc'; + const wantNgf = wantDebugMetrics || scoreMetric === 'ngf'; + const wantCensus = wantDebugMetrics || scoreMetric === 'census'; + + // MIND / phase correlation are more expensive, so we only compute them when selected, + // OR when debug metrics are enabled (so the in-viewer debug overlay is fully populated). + const wantMind = wantDebugMetrics || scoreMetric === 'mind'; + const wantPhase = wantDebugMetrics || scoreMetric === 'phase'; + + if (inclusionMask && inclusionMask.length !== referencePixels.length) { + throw new Error( + `findBestMatchingSlice: inclusionMask length mismatch (mask=${inclusionMask.length}, image=${referencePixels.length})` + ); + } + + // Precompute exclusion bounds once (shared across all similarity metrics). + let hasExclusion = false; + let exclX0 = 0; + let exclY0 = 0; + let exclX1 = 0; + let exclY1 = 0; + if (exclusionRect && squareSize > 0) { + exclX0 = Math.floor(exclusionRect.x * squareSize); + exclY0 = Math.floor(exclusionRect.y * squareSize); + exclX1 = Math.ceil((exclusionRect.x + exclusionRect.width) * squareSize); + exclY1 = Math.ceil((exclusionRect.y + exclusionRect.height) * squareSize); + hasExclusion = exclX1 > exclX0 && exclY1 > exclY0; + } + + // Optional downsampled metrics. + // + // MIND and phase correlation are expensive at 512px. We run them on a smaller grid. 
+ const downsampleMaskSquare = (mask: Uint8Array, srcSize: number, dstSize: number): Uint8Array => { + if (dstSize === srcSize) return mask; + + const f = resample2dAreaAverage(mask, srcSize, srcSize, dstSize, dstSize); + const out = new Uint8Array(dstSize * dstSize); + for (let i = 0; i < out.length; i++) { + out[i] = (f[i] ?? 0) >= 0.5 ? 1 : 0; + } + return out; + }; + + const mindSizeRequested = Math.max(16, Math.round(options?.mindSize ?? 64)); + const mindSize = Math.min(squareSize, mindSizeRequested); + + const phaseSizeRequested = Math.max(8, Math.round(options?.phaseSize ?? 64)); + const phaseSizeClamped = Math.min(squareSize, phaseSizeRequested); + + // Ensure phase correlation size is power-of-two (FFT requirement). + const floorPowerOfTwo = (v: number): number => { + let n = 1; + while (n * 2 <= v) n *= 2; + return n; + }; + const phaseSize = floorPowerOfTwo(phaseSizeClamped); + + const mindPrepared = wantMind + ? (() => { + const refMindPixels = + mindSize === squareSize + ? referencePixels + : resample2dAreaAverage(referencePixels, squareSize, squareSize, mindSize, mindSize); + + const mindMask = inclusionMask + ? mindSize === squareSize + ? inclusionMask + : downsampleMaskSquare(inclusionMask, squareSize, mindSize) + : undefined; + + return prepareMindReference(refMindPixels, { + inclusionMask: mindMask, + exclusionRect, + imageWidth: mindSize, + imageHeight: mindSize, + patchRadius: 1, + }); + })() + : null; + + const phasePrepared = wantPhase + ? (() => { + const refPhasePixels = + phaseSize === squareSize + ? referencePixels + : resample2dAreaAverage(referencePixels, squareSize, squareSize, phaseSize, phaseSize); + + const phaseMask = inclusionMask + ? phaseSize === squareSize + ? 
inclusionMask + : downsampleMaskSquare(inclusionMask, squareSize, phaseSize) + : undefined; + + return preparePhaseCorrelationReference(refPhasePixels, { + inclusionMask: phaseMask, + exclusionRect, + imageWidth: phaseSize, + imageHeight: phaseSize, + window: true, + }); + })() + : null; + + const phaseScratch = wantPhase && phasePrepared ? createPhaseCorrelationScratch(phasePrepared.size) : null; + + // NGF reference: store normalized gradients for the reference slice. + const refNgf = wantNgf + ? (() => { + const nx = new Float32Array(referencePixels.length); + const ny = new Float32Array(referencePixels.length); + const eps = 1e-8; + + if (squareSize > 2) { + for (let y = 1; y < squareSize - 1; y++) { + const row = y * squareSize; + for (let x = 1; x < squareSize - 1; x++) { + const idx = row + x; + const dx = (referencePixels[idx + 1] ?? 0) - (referencePixels[idx - 1] ?? 0); + const dy = (referencePixels[idx + squareSize] ?? 0) - (referencePixels[idx - squareSize] ?? 0); + const denom = Math.sqrt(dx * dx + dy * dy + eps); + nx[idx] = dx / denom; + ny[idx] = dy / denom; + } + } + } + + return { nx, ny }; + })() + : null; + + // Census reference (3x3): store 8-bit codes for the reference slice. + const refCensus = wantCensus + ? (() => { + const codes = new Uint8Array(referencePixels.length); + if (squareSize > 2) { + for (let y = 1; y < squareSize - 1; y++) { + const row = y * squareSize; + for (let x = 1; x < squareSize - 1; x++) { + const idx = row + x; + const c = referencePixels[idx] ?? 0; + let code = 0; + // 8 neighbors (clockwise starting top-left) + if ((referencePixels[idx - squareSize - 1] ?? 0) < c) code |= 1 << 0; + if ((referencePixels[idx - squareSize] ?? 0) < c) code |= 1 << 1; + if ((referencePixels[idx - squareSize + 1] ?? 0) < c) code |= 1 << 2; + if ((referencePixels[idx - 1] ?? 0) < c) code |= 1 << 3; + if ((referencePixels[idx + 1] ?? 0) < c) code |= 1 << 4; + if ((referencePixels[idx + squareSize - 1] ?? 
0) < c) code |= 1 << 5; + if ((referencePixels[idx + squareSize] ?? 0) < c) code |= 1 << 6; + if ((referencePixels[idx + squareSize + 1] ?? 0) < c) code |= 1 << 7; + codes[idx] = code; + } + } + } + return codes; + })() + : null; + + const computeMetrics = (targetPixels: Float32Array): { + ssim: number; + lncc: number; + zncc: number; + ngf: number; + census: number; + mind?: number; + phase?: number; + mi: number; + nmi: number; + score: number; + miGrad?: number; + nmiGrad?: number; + pixelsUsed?: number; + } => { const t0 = nowMs(); - // We compute MI + NMI together from the histogram. - // If an exclusion rect is provided, skip those pixels. - const miOptions: MutualInformationOptions = { - bins: MI_BINS, - exclusionRect, - imageWidth, - imageHeight, - }; - const miResult = computeMutualInformation(referencePixels, targetPixels, miOptions); + // SSIM/LNCC/ZNCC. + const sim = wantBlockSimilarity + ? computeBlockSimilarity(referencePixels, targetPixels, { + blockSize: options?.ssimBlockSize, + inclusionMask, + exclusionRect, + imageWidth, + imageHeight, + }) + : { ssim: 0, lncc: 0, zncc: 0, blocksUsed: 0, pixelsUsed: 0 }; + + // NGF. + let ngf = 0; + let ngfPixelsUsed = 0; + if (wantNgf && refNgf && squareSize > 2) { + const eps = 1e-8; + let sum = 0; + let used = 0; + for (let y = 1; y < squareSize - 1; y++) { + const row = y * squareSize; + for (let x = 1; x < squareSize - 1; x++) { + const idx = row + x; + if (inclusionMask && inclusionMask[idx] === 0) continue; + if (hasExclusion && x >= exclX0 && x < exclX1 && y >= exclY0 && y < exclY1) continue; + + const dx = (targetPixels[idx + 1] ?? 0) - (targetPixels[idx - 1] ?? 0); + const dy = (targetPixels[idx + squareSize] ?? 0) - (targetPixels[idx - squareSize] ?? 0); + const denom = Math.sqrt(dx * dx + dy * dy + eps); + const nx = dx / denom; + const ny = dy / denom; + + const dot = (refNgf.nx[idx] ?? 0) * nx + (refNgf.ny[idx] ?? 
0) * ny; + // Use squared dot product so opposite directions aren't treated as dissimilar. + sum += dot * dot; + used++; + } + } + + ngfPixelsUsed = used; + ngf = used > 0 ? sum / used : 0; + } + + // Census (3x3). + let census = 0; + let censusPixelsUsed = 0; + if (wantCensus && refCensus && squareSize > 2) { + let diffBits = 0; + let used = 0; + + for (let y = 1; y < squareSize - 1; y++) { + const row = y * squareSize; + for (let x = 1; x < squareSize - 1; x++) { + const idx = row + x; + if (inclusionMask && inclusionMask[idx] === 0) continue; + if (hasExclusion && x >= exclX0 && x < exclX1 && y >= exclY0 && y < exclY1) continue; + + const c = targetPixels[idx] ?? 0; + let code = 0; + if ((targetPixels[idx - squareSize - 1] ?? 0) < c) code |= 1 << 0; + if ((targetPixels[idx - squareSize] ?? 0) < c) code |= 1 << 1; + if ((targetPixels[idx - squareSize + 1] ?? 0) < c) code |= 1 << 2; + if ((targetPixels[idx - 1] ?? 0) < c) code |= 1 << 3; + if ((targetPixels[idx + 1] ?? 0) < c) code |= 1 << 4; + if ((targetPixels[idx + squareSize - 1] ?? 0) < c) code |= 1 << 5; + if ((targetPixels[idx + squareSize] ?? 0) < c) code |= 1 << 6; + if ((targetPixels[idx + squareSize + 1] ?? 0) < c) code |= 1 << 7; + + const refCode = refCensus[idx] ?? 0; + diffBits += POPCOUNT_8[(refCode ^ code) & 0xff] ?? 0; + used++; + } + } + + censusPixelsUsed = used; + const totalBits = used * 8; + census = totalBits > 0 ? 1 - diffBits / totalBits : 0; + } + + // MIND (downsampled). + let mind: number | undefined; + let mindPixelsUsed = 0; + if (wantMind && mindPrepared && mindSize > 0) { + const targetMindPixels = + mindSize === squareSize + ? targetPixels + : resample2dAreaAverage(targetPixels, squareSize, squareSize, mindSize, mindSize); + + const r = computeMindSimilarity(mindPrepared, targetMindPixels); + mind = r.mind; + mindPixelsUsed = r.pixelsUsed; + } + + // Phase correlation (downsampled FFT). 
+ let phase: number | undefined; + let phasePixelsUsed = 0; + if (wantPhase && phasePrepared && phaseScratch) { + const targetPhasePixels = + phasePrepared.size === squareSize + ? targetPixels + : resample2dAreaAverage(targetPixels, squareSize, squareSize, phasePrepared.size, phasePrepared.size); + + const r = computePhaseCorrelationSimilarity(phasePrepared, targetPhasePixels, phaseScratch); + phase = r.phase; + phasePixelsUsed = r.pixelsUsed; + } + + // For debug overlays/logs we also compute MI/NMI so we can compare metrics. + // Avoid this work unless the caller has requested per-slice metrics. + let mi = 0; + let nmi = 0; + let miGrad: number | undefined; + let nmiGrad: number | undefined; + let pixelsUsed: number | undefined = sim.pixelsUsed; + + if (wantDebugMetrics) { + // We compute MI + NMI together from the histogram. + const miOptions: MutualInformationOptions = { + bins: MI_BINS, + inclusionMask, + exclusionRect, + imageWidth, + imageHeight, + }; + + const raw = computeMutualInformation(referencePixels, targetPixels, miOptions); + mi = raw.mi; + nmi = raw.nmi; + pixelsUsed = raw.pixelsUsed; + + if (grad && Number.isFinite(grad.weight) && grad.weight !== 0) { + const targetGrad = computeGradientMagnitudeL1Square(targetPixels, squareSize); + const g = computeMutualInformation(grad.referenceGradPixels, targetGrad, miOptions); + miGrad = g.mi; + nmiGrad = g.nmi; + } + } else { + // For non-debug runs, still expose a useful pixel count for the selected metric. + if (scoreMetric === 'ngf') pixelsUsed = ngfPixelsUsed; + else if (scoreMetric === 'census') pixelsUsed = censusPixelsUsed; + else if (scoreMetric === 'mind') pixelsUsed = mindPixelsUsed; + else if (scoreMetric === 'phase') pixelsUsed = phasePixelsUsed; + } + + // Score used for bestIndex selection. + const score = + scoreMetric === 'lncc' + ? sim.lncc + : scoreMetric === 'zncc' + ? sim.zncc + : scoreMetric === 'ngf' + ? ngf + : scoreMetric === 'census' + ? census + : scoreMetric === 'mind' + ? 
mind ?? 0 + : scoreMetric === 'phase' + ? phase ?? 0 + : sim.ssim; scoreMs += nowMs() - t0; - return { mi: miResult.mi, nmi: miResult.nmi }; + return { + ssim: sim.ssim, + lncc: sim.lncc, + zncc: sim.zncc, + ngf, + census, + mind, + phase, + mi, + nmi, + score, + miGrad, + nmiGrad, + pixelsUsed, + }; }; // Compute starting index from normalized position. @@ -110,20 +612,20 @@ export async function findBestMatchingSlice( // noticeably off. Callers can override the start index with a better guess (e.g. from a // coarse registration seed). const startIdx = Math.round((refSliceIndex / Math.max(1, refSliceCount - 1)) * (targetSliceCount - 1)); - const fallbackStart = clamp(startIdx, 0, targetSliceCount - 1); + const fallbackStart = clamp(startIdx, minIndex, maxIndex); const startIndexOverride = options?.startIndexOverride; const clampedStart = typeof startIndexOverride === 'number' && Number.isFinite(startIndexOverride) - ? clamp(Math.round(startIndexOverride), 0, targetSliceCount - 1) + ? clamp(Math.round(startIndexOverride), minIndex, maxIndex) : fallbackStart; // Initialize with starting slice const startPixels = await getTargetSlicePixels(clampedStart); let bestIdx = clampedStart; const startMetrics = computeMetrics(startPixels); - let bestMI = startMetrics.mi; + let bestMI = startMetrics.score; let slicesChecked = 1; options?.onSliceScored?.(clampedStart, startMetrics, 'start'); @@ -133,11 +635,14 @@ export async function findBestMatchingSlice( let leftIdx = clampedStart - 1; let rightIdx = clampedStart + 1; - let leftDone = leftIdx < 0; - let rightDone = rightIdx >= targetSliceCount; + let leftDone = leftIdx < minIndex; + let rightDone = rightIdx > maxIndex; + + let leftSteps = 0; + let rightSteps = 0; - let leftPrevMI = bestMI; - let rightPrevMI = bestMI; + let leftPrevScore = startMetrics.score; + let rightPrevScore = startMetrics.score; let leftDecreaseStreak = 0; let rightDecreaseStreak = 0; @@ -146,78 +651,96 @@ export async function findBestMatchingSlice( // 
Search left if (!leftDone) { const idx = leftIdx; - if (idx < 0) { + if (idx < minIndex) { leftDone = true; } else { const leftPixels = await getTargetSlicePixels(idx); const leftMetrics = computeMetrics(leftPixels); - const leftMI = leftMetrics.mi; + const leftScore = leftMetrics.score; slicesChecked++; + leftSteps++; options?.onSliceScored?.(idx, leftMetrics, 'left'); - if (leftMI > bestMI) { - bestMI = leftMI; + if (leftScore > bestMI) { + bestMI = leftScore; bestIdx = idx; } // Track consecutive decreases in this direction. - if (leftMI < leftPrevMI) { + if (leftScore < leftPrevScore) { leftDecreaseStreak++; } else { leftDecreaseStreak = 0; } - leftPrevMI = leftMI; + leftPrevScore = leftScore; leftIdx = idx - 1; - if (leftIdx < 0) { + if (leftIdx < minIndex) { leftDone = true; } else { - // Stop only after N consecutive decreases. - if (leftDecreaseStreak >= STOP_DECREASE_STREAK) { + // Stop only after N consecutive decreases *and* after we have searched far enough. + if (leftDecreaseStreak >= STOP_DECREASE_STREAK && leftSteps >= minSearchRadius) { leftDone = true; } } onProgress?.(slicesChecked, bestMI); + + if (yieldEverySlices > 0 && yieldFn) { + slicesSinceYield++; + if (slicesSinceYield >= yieldEverySlices) { + slicesSinceYield = 0; + await yieldFn(); + } + } } } // Search right if (!rightDone) { const idx = rightIdx; - if (idx >= targetSliceCount) { + if (idx > maxIndex) { rightDone = true; } else { const rightPixels = await getTargetSlicePixels(idx); const rightMetrics = computeMetrics(rightPixels); - const rightMI = rightMetrics.mi; + const rightScore = rightMetrics.score; slicesChecked++; + rightSteps++; options?.onSliceScored?.(idx, rightMetrics, 'right'); - if (rightMI > bestMI) { - bestMI = rightMI; + if (rightScore > bestMI) { + bestMI = rightScore; bestIdx = idx; } - if (rightMI < rightPrevMI) { + if (rightScore < rightPrevScore) { rightDecreaseStreak++; } else { rightDecreaseStreak = 0; } - rightPrevMI = rightMI; + rightPrevScore = rightScore; 
rightIdx = idx + 1; - if (rightIdx >= targetSliceCount) { + if (rightIdx > maxIndex) { rightDone = true; } else { - if (rightDecreaseStreak >= STOP_DECREASE_STREAK) { + if (rightDecreaseStreak >= STOP_DECREASE_STREAK && rightSteps >= minSearchRadius) { rightDone = true; } } onProgress?.(slicesChecked, bestMI); + + if (yieldEverySlices > 0 && yieldFn) { + slicesSinceYield++; + if (slicesSinceYield >= yieldEverySlices) { + slicesSinceYield = 0; + await yieldFn(); + } + } } } } diff --git a/frontend/src/utils/alignmentSliceScoreStore.ts b/frontend/src/utils/alignmentSliceScoreStore.ts new file mode 100644 index 0000000..c92ce63 --- /dev/null +++ b/frontend/src/utils/alignmentSliceScoreStore.ts @@ -0,0 +1,93 @@ +export type AlignmentSliceScoreMetrics = { + ssim: number; + lncc: number; + zncc: number; + ngf: number; + census: number; + mind: number | null; + phase: number | null; + mi: number; + nmi: number; + miGrad: number | null; + nmiGrad: number | null; + score: number; +}; + +export type AlignmentSliceScoreContext = { + referenceSeriesUid: string; + referenceSliceIndex: number; + startedAtMs: number; +}; + +let context: AlignmentSliceScoreContext | null = null; + +// Keyed by series UID (moving series), then by instance index (0..instance_count-1). 
+const scoresBySeries = new Map<string, Map<number, AlignmentSliceScoreMetrics>>(); + +export function resetAlignmentSliceScoreStore(nextContext: { + referenceSeriesUid: string; + referenceSliceIndex: number; +}): void { + scoresBySeries.clear(); + context = { + referenceSeriesUid: nextContext.referenceSeriesUid, + referenceSliceIndex: nextContext.referenceSliceIndex, + startedAtMs: Date.now(), + }; +} + +export function getAlignmentSliceScoreContext(): AlignmentSliceScoreContext | null { + return context; +} + +export function recordAlignmentSliceScore( + seriesUid: string, + instanceIndex: number, + metrics: { + ssim: number; + lncc: number; + zncc: number; + ngf: number; + census: number; + mind?: number | null; + phase?: number | null; + mi: number; + nmi: number; + miGrad?: number | null; + nmiGrad?: number | null; + score: number; + } +): void { + if (!seriesUid) return; + if (!Number.isFinite(instanceIndex) || instanceIndex < 0) return; + + let perSeries = scoresBySeries.get(seriesUid); + if (!perSeries) { + perSeries = new Map(); + scoresBySeries.set(seriesUid, perSeries); + } + + perSeries.set(instanceIndex, { + ssim: metrics.ssim, + lncc: metrics.lncc, + zncc: metrics.zncc, + ngf: metrics.ngf, + census: metrics.census, + mind: metrics.mind ?? null, + phase: metrics.phase ?? null, + mi: metrics.mi, + nmi: metrics.nmi, + miGrad: metrics.miGrad ?? null, + nmiGrad: metrics.nmiGrad ?? null, + score: metrics.score, + }); +} + +export function getAlignmentSliceScore( + seriesUid: string, + instanceIndex: number +): AlignmentSliceScoreMetrics | null { + const perSeries = scoresBySeries.get(seriesUid); + if (!perSeries) return null; + return perSeries.get(instanceIndex) ?? null; +} diff --git a/frontend/src/utils/debugSvr.ts b/frontend/src/utils/debugSvr.ts new file mode 100644 index 0000000..e537539 --- /dev/null +++ b/frontend/src/utils/debugSvr.ts @@ -0,0 +1,33 @@ +/** + * Debug SVR utilities. + * + * Defaults: + * - In DEV builds, SVR debug logging is enabled by default.
+ * - In production builds, it is opt-in. + * + * You can always override via localStorage: + * localStorage.setItem('miraviewer:debug-svr', '1') // force on + * localStorage.setItem('miraviewer:debug-svr', '0') // force off + */ + +export const DEBUG_SVR_STORAGE_KEY = 'miraviewer:debug-svr'; + +export function isDebugSvrEnabled(): boolean { + if (typeof window === 'undefined') return false; + + try { + const v = window.localStorage.getItem(DEBUG_SVR_STORAGE_KEY); + if (v === '1') return true; + if (v === '0') return false; + + // If the key is unset, default to on in dev builds so SVR work is visible without setup. + return !!import.meta.env.DEV; + } catch { + return false; + } +} + +export function debugSvrLog(step: string, details: Record<string, unknown>, enabled: boolean): void { + if (!enabled) return; + console.log(`[svr] ${step}`, details); +} diff --git a/frontend/src/utils/elastixRegistration.ts b/frontend/src/utils/elastixRegistration.ts index 25d132d..d37b044 100644 --- a/frontend/src/utils/elastixRegistration.ts +++ b/frontend/src/utils/elastixRegistration.ts @@ -157,6 +157,79 @@ function makeItkFloat32ScalarImage(pixels: Float32Array, size: number, name: str return img; } +type NormalizedRect = { x: number; y: number; width: number; height: number }; + +function clamp(v: number, lo: number, hi: number) { + return Math.max(lo, Math.min(hi, v)); +} + +/** + * Best-effort exclusion for Elastix registration. + * + * The upstream elastix pipeline build we use appears to crash when passing `-fMask` / `-mMask`. + * Instead of true mask support, we neutralize (flatten) the excluded region to make it + * low-information for the optimizer. + * + * We feather the boundary so we don't introduce sharp edges that could become artificial + * alignment features.
+ */ +function applyExclusionRectFeather( + pixels: Float32Array, + size: number, + exclusionRect: NormalizedRect, + featherPx: number +): { pixels: Float32Array; excludedFrac: number } | null { + if (size <= 0) return null; + + const x0 = clamp(Math.floor(exclusionRect.x * size), 0, size); + const y0 = clamp(Math.floor(exclusionRect.y * size), 0, size); + const x1 = clamp(Math.ceil((exclusionRect.x + exclusionRect.width) * size), 0, size); + const y1 = clamp(Math.ceil((exclusionRect.y + exclusionRect.height) * size), 0, size); + + if (x1 <= x0 || y1 <= y0) return null; + + // Mean of pixels outside the rect (fallback to mid-gray). + let sum = 0; + let count = 0; + for (let y = 0; y < size; y++) { + const row = y * size; + const inY = y >= y0 && y < y1; + for (let x = 0; x < size; x++) { + if (inY && x >= x0 && x < x1) continue; + sum += pixels[row + x] ?? 0; + count++; + } + } + const mean = count > 0 ? sum / count : 0.5; + + const out = Float32Array.from(pixels); + + const feather = Math.max(0, Math.round(featherPx)); + + for (let y = y0; y < y1; y++) { + const row = y * size; + for (let x = x0; x < x1; x++) { + const idx = row + x; + + if (feather <= 0) { + out[idx] = mean; + continue; + } + + const d = Math.min(x - x0, x1 - 1 - x, y - y0, y1 - 1 - y); + // Replace more aggressively as we move deeper inside the rect. + const t = clamp((d + 1) / (feather + 1), 0, 1); + const v = pixels[idx] ?? mean; + out[idx] = v * (1 - t) + mean * t; + } + } + + const excludedPx = (x1 - x0) * (y1 - y0); + const excludedFrac = excludedPx / (size * size); + + return { pixels: out, excludedFrac }; +} + export type ElastixAffine2DRegistrationResult = { /** Moving -> fixed linear matrix (about image center when applied with translatePx). 
*/ A: Mat2; @@ -248,6 +321,16 @@ export async function registerAffine2DWithElastix( numberOfResolutions?: number; initialTransformParameterObject?: JsonCompatible; webWorker?: Worker; + + /** + * Optional exclusion rectangle in normalized [0,1] coordinates (fixed image space). + * + * Note: the current `@itk-wasm/elastix` pipeline build we use appears to crash when passing + * real mask args (`-fMask` / `-mMask`) under `--memory-io`. As a practical workaround, we + * preprocess the pixels inside this rect (feathered fill) so the region becomes + * low-information for the optimizer. + */ + exclusionRect?: NormalizedRect; } ): Promise { assertSquareSize(fixedPixels, size, 'fixedPixels'); @@ -257,8 +340,28 @@ export async function registerAffine2DWithElastix( const webWorker = opts?.webWorker ?? (await getElastixWorker()); - const fixed = makeItkFloat32ScalarImage(fixedPixels, size, 'fixed'); - const moving = makeItkFloat32ScalarImage(movingPixels, size, 'moving'); + const debug = isDebugAlignmentEnabled(); + + const exclusion = + opts?.exclusionRect ? applyExclusionRectFeather(fixedPixels, size, opts.exclusionRect, 4) : null; + + const fixedPixelsForReg = exclusion ? exclusion.pixels : fixedPixels; + + const movingPixelsForReg = opts?.exclusionRect + ? (applyExclusionRectFeather(movingPixels, size, opts.exclusionRect, 4)?.pixels ?? movingPixels) + : movingPixels; + + if (debug && opts?.exclusionRect) { + console.info('[alignment] Elastix exclusion rect (preprocess)', { + size, + exclusionRect: opts.exclusionRect, + excludedFrac: exclusion ? 
Number(exclusion.excludedFrac.toFixed(4)) : null, + mode: 'feathered-mean-fill', + }); + } + + const fixed = makeItkFloat32ScalarImage(fixedPixelsForReg, size, 'fixed'); + const moving = makeItkFloat32ScalarImage(movingPixelsForReg, size, 'moving'); const affineParameterMap = await getAffineParameterMap(webWorker, numberOfResolutions); @@ -267,7 +370,6 @@ export async function registerAffine2DWithElastix( // We run the pipeline directly (instead of calling the generated `elastix()` wrapper) // so we can capture stdout/stderr and optionally parse Elastix' own metric trace. - const debug = isDebugAlignmentEnabled(); let result: { webWorker: Worker; @@ -290,6 +392,8 @@ export async function registerAffine2DWithElastix( | { type: typeof InterfaceTypes.JsonCompatible; data: JsonCompatible } | { type: typeof InterfaceTypes.Image; data: Image }; + const pipelineBaseUrl = getAppPipelinesBaseUrl(); + const inputs: ElastixPipelineInput[] = [{ type: InterfaceTypes.JsonCompatible, data: parameterObject }]; const args: string[] = []; @@ -325,8 +429,6 @@ export async function registerAffine2DWithElastix( args.push('--initial-transform-parameter-object', inputCountString); } - const pipelineBaseUrl = getAppPipelinesBaseUrl(); - const { webWorker: usedWebWorker, returnValue, stdout, stderr, outputs } = await runPipeline( 'elastix', args, @@ -391,7 +493,7 @@ export async function registerAffine2DWithElastix( // This prevents subtle convention mismatches (or chain ordering issues) from silently // producing incorrect on-screen geometry despite the registration output looking plausible. const { best, candidates } = chooseBestElastixTransformCandidateAboutOrigin({ - movingPixels, + movingPixels: movingPixelsForReg, resampledMovingPixels, size, candidatesStd, @@ -417,7 +519,18 @@ export async function registerAffine2DWithElastix( const m2fAboutOrigin = best.aboutOrigin; // Quality metrics (computed in fixed space against elastix' resampled moving). 
- const miResult = computeMutualInformation(fixedPixels, resampledMovingPixels, 64); + const miResult = computeMutualInformation( + fixedPixelsForReg, + resampledMovingPixels, + opts?.exclusionRect + ? { + bins: 64, + exclusionRect: opts.exclusionRect, + imageWidth: size, + imageHeight: size, + } + : 64 + ); const metricFromLogs = tryParseElastixFinalMetricFromLogs(result.stdout, result.stderr); const elastixLogTail = debug diff --git a/frontend/src/utils/imageFeatures.ts b/frontend/src/utils/imageFeatures.ts new file mode 100644 index 0000000..2574a99 --- /dev/null +++ b/frontend/src/utils/imageFeatures.ts @@ -0,0 +1,83 @@ +/** + * Small image-processing helpers for alignment. + * + * Notes: + * - These operate on normalized grayscale Float32 pixels (typically [0..1]). + * - Keep them fast and allocation-light; slice search may call them many times. + */ + +function assertSquareSize(pixels: Float32Array, size: number, label: string) { + const n = size * size; + if (pixels.length !== n) { + throw new Error(`${label}: expected ${size}x${size} (${n}) pixels, got ${pixels.length}`); + } +} + +/** + * Approximate gradient magnitude using a simple central-difference L1 norm: + * |dx| + |dy| + * + * This is cheaper than Sobel and avoids a sqrt. + */ +export function computeGradientMagnitudeL1Square(pixels: Float32Array, size: number): Float32Array { + assertSquareSize(pixels, size, 'computeGradientMagnitudeL1Square'); + + const out = new Float32Array(pixels.length); + if (size <= 2) return out; + + // Leave a 1px border as zeros. + for (let y = 1; y < size - 1; y++) { + const row = y * size; + for (let x = 1; x < size - 1; x++) { + const idx = row + x; + const dx = (pixels[idx + 1] ?? 0) - (pixels[idx - 1] ?? 0); + const dy = (pixels[idx + size] ?? 0) - (pixels[idx - size] ?? 
0); + out[idx] = Math.abs(dx) + Math.abs(dy); + } + } + + return out; +} + +export type InclusionMaskBuildResult = { + mask: Uint8Array; + includedCount: number; + includedFrac: number; +}; + +/** + * Build a simple inclusion mask that keeps pixels above a fixed threshold. + * + * Returns null if the mask would be too sparse (so callers can fall back to "no mask"). + */ +export function buildInclusionMaskFromThresholdSquare( + pixels: Float32Array, + size: number, + threshold: number, + opts?: { + /** If includedFrac falls below this, return null. Default: 0.05 (5%). */ + minIncludedFrac?: number; + } +): InclusionMaskBuildResult | null { + assertSquareSize(pixels, size, 'buildInclusionMaskFromThresholdSquare'); + + const minIncludedFrac = opts?.minIncludedFrac ?? 0.05; + + const mask = new Uint8Array(pixels.length); + let includedCount = 0; + + for (let i = 0; i < pixels.length; i++) { + if ((pixels[i] ?? 0) > threshold) { + mask[i] = 1; + includedCount++; + } + } + + const includedFrac = includedCount / Math.max(1, pixels.length); + + if (!Number.isFinite(includedFrac) || includedFrac < minIncludedFrac) { + return null; + } + + return { mask, includedCount, includedFrac }; +} diff --git a/frontend/src/utils/localApi.ts b/frontend/src/utils/localApi.ts index 317416e..1e5cc87 100644 --- a/frontend/src/utils/localApi.ts +++ b/frontend/src/utils/localApi.ts @@ -1,5 +1,14 @@ import { getDB } from '../db/db'; -import type { DicomSeries } from '../db/schema'; +import type { + DicomSeries, + TumorSegmentationRow, + TumorGroundTruthRow, + TumorThreshold, + TumorPolygon, + NormalizedPoint, + ViewerTransform, + ViewportSize, +} from '../db/schema'; import type { ComparisonData, SequenceCombo, SeriesRef, PanelSettingsPartial, PanelSettings } from '../types/api'; import { parseSeriesDescription } from './dicomSeriesParsing'; @@ -313,7 +322,7 @@ function cacheSeriesInstanceOrder(seriesUid: string, uids: string[]) { } } -async function getSortedInstanceUidsForSeries(seriesUid: 
string): Promise { +export async function getSortedSopInstanceUidsForSeries(seriesUid: string): Promise { const cached = seriesInstanceOrderCache.get(seriesUid); if (cached) { // Touch LRU. @@ -359,9 +368,137 @@ async function getSortedInstanceUidsForSeries(seriesUid: string): Promise { - const uids = await getSortedInstanceUidsForSeries(seriesUid); +export async function getSopInstanceUidForInstanceIndex(seriesUid: string, instanceIndex: number): Promise { + const uids = await getSortedSopInstanceUidsForSeries(seriesUid); const uid = uids[instanceIndex]; if (!uid) throw new Error('Instance index out of range'); + return uid; +} + +export async function getImageIdForInstance(seriesUid: string, instanceIndex: number): Promise { + const uid = await getSopInstanceUidForInstanceIndex(seriesUid, instanceIndex); return `miradb:${uid}`; } + +function tumorSegmentationId(seriesUid: string, sopInstanceUid: string): string { + // Keep this stable and URL-safe. Series UID can contain dots. + return `${seriesUid}::${sopInstanceUid}`; +} + +export async function getTumorSegmentationForInstance( + seriesUid: string, + sopInstanceUid: string +): Promise { + const db = await getDB(); + const id = tumorSegmentationId(seriesUid, sopInstanceUid); + const row = await db.get('tumor_segmentations', id); + return row ?? 
null; +} + +export async function getTumorSegmentationsForSeries(seriesUid: string): Promise { + const db = await getDB(); + return db.getAllFromIndex('tumor_segmentations', 'by-series', seriesUid); +} + +export type SaveTumorSegmentationInput = { + comboId: string; + dateIso: string; + studyId: string; + seriesUid: string; + sopInstanceUid: string; + polygon: TumorPolygon; + threshold: TumorThreshold; + seed?: NormalizedPoint; + meta?: TumorSegmentationRow['meta']; + algorithmVersion?: string; +}; + +export async function saveTumorSegmentation(input: SaveTumorSegmentationInput): Promise { + const db = await getDB(); + const now = Date.now(); + + const id = tumorSegmentationId(input.seriesUid, input.sopInstanceUid); + const existing = await db.get('tumor_segmentations', id); + + const row: TumorSegmentationRow = { + id, + comboId: input.comboId, + dateIso: input.dateIso, + studyId: input.studyId, + seriesUid: input.seriesUid, + sopInstanceUid: input.sopInstanceUid, + algorithmVersion: input.algorithmVersion ?? 'v1-display-domain-threshold', + polygon: input.polygon, + threshold: input.threshold, + seed: input.seed, + createdAtMs: existing?.createdAtMs ?? now, + updatedAtMs: now, + meta: input.meta, + }; + + await db.put('tumor_segmentations', row); +} + +export async function deleteTumorSegmentation(seriesUid: string, sopInstanceUid: string): Promise { + const db = await getDB(); + await db.delete('tumor_segmentations', tumorSegmentationId(seriesUid, sopInstanceUid)); +} + +function tumorGroundTruthId(seriesUid: string, sopInstanceUid: string): string { + return `${seriesUid}::${sopInstanceUid}`; +} + +export async function getTumorGroundTruthForInstance( + seriesUid: string, + sopInstanceUid: string +): Promise { + const db = await getDB(); + const id = tumorGroundTruthId(seriesUid, sopInstanceUid); + const row = await db.get('tumor_ground_truth', id); + return row ?? 
null; +} + +export async function getAllTumorGroundTruth(): Promise { + const db = await getDB(); + return db.getAll('tumor_ground_truth'); +} + +export type SaveTumorGroundTruthInput = { + comboId: string; + dateIso: string; + studyId: string; + seriesUid: string; + sopInstanceUid: string; + polygon: TumorPolygon; + viewTransform?: ViewerTransform; + viewportSize?: ViewportSize; +}; + +export async function saveTumorGroundTruth(input: SaveTumorGroundTruthInput): Promise { + const db = await getDB(); + const now = Date.now(); + + const id = tumorGroundTruthId(input.seriesUid, input.sopInstanceUid); + const existing = await db.get('tumor_ground_truth', id); + + const row: TumorGroundTruthRow = { + id, + comboId: input.comboId, + dateIso: input.dateIso, + studyId: input.studyId, + seriesUid: input.seriesUid, + sopInstanceUid: input.sopInstanceUid, + polygon: input.polygon, + viewTransform: input.viewTransform, + viewportSize: input.viewportSize, + createdAtMs: existing?.createdAtMs ?? now, + updatedAtMs: now, + }; + + await db.put('tumor_ground_truth', row); +} + +export async function deleteTumorGroundTruth(seriesUid: string, sopInstanceUid: string): Promise { + const db = await getDB(); + await db.delete('tumor_ground_truth', tumorGroundTruthId(seriesUid, sopInstanceUid)); +} diff --git a/frontend/src/utils/mind.ts b/frontend/src/utils/mind.ts new file mode 100644 index 0000000..6456ac1 --- /dev/null +++ b/frontend/src/utils/mind.ts @@ -0,0 +1,334 @@ +import type { ExclusionMask } from '../types/api'; + +export type MindOffset = { dx: number; dy: number }; + +export type MindOptions = { + /** Optional inclusion mask (same shape as images). Keep pixels where mask[idx] != 0. */ + inclusionMask?: Uint8Array; + + /** Optional exclusion rectangle in normalized [0,1] image coordinates. */ + exclusionRect?: ExclusionMask; + + /** Image width in pixels (required if exclusionRect is provided). 
*/ + imageWidth?: number; + /** Image height in pixels (required if exclusionRect is provided). */ + imageHeight?: number; + + /** Patch radius in pixels (default: 1 => 3x3). */ + patchRadius?: number; + + /** Offsets used for the self-similarity descriptor (default: 4-neighborhood at radius 1). */ + offsets?: MindOffset[]; +}; + +export type PreparedMindReference = { + size: number; + offsets: MindOffset[]; + patchRadius: number; + + /** Effective mask used for descriptor construction (1 = included). */ + effectiveMask: Uint8Array; + + /** Descriptor values per pixel (row-major), concatenated as [idx*K + k]. */ + descriptor: Float32Array; + + /** 1 if descriptor was computed for that pixel (else 0). */ + valid: Uint8Array; +}; + +function inferSquareSize(n: number): number { + const s = Math.round(Math.sqrt(n)); + if (s <= 0 || s * s !== n) { + throw new Error('mind: expected square image (provide imageWidth/imageHeight)'); + } + return s; +} + +function buildEffectiveMask(n: number, size: number, opts: MindOptions): Uint8Array { + const inclusionMask = opts.inclusionMask; + if (inclusionMask && inclusionMask.length !== n) { + throw new Error(`mind: inclusionMask length mismatch (mask=${inclusionMask.length}, image=${n})`); + } + + const out = new Uint8Array(n); + if (inclusionMask) { + for (let i = 0; i < n; i++) out[i] = inclusionMask[i] ? 
1 : 0; + } else { + out.fill(1); + } + + const exclusionRect = opts.exclusionRect; + if (exclusionRect && size > 0) { + const x0 = Math.floor(exclusionRect.x * size); + const y0 = Math.floor(exclusionRect.y * size); + const x1 = Math.ceil((exclusionRect.x + exclusionRect.width) * size); + const y1 = Math.ceil((exclusionRect.y + exclusionRect.height) * size); + + if (x1 > x0 && y1 > y0) { + for (let y = Math.max(0, y0); y < Math.min(size, y1); y++) { + const row = y * size; + for (let x = Math.max(0, x0); x < Math.min(size, x1); x++) { + out[row + x] = 0; + } + } + } + } + + return out; +} + +function defaultOffsets(): MindOffset[] { + // 2D 4-neighborhood (radius 1). This is a simplified MIND-like descriptor. + return [ + { dx: 1, dy: 0 }, + { dx: -1, dy: 0 }, + { dx: 0, dy: 1 }, + { dx: 0, dy: -1 }, + ]; +} + +/** + * Prepare a reference MIND-like self-similarity descriptor. + * + * This is intentionally a simplified 2D variant (patch SSD to a small set of offsets, normalized + * by local variance). It is meant for experimentation in slice search. + */ +export function prepareMindReference(referencePixels: Float32Array, opts: MindOptions = {}): PreparedMindReference { + const n = referencePixels.length; + if (n === 0) { + return { + size: 0, + offsets: [], + patchRadius: 1, + effectiveMask: new Uint8Array(0), + descriptor: new Float32Array(0), + valid: new Uint8Array(0), + }; + } + + const size = + typeof opts.imageWidth === 'number' && typeof opts.imageHeight === 'number' && opts.imageWidth === opts.imageHeight + ? opts.imageWidth + : inferSquareSize(n); + + const offsets = (opts.offsets && opts.offsets.length > 0 ? opts.offsets : defaultOffsets()).map((o) => ({ + dx: Math.round(o.dx), + dy: Math.round(o.dy), + })); + + const patchRadius = Math.max(1, Math.round(opts.patchRadius ?? 
1)); + + const effectiveMask = buildEffectiveMask(n, size, { ...opts, imageWidth: size, imageHeight: size }); + + const k = offsets.length; + const descriptor = new Float32Array(n * k); + const valid = new Uint8Array(n); + + // Determine how close to the border we can compute descriptors. + let maxAbsOffset = 0; + for (const o of offsets) { + maxAbsOffset = Math.max(maxAbsOffset, Math.abs(o.dx), Math.abs(o.dy)); + } + const margin = patchRadius + maxAbsOffset; + + const ssd = new Float64Array(k); + + const eps = 1e-12; + + for (let y = margin; y < size - margin; y++) { + const row = y * size; + for (let x = margin; x < size - margin; x++) { + const centerIdx = row + x; + if (effectiveMask[centerIdx] === 0) continue; + + let ok = true; + + for (let kk = 0; kk < k; kk++) { + const { dx, dy } = offsets[kk]!; + let sum = 0; + let count = 0; + + for (let py = -patchRadius; py <= patchRadius; py++) { + const y1 = y + py; + const y2 = y1 + dy; + const row1 = y1 * size; + const row2 = y2 * size; + + for (let px = -patchRadius; px <= patchRadius; px++) { + const x1 = x + px; + const x2 = x1 + dx; + + const idx1 = row1 + x1; + const idx2 = row2 + x2; + + if (effectiveMask[idx1] === 0 || effectiveMask[idx2] === 0) continue; + + const a = referencePixels[idx1]!; + const b = referencePixels[idx2]!; + const d = a - b; + sum += d * d; + count++; + } + } + + if (count === 0) { + ok = false; + break; + } + + // Use mean squared difference so patch size doesn't change the scale. + ssd[kk] = sum / count; + } + + if (!ok) continue; + + // Local variance estimate: mean SSD across offsets. + let v = 0; + for (let kk = 0; kk < k; kk++) v += ssd[kk]!; + v /= k; + if (v < eps) v = eps; + + // Build descriptor and normalize so the max component is 1. + let maxD = 0; + const base = centerIdx * k; + for (let kk = 0; kk < k; kk++) { + const d = Math.exp(-ssd[kk]! 
/ v); + descriptor[base + kk] = d; + if (d > maxD) maxD = d; + } + + if (maxD > eps) { + const inv = 1 / maxD; + for (let kk = 0; kk < k; kk++) { + descriptor[base + kk] *= inv; + } + } + + valid[centerIdx] = 1; + } + } + + return { size, offsets, patchRadius, effectiveMask, descriptor, valid }; +} + +/** + * Compute similarity between a prepared reference descriptor and a target image. + * + * Returns a similarity in (0..1], where 1 means identical descriptors. + */ +export function computeMindSimilarity( + prepared: PreparedMindReference, + targetPixels: Float32Array +): { mind: number; pixelsUsed: number } { + const size = prepared.size; + if (size <= 0) return { mind: 0, pixelsUsed: 0 }; + + const n = size * size; + if (targetPixels.length !== n) { + throw new Error(`mind: target size mismatch (expected ${n}, got ${targetPixels.length})`); + } + + const offsets = prepared.offsets; + const k = offsets.length; + const patchRadius = prepared.patchRadius; + + // Determine compute margin based on offsets + patch radius. 
+ let maxAbsOffset = 0; + for (const o of offsets) { + maxAbsOffset = Math.max(maxAbsOffset, Math.abs(o.dx), Math.abs(o.dy)); + } + const margin = patchRadius + maxAbsOffset; + + const ssd = new Float64Array(k); + + const eps = 1e-12; + + let distSum = 0; + let used = 0; + + for (let y = margin; y < size - margin; y++) { + const row = y * size; + for (let x = margin; x < size - margin; x++) { + const centerIdx = row + x; + if (prepared.valid[centerIdx] === 0) continue; + if (prepared.effectiveMask[centerIdx] === 0) continue; + + let ok = true; + + for (let kk = 0; kk < k; kk++) { + const { dx, dy } = offsets[kk]!; + let sum = 0; + let count = 0; + + for (let py = -patchRadius; py <= patchRadius; py++) { + const y1 = y + py; + const y2 = y1 + dy; + const row1 = y1 * size; + const row2 = y2 * size; + + for (let px = -patchRadius; px <= patchRadius; px++) { + const x1 = x + px; + const x2 = x1 + dx; + + const idx1 = row1 + x1; + const idx2 = row2 + x2; + + if (prepared.effectiveMask[idx1] === 0 || prepared.effectiveMask[idx2] === 0) continue; + + const a = targetPixels[idx1]!; + const b = targetPixels[idx2]!; + const d = a - b; + sum += d * d; + count++; + } + } + + if (count === 0) { + ok = false; + break; + } + + ssd[kk] = sum / count; + } + + if (!ok) continue; + + // Local variance estimate. + let v = 0; + for (let kk = 0; kk < k; kk++) v += ssd[kk]!; + v /= k; + if (v < eps) v = eps; + + // Build target descriptor and compare against the precomputed reference. + let maxD = 0; + for (let kk = 0; kk < k; kk++) { + const d = Math.exp(-ssd[kk]! / v); + ssd[kk] = d; + if (d > maxD) maxD = d; + } + + if (maxD <= eps) continue; + const invMax = 1 / maxD; + + const base = centerIdx * k; + + let perPixel = 0; + for (let kk = 0; kk < k; kk++) { + const t = (ssd[kk]! 
as number) * invMax; + const r = prepared.descriptor[base + kk]!; + const diff = t - r; + perPixel += diff * diff; + } + + distSum += perPixel / k; + used++; + } + } + + if (used === 0) return { mind: 0, pixelsUsed: 0 }; + + const meanDist = distSum / used; + const mind = Math.exp(-meanDist); + + return { mind, pixelsUsed: used }; +} diff --git a/frontend/src/utils/mutualInformation.ts b/frontend/src/utils/mutualInformation.ts index c81afa1..5b00136 100644 --- a/frontend/src/utils/mutualInformation.ts +++ b/frontend/src/utils/mutualInformation.ts @@ -23,11 +23,21 @@ import { clamp01 } from './math'; export type MutualInformationOptions = { /** Number of histogram bins (default: 64). */ bins?: number; + + /** + * Optional inclusion mask. + * + * If provided, only pixels where inclusionMask[idx] != 0 are used. + * This is useful for ignoring background / low-information regions during slice search. + */ + inclusionMask?: Uint8Array; + /** * Optional exclusion rectangle in normalized [0,1] image coordinates. * Pixels inside this rect are excluded from the histogram computation. */ exclusionRect?: { x: number; y: number; width: number; height: number }; + /** Image width in pixels (required if exclusionRect is provided). */ imageWidth?: number; /** Image height in pixels (required if exclusionRect is provided). */ @@ -70,6 +80,7 @@ export function computeMutualInformation( typeof optionsOrBins === 'number' ? { bins: optionsOrBins } : optionsOrBins; const bins = opts.bins ?? 64; + const inclusionMask = opts.inclusionMask; const exclusionRect = opts.exclusionRect; const imageWidth = opts.imageWidth; const imageHeight = opts.imageHeight; @@ -98,12 +109,20 @@ export function computeMutualInformation( hasExclusion = exclX1 > exclX0 && exclY1 > exclY0; } - // Helper to check if pixel index is inside exclusion rect. 
- const isExcluded = (idx: number): boolean => { - if (!hasExclusion) return false; + if (inclusionMask && inclusionMask.length !== n) { + throw new Error( + `computeMutualInformation: inclusionMask length mismatch (mask=${inclusionMask.length}, image=${n})` + ); + } + + // Helper to check if pixel index is included by masks. + const isIncluded = (idx: number): boolean => { + if (inclusionMask && inclusionMask[idx] === 0) return false; + + if (!hasExclusion) return true; const px = idx % imgW; const py = Math.floor(idx / imgW); - return px >= exclX0 && px < exclX1 && py >= exclY0 && py < exclY1; + return !(px >= exclX0 && px < exclX1 && py >= exclY0 && py < exclY1); }; let minA = Number.POSITIVE_INFINITY; @@ -112,7 +131,7 @@ export function computeMutualInformation( let maxB = Number.NEGATIVE_INFINITY; for (let i = 0; i < n; i++) { - if (isExcluded(i)) continue; + if (!isIncluded(i)) continue; const a = imageA[i]; const b = imageB[i]; @@ -145,7 +164,7 @@ export function computeMutualInformation( let pixelsUsed = 0; for (let i = 0; i < n; i++) { - if (isExcluded(i)) continue; + if (!isIncluded(i)) continue; const a = imageA[i]; const b = imageB[i]; diff --git a/frontend/src/utils/phaseCorrelation.ts b/frontend/src/utils/phaseCorrelation.ts new file mode 100644 index 0000000..464322a --- /dev/null +++ b/frontend/src/utils/phaseCorrelation.ts @@ -0,0 +1,416 @@ +import type { ExclusionMask } from '../types/api'; + +export type PhaseCorrelationOptions = { + /** Optional inclusion mask (same shape as images). Keep pixels where mask[idx] != 0. */ + inclusionMask?: Uint8Array; + + /** Optional exclusion rectangle in normalized [0,1] image coordinates. */ + exclusionRect?: ExclusionMask; + + /** Image width in pixels (required if exclusionRect is provided). */ + imageWidth?: number; + /** Image height in pixels (required if exclusionRect is provided). */ + imageHeight?: number; + + /** Apply a Hann window before FFT to reduce edge artifacts (default: true). 
*/ + window?: boolean; +}; + +export type PreparedPhaseCorrelationReference = { + size: number; + window1d: Float32Array; + effectiveMask: Uint8Array; + pixelsUsed: number; + refFRe: Float32Array; + refFIm: Float32Array; +}; + +export type PhaseCorrelationScratch = { + size: number; + // Target FFT buffers + targetRe: Float32Array; + targetIm: Float32Array; + + // Cross-power spectrum buffers (reused for IFFT result) + crossRe: Float32Array; + crossIm: Float32Array; + + // Temp buffers for 1D FFTs + tmpRe: Float32Array; + tmpIm: Float32Array; + tmp2Re: Float32Array; + tmp2Im: Float32Array; +}; + +function inferSquareSize(n: number): number { + const s = Math.round(Math.sqrt(n)); + if (s <= 0 || s * s !== n) { + throw new Error('phaseCorrelation: expected square image (provide imageWidth/imageHeight)'); + } + return s; +} + +function isPowerOfTwo(n: number): boolean { + return n > 0 && (n & (n - 1)) === 0; +} + +function buildEffectiveMask(n: number, size: number, opts: PhaseCorrelationOptions): { mask: Uint8Array; pixelsUsed: number } { + const inclusionMask = opts.inclusionMask; + if (inclusionMask && inclusionMask.length !== n) { + throw new Error(`phaseCorrelation: inclusionMask length mismatch (mask=${inclusionMask.length}, image=${n})`); + } + + const out = new Uint8Array(n); + if (inclusionMask) { + for (let i = 0; i < n; i++) out[i] = inclusionMask[i] ? 
1 : 0; + } else { + out.fill(1); + } + + const exclusionRect = opts.exclusionRect; + if (exclusionRect && size > 0) { + const x0 = Math.floor(exclusionRect.x * size); + const y0 = Math.floor(exclusionRect.y * size); + const x1 = Math.ceil((exclusionRect.x + exclusionRect.width) * size); + const y1 = Math.ceil((exclusionRect.y + exclusionRect.height) * size); + + if (x1 > x0 && y1 > y0) { + for (let y = Math.max(0, y0); y < Math.min(size, y1); y++) { + const row = y * size; + for (let x = Math.max(0, x0); x < Math.min(size, x1); x++) { + out[row + x] = 0; + } + } + } + } + + let pixelsUsed = 0; + for (let i = 0; i < n; i++) { + if (out[i]) pixelsUsed++; + } + + return { mask: out, pixelsUsed }; +} + +function buildHannWindow1d(n: number, enabled: boolean): Float32Array { + const w = new Float32Array(n); + if (!enabled) { + w.fill(1); + return w; + } + + if (n <= 1) { + w.fill(1); + return w; + } + + const denom = n - 1; + for (let i = 0; i < n; i++) { + w[i] = 0.5 * (1 - Math.cos((2 * Math.PI * i) / denom)); + } + return w; +} + +function fftRadix2InPlace(re: Float32Array, im: Float32Array, inverse: boolean): void { + const n = re.length; + if (im.length !== n) throw new Error('phaseCorrelation: fft buffer size mismatch'); + if (!isPowerOfTwo(n)) throw new Error(`phaseCorrelation: fft length must be power of two (got ${n})`); + + // Bit reversal permutation. + for (let i = 1, j = 0; i < n; i++) { + let bit = n >> 1; + while (j & bit) { + j ^= bit; + bit >>= 1; + } + j ^= bit; + + if (i < j) { + const tr = re[i]!; + re[i] = re[j]!; + re[j] = tr; + + const ti = im[i]!; + im[i] = im[j]!; + im[j] = ti; + } + } + + for (let len = 2; len <= n; len <<= 1) { + const ang = (2 * Math.PI) / len * (inverse ? 
1 : -1); + const wlenRe = Math.cos(ang); + const wlenIm = Math.sin(ang); + + for (let i = 0; i < n; i += len) { + let wRe = 1; + let wIm = 0; + + const half = len >> 1; + for (let j = 0; j < half; j++) { + const uRe = re[i + j]!; + const uIm = im[i + j]!; + + const vRe0 = re[i + j + half]!; + const vIm0 = im[i + j + half]!; + + // v = v0 * w + const vRe = vRe0 * wRe - vIm0 * wIm; + const vIm = vRe0 * wIm + vIm0 * wRe; + + re[i + j] = uRe + vRe; + im[i + j] = uIm + vIm; + re[i + j + half] = uRe - vRe; + im[i + j + half] = uIm - vIm; + + // w *= wlen + const nextWRe = wRe * wlenRe - wIm * wlenIm; + const nextWIm = wRe * wlenIm + wIm * wlenRe; + wRe = nextWRe; + wIm = nextWIm; + } + } + } + + if (inverse) { + const invN = 1 / n; + for (let i = 0; i < n; i++) { + re[i] = (re[i] ?? 0) * invN; + im[i] = (im[i] ?? 0) * invN; + } + } +} + +function fft2dInPlace( + re: Float32Array, + im: Float32Array, + size: number, + inverse: boolean, + scratch: { tmpRe: Float32Array; tmpIm: Float32Array; tmp2Re: Float32Array; tmp2Im: Float32Array } +): void { + const n = size * size; + if (re.length !== n || im.length !== n) { + throw new Error('phaseCorrelation: fft2d buffers size mismatch'); + } + + const rowRe = scratch.tmpRe; + const rowIm = scratch.tmpIm; + const colRe = scratch.tmp2Re; + const colIm = scratch.tmp2Im; + + if (rowRe.length !== size || rowIm.length !== size || colRe.length !== size || colIm.length !== size) { + throw new Error('phaseCorrelation: fft2d scratch size mismatch'); + } + + // Rows + for (let y = 0; y < size; y++) { + const row = y * size; + for (let x = 0; x < size; x++) { + rowRe[x] = re[row + x]!; + rowIm[x] = im[row + x]!; + } + + fftRadix2InPlace(rowRe, rowIm, inverse); + + for (let x = 0; x < size; x++) { + re[row + x] = rowRe[x]!; + im[row + x] = rowIm[x]!; + } + } + + // Columns + for (let x = 0; x < size; x++) { + for (let y = 0; y < size; y++) { + const idx = y * size + x; + colRe[y] = re[idx]!; + colIm[y] = im[idx]!; + } + + 
fftRadix2InPlace(colRe, colIm, inverse); + + for (let y = 0; y < size; y++) { + const idx = y * size + x; + re[idx] = colRe[y]!; + im[idx] = colIm[y]!; + } + } +} + +function fillPreprocessedReal( + outRe: Float32Array, + outIm: Float32Array, + pixels: Float32Array, + size: number, + mask: Uint8Array, + window1d: Float32Array +): void { + const n = size * size; + if (pixels.length !== n || outRe.length !== n || outIm.length !== n || mask.length !== n) { + throw new Error('phaseCorrelation: preprocess size mismatch'); + } + + // Mean over included pixels. + let sum = 0; + let count = 0; + for (let i = 0; i < n; i++) { + if (!mask[i]) continue; + const v = pixels[i]!; + if (!Number.isFinite(v)) continue; + sum += v; + count++; + } + const mean = count > 0 ? sum / count : 0; + + for (let y = 0; y < size; y++) { + const wy = window1d[y]!; + const row = y * size; + for (let x = 0; x < size; x++) { + const idx = row + x; + if (!mask[idx]) { + outRe[idx] = 0; + outIm[idx] = 0; + continue; + } + const wx = window1d[x]!; + const w = wx * wy; + const v = pixels[idx]!; + outRe[idx] = (Number.isFinite(v) ? 
v - mean : -mean) * w; + outIm[idx] = 0; + } + } +} + +export function createPhaseCorrelationScratch(size: number): PhaseCorrelationScratch { + const n = size * size; + return { + size, + targetRe: new Float32Array(n), + targetIm: new Float32Array(n), + crossRe: new Float32Array(n), + crossIm: new Float32Array(n), + tmpRe: new Float32Array(size), + tmpIm: new Float32Array(size), + tmp2Re: new Float32Array(size), + tmp2Im: new Float32Array(size), + }; +} + +export function preparePhaseCorrelationReference( + referencePixels: Float32Array, + opts: PhaseCorrelationOptions = {} +): PreparedPhaseCorrelationReference { + const n = referencePixels.length; + if (n === 0) { + return { + size: 0, + window1d: new Float32Array(0), + effectiveMask: new Uint8Array(0), + pixelsUsed: 0, + refFRe: new Float32Array(0), + refFIm: new Float32Array(0), + }; + } + + const size = + typeof opts.imageWidth === 'number' && typeof opts.imageHeight === 'number' && opts.imageWidth === opts.imageHeight + ? opts.imageWidth + : inferSquareSize(n); + + if (!isPowerOfTwo(size)) { + throw new Error(`phaseCorrelation: size must be power of two (got ${size})`); + } + + const { mask: effectiveMask, pixelsUsed } = buildEffectiveMask(n, size, { ...opts, imageWidth: size, imageHeight: size }); + + const window1d = buildHannWindow1d(size, opts.window ?? true); + + const refFRe = new Float32Array(n); + const refFIm = new Float32Array(n); + + // Preprocess: mean-subtract, mask, and window. + fillPreprocessedReal(refFRe, refFIm, referencePixels, size, effectiveMask, window1d); + + // FFT reference. + const scratch = createPhaseCorrelationScratch(size); + fft2dInPlace(refFRe, refFIm, size, false, scratch); + + return { size, window1d, effectiveMask, pixelsUsed, refFRe, refFIm }; +} + +/** + * Compute phase correlation similarity between a prepared reference and a target. + * + * Returns `phase` as the peak value of the phase-only correlation surface (higher is better). 
+ * + * Notes: + * - This is primarily a translation-focused similarity. In our slice-search pipeline, the candidate + * slice is already pre-warped by a seed affine transform. + */ +export function computePhaseCorrelationSimilarity( + prepared: PreparedPhaseCorrelationReference, + targetPixels: Float32Array, + scratch: PhaseCorrelationScratch +): { phase: number; pixelsUsed: number } { + const size = prepared.size; + if (size <= 0) return { phase: 0, pixelsUsed: 0 }; + + if (scratch.size !== size) { + throw new Error('phaseCorrelation: scratch size mismatch'); + } + + const n = size * size; + if (targetPixels.length !== n) { + throw new Error(`phaseCorrelation: target size mismatch (expected ${n}, got ${targetPixels.length})`); + } + + // Target FFT buffers. + fillPreprocessedReal( + scratch.targetRe, + scratch.targetIm, + targetPixels, + size, + prepared.effectiveMask, + prepared.window1d + ); + + fft2dInPlace(scratch.targetRe, scratch.targetIm, size, false, scratch); + + // Cross-power spectrum: R = F_ref * conj(F_tgt) / |F_ref * conj(F_tgt)| + const eps = 1e-12; + for (let i = 0; i < n; i++) { + const aRe = prepared.refFRe[i]!; + const aIm = prepared.refFIm[i]!; + const bRe = scratch.targetRe[i]!; + const bIm = scratch.targetIm[i]!; + + // a * conj(b) + const cRe = aRe * bRe + aIm * bIm; + const cIm = aIm * bRe - aRe * bIm; + + const mag = Math.sqrt(cRe * cRe + cIm * cIm); + if (mag > eps) { + const inv = 1 / mag; + scratch.crossRe[i] = cRe * inv; + scratch.crossIm[i] = cIm * inv; + } else { + scratch.crossRe[i] = 0; + scratch.crossIm[i] = 0; + } + } + + // Inverse FFT to get correlation surface. + fft2dInPlace(scratch.crossRe, scratch.crossIm, size, true, scratch); + + // Find peak (real part). + let peak = Number.NEGATIVE_INFINITY; + for (let i = 0; i < n; i++) { + const v = scratch.crossRe[i]!; + if (v > peak) peak = v; + } + + if (!Number.isFinite(peak)) peak = 0; + + // Phase correlation peak should be in ~[0..1]. Clamp for sanity. 
+ const phase = Math.max(0, Math.min(1, peak)); + + return { phase, pixelsUsed: prepared.pixelsUsed }; +} diff --git a/frontend/src/utils/segmentation/geodesicDistance.ts b/frontend/src/utils/segmentation/geodesicDistance.ts new file mode 100644 index 0000000..e046d58 --- /dev/null +++ b/frontend/src/utils/segmentation/geodesicDistance.ts @@ -0,0 +1,195 @@ +export type Roi = { x0: number; y0: number; x1: number; y1: number }; + +type HeapItem = { idx: number; d: number }; + +class MinHeap { + private items: HeapItem[] = []; + + push(item: HeapItem) { + const a = this.items; + a.push(item); + let i = a.length - 1; + while (i > 0) { + const p = (i - 1) >> 1; + if (a[p]!.d <= a[i]!.d) break; + const tmp = a[p]!; + a[p] = a[i]!; + a[i] = tmp; + i = p; + } + } + + pop(): HeapItem | null { + const a = this.items; + const n = a.length; + if (n === 0) return null; + + const out = a[0]!; + const last = a.pop()!; + if (n > 1) { + a[0] = last; + + let i = 0; + for (;;) { + const l = i * 2 + 1; + const r = l + 1; + let smallest = i; + + if (l < a.length && a[l]!.d < a[smallest]!.d) smallest = l; + if (r < a.length && a[r]!.d < a[smallest]!.d) smallest = r; + + if (smallest === i) break; + const tmp = a[i]!; + a[i] = a[smallest]!; + a[smallest] = tmp; + i = smallest; + } + } + + return out; + } + + get size() { + return this.items.length; + } +} + +function clamp(v: number, lo: number, hi: number) { + return Math.max(lo, Math.min(hi, v)); +} + +/** + * Compute an edge-aware geodesic distance from seed pixels within an ROI. + * + * The cost to step into a pixel is: + * 1 + edgeCostStrength * (grad/255) + * + * This makes crossing strong edges more expensive, which helps prevent leakage + * across boundaries without requiring the user to paint negative/background strokes. + */ +export function computeGeodesicDistanceToSeeds(params: { + w: number; + h: number; + roi: Roi; + /** Seed pixels (image pixel coords). 
*/ + seeds: Array<{ x: number; y: number }>; + /** Gradient magnitude image (0..255). If omitted, treated as all zeros. */ + grad?: Uint8Array; + edgeCostStrength: number; + /** + * Optional onset for edge costs (0..255). + * + * If provided, gradients below this barrier are treated as "not an edge" (zero extra cost), + * and only gradients above the barrier increase step cost. + */ + edgeBarrier?: number; + /** Optional cutoff: distances beyond this are not expanded (remain Infinity). */ + maxDist?: number; +}): Float32Array { + const { w, h } = params; + const dist = new Float32Array(w * h); + dist.fill(Number.POSITIVE_INFINITY); + + if (w <= 0 || h <= 0) return dist; + + const x0 = clamp(Math.floor(params.roi.x0), 0, w - 1); + const y0 = clamp(Math.floor(params.roi.y0), 0, h - 1); + const x1 = clamp(Math.ceil(params.roi.x1), 0, w - 1); + const y1 = clamp(Math.ceil(params.roi.y1), 0, h - 1); + + const grad = params.grad; + const k = Math.max(0, params.edgeCostStrength); + const maxDist = typeof params.maxDist === 'number' && Number.isFinite(params.maxDist) ? params.maxDist : null; + + const barrier = + typeof params.edgeBarrier === 'number' && Number.isFinite(params.edgeBarrier) + ? clamp(params.edgeBarrier, 0, 255) + : null; + + const heap = new MinHeap(); + + // Seed the heap. + for (const s of params.seeds) { + const sx = clamp(Math.round(s.x), x0, x1); + const sy = clamp(Math.round(s.y), y0, y1); + const idx = sy * w + sx; + if (dist[idx] === 0) continue; + dist[idx] = 0; + heap.push({ idx, d: 0 }); + } + + if (heap.size === 0) return dist; + + const stepCost = (idx: number) => { + const gRaw = grad ? grad[idx] ?? 0 : 0; + + const edgeFrac = (() => { + // Default behavior (no barrier): treat grad as a continuous 0..255 edge weight. + if (barrier == null) return gRaw / 255; + + // Barrier behavior: only penalize gradients above the onset. + // This avoids making *all* mild texture act like a distance wall. 
+ if (gRaw <= barrier) return 0; + + const denom = Math.max(1, 255 - barrier); + return clamp((gRaw - barrier) / denom, 0, 1); + })(); + + return 1 + k * edgeFrac; + }; + + while (heap.size > 0) { + const item = heap.pop(); + if (!item) break; + + const d = item.d; + const idx = item.idx; + + // Skip stale heap entries. + if (d !== dist[idx]) continue; + + if (maxDist != null && d > maxDist) { + // With positive costs, all remaining paths will be >= d, so we can stop. + break; + } + + const x = idx % w; + const y = (idx - x) / w; + + // 4-neighborhood within ROI. + if (x > x0) { + const ni = idx - 1; + const nd = d + stepCost(ni); + if (nd < dist[ni]) { + dist[ni] = nd; + heap.push({ idx: ni, d: nd }); + } + } + if (x < x1) { + const ni = idx + 1; + const nd = d + stepCost(ni); + if (nd < dist[ni]) { + dist[ni] = nd; + heap.push({ idx: ni, d: nd }); + } + } + if (y > y0) { + const ni = idx - w; + const nd = d + stepCost(ni); + if (nd < dist[ni]) { + dist[ni] = nd; + heap.push({ idx: ni, d: nd }); + } + } + if (y < y1) { + const ni = idx + w; + const nd = d + stepCost(ni); + if (nd < dist[ni]) { + dist[ni] = nd; + heap.push({ idx: ni, d: nd }); + } + } + } + + return dist; +} diff --git a/frontend/src/utils/segmentation/gtBenchmark.ts b/frontend/src/utils/segmentation/gtBenchmark.ts new file mode 100644 index 0000000..cd74707 --- /dev/null +++ b/frontend/src/utils/segmentation/gtBenchmark.ts @@ -0,0 +1,614 @@ +import cornerstone from 'cornerstone-core'; +import type { NormalizedPoint, TumorPolygon, TumorThreshold, ViewerTransform } from '../../db/schema'; +import { + estimateThresholdFromSeedPoints, + segmentTumorFromGrayscale, + type SegmentTumorOptions, +} from './segmentTumor'; +import { remapPolygonToImage01 } from './harness/canonicalize'; +import { computeMaskMetrics, type MaskMetrics } from './maskMetrics'; +import { + computePolygonBoundaryMetrics, + type PolygonBoundaryMetrics, +} from './polygonBoundaryMetrics'; +import { rasterizePolygonToMask } from 
'./rasterizePolygon'; + +type CornerstoneImageLike = { + rows: number; + columns: number; + getPixelData: () => ArrayLike<number>; + minPixelValue?: number; + maxPixelValue?: number; +}; + +function clamp(v: number, lo: number, hi: number) { + return Math.max(lo, Math.min(hi, v)); +} + +function clamp01(v: number) { + return clamp(v, 0, 1); +} + +function safeViewportSize(v?: { w: number; h: number } | null): { w: number; h: number } { + const w = Math.max(1, Math.round(Number.isFinite(v?.w) ? v!.w : 0)); + const h = Math.max(1, Math.round(Number.isFinite(v?.h) ? v!.h : 0)); + return { w, h }; +} + +function toByte(v: number) { + return Math.max(0, Math.min(255, Math.round(v))); +} + +function hashStringToSeed(s: string): number { + // FNV-1a 32-bit + let h = 2166136261; + for (let i = 0; i < s.length; i++) { + h ^= s.charCodeAt(i); + h = Math.imul(h, 16777619); + } + return h >>> 0; +} + +function makeLcg(seed: number) { + let state = seed >>> 0; + return () => { + state = (Math.imul(1664525, state) + 1013904223) >>> 0; + return state / 4294967296; + }; +} + +function polygonBounds01(poly: TumorPolygon): { minX: number; minY: number; maxX: number; maxY: number } { + let minX = Number.POSITIVE_INFINITY; + let minY = Number.POSITIVE_INFINITY; + let maxX = Number.NEGATIVE_INFINITY; + let maxY = Number.NEGATIVE_INFINITY; + + for (const p of poly.points) { + minX = Math.min(minX, p.x); + minY = Math.min(minY, p.y); + maxX = Math.max(maxX, p.x); + maxY = Math.max(maxY, p.y); + } + + return { + minX: clamp01(minX), + minY: clamp01(minY), + maxX: clamp01(maxX), + maxY: clamp01(maxY), + }; +} + +function pointInPolygon(pt: NormalizedPoint, poly: TumorPolygon): boolean { + // Ray casting.
+ const pts = poly.points; + const n = pts.length; + if (n < 3) return false; + + let inside = false; + for (let i = 0, j = n - 1; i < n; j = i++) { + const a = pts[i]!; + const b = pts[j]!; + const intersects = + a.y > pt.y !== b.y > pt.y && + pt.x < ((b.x - a.x) * (pt.y - a.y)) / (b.y - a.y + 1e-12) + a.x; + + if (intersects) inside = !inside; + } + + return inside; +} + +function polygonAreaCentroid01(poly: TumorPolygon): NormalizedPoint { + const pts = poly.points; + const n = pts.length; + if (n < 3) { + // Fallback: average. + let sx = 0; + let sy = 0; + for (const p of pts) { + sx += p.x; + sy += p.y; + } + const d = Math.max(1, n); + return { x: clamp01(sx / d), y: clamp01(sy / d) }; + } + + // Polygon centroid (shoelace). Works for simple polygons. + let a2 = 0; + let cx = 0; + let cy = 0; + + for (let i = 0; i < n; i++) { + const p0 = pts[i]!; + const p1 = pts[(i + 1) % n]!; + const cross = p0.x * p1.y - p1.x * p0.y; + a2 += cross; + cx += (p0.x + p1.x) * cross; + cy += (p0.y + p1.y) * cross; + } + + if (Math.abs(a2) < 1e-10) { + // Degenerate polygon. + let sx = 0; + let sy = 0; + for (const p of pts) { + sx += p.x; + sy += p.y; + } + const d = Math.max(1, n); + return { x: clamp01(sx / d), y: clamp01(sy / d) }; + } + + const inv6a = 1 / (3 * a2); + return { x: clamp01(cx * inv6a), y: clamp01(cy * inv6a) }; +} + +function findInteriorPoint01(poly: TumorPolygon): NormalizedPoint { + const c = polygonAreaCentroid01(poly); + if (pointInPolygon(c, poly)) return c; + + // Try bbox center. + const b = polygonBounds01(poly); + const mid = { x: (b.minX + b.maxX) / 2, y: (b.minY + b.maxY) / 2 }; + if (pointInPolygon(mid, poly)) return mid; + + // Brute force a small grid search. 
+ const steps = 9; + for (let yi = 0; yi < steps; yi++) { + for (let xi = 0; xi < steps; xi++) { + const x = b.minX + ((xi + 0.5) / steps) * (b.maxX - b.minX); + const y = b.minY + ((yi + 0.5) / steps) * (b.maxY - b.minY); + const p = { x, y }; + if (pointInPolygon(p, poly)) return p; + } + } + + // Give up: return clamped centroid. + return c; +} + +function generatePaintPointsFromGt( + gt: TumorPolygon, + seedKey: string, + targetCount: number +): NormalizedPoint[] { + const pts: NormalizedPoint[] = []; + const seed = findInteriorPoint01(gt); + + // Always include a small cross around the seed for robustness. + const j = 0.004; + pts.push(seed); + pts.push({ x: clamp01(seed.x + j), y: seed.y }); + pts.push({ x: clamp01(seed.x - j), y: seed.y }); + pts.push({ x: seed.x, y: clamp01(seed.y + j) }); + pts.push({ x: seed.x, y: clamp01(seed.y - j) }); + + const b = polygonBounds01(gt); + const rand = makeLcg(hashStringToSeed(seedKey)); + + const want = Math.max(8, targetCount); + const maxAttempts = want * 80; + + for (let attempt = 0; attempt < maxAttempts && pts.length < want; attempt++) { + // Bias sampling toward the seed by mixing uniform bbox with a seed-centered jitter. + const mix = rand(); + + let x: number; + let y: number; + + if (mix < 0.7) { + // Seed-centered jitter (roughly "scribble" sized). + const r = 0.03; + x = clamp01(seed.x + (rand() * 2 - 1) * r); + y = clamp01(seed.y + (rand() * 2 - 1) * r); + } else { + // Uniform in bbox. + x = b.minX + rand() * (b.maxX - b.minX); + y = b.minY + rand() * (b.maxY - b.minY); + } + + const p = { x, y }; + if (pointInPolygon(p, gt)) { + pts.push(p); + } + } + + return pts; +} + +function computeMaskMetricsFromCounts(tp: number, fp: number, fn: number, tn: number): MaskMetrics { + const safeDiv = (num: number, den: number) => (den > 0 ? 
num / den : 0); + + const precision = safeDiv(tp, tp + fp); + const recall = safeDiv(tp, tp + fn); + const dice = safeDiv(2 * tp, 2 * tp + fp + fn); + const iou = safeDiv(tp, tp + fp + fn); + + const beta2 = 4; + const f2 = safeDiv((1 + beta2) * precision * recall, beta2 * precision + recall); + + return { tp, fp, fn, tn, precision, recall, dice, iou, f2 }; +} + +async function yieldToUi() { + await new Promise((resolve) => { + (globalThis.setTimeout ?? setTimeout)(resolve, 0); + }); +} + +export type GtBenchmarkCase = { + id: string; + comboId: string; + dateIso: string; + seriesUid: string; + sopInstanceUid: string; + + // GT polygon is stored in viewer-normalized coordinates. + gtPolygon: TumorPolygon; + + // Optional metadata needed to canonicalize GT into image coordinates. + gtViewTransform?: ViewerTransform; + gtViewportSize?: { w: number; h: number }; +}; + +export type GtBenchmarkConfig = { + name: string; + opts?: SegmentTumorOptions; +}; + +export type GtBenchmarkCaseConfigResult = { + ok: boolean; + error?: string; + threshold?: TumorThreshold; + metrics?: MaskMetrics; + boundary?: PolygonBoundaryMetrics; + predPolygonPointCount?: number; + timingMs?: { + segment: number; + evaluate: number; + }; +}; + +export type GtBenchmarkCaseResult = { + id: string; + comboId: string; + dateIso: string; + seriesUid: string; + sopInstanceUid: string; + image: { + imageId: string; + sourceW: number; + sourceH: number; + evalW: number; + evalH: number; + }; + paintPointsCount: number; + resultsByConfig: Record<string, GtBenchmarkCaseConfigResult>; + timingMs: { + loadImage: number; + total: number; + }; +}; + +export type GtBenchmarkSummary = { + config: string; + casesTotal: number; + casesOk: number; + casesError: number; + micro: MaskMetrics; + boundary: { + meanPredToGtPx: number; + meanGtToPredPx: number; + meanSymPx: number; + maxSymPx: number; + count: number; + }; +}; + +export type GtBenchmarkReport = { + version: 1; + generatedAtIso: string; + settings: { + maxEvalDim: number; + 
paintPointsPerCase: number; + }; + configs: Array<{ name: string; opts?: SegmentTumorOptions }>; + summary: GtBenchmarkSummary[]; + cases: GtBenchmarkCaseResult[]; + note: string; +}; + +async function loadAndNormalizeImage( + sopInstanceUid: string, + maxEvalDim: number +): Promise<{ imageId: string; gray: Uint8Array; w: number; h: number; sourceW: number; sourceH: number }> { + const imageId = `miradb:${sopInstanceUid}`; + const image = (await cornerstone.loadImage(imageId)) as unknown as CornerstoneImageLike; + + const rows = image.rows; + const cols = image.columns; + const getPixelData = image.getPixelData; + if (!rows || !cols || typeof getPixelData !== 'function') { + throw new Error('Cornerstone image missing pixel data'); + } + + const pd = getPixelData(); + + let min = image.minPixelValue; + let max = image.maxPixelValue; + + if (!Number.isFinite(min) || !Number.isFinite(max)) { + min = Number.POSITIVE_INFINITY; + max = Number.NEGATIVE_INFINITY; + for (let i = 0; i < pd.length; i++) { + const v = pd[i]; + if (v < min) min = v; + if (v > max) max = v; + } + } + + const denom = (max as number) - (min as number); + + // Downsample for speed, preserving aspect ratio. + const scale = Math.max(cols, rows) / Math.max(16, maxEvalDim); + const w = scale > 1 ? Math.max(16, Math.round(cols / scale)) : cols; + const h = scale > 1 ? Math.max(16, Math.round(rows / scale)) : rows; + + const gray = new Uint8Array(w * h); + + if (!Number.isFinite(denom) || Math.abs(denom) < 1e-8) { + gray.fill(0); + return { imageId, gray, w, h, sourceW: cols, sourceH: rows }; + } + + for (let y = 0; y < h; y++) { + const sy = h <= 1 ? 0 : Math.round((y * (rows - 1)) / (h - 1)); + for (let x = 0; x < w; x++) { + const sx = w <= 1 ? 
0 : Math.round((x * (cols - 1)) / (w - 1)); + const v = pd[sy * cols + sx]; + const t = ((v - (min as number)) / denom) * 255; + gray[y * w + x] = toByte(t); + } + } + + return { imageId, gray, w, h, sourceW: cols, sourceH: rows }; +} + +export type RunGtBenchmarkInput = { + cases: GtBenchmarkCase[]; + configs: GtBenchmarkConfig[]; + maxEvalDim?: number; + paintPointsPerCase?: number; + yieldEveryCases?: number; + onProgress?: (p: { caseIndex: number; caseCount: number; configName?: string; message: string }) => void; +}; + +export async function runGtBenchmark(input: RunGtBenchmarkInput): Promise<GtBenchmarkReport> { + const maxEvalDim = input.maxEvalDim ?? 256; + const paintPointsPerCase = input.paintPointsPerCase ?? 24; + const yieldEveryCases = input.yieldEveryCases ?? 1; + + const configs = input.configs.map((c) => ({ name: c.name, opts: c.opts })); + + const cases: GtBenchmarkCaseResult[] = []; + + type Agg = { + casesOk: number; + casesError: number; + tp: number; + fp: number; + fn: number; + tn: number; + bMeanPredToGtSum: number; + bMeanGtToPredSum: number; + bMeanSymSum: number; + bMaxSymMax: number; + bCount: number; + }; + + const aggs: Record<string, Agg> = {}; + for (const c of configs) { + aggs[c.name] = { + casesOk: 0, + casesError: 0, + tp: 0, + fp: 0, + fn: 0, + tn: 0, + bMeanPredToGtSum: 0, + bMeanGtToPredSum: 0, + bMeanSymSum: 0, + bMaxSymMax: 0, + bCount: 0, + }; + } + + const caseCount = input.cases.length; + + for (let caseIndex = 0; caseIndex < caseCount; caseIndex++) { + const c = input.cases[caseIndex]!; + const tCase0 = performance.now(); + + input.onProgress?.({ + caseIndex, + caseCount, + message: `Benchmark: loading slice ${caseIndex + 1}/${caseCount}…`, + }); + + const tLoad0 = performance.now(); + let image: + | { imageId: string; gray: Uint8Array; w: number; h: number; sourceW: number; sourceH: number } + | null = null; + let loadError: string | null = null; + + try { + image = await loadAndNormalizeImage(c.sopInstanceUid, maxEvalDim); + } catch (e) { + loadError 
= e instanceof Error ? e.message : 'Failed to load image'; + } + const tLoad1 = performance.now(); + + const resultsByConfig: Record<string, GtBenchmarkCaseConfigResult> = {}; + + if (!image) { + for (const cfg of configs) { + resultsByConfig[cfg.name] = { ok: false, error: `Image load failed: ${loadError ?? 'unknown error'}` }; + aggs[cfg.name].casesError++; + } + + const tCase1 = performance.now(); + cases.push({ + id: c.id, + comboId: c.comboId, + dateIso: c.dateIso, + seriesUid: c.seriesUid, + sopInstanceUid: c.sopInstanceUid, + image: { + imageId: `miradb:${c.sopInstanceUid}`, + sourceW: 0, + sourceH: 0, + evalW: 0, + evalH: 0, + }, + paintPointsCount: 0, + resultsByConfig, + timingMs: { + loadImage: tLoad1 - tLoad0, + total: tCase1 - tCase0, + }, + }); + + if (yieldEveryCases > 0 && caseIndex % yieldEveryCases === 0) { + await yieldToUi(); + } + continue; + } + + const gtPolyImage01 = remapPolygonToImage01({ + polygon: c.gtPolygon, + viewportSize: safeViewportSize(c.gtViewportSize ?? { w: 512, h: 512 }), + fromViewTransform: c.gtViewTransform, + imageSize: { w: image.w, h: image.h }, + }); + + const paintPoints = generatePaintPointsFromGt(gtPolyImage01, c.id, paintPointsPerCase); + const threshold = estimateThresholdFromSeedPoints(image.gray, image.w, image.h, paintPoints); + + for (const cfg of configs) { + const tSeg0 = performance.now(); + input.onProgress?.({ + caseIndex, + caseCount, + configName: cfg.name, + message: `Benchmark: segmenting (${cfg.name}) ${caseIndex + 1}/${caseCount}…`, + }); + + try { + const res = segmentTumorFromGrayscale(image.gray, image.w, image.h, paintPoints, threshold, cfg.opts); + + const tEval0 = performance.now(); + const gtMask = rasterizePolygonToMask(gtPolyImage01, image.w, image.h); + const predMask = rasterizePolygonToMask(res.polygon, image.w, image.h); + const metrics = computeMaskMetrics(predMask, gtMask); + const boundary = computePolygonBoundaryMetrics(res.polygon, gtPolyImage01, image.w, image.h); + const tEval1 = performance.now(); + + 
resultsByConfig[cfg.name] = { + ok: true, + threshold, + metrics, + boundary, + predPolygonPointCount: res.polygon.points.length, + timingMs: { + segment: tEval0 - tSeg0, + evaluate: tEval1 - tEval0, + }, + }; + + const a = aggs[cfg.name]; + a.casesOk++; + a.tp += metrics.tp; + a.fp += metrics.fp; + a.fn += metrics.fn; + a.tn += metrics.tn; + + if (Number.isFinite(boundary.meanSymPx)) { + a.bMeanPredToGtSum += boundary.meanPredToGtPx; + a.bMeanGtToPredSum += boundary.meanGtToPredPx; + a.bMeanSymSum += boundary.meanSymPx; + a.bMaxSymMax = Math.max(a.bMaxSymMax, boundary.maxSymPx); + a.bCount++; + } + } catch (e) { + const msg = e instanceof Error ? e.message : 'Segmentation failed'; + resultsByConfig[cfg.name] = { + ok: false, + error: msg, + }; + aggs[cfg.name].casesError++; + } + } + + const tCase1 = performance.now(); + + cases.push({ + id: c.id, + comboId: c.comboId, + dateIso: c.dateIso, + seriesUid: c.seriesUid, + sopInstanceUid: c.sopInstanceUid, + image: { + imageId: image.imageId, + sourceW: image.sourceW, + sourceH: image.sourceH, + evalW: image.w, + evalH: image.h, + }, + paintPointsCount: paintPoints.length, + resultsByConfig, + timingMs: { + loadImage: tLoad1 - tLoad0, + total: tCase1 - tCase0, + }, + }); + + if (yieldEveryCases > 0 && caseIndex % yieldEveryCases === 0) { + await yieldToUi(); + } + } + + const summary: GtBenchmarkSummary[] = []; + for (const cfg of configs) { + const a = aggs[cfg.name]; + const micro = computeMaskMetricsFromCounts(a.tp, a.fp, a.fn, a.tn); + + summary.push({ + config: cfg.name, + casesTotal: caseCount, + casesOk: a.casesOk, + casesError: a.casesError, + micro, + boundary: { + meanPredToGtPx: a.bCount ? a.bMeanPredToGtSum / a.bCount : Number.POSITIVE_INFINITY, + meanGtToPredPx: a.bCount ? a.bMeanGtToPredSum / a.bCount : Number.POSITIVE_INFINITY, + meanSymPx: a.bCount ? a.bMeanSymSum / a.bCount : Number.POSITIVE_INFINITY, + maxSymPx: a.bCount ? 
a.bMaxSymMax : Number.POSITIVE_INFINITY, + count: a.bCount, + }, + }); + } + + return { + version: 1, + generatedAtIso: new Date().toISOString(), + settings: { + maxEvalDim, + paintPointsPerCase, + }, + configs, + summary, + cases, + note: + 'This benchmark uses auto-generated paint points inside the GT polygon (deterministic per GT id) and thresholds estimated from those samples. Images are loaded from Cornerstone pixel data and downsampled (preserving aspect ratio) to maxEvalDim for speed.', + }; +} diff --git a/frontend/src/utils/segmentation/harness/base64.ts b/frontend/src/utils/segmentation/harness/base64.ts new file mode 100644 index 0000000..1dee4d8 --- /dev/null +++ b/frontend/src/utils/segmentation/harness/base64.ts @@ -0,0 +1,62 @@ +// Cross-platform base64 helpers. +// +// We want the tumor harness dataset to be usable both: +// - in the browser (exporter UI), and +// - in Node (offline harness runner). +// +// Node has Buffer; browsers typically have atob/btoa. +// These helpers avoid pulling in additional dependencies. + +type BufferCtor = { + from(data: Uint8Array | string, encoding?: string): Uint8Array & { toString(encoding: string): string }; +}; + +function getBufferCtor(): BufferCtor | null { + const maybe = (globalThis as unknown as { Buffer?: BufferCtor }).Buffer; + return typeof maybe?.from === 'function' ? maybe : null; +} + +export function bytesToBase64(bytes: Uint8Array): string { + const Buffer = getBufferCtor(); + if (Buffer) { + return Buffer.from(bytes).toString('base64'); + } + + if (typeof btoa !== 'function') { + throw new Error('bytesToBase64: no Buffer and no btoa available'); + } + + // btoa expects a binary string. Build it in chunks to avoid stack/arg limits. + const chunkSize = 0x8000; + let binary = ''; + for (let i = 0; i < bytes.length; i += chunkSize) { + const chunk = bytes.subarray(i, Math.min(bytes.length, i + chunkSize)); + + let s = ''; + for (let j = 0; j < chunk.length; j++) { + s += String.fromCharCode(chunk[j] ?? 
0); + } + + binary += s; + } + + return btoa(binary); +} + +export function base64ToBytes(b64: string): Uint8Array { + const Buffer = getBufferCtor(); + if (Buffer) { + return Buffer.from(b64, 'base64'); + } + + if (typeof atob !== 'function') { + throw new Error('base64ToBytes: no Buffer and no atob available'); + } + + const binary = atob(b64); + const out = new Uint8Array(binary.length); + for (let i = 0; i < binary.length; i++) { + out[i] = binary.charCodeAt(i) & 0xff; + } + return out; +} diff --git a/frontend/src/utils/segmentation/harness/canonicalize.ts b/frontend/src/utils/segmentation/harness/canonicalize.ts new file mode 100644 index 0000000..c826141 --- /dev/null +++ b/frontend/src/utils/segmentation/harness/canonicalize.ts @@ -0,0 +1,37 @@ +import type { NormalizedPoint, TumorPolygon, ViewerTransform } from '../../../db/schema'; +import { normalizeViewerTransform, remapPointsBetweenViewerTransforms, remapPolygonBetweenViewerTransforms, type ViewportSize } from '../../viewTransform'; +import { viewerNormToImageNorm, type ImageSizePx } from '../../viewportMapping'; + +function clamp01(v: number) { + return Math.max(0, Math.min(1, v)); +} + +export function remapPointsToImage01(args: { + points: NormalizedPoint[]; + viewportSize: ViewportSize; + fromViewTransform?: ViewerTransform | null; + imageSize: ImageSizePx; +}): NormalizedPoint[] { + const { points, viewportSize, fromViewTransform, imageSize } = args; + const from = normalizeViewerTransform(fromViewTransform ?? 
null); + const to = normalizeViewerTransform(null); + + const pointsIdentity = remapPointsBetweenViewerTransforms(points, viewportSize, from, to); + return pointsIdentity.map((p) => viewerNormToImageNorm({ x: clamp01(p.x), y: clamp01(p.y) }, viewportSize, imageSize)); +} + +export function remapPolygonToImage01(args: { + polygon: TumorPolygon; + viewportSize: ViewportSize; + fromViewTransform?: ViewerTransform | null; + imageSize: ImageSizePx; +}): TumorPolygon { + const { polygon, viewportSize, fromViewTransform, imageSize } = args; + const from = normalizeViewerTransform(fromViewTransform ?? null); + const to = normalizeViewerTransform(null); + + const polyIdentity = remapPolygonBetweenViewerTransforms(polygon, viewportSize, from, to); + return { + points: polyIdentity.points.map((p) => viewerNormToImageNorm({ x: clamp01(p.x), y: clamp01(p.y) }, viewportSize, imageSize)), + }; +} diff --git a/frontend/src/utils/segmentation/harness/dataset.ts b/frontend/src/utils/segmentation/harness/dataset.ts new file mode 100644 index 0000000..85bac44 --- /dev/null +++ b/frontend/src/utils/segmentation/harness/dataset.ts @@ -0,0 +1,98 @@ +import type { NormalizedPoint, TumorPolygon, TumorThreshold } from '../../../db/schema'; +import type { SegmentTumorOptions } from '../segmentTumor'; + +export type TumorHarnessImageV1 = { + // Evaluated image size (may be downsampled). + w: number; + h: number; + + // Original DICOM pixel dimensions (for reference/debug). + sourceW: number; + sourceH: number; + + // Base64 of raw grayscale bytes (Uint8Array, length = w*h). + grayB64: string; +}; + +export type TumorHarnessCaseV1 = { + id: string; + + comboId: string; + dateIso: string; + studyId: string; + seriesUid: string; + sopInstanceUid: string; + + image: TumorHarnessImageV1; + + // Ground truth polygon in *image* normalized coordinates (0..1). + gtPolygonImage01: TumorPolygon; + + // Optional paint points (image coords). If omitted, the harness can synthesize paint from GT. 
+ paintPointsImage01?: NormalizedPoint[]; +}; + +export type TumorHarnessPropagationFrameV1 = { + // Index in the series' effective ordering (0..N-1). + effectiveIndex: number; + sopInstanceUid: string; + + image: TumorHarnessImageV1; + + // Present only for frames where GT exists. + gtPolygonImage01?: TumorPolygon; +}; + +export type TumorHarnessPropagationScenarioV1 = { + id: string; + + comboId: string; + dateIso: string; + studyId: string; + seriesUid: string; + + // Ordered frames for a slice range. + frames: TumorHarnessPropagationFrameV1[]; + + start: { + effectiveIndex: number; + sopInstanceUid: string; + + // Starting paint gesture in image coords (used to compute initial threshold + seed). + paintPointsImage01: NormalizedPoint[]; + + // If provided, this is the threshold the user used at the start slice. + // If omitted, the harness can initialize via estimateThresholdFromSeedPoints. + threshold?: TumorThreshold; + + // Optional overrides for the initial segmentation/seed computation. + // Propagation may still use defaults unless explicitly threaded through. + initialOpts?: SegmentTumorOptions; + }; + + // Propagation stopping rules (defaults should mirror the UI). + stop?: { + minAreaPx: number; + maxMissesInARow: number; + }; + + note?: string; +}; + +export type TumorHarnessDatasetV1 = { + version: 1; + generatedAtIso: string; + + settings: { + // Max dimension used during export downsampling (preserves aspect ratio). + maxEvalDim: number; + }; + + // Single-slice cases (GT rows). + cases: TumorHarnessCaseV1[]; + + // Optional propagation scenarios. 
+ propagationScenarios?: TumorHarnessPropagationScenarioV1[]; + + note?: string; +}; diff --git a/frontend/src/utils/segmentation/harness/exportTumorHarnessDataset.ts b/frontend/src/utils/segmentation/harness/exportTumorHarnessDataset.ts new file mode 100644 index 0000000..ca3139c --- /dev/null +++ b/frontend/src/utils/segmentation/harness/exportTumorHarnessDataset.ts @@ -0,0 +1,261 @@ +import JSZip from 'jszip'; +import type { NormalizedPoint, TumorGroundTruthRow, TumorThreshold, ViewerTransform } from '../../../db/schema'; +import type { ViewportSize } from '../../viewTransform'; +import type { TumorHarnessCaseV1, TumorHarnessDatasetV1, TumorHarnessPropagationScenarioV1 } from './dataset'; +import { bytesToBase64 } from './base64'; +import { remapPointsToImage01, remapPolygonToImage01 } from './canonicalize'; +import { loadCornerstoneSliceToGrayscale } from './loadCornerstoneGrayscale'; +import { generateSyntheticPaintPointsFromGt } from './syntheticPaint'; +import { getSortedSopInstanceUidsForSeries } from '../../localApi'; + +export type ExportTumorHarnessDatasetInput = { + maxEvalDim: number; + + // Export per-slice cases for all GT rows. + gtRows: TumorGroundTruthRow[]; + paintPointsPerCase?: number; + + // Optional propagation scenario derived from the current paint gesture. + propagationScenario?: { + comboId: string; + dateIso: string; + studyId: string; + seriesUid: string; + + startEffectiveIndex: number; + startSopInstanceUid: string; + + paintPointsViewer01: NormalizedPoint[]; + paintPointsViewTransform?: ViewerTransform | null; + viewportSize: ViewportSize; + + threshold?: TumorThreshold; + + stop?: { + minAreaPx: number; + maxMissesInARow: number; + }; + + // How many slices to include beyond the GT min/max range. + marginSlices?: number; + }; + + onProgress?: (msg: string) => void; +}; + +function sanitizeNumber(v: number): number { + return Number.isFinite(v) ? 
v : 0; +} + +function safeViewportSize(v: { w: number; h: number } | null | undefined): ViewportSize { + const w = Math.max(1, Math.round(sanitizeNumber(v?.w ?? 0))); + const h = Math.max(1, Math.round(sanitizeNumber(v?.h ?? 0))); + return { w, h }; +} + +function downloadBlob(blob: Blob, filename: string) { + const url = URL.createObjectURL(blob); + try { + const a = document.createElement('a'); + a.href = url; + a.download = filename; + a.rel = 'noopener'; + a.click(); + } finally { + URL.revokeObjectURL(url); + } +} + +export async function exportTumorHarnessDatasetToZip(input: ExportTumorHarnessDatasetInput): Promise<{ + dataset: TumorHarnessDatasetV1; + zipBlob: Blob; +}> { + const maxEvalDim = Math.max(16, Math.round(input.maxEvalDim)); + const paintPointsPerCase = Math.max(8, Math.round(input.paintPointsPerCase ?? 24)); + + const cases: TumorHarnessCaseV1[] = []; + + const gtRows = input.gtRows.filter((r) => (r.polygon?.points?.length ?? 0) >= 3); + + input.onProgress?.(`Export: building cases from ${gtRows.length} GT rows…`); + + for (let i = 0; i < gtRows.length; i++) { + const r = gtRows[i]!; + + input.onProgress?.(`Export: loading GT slice ${i + 1}/${gtRows.length}…`); + + const img = await loadCornerstoneSliceToGrayscale({ sopInstanceUid: r.sopInstanceUid, maxEvalDim }); + const imageSize = { w: img.w, h: img.h }; + + // Prefer recorded viewport size; fall back to the common capture size. + const viewport = safeViewportSize(r.viewportSize ?? 
{ w: 512, h: 512 }); + + const gtPolygonImage01 = remapPolygonToImage01({ + polygon: r.polygon, + viewportSize: viewport, + fromViewTransform: r.viewTransform, + imageSize, + }); + + const paintPointsImage01 = generateSyntheticPaintPointsFromGt(gtPolygonImage01, r.id, paintPointsPerCase); + + cases.push({ + id: r.id, + comboId: r.comboId, + dateIso: r.dateIso, + studyId: r.studyId, + seriesUid: r.seriesUid, + sopInstanceUid: r.sopInstanceUid, + image: { + w: img.w, + h: img.h, + sourceW: img.sourceW, + sourceH: img.sourceH, + grayB64: bytesToBase64(img.gray), + }, + gtPolygonImage01, + paintPointsImage01, + }); + } + + let propagationScenarios: TumorHarnessPropagationScenarioV1[] | undefined; + + if (input.propagationScenario) { + const s = input.propagationScenario; + + input.onProgress?.('Export: building propagation scenario…'); + + const uids = await getSortedSopInstanceUidsForSeries(s.seriesUid); + const uidToIndex = new Map(); + for (let idx = 0; idx < uids.length; idx++) { + const uid = uids[idx]; + if (uid) uidToIndex.set(uid, idx); + } + + const gtInSeries = gtRows.filter((r) => r.seriesUid === s.seriesUid); + const gtIdxs = gtInSeries + .map((r) => uidToIndex.get(r.sopInstanceUid)) + .filter((v): v is number => typeof v === 'number' && Number.isFinite(v)); + + const margin = Math.max(0, Math.round(s.marginSlices ?? 2)); + let minIdx = gtIdxs.length ? Math.max(0, Math.min(...gtIdxs) - margin) : Math.max(0, s.startEffectiveIndex - margin); + let maxIdx = gtIdxs.length + ? Math.min(uids.length - 1, Math.max(...gtIdxs) + margin) + : Math.min(uids.length - 1, s.startEffectiveIndex + margin); + + const startIdx = Math.max(0, Math.min(uids.length - 1, Math.round(s.startEffectiveIndex))); + + // Always include the start slice, even if GT range is elsewhere. + minIdx = Math.min(minIdx, startIdx); + maxIdx = Math.max(maxIdx, startIdx); + + // Precompute GT polygons per SOP in image coords (using each row's saved viewport size). 
+ const gtBySop = new Map(); + for (const r of gtInSeries) { + gtBySop.set(r.sopInstanceUid, { polygon: r.polygon, viewTransform: r.viewTransform, viewportSize: r.viewportSize }); + } + + const frames: TumorHarnessPropagationScenarioV1['frames'] = []; + + for (let idx = minIdx; idx <= maxIdx; idx++) { + const sop = uids[idx]; + if (!sop) continue; + + input.onProgress?.(`Export: loading series slice ${idx + 1}/${uids.length}…`); + + const img = await loadCornerstoneSliceToGrayscale({ sopInstanceUid: sop, maxEvalDim }); + const imageSize = { w: img.w, h: img.h }; + + const gt = gtBySop.get(sop); + const gtPolygonImage01 = gt + ? remapPolygonToImage01({ + polygon: gt.polygon, + viewportSize: safeViewportSize(gt.viewportSize ?? { w: 512, h: 512 }), + fromViewTransform: gt.viewTransform, + imageSize, + }) + : undefined; + + frames.push({ + effectiveIndex: idx, + sopInstanceUid: sop, + image: { + w: img.w, + h: img.h, + sourceW: img.sourceW, + sourceH: img.sourceH, + grayB64: bytesToBase64(img.gray), + }, + gtPolygonImage01, + }); + + // Small yield to keep the UI responsive. + await new Promise((resolve) => window.setTimeout(resolve, 0)); + } + + // Convert recorded start paint points to image coords at the start frame's eval size. 
+ const startFrame = frames.find((f) => f.effectiveIndex === startIdx); + if (!startFrame) { + throw new Error('Propagation scenario export failed: start frame not found in exported range'); + } + + const startPaintImage01 = remapPointsToImage01({ + points: s.paintPointsViewer01, + viewportSize: safeViewportSize(s.viewportSize), + fromViewTransform: s.paintPointsViewTransform, + imageSize: { w: startFrame.image.w, h: startFrame.image.h }, + }); + + const scenario: TumorHarnessPropagationScenarioV1 = { + id: `${s.seriesUid}::start=${startIdx}::${new Date().toISOString()}`, + comboId: s.comboId, + dateIso: s.dateIso, + studyId: s.studyId, + seriesUid: s.seriesUid, + frames, + start: { + effectiveIndex: startIdx, + sopInstanceUid: s.startSopInstanceUid, + paintPointsImage01: startPaintImage01, + threshold: s.threshold, + }, + stop: s.stop, + note: 'Frames are downsampled+normalized DICOM pixel data. Start paint points are exported from the overlay and remapped into image coords.', + }; + + propagationScenarios = [scenario]; + } + + const dataset: TumorHarnessDatasetV1 = { + version: 1, + generatedAtIso: new Date().toISOString(), + settings: { maxEvalDim }, + cases, + propagationScenarios, + note: 'Generated from MiraViewer IndexedDB GT polygons and Cornerstone pixel data. 
GT is remapped via saved viewTransform + viewportSize into image coordinates.', + }; + + const zip = new JSZip(); + zip.file('dataset.json', JSON.stringify(dataset, null, 2)); + + input.onProgress?.('Export: creating zip…'); + + const zipBlob = await zip.generateAsync({ + type: 'blob', + compression: 'DEFLATE', + compressionOptions: { level: 6 }, + }); + + return { dataset, zipBlob }; +} + +export async function exportTumorHarnessDatasetAndDownload(input: ExportTumorHarnessDatasetInput): Promise { + const { zipBlob } = await exportTumorHarnessDatasetToZip(input); + + const dt = new Date(); + const stamp = `${dt.getFullYear()}-${String(dt.getMonth() + 1).padStart(2, '0')}-${String(dt.getDate()).padStart(2, '0')}_${String( + dt.getHours() + ).padStart(2, '0')}${String(dt.getMinutes()).padStart(2, '0')}${String(dt.getSeconds()).padStart(2, '0')}`; + + downloadBlob(zipBlob, `miraviewer_tumor_harness_${stamp}.zip`); +} diff --git a/frontend/src/utils/segmentation/harness/loadCornerstoneGrayscale.ts b/frontend/src/utils/segmentation/harness/loadCornerstoneGrayscale.ts new file mode 100644 index 0000000..752f9a6 --- /dev/null +++ b/frontend/src/utils/segmentation/harness/loadCornerstoneGrayscale.ts @@ -0,0 +1,77 @@ +import cornerstone from 'cornerstone-core'; + +type CornerstoneImageLike = { + rows: number; + columns: number; + getPixelData: () => ArrayLike; + minPixelValue?: number; + maxPixelValue?: number; +}; + +function toByte(v: number) { + return Math.max(0, Math.min(255, Math.round(v))); +} + +/** + * Load a DICOM slice via Cornerstone (miradb:) and normalize pixel data to a 0..255 grayscale byte array. + * + * If maxEvalDim is smaller than the source image, the output is downsampled (nearest-neighbor) preserving + * aspect ratio. This is a speed/size tradeoff for offline harness runs. 
+ */ +export async function loadCornerstoneSliceToGrayscale(args: { + sopInstanceUid: string; + maxEvalDim: number; +}): Promise<{ imageId: string; gray: Uint8Array; w: number; h: number; sourceW: number; sourceH: number }> { + const { sopInstanceUid, maxEvalDim } = args; + + const imageId = `miradb:${sopInstanceUid}`; + const image = (await cornerstone.loadImage(imageId)) as unknown as CornerstoneImageLike; + + const rows = image.rows; + const cols = image.columns; + const getPixelData = image.getPixelData; + if (!rows || !cols || typeof getPixelData !== 'function') { + throw new Error('Cornerstone image missing pixel data'); + } + + const pd = getPixelData(); + + let min = image.minPixelValue; + let max = image.maxPixelValue; + + if (!Number.isFinite(min) || !Number.isFinite(max)) { + min = Number.POSITIVE_INFINITY; + max = Number.NEGATIVE_INFINITY; + for (let i = 0; i < pd.length; i++) { + const v = pd[i]; + if (v < (min as number)) min = v; + if (v > (max as number)) max = v; + } + } + + const denom = (max as number) - (min as number); + + // Downsample for speed, preserving aspect ratio. + const scale = Math.max(cols, rows) / Math.max(16, maxEvalDim); + const w = scale > 1 ? Math.max(16, Math.round(cols / scale)) : cols; + const h = scale > 1 ? Math.max(16, Math.round(rows / scale)) : rows; + + const gray = new Uint8Array(w * h); + + if (!Number.isFinite(denom) || Math.abs(denom) < 1e-8) { + gray.fill(0); + return { imageId, gray, w, h, sourceW: cols, sourceH: rows }; + } + + for (let y = 0; y < h; y++) { + const sy = h <= 1 ? 0 : Math.round((y * (rows - 1)) / (h - 1)); + for (let x = 0; x < w; x++) { + const sx = w <= 1 ? 
0 : Math.round((x * (cols - 1)) / (w - 1)); + const v = pd[sy * cols + sx]; + const t = ((v - (min as number)) / denom) * 255; + gray[y * w + x] = toByte(t); + } + } + + return { imageId, gray, w, h, sourceW: cols, sourceH: rows }; +} diff --git a/frontend/src/utils/segmentation/harness/runTumorHarness.ts b/frontend/src/utils/segmentation/harness/runTumorHarness.ts new file mode 100644 index 0000000..19d5b3c --- /dev/null +++ b/frontend/src/utils/segmentation/harness/runTumorHarness.ts @@ -0,0 +1,443 @@ +import type { NormalizedPoint, TumorPolygon, TumorThreshold } from '../../../db/schema'; +import type { SegmentTumorOptions } from '../segmentTumor'; +import { estimateThresholdFromSeedPoints, segmentTumorFromGrayscale } from '../segmentTumor'; +import { computeMaskMetrics, type MaskMetrics } from '../maskMetrics'; +import { + computePolygonBoundaryMetrics, + type PolygonBoundaryMetrics, +} from '../polygonBoundaryMetrics'; +import { rasterizePolygonToMask } from '../rasterizePolygon'; +import { + propagateTumorAcrossFramesCore, + type PropagationFrame, +} from '../../tumorPropagationCore'; +import { base64ToBytes } from './base64'; +import type { + TumorHarnessCaseV1, + TumorHarnessDatasetV1, + TumorHarnessPropagationFrameV1, +} from './dataset'; +import { generateSyntheticPaintPointsFromGt } from './syntheticPaint'; + +export type TumorHarnessConfig = { + name: string; + // Segmentation opts used for both single-slice and propagation runs. 
+ opts?: SegmentTumorOptions; +}; + +export type TumorHarnessCaseResult = { + caseId: string; + ok: boolean; + error?: string; + threshold?: TumorThreshold; + metrics?: MaskMetrics; + boundary?: PolygonBoundaryMetrics; + predPolygonPointCount?: number; +}; + +export type TumorHarnessPropagationSliceEval = { + effectiveIndex: number; + sopInstanceUid: string; + ok: boolean; + hadPrediction: boolean; + metrics: MaskMetrics; + boundary?: PolygonBoundaryMetrics; +}; + +export type TumorHarnessPropagationScenarioResult = { + scenarioId: string; + config: string; + startIndex: number; + savedCount: number; + scoredGtSlices: number; + micro: MaskMetrics; + boundaryAgg: { + meanPredToGtPx: number; + meanGtToPredPx: number; + meanSymPx: number; + maxSymPx: number; + count: number; + }; + slices: TumorHarnessPropagationSliceEval[]; +}; + +export type TumorHarnessReport = { + version: 1; + generatedAtIso: string; + dataset: { + generatedAtIso: string; + maxEvalDim: number; + cases: number; + propagationScenarios: number; + }; + configs: Array<{ name: string; opts?: SegmentTumorOptions }>; + segmentation: { + byConfig: Array<{ + config: string; + casesTotal: number; + casesOk: number; + casesError: number; + micro: MaskMetrics; + boundaryAgg: { + meanPredToGtPx: number; + meanGtToPredPx: number; + meanSymPx: number; + maxSymPx: number; + count: number; + }; + cases: TumorHarnessCaseResult[]; + }>; + }; + propagation: { + byScenarioConfig: TumorHarnessPropagationScenarioResult[]; + }; +}; + +function safeDiv(num: number, den: number) { + return den > 0 ? 
num / den : 0; +} + +function metricsFromCounts(tp: number, fp: number, fn: number, tn: number): MaskMetrics { + const precision = safeDiv(tp, tp + fp); + const recall = safeDiv(tp, tp + fn); + const dice = safeDiv(2 * tp, 2 * tp + fp + fn); + const iou = safeDiv(tp, tp + fp + fn); + + const beta2 = 4; + const f2 = safeDiv((1 + beta2) * precision * recall, beta2 * precision + recall); + + return { tp, fp, fn, tn, precision, recall, dice, iou, f2 }; +} + +function decodeGray(caseOrFrame: { image: { w: number; h: number; grayB64: string } }): Uint8Array { + const { w, h, grayB64 } = caseOrFrame.image; + const bytes = base64ToBytes(grayB64); + const want = w * h; + if (bytes.length !== want) { + throw new Error(`Decoded gray length mismatch (got ${bytes.length}, want ${want} for ${w}x${h})`); + } + return bytes; +} + +function getPaintPointsForCase(c: TumorHarnessCaseV1): NormalizedPoint[] { + if (c.paintPointsImage01 && c.paintPointsImage01.length >= 2) return c.paintPointsImage01; + return generateSyntheticPaintPointsFromGt(c.gtPolygonImage01, c.id, 24); +} + +function jitterCross(seed: NormalizedPoint, w: number, h: number): NormalizedPoint[] { + const clamp01 = (v: number) => Math.max(0, Math.min(1, v)); + const jitter = Math.max(0.002, 1 / Math.max(w, h)); + return [ + { x: clamp01(seed.x), y: clamp01(seed.y) }, + { x: clamp01(seed.x + jitter), y: clamp01(seed.y) }, + { x: clamp01(seed.x - jitter), y: clamp01(seed.y) }, + { x: clamp01(seed.x), y: clamp01(seed.y + jitter) }, + { x: clamp01(seed.x), y: clamp01(seed.y - jitter) }, + ]; +} + +export async function runTumorHarnessDataset(args: { + dataset: TumorHarnessDatasetV1; + configs: TumorHarnessConfig[]; +}): Promise { + const { dataset, configs } = args; + + const report: TumorHarnessReport = { + version: 1, + generatedAtIso: new Date().toISOString(), + dataset: { + generatedAtIso: dataset.generatedAtIso, + maxEvalDim: dataset.settings.maxEvalDim, + cases: dataset.cases.length, + propagationScenarios: 
/**
 * Run the offline tumor-segmentation harness over an exported dataset.
 *
 * For every config it (1) segments each single-slice case from its paint points
 * and scores the result against the GT polygon (pixel-overlap metrics plus
 * boundary distances), and (2) replays each propagation scenario from its start
 * slice and scores every GT-annotated frame in the scenario.
 *
 * Aggregates are "micro" — confusion counts are summed over slices and the
 * metrics computed once — rather than per-slice averages.
 *
 * NOTE(review): several type arguments appear stripped in this capture
 * (`Promise {`, `new Map()`, `getFrame: ... Promise =>`) — restore the generics
 * from the upstream source before compiling.
 *
 * @param args.dataset Parsed harness dataset (see parseTumorHarnessDataset).
 * @param args.configs Named segmentation option sets to evaluate.
 * @returns A TumorHarnessReport with per-config and per-scenario-config results.
 */
export async function runTumorHarnessDataset(args: {
  dataset: TumorHarnessDatasetV1;
  configs: TumorHarnessConfig[];
}): Promise {
  const { dataset, configs } = args;

  // Report skeleton; both result arrays are filled in below.
  const report: TumorHarnessReport = {
    version: 1,
    generatedAtIso: new Date().toISOString(),
    dataset: {
      generatedAtIso: dataset.generatedAtIso,
      maxEvalDim: dataset.settings.maxEvalDim,
      cases: dataset.cases.length,
      propagationScenarios: dataset.propagationScenarios?.length ?? 0,
    },
    configs: configs.map((c) => ({ name: c.name, opts: c.opts })),
    segmentation: { byConfig: [] },
    propagation: { byScenarioConfig: [] },
  };

  // --------
  // Single-slice segmentation evaluation
  // --------
  for (const cfg of configs) {
    const casesOut: TumorHarnessCaseResult[] = [];

    // Micro-aggregated confusion counts across all cases for this config.
    let tp = 0;
    let fp = 0;
    let fn = 0;
    let tn = 0;

    // Boundary-distance aggregates; only cases with a finite symmetric mean
    // contribute (bCount tracks how many did).
    let bCount = 0;
    let bMeanPredToGtSum = 0;
    let bMeanGtToPredSum = 0;
    let bMeanSymSum = 0;
    let bMaxSymMax = 0;

    for (const c of dataset.cases) {
      try {
        const gray = decodeGray(c);
        const w = c.image.w;
        const h = c.image.h;

        // Recorded paint points if present, otherwise synthetic ones from GT.
        const paint = getPaintPointsForCase(c);
        const threshold = estimateThresholdFromSeedPoints(gray, w, h, paint);
        const seg = segmentTumorFromGrayscale(gray, w, h, paint, threshold, cfg.opts);

        // Rasterize both polygons at the evaluation resolution and compare.
        const gtMask = rasterizePolygonToMask(c.gtPolygonImage01, w, h);
        const predMask = rasterizePolygonToMask(seg.polygon, w, h);

        const metrics = computeMaskMetrics(predMask, gtMask);
        const boundary = computePolygonBoundaryMetrics(seg.polygon, c.gtPolygonImage01, w, h);

        casesOut.push({
          caseId: c.id,
          ok: true,
          threshold,
          metrics,
          boundary,
          predPolygonPointCount: seg.polygon.points.length,
        });

        tp += metrics.tp;
        fp += metrics.fp;
        fn += metrics.fn;
        tn += metrics.tn;

        if (Number.isFinite(boundary.meanSymPx)) {
          bCount++;
          bMeanPredToGtSum += boundary.meanPredToGtPx;
          bMeanGtToPredSum += boundary.meanGtToPredPx;
          bMeanSymSum += boundary.meanSymPx;
          bMaxSymMax = Math.max(bMaxSymMax, boundary.maxSymPx);
        }
      } catch (e) {
        // A failing case is recorded, not fatal — other cases still run.
        casesOut.push({
          caseId: c.id,
          ok: false,
          error: e instanceof Error ? e.message : 'Segmentation failed',
        });
      }
    }

    report.segmentation.byConfig.push({
      config: cfg.name,
      casesTotal: dataset.cases.length,
      casesOk: casesOut.filter((r) => r.ok).length,
      casesError: casesOut.filter((r) => !r.ok).length,
      micro: metricsFromCounts(tp, fp, fn, tn),
      boundaryAgg: {
        // Infinity signals "no finite boundary measurements" downstream.
        meanPredToGtPx: bCount ? bMeanPredToGtSum / bCount : Number.POSITIVE_INFINITY,
        meanGtToPredPx: bCount ? bMeanGtToPredSum / bCount : Number.POSITIVE_INFINITY,
        meanSymPx: bCount ? bMeanSymSum / bCount : Number.POSITIVE_INFINITY,
        maxSymPx: bCount ? bMaxSymMax : Number.POSITIVE_INFINITY,
        count: bCount,
      },
      cases: casesOut,
    });
  }

  // --------
  // Propagation evaluation
  // --------
  const scenarios = dataset.propagationScenarios ?? [];
  for (const scenario of scenarios) {
    for (const cfg of configs) {
      const startIdx = scenario.start.effectiveIndex;

      // Empty scenario: emit an all-zero/Infinity result so the report still
      // lists the (scenario, config) pair.
      if (scenario.frames.length === 0) {
        report.propagation.byScenarioConfig.push({
          scenarioId: scenario.id,
          config: cfg.name,
          startIndex: startIdx,
          savedCount: 0,
          scoredGtSlices: 0,
          micro: metricsFromCounts(0, 0, 0, 0),
          boundaryAgg: {
            meanPredToGtPx: Number.POSITIVE_INFINITY,
            meanGtToPredPx: Number.POSITIVE_INFINITY,
            meanSymPx: Number.POSITIVE_INFINITY,
            maxSymPx: Number.POSITIVE_INFINITY,
            count: 0,
          },
          slices: [],
        });
        continue;
      }

      const framesByIndex = new Map();
      for (const f of scenario.frames) {
        framesByIndex.set(f.effectiveIndex, f);
      }

      const minIdx = Math.min(...scenario.frames.map((f) => f.effectiveIndex));
      const maxIdx = Math.max(...scenario.frames.map((f) => f.effectiveIndex));

      // The start slice must exist in the exported range, otherwise there is
      // nothing to seed propagation from.
      const startFrame = framesByIndex.get(startIdx);
      if (!startFrame) {
        report.propagation.byScenarioConfig.push({
          scenarioId: scenario.id,
          config: cfg.name,
          startIndex: startIdx,
          savedCount: 0,
          scoredGtSlices: 0,
          micro: metricsFromCounts(0, 0, 0, 0),
          boundaryAgg: {
            meanPredToGtPx: Number.POSITIVE_INFINITY,
            meanGtToPredPx: Number.POSITIVE_INFINITY,
            meanSymPx: Number.POSITIVE_INFINITY,
            maxSymPx: Number.POSITIVE_INFINITY,
            count: 0,
          },
          slices: [],
        });
        continue;
      }

      // Initial segmentation on the start slice (from paint).
      const startGray = decodeGray(startFrame);
      const startW = startFrame.image.w;
      const startH = startFrame.image.h;

      const startPaint = scenario.start.paintPointsImage01;
      // Prefer the threshold recorded at export time; fall back to estimating
      // one from the paint points.
      const startThreshold =
        scenario.start.threshold ?? estimateThresholdFromSeedPoints(startGray, startW, startH, startPaint);

      const startSeg = segmentTumorFromGrayscale(startGray, startW, startH, startPaint, startThreshold, cfg.opts);
      const seed = startSeg.seed;

      // Frame provider for the propagation core: decodes the frame and reseeds
      // each slice with a jittered cross around the start segmentation's seed.
      const getFrame = async (index: number): Promise => {
        const f = framesByIndex.get(index);
        if (!f) return null;
        const gray = decodeGray(f);
        return {
          sopInstanceUid: f.sopInstanceUid,
          w: f.image.w,
          h: f.image.h,
          gray,
          seedPointsNorm: jitterCross(seed, f.image.w, f.image.h),
        };
      };

      // Default stop criteria mirror the interactive tool's defaults.
      const stop = scenario.stop ?? { minAreaPx: 80, maxMissesInARow: 3 };
      const propRes = await propagateTumorAcrossFramesCore({
        minIndex: minIdx,
        maxIndex: maxIdx,
        startEffectiveIndex: startIdx,
        getFrame,
        threshold: startThreshold,
        opts: cfg.opts,
        stop,
      });

      // Merge start slice into predictions.
      const predByIndex = new Map();
      predByIndex.set(startIdx, startSeg.polygon);
      for (const r of propRes.results) {
        predByIndex.set(r.index, r.segmentation.polygon);
      }

      // Micro counts + boundary aggregates over GT-annotated frames only.
      let tp = 0;
      let fp = 0;
      let fn = 0;
      let tn = 0;

      let bCount = 0;
      let bMeanPredToGtSum = 0;
      let bMeanGtToPredSum = 0;
      let bMeanSymSum = 0;
      let bMaxSymMax = 0;

      const sliceEvals: TumorHarnessPropagationSliceEval[] = [];

      for (const f of scenario.frames) {
        const gt = f.gtPolygonImage01;
        // Only frames with a usable GT polygon are scored.
        if (!gt || (gt.points?.length ?? 0) < 3) continue;

        const w = f.image.w;
        const h = f.image.h;

        const gtMask = rasterizePolygonToMask(gt, w, h);

        // A missing prediction scores as an empty mask (pure false negatives),
        // which penalizes propagation that stopped too early.
        const predPoly = predByIndex.get(f.effectiveIndex) ?? null;
        const hadPrediction = !!predPoly;
        const predMask = predPoly ? rasterizePolygonToMask(predPoly, w, h) : new Uint8Array(w * h);

        const metrics = computeMaskMetrics(predMask, gtMask);

        const boundary = predPoly ? computePolygonBoundaryMetrics(predPoly, gt, w, h) : undefined;

        sliceEvals.push({
          effectiveIndex: f.effectiveIndex,
          sopInstanceUid: f.sopInstanceUid,
          ok: true,
          hadPrediction,
          metrics,
          boundary,
        });

        tp += metrics.tp;
        fp += metrics.fp;
        fn += metrics.fn;
        tn += metrics.tn;

        if (boundary && Number.isFinite(boundary.meanSymPx)) {
          bCount++;
          bMeanPredToGtSum += boundary.meanPredToGtPx;
          bMeanGtToPredSum += boundary.meanGtToPredPx;
          bMeanSymSum += boundary.meanSymPx;
          bMaxSymMax = Math.max(bMaxSymMax, boundary.maxSymPx);
        }
      }

      report.propagation.byScenarioConfig.push({
        scenarioId: scenario.id,
        config: cfg.name,
        startIndex: startIdx,
        savedCount: propRes.saved,
        scoredGtSlices: sliceEvals.length,
        micro: metricsFromCounts(tp, fp, fn, tn),
        boundaryAgg: {
          meanPredToGtPx: bCount ? bMeanPredToGtSum / bCount : Number.POSITIVE_INFINITY,
          meanGtToPredPx: bCount ? bMeanGtToPredSum / bCount : Number.POSITIVE_INFINITY,
          meanSymPx: bCount ? bMeanSymSum / bCount : Number.POSITIVE_INFINITY,
          maxSymPx: bCount ? bMaxSymMax : Number.POSITIVE_INFINITY,
          count: bCount,
        },
        slices: sliceEvals,
      });
    }
  }

  return report;
}
bMaxSymMax : Number.POSITIVE_INFINITY, + count: bCount, + }, + slices: sliceEvals, + }); + } + } + + return report; +} + +export function parseTumorHarnessDataset(jsonText: string): TumorHarnessDatasetV1 { + const raw = JSON.parse(jsonText) as unknown; + const d = raw as TumorHarnessDatasetV1; + + if (!d || d.version !== 1) { + throw new Error('Unsupported tumor harness dataset version'); + } + + if (!Array.isArray(d.cases)) { + throw new Error('Invalid tumor harness dataset: missing cases'); + } + + return d; +} + +export function summarizeReport(report: TumorHarnessReport): { + bestSegConfigByF2: { name: string; f2: number } | null; + bestSegConfigByDice: { name: string; dice: number } | null; +} { + const seg = report.segmentation.byConfig; + if (!seg.length) return { bestSegConfigByF2: null, bestSegConfigByDice: null }; + + const bestF2 = seg.reduce((a, b) => (b.micro.f2 > a.micro.f2 ? b : a)); + const bestDice = seg.reduce((a, b) => (b.micro.dice > a.micro.dice ? b : a)); + + return { + bestSegConfigByF2: { name: bestF2.config, f2: bestF2.micro.f2 }, + bestSegConfigByDice: { name: bestDice.config, dice: bestDice.micro.dice }, + }; +} diff --git a/frontend/src/utils/segmentation/harness/syntheticPaint.ts b/frontend/src/utils/segmentation/harness/syntheticPaint.ts new file mode 100644 index 0000000..d35e4b9 --- /dev/null +++ b/frontend/src/utils/segmentation/harness/syntheticPaint.ts @@ -0,0 +1,185 @@ +import type { NormalizedPoint, TumorPolygon } from '../../../db/schema'; + +function clamp(v: number, lo: number, hi: number) { + return Math.max(lo, Math.min(hi, v)); +} + +function clamp01(v: number) { + return clamp(v, 0, 1); +} + +function hashStringToSeed(s: string): number { + // FNV-1a 32-bit + let h = 2166136261; + for (let i = 0; i < s.length; i++) { + h ^= s.charCodeAt(i); + h = Math.imul(h, 16777619); + } + return h >>> 0; +} + +function makeLcg(seed: number) { + let state = seed >>> 0; + return () => { + state = (Math.imul(1664525, state) + 1013904223) 
>>> 0; + return state / 4294967296; + }; +} + +function polygonBounds01(poly: TumorPolygon): { minX: number; minY: number; maxX: number; maxY: number } { + let minX = Number.POSITIVE_INFINITY; + let minY = Number.POSITIVE_INFINITY; + let maxX = Number.NEGATIVE_INFINITY; + let maxY = Number.NEGATIVE_INFINITY; + + for (const p of poly.points) { + minX = Math.min(minX, p.x); + minY = Math.min(minY, p.y); + maxX = Math.max(maxX, p.x); + maxY = Math.max(maxY, p.y); + } + + return { + minX: clamp01(minX), + minY: clamp01(minY), + maxX: clamp01(maxX), + maxY: clamp01(maxY), + }; +} + +function pointInPolygon(pt: NormalizedPoint, poly: TumorPolygon): boolean { + // Ray casting. + const pts = poly.points; + const n = pts.length; + if (n < 3) return false; + + let inside = false; + for (let i = 0, j = n - 1; i < n; j = i++) { + const a = pts[i]!; + const b = pts[j]!; + const intersects = a.y > pt.y !== b.y > pt.y && pt.x < ((b.x - a.x) * (pt.y - a.y)) / (b.y - a.y + 1e-12) + a.x; + + if (intersects) inside = !inside; + } + + return inside; +} + +function polygonAreaCentroid01(poly: TumorPolygon): NormalizedPoint { + const pts = poly.points; + const n = pts.length; + if (n < 3) { + // Fallback: average. + let sx = 0; + let sy = 0; + for (const p of pts) { + sx += p.x; + sy += p.y; + } + const d = Math.max(1, n); + return { x: clamp01(sx / d), y: clamp01(sy / d) }; + } + + // Polygon centroid (shoelace). Works for simple polygons. + let a2 = 0; + let cx = 0; + let cy = 0; + + for (let i = 0; i < n; i++) { + const p0 = pts[i]!; + const p1 = pts[(i + 1) % n]!; + const cross = p0.x * p1.y - p1.x * p0.y; + a2 += cross; + cx += (p0.x + p1.x) * cross; + cy += (p0.y + p1.y) * cross; + } + + if (Math.abs(a2) < 1e-10) { + // Degenerate polygon. 
+ let sx = 0; + let sy = 0; + for (const p of pts) { + sx += p.x; + sy += p.y; + } + const d = Math.max(1, n); + return { x: clamp01(sx / d), y: clamp01(sy / d) }; + } + + const inv6a = 1 / (3 * a2); + return { x: clamp01(cx * inv6a), y: clamp01(cy * inv6a) }; +} + +function findInteriorPoint01(poly: TumorPolygon): NormalizedPoint { + const c = polygonAreaCentroid01(poly); + if (pointInPolygon(c, poly)) return c; + + // Try bbox center. + const b = polygonBounds01(poly); + const mid = { x: (b.minX + b.maxX) / 2, y: (b.minY + b.maxY) / 2 }; + if (pointInPolygon(mid, poly)) return mid; + + // Brute force a small grid search. + const steps = 9; + for (let yi = 0; yi < steps; yi++) { + for (let xi = 0; xi < steps; xi++) { + const x = b.minX + ((xi + 0.5) / steps) * (b.maxX - b.minX); + const y = b.minY + ((yi + 0.5) / steps) * (b.maxY - b.minY); + const p = { x, y }; + if (pointInPolygon(p, poly)) return p; + } + } + + // Give up: return clamped centroid. + return c; +} + +/** + * Deterministically generate a set of "paint" points inside the GT polygon. + * + * This lets us benchmark/tune without requiring real user paint gestures. + */ +export function generateSyntheticPaintPointsFromGt(gt: TumorPolygon, seedKey: string, targetCount: number): NormalizedPoint[] { + const pts: NormalizedPoint[] = []; + const seed = findInteriorPoint01(gt); + + // Always include a small cross around the seed for robustness. + const j = 0.004; + pts.push(seed); + pts.push({ x: clamp01(seed.x + j), y: seed.y }); + pts.push({ x: clamp01(seed.x - j), y: seed.y }); + pts.push({ x: seed.x, y: clamp01(seed.y + j) }); + pts.push({ x: seed.x, y: clamp01(seed.y - j) }); + + const b = polygonBounds01(gt); + const rand = makeLcg(hashStringToSeed(seedKey)); + + const want = Math.max(8, targetCount); + const maxAttempts = want * 80; + + for (let attempt = 0; attempt < maxAttempts && pts.length < want; attempt++) { + // Bias sampling toward the seed by mixing uniform bbox with a seed-centered jitter. 
+ const mix = rand(); + + let x: number; + let y: number; + + if (mix < 0.7) { + // Seed-centered jitter (roughly "scribble" sized). + const r = 0.03; + x = clamp01(seed.x + (rand() * 2 - 1) * r); + y = clamp01(seed.y + (rand() * 2 - 1) * r); + } else { + // Uniform in bbox. + x = b.minX + rand() * (b.maxX - b.minX); + y = b.minY + rand() * (b.maxY - b.minY); + } + + const p = { x, y }; + if (pointInPolygon(p, gt)) { + pts.push(p); + } + } + + return pts; +} diff --git a/frontend/src/utils/segmentation/marchingSquares.ts b/frontend/src/utils/segmentation/marchingSquares.ts new file mode 100644 index 0000000..e85c58f --- /dev/null +++ b/frontend/src/utils/segmentation/marchingSquares.ts @@ -0,0 +1,201 @@ +export type Roi = { x0: number; y0: number; x1: number; y1: number }; +export type PxPoint = { x: number; y: number }; + +function clamp(v: number, lo: number, hi: number) { + return Math.max(lo, Math.min(hi, v)); +} + +function edgeKey(a: string, b: string) { + return a < b ? `${a}|${b}` : `${b}|${a}`; +} + +function parseKey(k: string): { x2: number; y2: number } { + const [xs, ys] = k.split(','); + return { x2: Number(xs), y2: Number(ys) }; +} + +function keyOf(x2: number, y2: number): string { + return `${x2},${y2}`; +} + +function addUndirectedEdge(adj: Map, a: string, b: string) { + const la = adj.get(a); + if (la) la.push(b); + else adj.set(a, [b]); + + const lb = adj.get(b); + if (lb) lb.push(a); + else adj.set(b, [a]); +} + +function polygonArea(points: PxPoint[]): number { + let a = 0; + for (let i = 0; i < points.length; i++) { + const p = points[i]; + const q = points[(i + 1) % points.length]; + a += p.x * q.y - q.x * p.y; + } + return a / 2; +} + +// Marching-squares-style table that connects midpoints on the 4 cell edges. 
+// Edge indices: +// 0 = top, 1 = right, 2 = bottom, 3 = left +const CASE_TO_SEGMENTS: ReadonlyArray> = [ + [], // 0 + [[3, 2]], // 1 + [[2, 1]], // 2 + [[3, 1]], // 3 + [[0, 1]], // 4 + [ + [0, 1], + [3, 2], + ], // 5 (ambiguous) + [[0, 2]], // 6 + [[0, 3]], // 7 + [[0, 3]], // 8 + [[0, 2]], // 9 + [ + [0, 3], + [2, 1], + ], // 10 (ambiguous) + [[0, 1]], // 11 + [[3, 1]], // 12 + [[2, 1]], // 13 + [[3, 2]], // 14 + [], // 15 +]; + +function edgeMidpointKey2x(edge: number, cellX: number, cellY: number): string { + // Coordinates are in pixel-index space, scaled by 2 to keep integers. + // Cell corners are at (cellX, cellY) .. (cellX+1, cellY+1). + switch (edge) { + case 0: // top + return keyOf(2 * cellX + 1, 2 * cellY); + case 1: // right + return keyOf(2 * cellX + 2, 2 * cellY + 1); + case 2: // bottom + return keyOf(2 * cellX + 1, 2 * cellY + 2); + case 3: // left + return keyOf(2 * cellX, 2 * cellY + 1); + default: + return keyOf(0, 0); + } +} + +/** + * Extract the outer contour of a binary mask. + * + * Returns a single polygon (the loop with the largest absolute area). + * Coordinates are returned in pixel-index space (0..w-1 / 0..h-1) with 0.5 increments. + */ +export function marchingSquaresContour( + mask: Uint8Array, + w: number, + h: number, + roi?: Roi +): PxPoint[] { + if (w <= 1 || h <= 1) return []; + + // Build segment adjacency graph. + const adj = new Map(); + + const cellX0 = roi ? clamp(Math.floor(roi.x0) - 1, 0, w - 2) : 0; + const cellY0 = roi ? clamp(Math.floor(roi.y0) - 1, 0, h - 2) : 0; + const cellX1 = roi ? clamp(Math.ceil(roi.x1) + 1, 0, w - 1) : w - 1; // exclusive upper bound + const cellY1 = roi ? clamp(Math.ceil(roi.y1) + 1, 0, h - 1) : h - 1; // exclusive upper bound + + for (let y = cellY0; y < cellY1; y++) { + const row0 = y * w; + const row1 = (y + 1) * w; + + for (let x = cellX0; x < cellX1; x++) { + const tl = mask[row0 + x] ? 1 : 0; + const tr = mask[row0 + x + 1] ? 1 : 0; + const br = mask[row1 + x + 1] ? 
1 : 0; + const bl = mask[row1 + x] ? 1 : 0; + + const idx = (tl << 3) | (tr << 2) | (br << 1) | bl; + const segments = CASE_TO_SEGMENTS[idx]; + if (!segments.length) continue; + + for (const [e0, e1] of segments) { + const a = edgeMidpointKey2x(e0, x, y); + const b = edgeMidpointKey2x(e1, x, y); + addUndirectedEdge(adj, a, b); + } + } + } + + if (adj.size === 0) return []; + + // Trace all loops. + const visitedEdges = new Set(); + const loops: PxPoint[][] = []; + + for (const [start, nbrs] of adj) { + for (const first of nbrs) { + const ek0 = edgeKey(start, first); + if (visitedEdges.has(ek0)) continue; + + const loopKeys: string[] = [start]; + let prev = start; + let curr = first; + + // Safety to avoid infinite walks on malformed graphs. + const maxSteps = 200000; + let steps = 0; + + while (steps++ < maxSteps) { + visitedEdges.add(edgeKey(prev, curr)); + loopKeys.push(curr); + + if (curr === start) break; + + const nextCandidates = adj.get(curr); + if (!nextCandidates || nextCandidates.length === 0) break; + + // Prefer an unvisited edge that doesn't go back to prev. + const next = + nextCandidates.find((n) => n !== prev && !visitedEdges.has(edgeKey(curr, n))) ?? + nextCandidates.find((n) => n !== prev) ?? + nextCandidates[0]; + + prev = curr; + curr = next; + } + + // Keep only closed loops. + if (loopKeys.length >= 4 && loopKeys[loopKeys.length - 1] === start) { + // Remove duplicated closing point. + loopKeys.pop(); + + const pts: PxPoint[] = loopKeys.map((k) => { + const { x2, y2 } = parseKey(k); + return { x: x2 / 2, y: y2 / 2 }; + }); + + // Ignore tiny loops. + if (pts.length >= 3) { + loops.push(pts); + } + } + } + } + + if (loops.length === 0) return []; + + // Return the largest-area loop as the outer contour. 
+ let best = loops[0]; + let bestArea = Math.abs(polygonArea(best)); + + for (let i = 1; i < loops.length; i++) { + const a = Math.abs(polygonArea(loops[i])); + if (a > bestArea) { + bestArea = a; + best = loops[i]; + } + } + + return best; +} diff --git a/frontend/src/utils/segmentation/maskMetrics.ts b/frontend/src/utils/segmentation/maskMetrics.ts new file mode 100644 index 0000000..8e09f15 --- /dev/null +++ b/frontend/src/utils/segmentation/maskMetrics.ts @@ -0,0 +1,51 @@ +export type MaskMetrics = { + tp: number; + fp: number; + fn: number; + tn: number; + + precision: number; + recall: number; + dice: number; + iou: number; + + /** F-beta with beta=2 (weights recall higher than precision). */ + f2: number; +}; + +function safeDiv(num: number, den: number) { + return den > 0 ? num / den : 0; +} + +export function computeMaskMetrics(pred: Uint8Array, gt: Uint8Array): MaskMetrics { + if (pred.length !== gt.length) { + throw new Error('Mask sizes do not match'); + } + + let tp = 0; + let fp = 0; + let fn = 0; + let tn = 0; + + for (let i = 0; i < pred.length; i++) { + const p = pred[i] ? 1 : 0; + const g = gt[i] ? 1 : 0; + + if (p && g) tp++; + else if (p && !g) fp++; + else if (!p && g) fn++; + else tn++; + } + + const precision = safeDiv(tp, tp + fp); + const recall = safeDiv(tp, tp + fn); + + const dice = safeDiv(2 * tp, 2 * tp + fp + fn); + const iou = safeDiv(tp, tp + fp + fn); + + // F2 emphasizes recall. + const beta2 = 4; + const f2 = safeDiv((1 + beta2) * precision * recall, beta2 * precision + recall); + + return { tp, fp, fn, tn, precision, recall, dice, iou, f2 }; +} diff --git a/frontend/src/utils/segmentation/morphology.ts b/frontend/src/utils/segmentation/morphology.ts new file mode 100644 index 0000000..0f1f7a4 --- /dev/null +++ b/frontend/src/utils/segmentation/morphology.ts @@ -0,0 +1,97 @@ +function idx(x: number, y: number, w: number) { + return y * w + x; +} + +/** + * Morphological dilation (3x3) for a binary mask. 
+ * + * Mask values are expected to be 0 or 1. + */ +export function dilate3x3(mask: Uint8Array, w: number, h: number): Uint8Array { + const out = new Uint8Array(w * h); + + for (let y = 0; y < h; y++) { + for (let x = 0; x < w; x++) { + let on = 0; + + for (let dy = -1; dy <= 1 && !on; dy++) { + const yy = y + dy; + if (yy < 0 || yy >= h) continue; + const row = yy * w; + + for (let dx = -1; dx <= 1; dx++) { + const xx = x + dx; + if (xx < 0 || xx >= w) continue; + if (mask[row + xx]) { + on = 1; + break; + } + } + } + + out[idx(x, y, w)] = on; + } + } + + return out; +} + +/** + * Morphological erosion (3x3) for a binary mask. + * + * Out-of-bounds neighbors are treated as 0. + */ +export function erode3x3(mask: Uint8Array, w: number, h: number): Uint8Array { + const out = new Uint8Array(w * h); + + for (let y = 0; y < h; y++) { + for (let x = 0; x < w; x++) { + let on = 1; + + for (let dy = -1; dy <= 1 && on; dy++) { + const yy = y + dy; + if (yy < 0 || yy >= h) { + on = 0; + break; + } + const row = yy * w; + + for (let dx = -1; dx <= 1; dx++) { + const xx = x + dx; + if (xx < 0 || xx >= w) { + on = 0; + break; + } + if (!mask[row + xx]) { + on = 0; + break; + } + } + } + + out[idx(x, y, w)] = on; + } + } + + return out; +} + +/** + * Morphological close: dilate then erode. + * + * This fills small holes and bridges tiny gaps in the mask. + */ +export function morphologicalClose(mask: Uint8Array, w: number, h: number): Uint8Array { + const dilated = dilate3x3(mask, w, h); + return erode3x3(dilated, w, h); +} + +/** + * Morphological open: erode then dilate. + * + * This removes thin spurs / narrow connections and can reduce small leakage regions. 
+ */ +export function morphologicalOpen(mask: Uint8Array, w: number, h: number): Uint8Array { + const eroded = erode3x3(mask, w, h); + return dilate3x3(eroded, w, h); +} diff --git a/frontend/src/utils/segmentation/polygonBoundaryMetrics.ts b/frontend/src/utils/segmentation/polygonBoundaryMetrics.ts new file mode 100644 index 0000000..0119981 --- /dev/null +++ b/frontend/src/utils/segmentation/polygonBoundaryMetrics.ts @@ -0,0 +1,174 @@ +import type { TumorPolygon } from '../../db/schema'; + +export type PolygonBoundaryMetrics = { + /** Mean distance from predicted boundary samples to GT boundary (pixels). */ + meanPredToGtPx: number; + /** Mean distance from GT boundary samples to predicted boundary (pixels). */ + meanGtToPredPx: number; + /** Symmetric mean boundary distance (pixels). */ + meanSymPx: number; + + /** Max distance from predicted boundary samples to GT boundary (pixels). */ + maxPredToGtPx: number; + /** Max distance from GT boundary samples to predicted boundary (pixels). */ + maxGtToPredPx: number; + /** Symmetric max boundary distance (pixels). */ + maxSymPx: number; + + /** Sample counts (for debugging). */ + samplesPred: number; + samplesGt: number; +}; + +type Pt = { x: number; y: number }; + +type Seg = { x0: number; y0: number; x1: number; y1: number }; + +function clamp(v: number, lo: number, hi: number) { + return Math.max(lo, Math.min(hi, v)); +} + +function toPixelPoints(poly: TumorPolygon, w: number, h: number): Pt[] { + const pts = poly.points ?? 
[]; + if (pts.length < 3) return []; + + const out: Pt[] = []; + for (const p of pts) { + out.push({ + x: clamp(p.x, 0, 1) * (w - 1), + y: clamp(p.y, 0, 1) * (h - 1), + }); + } + return out; +} + +function toSegments(pts: Pt[]): Seg[] { + if (pts.length < 2) return []; + const segs: Seg[] = []; + for (let i = 0; i < pts.length; i++) { + const a = pts[i]!; + const b = pts[(i + 1) % pts.length]!; + segs.push({ x0: a.x, y0: a.y, x1: b.x, y1: b.y }); + } + return segs; +} + +function pointToSegmentDist2(px: number, py: number, s: Seg): number { + const vx = s.x1 - s.x0; + const vy = s.y1 - s.y0; + const wx = px - s.x0; + const wy = py - s.y0; + + const vv = vx * vx + vy * vy; + if (vv <= 1e-12) { + const dx = px - s.x0; + const dy = py - s.y0; + return dx * dx + dy * dy; + } + + let t = (wx * vx + wy * vy) / vv; + t = clamp(t, 0, 1); + + const cx = s.x0 + t * vx; + const cy = s.y0 + t * vy; + + const dx = px - cx; + const dy = py - cy; + return dx * dx + dy * dy; +} + +function samplePolygonBoundary(pts: Pt[], stepPx: number): Pt[] { + if (pts.length < 3) return []; + const step = Math.max(0.25, stepPx); + + const out: Pt[] = []; + + for (let i = 0; i < pts.length; i++) { + const a = pts[i]!; + const b = pts[(i + 1) % pts.length]!; + const dx = b.x - a.x; + const dy = b.y - a.y; + const len = Math.hypot(dx, dy); + + // Always sample at least the segment start. + const n = Math.max(1, Math.ceil(len / step)); + for (let s = 0; s < n; s++) { + const t = n <= 1 ? 
0 : s / n; + out.push({ x: a.x + t * dx, y: a.y + t * dy }); + } + } + + return out; +} + +function meanMaxDistanceToSegments(samples: Pt[], segs: Seg[]): { mean: number; max: number; count: number } { + if (samples.length === 0 || segs.length === 0) { + return { mean: Number.POSITIVE_INFINITY, max: Number.POSITIVE_INFINITY, count: 0 }; + } + + let sum = 0; + let max = 0; + + for (const p of samples) { + let best2 = Number.POSITIVE_INFINITY; + for (const s of segs) { + const d2 = pointToSegmentDist2(p.x, p.y, s); + if (d2 < best2) best2 = d2; + } + + const d = Math.sqrt(best2); + sum += d; + if (d > max) max = d; + } + + return { mean: sum / samples.length, max, count: samples.length }; +} + +export function computePolygonBoundaryMetrics( + pred: TumorPolygon, + gt: TumorPolygon, + w: number, + h: number, + opts?: { sampleStepPx?: number } +): PolygonBoundaryMetrics { + const predPts = toPixelPoints(pred, w, h); + const gtPts = toPixelPoints(gt, w, h); + + if (predPts.length < 3 || gtPts.length < 3) { + return { + meanPredToGtPx: Number.POSITIVE_INFINITY, + meanGtToPredPx: Number.POSITIVE_INFINITY, + meanSymPx: Number.POSITIVE_INFINITY, + maxPredToGtPx: Number.POSITIVE_INFINITY, + maxGtToPredPx: Number.POSITIVE_INFINITY, + maxSymPx: Number.POSITIVE_INFINITY, + samplesPred: 0, + samplesGt: 0, + }; + } + + const step = opts?.sampleStepPx ?? 
1.25; + + const predSegs = toSegments(predPts); + const gtSegs = toSegments(gtPts); + + const predSamples = samplePolygonBoundary(predPts, step); + const gtSamples = samplePolygonBoundary(gtPts, step); + + const a = meanMaxDistanceToSegments(predSamples, gtSegs); + const b = meanMaxDistanceToSegments(gtSamples, predSegs); + + const meanSym = (a.mean + b.mean) / 2; + const maxSym = Math.max(a.max, b.max); + + return { + meanPredToGtPx: a.mean, + meanGtToPredPx: b.mean, + meanSymPx: meanSym, + maxPredToGtPx: a.max, + maxGtToPredPx: b.max, + maxSymPx: maxSym, + samplesPred: a.count, + samplesGt: b.count, + }; +} diff --git a/frontend/src/utils/segmentation/rasterizePolygon.ts b/frontend/src/utils/segmentation/rasterizePolygon.ts new file mode 100644 index 0000000..d5d7c43 --- /dev/null +++ b/frontend/src/utils/segmentation/rasterizePolygon.ts @@ -0,0 +1,71 @@ +import type { TumorPolygon } from '../../db/schema'; + +function clamp(v: number, lo: number, hi: number) { + return Math.max(lo, Math.min(hi, v)); +} + +/** + * Rasterize a polygon into a binary mask using an even-odd scanline fill. + * + * - Polygon points are expected in normalized image coordinates (0..1). + * - Output mask is 1 for pixels whose center lies inside the polygon. + */ +export function rasterizePolygonToMask(polygon: TumorPolygon, w: number, h: number): Uint8Array { + const out = new Uint8Array(Math.max(0, w * h)); + if (w <= 0 || h <= 0) return out; + if (!polygon.points || polygon.points.length < 3) return out; + + const n = polygon.points.length; + + // Convert to pixel-space float coordinates. 
+ const xs = new Float64Array(n); + const ys = new Float64Array(n); + for (let i = 0; i < n; i++) { + const p = polygon.points[i]!; + xs[i] = clamp(p.x, 0, 1) * (w - 1); + ys[i] = clamp(p.y, 0, 1) * (h - 1); + } + + const intersections: number[] = []; + + for (let y = 0; y < h; y++) { + intersections.length = 0; + + const yCenter = y + 0.5; + + for (let i = 0; i < n; i++) { + const j = (i + 1) % n; + const y0 = ys[i]!; + const y1 = ys[j]!; + + // Edge crosses scanline? (Half-open to avoid double counting vertices.) + const crosses = (y0 <= yCenter && y1 > yCenter) || (y1 <= yCenter && y0 > yCenter); + if (!crosses) continue; + + const x0 = xs[i]!; + const x1 = xs[j]!; + + const t = (yCenter - y0) / (y1 - y0); + const x = x0 + t * (x1 - x0); + intersections.push(x); + } + + if (intersections.length < 2) continue; + + intersections.sort((a, b) => a - b); + + for (let k = 0; k + 1 < intersections.length; k += 2) { + const a = intersections[k]!; + const b = intersections[k + 1]!; + const xStart = clamp(Math.ceil(Math.min(a, b)), 0, w - 1); + const xEnd = clamp(Math.floor(Math.max(a, b)), 0, w - 1); + + const rowBase = y * w; + for (let x = xStart; x <= xEnd; x++) { + out[rowBase + x] = 1; + } + } + } + + return out; +} diff --git a/frontend/src/utils/segmentation/segmentTumor.ts b/frontend/src/utils/segmentation/segmentTumor.ts new file mode 100644 index 0000000..3a96da5 --- /dev/null +++ b/frontend/src/utils/segmentation/segmentTumor.ts @@ -0,0 +1,1484 @@ +import type { NormalizedPoint, TumorPolygon, TumorThreshold } from '../../db/schema'; +import { computeGeodesicDistanceToSeeds } from './geodesicDistance'; +import { marchingSquaresContour } from './marchingSquares'; +import { morphologicalClose, morphologicalOpen } from './morphology'; +import { rdpSimplify } from './simplify'; +import { chaikinSmooth } from './smooth'; + +export type SegmentationResult = { + polygon: TumorPolygon; + threshold: TumorThreshold; + /** Seed centroid in normalized image coordinates. 
*/ + seed: NormalizedPoint; + meta: { + areaPx: number; + areaNorm: number; + imageWidth: number; + imageHeight: number; + }; +}; + +export type SegmentTumorOptions = { + /** + * Optional overrides for max distance gating from the painted boundary. + * + * maxDist ~= max(baseMin, paintScale * paintScaleFactor) + thresholdWidth * thresholdWidthFactor + */ + maxDistToPaint?: { + baseMin: number; + paintScaleFactor: number; + thresholdWidthFactor: number; + }; + + /** + * Optional *soft* distance penalty. + * + * If provided, the intensity tolerance is linearly scaled by distance-to-painted-boundary: + * - distance=0 => scale=1 + * - distance=max => scale=distanceToleranceScaleMin + * + * Lower values reduce leakage/FP and usually improve boundary alignment (more "granular"), + * but can increase FN if the tumor extends far beyond the painted region. + */ + distanceToleranceScaleMin?: number; + + /** + * Optional edge-aware tightening of the intensity tolerance. + * + * When enabled (>0), pixels with higher local gradient magnitude are held to a tighter tolerance. + * This helps stop region-grow leakage across edges and typically improves boundary granularity. + * + * 0 disables the edge penalty. + */ + edgePenaltyStrength?: number; + + /** + * Optional asymmetric tolerance band around the anchor. + * + * Default is symmetric: [anchor - tolerance, anchor + tolerance]. + * With these scales, the band becomes: + * - low = anchor - tolerance * toleranceLowScale + * - high = anchor + tolerance * toleranceHighScale + * + * This can help reduce leakage when only one side of the intensity spectrum is problematic. + */ + toleranceLowScale?: number; + toleranceHighScale?: number; + + /** + * Automatic background model derived from an annulus around the painted stroke. + * + * This is brush-only: it does not require explicit negative/background marking. + * + * If `enabled` is undefined, the implementation may fall back to a localStorage gate + * (e.g. 
`miraviewer:segmentation-v2`). + */ + bgModel?: { + enabled?: boolean; + /** Minimum manhattan distance (px) from paint to consider as background samples. */ + annulusMinPx?: number; + /** Maximum manhattan distance (px) from paint to consider as background samples. */ + annulusMaxPx?: number; + /** Cap on background sample count for performance. */ + maxSamples?: number; + /** How much more background-like a pixel must be to get rejected (z-score margin). */ + rejectMarginZ?: number; + /** Exclude very strong edges from background sampling to reduce mixing. */ + edgeExclusionGrad?: number; + }; + + /** + * Edge-aware geodesic distance gating. Distances grow faster when crossing strong edges. + * + * If `enabled` is undefined, the implementation may fall back to a localStorage gate + * (e.g. `miraviewer:segmentation-v2`). + */ + geodesic?: { + enabled?: boolean; + /** How strongly edges penalize distance (0 disables edge penalty in the distance metric). */ + edgeCostStrength?: number; + }; + + /** + * How many times to run morphological open before contour extraction. + * 0 disables the open. + */ + morphologicalOpenIterations?: number; + + /** + * How many times to run morphological close before contour extraction. + * 0 disables the close. + */ + morphologicalCloseIterations?: number; + + /** Chaikin smoothing iterations for the output polygon. (0 disables smoothing.) */ + smoothingIterations?: number; + + /** RDP epsilon for output polygon simplification. */ + simplifyEpsilon?: number; + + /** + * Force-enable/disable the experimental local-adaptive path. + * + * - If undefined, falls back to the localStorage gate. + * - If true/false, overrides the gate. 
+ */ + adaptiveEnabled?: boolean; +}; + +function clamp(v: number, lo: number, hi: number) { + return Math.max(lo, Math.min(hi, v)); +} + +function medianOfNumbers(values: number[]): number { + if (values.length === 0) return 0; + const sorted = [...values].sort((a, b) => a - b); + const mid = Math.floor(sorted.length / 2); + if (sorted.length % 2 === 1) return sorted[mid] ?? 0; + const a = sorted[mid - 1] ?? 0; + const b = sorted[mid] ?? 0; + return (a + b) / 2; +} + +type RobustStats = { mu: number; sigma: number }; + +function robustStats(samples: number[], sigmaFloor = 6): RobustStats | null { + if (samples.length < 16) return null; + + const mu = medianOfNumbers(samples); + const abs = samples.map((v) => Math.abs(v - mu)); + const mad = medianOfNumbers(abs); + + // Convert MAD to a robust estimate of sigma (normal distribution factor). + const sigmaMad = 1.4826 * mad; + + // Keep a floor so we don't become overly confident from small/noisy samples. + const sigma = Math.max(sigmaFloor, sigmaMad); + + if (!Number.isFinite(mu) || !Number.isFinite(sigma)) return null; + return { mu, sigma }; +} + +type IntegralImages = { + w: number; + h: number; + sum: Float64Array; // (w+1)*(h+1) + sumSq: Float64Array; // (w+1)*(h+1) +}; + +const integralCache = new WeakMap(); + +function getIntegralImages(gray: Uint8Array, w: number, h: number): IntegralImages { + const existing = integralCache.get(gray); + if (existing && existing.w === w && existing.h === h) return existing; + + const w1 = w + 1; + const h1 = h + 1; + const sum = new Float64Array(w1 * h1); + const sumSq = new Float64Array(w1 * h1); + + for (let y = 0; y < h; y++) { + let rowSum = 0; + let rowSumSq = 0; + for (let x = 0; x < w; x++) { + const v = gray[y * w + x] ?? 
0; + rowSum += v; + rowSumSq += v * v; + + const i = (y + 1) * w1 + (x + 1); + const above = y * w1 + (x + 1); + sum[i] = sum[above] + rowSum; + sumSq[i] = sumSq[above] + rowSumSq; + } + } + + const computed: IntegralImages = { w, h, sum, sumSq }; + integralCache.set(gray, computed); + return computed; +} + +function rectSum(prefix: Float64Array, w1: number, x0: number, y0: number, x1: number, y1: number): number { + // x0,y0,x1,y1 are inclusive in image coordinates. + // Convert to integral image coordinates (exclusive upper bounds). + const xa = x0; + const ya = y0; + const xb = x1 + 1; + const yb = y1 + 1; + + const A = prefix[ya * w1 + xa] ?? 0; + const B = prefix[ya * w1 + xb] ?? 0; + const C = prefix[yb * w1 + xa] ?? 0; + const D = prefix[yb * w1 + xb] ?? 0; + return D - B - C + A; +} + +function localMeanStd( + integrals: IntegralImages, + x: number, + y: number, + radius: number +): { mean: number; std: number } { + const w = integrals.w; + const h = integrals.h; + const w1 = w + 1; + + const x0 = clamp(x - radius, 0, w - 1); + const y0 = clamp(y - radius, 0, h - 1); + const x1 = clamp(x + radius, 0, w - 1); + const y1 = clamp(y + radius, 0, h - 1); + + const area = Math.max(1, (x1 - x0 + 1) * (y1 - y0 + 1)); + const s = rectSum(integrals.sum, w1, x0, y0, x1, y1); + const s2 = rectSum(integrals.sumSq, w1, x0, y0, x1, y1); + const mean = s / area; + const v = Math.max(0, s2 / area - mean * mean); + const std = Math.sqrt(v); + return { mean, std }; +} + +type GradientCache = { + w: number; + h: number; + grad: Uint8Array; +}; + +const gradientCache = new WeakMap(); + +function getGradientMagnitude(gray: Uint8Array, w: number, h: number): Uint8Array { + const existing = gradientCache.get(gray); + if (existing && existing.w === w && existing.h === h) return existing.grad; + + const out = new Uint8Array(w * h); + + // Simple Sobel gradient magnitude (L1 approx), scaled to 0..255. 
+ for (let y = 1; y < h - 1; y++) { + for (let x = 1; x < w - 1; x++) { + const i00 = gray[(y - 1) * w + (x - 1)] ?? 0; + const i01 = gray[(y - 1) * w + x] ?? 0; + const i02 = gray[(y - 1) * w + (x + 1)] ?? 0; + const i10 = gray[y * w + (x - 1)] ?? 0; + const i12 = gray[y * w + (x + 1)] ?? 0; + const i20 = gray[(y + 1) * w + (x - 1)] ?? 0; + const i21 = gray[(y + 1) * w + x] ?? 0; + const i22 = gray[(y + 1) * w + (x + 1)] ?? 0; + + const gx = -i00 - 2 * i10 - i20 + i02 + 2 * i12 + i22; + const gy = -i00 - 2 * i01 - i02 + i20 + 2 * i21 + i22; + + const mag = (Math.abs(gx) + Math.abs(gy)) / 4; + out[y * w + x] = clamp(Math.round(mag), 0, 255); + } + } + + gradientCache.set(gray, { w, h, grad: out }); + return out; +} + +function toGrayscaleByte(r: number, g: number, b: number): number { + // Perceptual luminance approximation. + return Math.round(0.2126 * r + 0.7152 * g + 0.0722 * b); +} + +async function decodePngBlobToImageData(blob: Blob): Promise { + // Prefer createImageBitmap (fast), but fall back to decoding for broader compatibility. 
+ try { + const bitmap = await createImageBitmap(blob); + const canvas = document.createElement('canvas'); + canvas.width = bitmap.width; + canvas.height = bitmap.height; + const ctx = canvas.getContext('2d'); + if (!ctx) throw new Error('Failed to create canvas context'); + ctx.drawImage(bitmap, 0, 0); + return ctx.getImageData(0, 0, canvas.width, canvas.height); + } catch { + const url = URL.createObjectURL(blob); + try { + const img = new Image(); + img.decoding = 'async'; + img.src = url; + + if (typeof img.decode === 'function') { + await img.decode(); + } else { + await new Promise((resolve, reject) => { + img.onload = () => resolve(); + img.onerror = () => reject(new Error('Failed to decode PNG')); + }); + } + + const canvas = document.createElement('canvas'); + canvas.width = img.naturalWidth; + canvas.height = img.naturalHeight; + const ctx = canvas.getContext('2d'); + if (!ctx) throw new Error('Failed to create canvas context'); + ctx.drawImage(img, 0, 0); + return ctx.getImageData(0, 0, canvas.width, canvas.height); + } finally { + URL.revokeObjectURL(url); + } + } +} + +function computeSeedCentroid(seedPx: Array<{ x: number; y: number }>, w: number, h: number): NormalizedPoint { + let sx = 0; + let sy = 0; + for (const p of seedPx) { + sx += p.x; + sy += p.y; + } + const n = Math.max(1, seedPx.length); + const cx = sx / n; + const cy = sy / n; + return { x: clamp(cx / Math.max(1, w - 1), 0, 1), y: clamp(cy / Math.max(1, h - 1), 0, 1) }; +} + +function estimateThresholdFromSeeds(gray: Uint8Array, w: number, h: number, seedPx: Array<{ x: number; y: number }>): TumorThreshold { + if (seedPx.length === 0) { + console.warn('[estimateThresholdFromSeeds] No seed points, using default range'); + return { low: 64, high: 192 }; + } + + const samples: number[] = []; + let minX = Number.POSITIVE_INFINITY; + let minY = Number.POSITIVE_INFINITY; + let maxX = Number.NEGATIVE_INFINITY; + let maxY = Number.NEGATIVE_INFINITY; + + for (const p of seedPx) { + const x = 
clamp(Math.round(p.x), 0, w - 1); + const y = clamp(Math.round(p.y), 0, h - 1); + + minX = Math.min(minX, x); + minY = Math.min(minY, y); + maxX = Math.max(maxX, x); + maxY = Math.max(maxY, y); + + samples.push(gray[y * w + x] ?? 0); + } + + if (samples.length === 0) { + console.warn('[estimateThresholdFromSeeds] No valid samples, using default range'); + return { low: 64, high: 192 }; + } + + const bboxW = Number.isFinite(maxX) ? Math.max(1, maxX - minX + 1) : 1; + const bboxH = Number.isFinite(maxY) ? Math.max(1, maxY - minY + 1) : 1; + const paintAreaPx = bboxW * bboxH; + const isLargePaintBlob = paintAreaPx > 2500 || seedPx.length > 200; + + samples.sort((a, b) => a - b); + const pick = (q: number) => samples[Math.floor(clamp(q, 0, 1) * (samples.length - 1))] ?? 0; + + // Initial threshold should start *near the paint* (so the first polygon looks reasonable). + // + // For very large/filled paint blobs, stroke samples often span multiple tissues (crossing the + // boundary). In practice, auto-tune frequently lands on a very large tolerance and relies on + // distance gating to prevent huge FP leaks. If we start too narrow here, the default result can be + // extremely conservative (high precision but terrible recall), which matches the failure mode + // we keep seeing in GT reports. + const p05 = pick(0.05); + const p20 = pick(0.2); + const p50 = pick(0.5); + const p80 = pick(0.8); + const p95 = pick(0.95); + + const isVeryLargePaintBlob = paintAreaPx > 8000 || seedPx.length > 400; + + // Keep the anchor stable (paint median). Any asymmetry should come from explicit opts / auto-tune + // rather than trying to infer contrast direction from paint samples. + const anchor = clamp(p50, 0, 255); + + const width = (() => { + if (isVeryLargePaintBlob) { + // Near-max band: rely on distance gating + background model (if enabled) to contain leakage. + // This intentionally mirrors what auto-tune often finds for big paint blobs. 
+ return 240; // tolerance = 120 + } + + if (isLargePaintBlob) { + // Wider band for filled strokes so we don't miss heterogeneous tumor signal. + const base = (p95 - p05) + 24; + return clamp(base, 160, 240); + } + + // Small/medium strokes: keep it tighter so the first result doesn't explode. + return clamp((p80 - p20) + 12, 24, 64); + })(); + + const tolerance = clamp(Math.round(width / 2), 0, 127); + + return { + low: clamp(anchor - tolerance, 0, 255), + high: clamp(anchor + tolerance, 0, 255), + anchor, + tolerance, + }; +} + +export type GrayscaleImage = { + gray: Uint8Array; + width: number; + height: number; +}; + +export async function decodeCapturedPngToGrayscale(png: Blob): Promise { + const imageData = await decodePngBlobToImageData(png); + const w = imageData.width; + const h = imageData.height; + + const gray = new Uint8Array(w * h); + const d = imageData.data; + for (let i = 0, p = 0; i < d.length; i += 4, p++) { + gray[p] = toGrayscaleByte(d[i], d[i + 1], d[i + 2]); + } + + return { gray, width: w, height: h }; +} + +export function estimateThresholdFromSeedPoints( + gray: Uint8Array, + w: number, + h: number, + seedPointsNorm: NormalizedPoint[] +): TumorThreshold { + const seedPx = seedPointsNorm.map((p) => ({ + x: clamp(p.x, 0, 1) * (w - 1), + y: clamp(p.y, 0, 1) * (h - 1), + })); + return estimateThresholdFromSeeds(gray, w, h, seedPx); +} + +export function regionGrowMask( + allowed: Uint8Array, + w: number, + h: number, + seeds: Array<{ x: number; y: number }>, + roi?: { x0: number; y0: number; x1: number; y1: number } +): { mask: Uint8Array; area: number } { + const mask = new Uint8Array(w * h); + const visited = new Uint8Array(w * h); + + const x0 = roi ? clamp(Math.floor(roi.x0), 0, w - 1) : 0; + const y0 = roi ? clamp(Math.floor(roi.y0), 0, h - 1) : 0; + const x1 = roi ? clamp(Math.ceil(roi.x1), 0, w - 1) : w - 1; + const y1 = roi ? 
clamp(Math.ceil(roi.y1), 0, h - 1) : h - 1; + + // Use a smaller queue size based on ROI to avoid memory issues. + const roiW = x1 - x0 + 1; + const roiH = y1 - y0 + 1; + const maxQueueSize = roiW * roiH; + + const qx = new Int32Array(maxQueueSize); + const qy = new Int32Array(maxQueueSize); + let qh = 0; + let qt = 0; + + const push = (x: number, y: number) => { + const i = y * w + x; + if (visited[i]) return; + visited[i] = 1; // mark enqueued so we never overflow the queue with duplicates + if (qt >= maxQueueSize) { + // This should be impossible if `visited` is correct, but keep a guard anyway. + console.error('[regionGrowMask] Queue overflow (bug)', { qt, maxQueueSize }); + return; + } + qx[qt] = x; + qy[qt] = y; + qt++; + }; + + for (const s of seeds) { + const x = clamp(Math.round(s.x), x0, x1); + const y = clamp(Math.round(s.y), y0, y1); + push(x, y); + } + + let area = 0; + let iterations = 0; + const maxIterations = maxQueueSize; // each pixel can be enqueued at most once + + while (qh < qt && iterations++ < maxIterations) { + const x = qx[qh]; + const y = qy[qh]; + qh++; + + const i = y * w + x; + + if (!allowed[i]) continue; + + mask[i] = 1; + area++; + + // 4-neighborhood. 
+ if (x > x0) push(x - 1, y); + if (x < x1) push(x + 1, y); + if (y > y0) push(x, y - 1); + if (y < y1) push(x, y + 1); + } + + if (iterations >= maxIterations) { + console.warn('[regionGrowMask] Hit max iteration guard (unexpected)'); + } + + return { mask, area }; +} + +let cachedDistKey: string | null = null; +let cachedDist: Int32Array | null = null; + +type GeodesicCacheEntry = { key: string; dist: Float32Array; maxComputed: number }; +const geodesicCache = new WeakMap(); + +function computeDistanceToPaint( + paintPx: Array<{ x: number; y: number }>, + w: number, + h: number, + roi: { x0: number; y0: number; x1: number; y1: number } +): Int32Array { + const dist = new Int32Array(w * h); + dist.fill(-1); + + if (paintPx.length === 0) return dist; + + const x0 = clamp(Math.floor(roi.x0), 0, w - 1); + const y0 = clamp(Math.floor(roi.y0), 0, h - 1); + const x1 = clamp(Math.ceil(roi.x1), 0, w - 1); + const y1 = clamp(Math.ceil(roi.y1), 0, h - 1); + + // We want distance-to-*boundary* rather than distance-to-any-stroke-point. + // + // If the user paints a filled scribble, interior pixels have many nearby stroke points, + // so distance-to-stroke would be ~0 everywhere inside and wouldn't express "how deep inside" + // a pixel is. Instead, we approximate the painted boundary by using the outer ring of paint points. + let cx = 0; + let cy = 0; + for (const p of paintPx) { + cx += p.x; + cy += p.y; + } + cx /= paintPx.length; + cy /= paintPx.length; + + let maxD2 = 0; + const d2s = new Float64Array(paintPx.length); + for (let i = 0; i < paintPx.length; i++) { + const dx = paintPx[i].x - cx; + const dy = paintPx[i].y - cy; + const d2 = dx * dx + dy * dy; + d2s[i] = d2; + if (d2 > maxD2) maxD2 = d2; + } + + // Decide whether the paint looks like a filled "blob" or a thin/elongated stroke. + // + // Why: + // - For blob-like paint, we want distance-to-*boundary* (prevents interior pixels being treated as "distance 0"). 
+ // - For thin/elongated strokes, the "outer ring" heuristic tends to pick only the endpoints, which makes + // distance-to-paint meaningless and can cause large FP leaks or brittle FN behavior. + let minX = Number.POSITIVE_INFINITY; + let minY = Number.POSITIVE_INFINITY; + let maxX = Number.NEGATIVE_INFINITY; + let maxY = Number.NEGATIVE_INFINITY; + for (const p of paintPx) { + if (p.x < minX) minX = p.x; + if (p.y < minY) minY = p.y; + if (p.x > maxX) maxX = p.x; + if (p.y > maxY) maxY = p.y; + } + const bboxW = Math.max(1, maxX - minX); + const bboxH = Math.max(1, maxY - minY); + const aspect = Math.min(bboxW, bboxH) / Math.max(bboxW, bboxH); + + // Outer ring threshold (~70% of max radius). If too few points qualify, fall back to all points. + const ringD2 = maxD2 * 0.7 * 0.7; + const boundarySeeds: Array<{ x: number; y: number }> = []; + for (let i = 0; i < paintPx.length; i++) { + if (d2s[i] >= ringD2) boundarySeeds.push(paintPx[i]); + } + + // If paint is very elongated, prefer distance-to-stroke to avoid endpoint-only seeding. + const seeds = aspect < 0.35 ? paintPx : boundarySeeds.length >= 8 ? boundarySeeds : paintPx; + + const roiW = x1 - x0 + 1; + const roiH = y1 - y0 + 1; + const maxQueueSize = roiW * roiH; + + const qx = new Int32Array(maxQueueSize); + const qy = new Int32Array(maxQueueSize); + let qh = 0; + let qt = 0; + + const push = (x: number, y: number, d: number) => { + const i = y * w + x; + if (dist[i] !== -1) return; + dist[i] = d; + qx[qt] = x; + qy[qt] = y; + qt++; + }; + + // Guard against pathological cases where we have more seeds than ROI pixels. 
+ const maxSeeds = Math.max(1, Math.min(seeds.length, maxQueueSize)); + const seedStep = Math.max(1, Math.floor(seeds.length / maxSeeds)); + + for (let si = 0; si < seeds.length; si += seedStep) { + const s = seeds[si]!; + const x = clamp(Math.round(s.x), x0, x1); + const y = clamp(Math.round(s.y), y0, y1); + push(x, y, 0); + } + + while (qh < qt) { + const x = qx[qh]; + const y = qy[qh]; + qh++; + + const base = dist[y * w + x]; + const nd = base + 1; + + if (x > x0) push(x - 1, y, nd); + if (x < x1) push(x + 1, y, nd); + if (y > y0) push(x, y - 1, nd); + if (y < y1) push(x, y + 1, nd); + } + + return dist; +} + +function computeSeedRoi(seeds: Array<{ x: number; y: number }>, w: number, h: number): { x0: number; y0: number; x1: number; y1: number } { + if (seeds.length === 0) { + return { x0: 0, y0: 0, x1: w - 1, y1: h - 1 }; + } + + let minX = Number.POSITIVE_INFINITY; + let minY = Number.POSITIVE_INFINITY; + let maxX = Number.NEGATIVE_INFINITY; + let maxY = Number.NEGATIVE_INFINITY; + + for (const s of seeds) { + minX = Math.min(minX, s.x); + minY = Math.min(minY, s.y); + maxX = Math.max(maxX, s.x); + maxY = Math.max(maxY, s.y); + } + + // Expand ROI so the tumor can extend beyond the rough paint strokes. + // + // We treat the paint region as a hint (where to start), not a hard boundary. + const bboxW = Math.max(1, maxX - minX); + const bboxH = Math.max(1, maxY - minY); + + // Minimum expansion based on image size so small paint strokes don't overly constrain ROI. + // + // Precision note: + // If this margin is too large, the allowed mask may include large same-intensity regions far from the paint, + // which can produce huge false-positive expansions (low precision), especially on FLAIR. + const minDim = Math.min(w, h); + const minMargin = Math.max(24, Math.round(minDim * 0.08)); + + // Expand to ~2.0x bbox search region (margin ~= 0.5*bbox). 
+ const marginX = Math.max(minMargin, Math.round(bboxW * 0.5)); + const marginY = Math.max(minMargin, Math.round(bboxH * 0.5)); + + return { + x0: clamp(minX - marginX, 0, w - 1), + y0: clamp(minY - marginY, 0, h - 1), + x1: clamp(maxX + marginX, 0, w - 1), + y1: clamp(maxY + marginY, 0, h - 1), + }; +} + +function buildAllowedMask(params: { + gray: Uint8Array; + w: number; + h: number; + roi: { x0: number; y0: number; x1: number; y1: number }; + paintPx: Array<{ x: number; y: number }>; + threshold: TumorThreshold; + looksLikePaintGesture: boolean; + distToPaint?: Int32Array; + maxDistToPaint?: number; + /** Optional soft distance penalty (see SegmentTumorOptions). */ + distanceToleranceScaleMin?: number; + /** Optional edge penalty (see SegmentTumorOptions). */ + edgePenaltyStrength?: number; + /** Optional asymmetric tolerance scales (see SegmentTumorOptions). */ + toleranceLowScale?: number; + toleranceHighScale?: number; + /** Optional brush-only background model (see SegmentTumorOptions). */ + bgModel?: SegmentTumorOptions['bgModel']; + /** Optional edge-aware geodesic gating (see SegmentTumorOptions). */ + geodesic?: SegmentTumorOptions['geodesic']; + adaptiveEnabled?: boolean; +}): Uint8Array { + const { gray, w, h, roi, paintPx, threshold, looksLikePaintGesture, distToPaint, maxDistToPaint } = params; + + const allowed = new Uint8Array(w * h); + if (w <= 0 || h <= 0 || paintPx.length === 0) return allowed; + + const x0 = clamp(Math.floor(roi.x0), 0, w - 1); + const y0 = clamp(Math.floor(roi.y0), 0, h - 1); + const x1 = clamp(Math.ceil(roi.x1), 0, w - 1); + const y1 = clamp(Math.ceil(roi.y1), 0, h - 1); + + const anchor = + typeof threshold.anchor === 'number' + ? clamp(Math.round(threshold.anchor), 0, 255) + : clamp(Math.round((threshold.low + threshold.high) / 2), 0, 255); + const tolerance = + typeof threshold.tolerance === 'number' + ? 
clamp(Math.round(threshold.tolerance), 0, 127) + : clamp(Math.round((threshold.high - threshold.low) / 2), 0, 127); + + const maxDist = typeof maxDistToPaint === 'number' ? Math.max(0, Math.round(maxDistToPaint)) : undefined; + const distTolScaleMin = + typeof params.distanceToleranceScaleMin === 'number' + ? clamp(params.distanceToleranceScaleMin, 0.15, 1) + : looksLikePaintGesture + ? 0.25 + : 1; + const edgePenaltyStrength = + typeof params.edgePenaltyStrength === 'number' ? clamp(params.edgePenaltyStrength, 0, 1) : 0; + + const tolLowScale = + typeof params.toleranceLowScale === 'number' ? clamp(params.toleranceLowScale, 0.25, 2) : 1; + const tolHighScale = + typeof params.toleranceHighScale === 'number' ? clamp(params.toleranceHighScale, 0.25, 2) : 1; + + const lowTolBase = tolerance * tolLowScale; + const highTolBase = tolerance * tolHighScale; + + // Approximate "inner" paint region by a radial cutoff relative to the paint centroid. + // + // IMPORTANT: We apply distance gating / penalties mainly to prevent *outward leakage*. + // Penalizing deep interior pixels can create false negatives when the user paints a small + // or off-center scribble. So we treat the inner region as always eligible and only apply + // distance-based constraints outside of it. + let paintCx = 0; + let paintCy = 0; + for (const p of paintPx) { + paintCx += p.x; + paintCy += p.y; + } + paintCx /= paintPx.length; + paintCy /= paintPx.length; + + let paintMaxD2 = 0; + const paintD2s = new Float64Array(paintPx.length); + for (let i = 0; i < paintPx.length; i++) { + const dx = paintPx[i].x - paintCx; + const dy = paintPx[i].y - paintCy; + const d2 = dx * dx + dy * dy; + paintD2s[i] = d2; + if (d2 > paintMaxD2) paintMaxD2 = d2; + } + + // Compute an "inner" paint region cutoff. + // + // IMPORTANT: using max radius can make this way too large for thin/elongated strokes (line scribbles), + // which disables distance gating over a huge area and can lead to low precision leaks. 
+ // + // We scale the inner fraction by paint bbox aspect ratio: + // - blob-like paint (aspect~1) => innerFrac ~0.7 (original behavior) + // - elongated paint (aspect<<1) => innerFrac shrinks toward ~0.35 + let bboxMinX = Number.POSITIVE_INFINITY; + let bboxMinY = Number.POSITIVE_INFINITY; + let bboxMaxX = Number.NEGATIVE_INFINITY; + let bboxMaxY = Number.NEGATIVE_INFINITY; + for (const p of paintPx) { + if (p.x < bboxMinX) bboxMinX = p.x; + if (p.y < bboxMinY) bboxMinY = p.y; + if (p.x > bboxMaxX) bboxMaxX = p.x; + if (p.y > bboxMaxY) bboxMaxY = p.y; + } + const bboxW = Math.max(1, bboxMaxX - bboxMinX); + const bboxH = Math.max(1, bboxMaxY - bboxMinY); + const aspect = Math.min(bboxW, bboxH) / Math.max(bboxW, bboxH); + + const innerFrac = 0.35 + 0.35 * clamp(aspect, 0, 1); + const paintInnerD2 = paintMaxD2 * innerFrac * innerFrac; + + // "Outer ring" (fixed fraction) used for edge sampling + boundary seeding. + const ringD2 = paintMaxD2 * 0.7 * 0.7; + const boundarySeeds: Array<{ x: number; y: number }> = []; + for (let i = 0; i < paintPx.length; i++) { + if (paintD2s[i] >= ringD2) boundarySeeds.push(paintPx[i]!); + } + + const distSeeds = (() => { + // Prefer boundary seeding for blob-like paint so distance expresses "how far outside the paint boundary". + // For thin/elongated strokes, boundary seeding tends to pick endpoints and becomes unstable, so fall back + // to distance-to-stroke instead. + const raw = aspect < 0.35 ? paintPx : boundarySeeds.length >= 8 ? boundarySeeds : paintPx; + + // Cap seed count for performance + stable cache keys. + const maxSeeds = 256; + const step = Math.max(1, Math.floor(raw.length / maxSeeds)); + return raw.filter((_, i) => i % step === 0); + })(); + + // Rollout gate for segmentation v2 (background model + geodesic distance). 
+ const v2Enabled = + looksLikePaintGesture && + typeof localStorage !== 'undefined' && + localStorage.getItem('miraviewer:segmentation-v2') === '1'; + + const geodesicEnabled = + typeof params.geodesic?.enabled === 'boolean' ? params.geodesic.enabled : v2Enabled; + const bgModelEnabled = typeof params.bgModel?.enabled === 'boolean' ? params.bgModel.enabled : v2Enabled; + + // Compute gradient once when needed; cached by grayscale buffer identity. + const grad = + looksLikePaintGesture && (edgePenaltyStrength > 0 || geodesicEnabled || bgModelEnabled) + ? getGradientMagnitude(gray, w, h) + : null; + + const edgePenalty = (() => { + if (edgePenaltyStrength <= 0) return null; + if (!looksLikePaintGesture) return null; + if (!grad) return null; + + // Estimate edge strength near painted boundary ring. + const sampleCount = 64; + const step = Math.max(1, Math.floor(paintPx.length / sampleCount)); + + const edgeSamples: number[] = []; + for (let i = 0; i < paintPx.length; i += step) { + if (paintD2s[i] < paintInnerD2) continue; + const x = clamp(Math.round(paintPx[i].x), 0, w - 1); + const y = clamp(Math.round(paintPx[i].y), 0, h - 1); + edgeSamples.push(grad[y * w + x] ?? 0); + } + + if (edgeSamples.length === 0) return null; + + const edgeMedian = medianOfNumbers(edgeSamples); + + // Use a floor so we don't massively over-penalize weak/noisy gradients. + // This keeps the edge penalty focused on truly strong edges (e.g. tissue boundaries) + // while still allowing it to activate even if the painted ring isn't perfectly on the edge. + const barrier = Math.max(25, edgeMedian * 1.2); + return { grad, barrier }; + })(); + + // Local-adaptive, edge-aware thresholding is still experimental. Gate it behind a flag so we don't + // regress segmentation quality by default. + const adaptiveEnabled = + typeof params.adaptiveEnabled === 'boolean' + ? 
params.adaptiveEnabled + : looksLikePaintGesture && + typeof localStorage !== 'undefined' && + localStorage.getItem('miraviewer:segmentation-adaptive') === '1'; + + const hasDist = distToPaint && typeof maxDist === 'number' && maxDist > 0; + + const geoDistToPaint = (() => { + if (!hasDist) return null; + if (!geodesicEnabled) return null; + if (!grad) return null; + if (distSeeds.length === 0) return null; + + const k = + typeof params.geodesic?.edgeCostStrength === 'number' ? clamp(params.geodesic.edgeCostStrength, 0, 20) : 6; + + // Cache keyed by image identity (gray buffer) + ROI + a subsample signature of boundary/stroke seeds. + // NOTE: We deliberately do NOT key on maxDist, so slider moves don't trigger recomputes. + const sampleCount = 64; + const step = Math.max(1, Math.floor(distSeeds.length / sampleCount)); + const sampled = distSeeds + .filter((_, i) => i % step === 0) + .map((p) => `${Math.round(p.x)},${Math.round(p.y)}`) + .join(';'); + + const key = `${w}x${h}|${x0},${y0},${x1},${y1}|k=${k}|${sampled}`; + const cached = geodesicCache.get(gray); + if (cached && cached.key === key && cached.maxComputed >= maxDist!) return cached.dist; + + // Compute a bit beyond the current maxDist so small slider moves can reuse the cached map. + const computeMaxDist = Math.ceil(maxDist! + 8); + + const edgeBarrier = (() => { + if (!grad) return null; + + const edgeSamples: number[] = []; + const seeds = boundarySeeds.length >= 8 ? boundarySeeds : distSeeds; + const step = Math.max(1, Math.floor(seeds.length / 64)); + + for (let i = 0; i < seeds.length; i += step) { + const p = seeds[i]!; + const xi = clamp(Math.round(p.x), 0, w - 1); + const yi = clamp(Math.round(p.y), 0, h - 1); + edgeSamples.push(grad[yi * w + xi] ?? 
0); + } + + if (edgeSamples.length === 0) return null; + + const edgeMedian = medianOfNumbers(edgeSamples); + return Math.max(25, edgeMedian * 1.2); + })(); + + const dist = computeGeodesicDistanceToSeeds({ + w, + h, + roi: { x0, y0, x1, y1 }, + seeds: distSeeds, + grad, + edgeCostStrength: k, + edgeBarrier: edgeBarrier ?? undefined, + maxDist: computeMaxDist, + }); + + geodesicCache.set(gray, { key, dist, maxComputed: computeMaxDist }); + return dist; + })(); + + const bgModel = (() => { + if (!bgModelEnabled) return null; + if (!hasDist) return null; + + const cfg = params.bgModel; + const annulusMinPx = typeof cfg?.annulusMinPx === 'number' ? clamp(Math.round(cfg.annulusMinPx), 1, 64) : 2; + const annulusMaxPxRaw = + typeof cfg?.annulusMaxPx === 'number' ? Math.round(cfg.annulusMaxPx) : Math.min(24, maxDist ?? 24); + const annulusMaxPx = clamp(annulusMaxPxRaw, annulusMinPx + 1, 128); + + const maxSamples = typeof cfg?.maxSamples === 'number' ? clamp(Math.round(cfg.maxSamples), 64, 8192) : 2048; + const rejectMarginZ = typeof cfg?.rejectMarginZ === 'number' ? clamp(cfg.rejectMarginZ, 0, 3) : 0.75; + const edgeExclusionGrad = + typeof cfg?.edgeExclusionGrad === 'number' ? clamp(Math.round(cfg.edgeExclusionGrad), 0, 255) : 200; + + // Tumor samples: sample intensities under the paint stroke. + const tumorSamples: number[] = []; + const paintStep = Math.max(1, Math.floor(paintPx.length / 96)); + for (let k = 0; k < paintPx.length; k += paintStep) { + const px = paintPx[k]!; + const xi = clamp(Math.round(px.x), 0, w - 1); + const yi = clamp(Math.round(px.y), 0, h - 1); + tumorSamples.push(gray[yi * w + xi] ?? 0); + } + + const tumor = robustStats(tumorSamples, 6); + if (!tumor) return null; + + // Background samples: annulus just outside paint. 
+ let candCount = 0; + for (let yy = y0; yy <= y1; yy++) { + for (let xx = x0; xx <= x1; xx++) { + const ii = yy * w + xx; + const dd = distToPaint![ii]; + if (dd < annulusMinPx || dd > annulusMaxPx) continue; + + const g = grad ? grad[ii] ?? 0 : 0; + if (g > edgeExclusionGrad) continue; + + candCount++; + } + } + + if (candCount < 64) return null; + + const stride = Math.max(1, Math.floor(candCount / maxSamples)); + const bgSamples: number[] = []; + let seen = 0; + + for (let yy = y0; yy <= y1 && bgSamples.length < maxSamples; yy++) { + for (let xx = x0; xx <= x1 && bgSamples.length < maxSamples; xx++) { + const ii = yy * w + xx; + const dd = distToPaint![ii]; + if (dd < annulusMinPx || dd > annulusMaxPx) continue; + + const g = grad ? grad[ii] ?? 0 : 0; + if (g > edgeExclusionGrad) continue; + + if (seen % stride === 0) { + bgSamples.push(gray[ii] ?? 0); + } + seen++; + } + } + + const bg = robustStats(bgSamples, 6); + if (!bg) return null; + + return { tumor, bg, rejectMarginZ }; + })(); + + const debugEnabled = + looksLikePaintGesture && + typeof localStorage !== 'undefined' && + localStorage.getItem('miraviewer:debug-segmentation') === '1'; + + if (debugEnabled) { + console.log('[segmentTumor] buildAllowedMask', { + adaptiveEnabled, + v2Enabled, + geodesicEnabled, + bgModelEnabled, + maxDist, + anchor, + tolerance, + tolLowScale, + tolHighScale, + }); + } + + if (!adaptiveEnabled) { + // Default path: absolute intensity band + optional distance gating. + // + // If distanceToleranceScaleMin < 1, we additionally tighten the intensity tolerance as distance + // from the painted boundary increases. This reduces leaking into similar-intensity regions. + + for (let y = y0; y <= y1; y++) { + for (let x = x0; x <= x1; x++) { + const i = y * w + x; + + const dx = x - paintCx; + const dy = y - paintCy; + const radialD2 = dx * dx + dy * dy; + + // Only apply distance penalties outside the inner paint region. 
+ const enforceDist = hasDist && radialD2 >= paintInnerD2; + + let effLowTol = lowTolBase; + let effHighTol = highTolBase; + let d = 0; + + if (enforceDist) { + d = geoDistToPaint ? geoDistToPaint[i]! : distToPaint![i]; + // Geodesic distance returns +Infinity for pixels outside the explored region. Treat non-finite + // values as out-of-range so they don't accidentally bypass distance gating. + if (!Number.isFinite(d) || d < 0) continue; + + // IMPORTANT: maxDist is a hard cutoff. This prevents "infinite radius" leakage into far-away + // same-intensity tissue when distanceToleranceScaleMin < 1. + if (d > maxDist!) continue; + + // Optional soft penalty *within* [0, maxDist]: linearly tighten tolerance with distance. + if (distTolScaleMin < 0.999) { + const frac = clamp(d / maxDist!, 0, 1); + const scale = 1 - frac * (1 - distTolScaleMin); + effLowTol = lowTolBase * scale; + effHighTol = highTolBase * scale; + } + + // Edge penalty is intended to prevent *outward leakage* across strong edges. + // We keep it selective so it doesn't create false negatives due to interior texture. + if (edgePenalty) { + const EDGE_PENALTY_START_FRAC = 0.3; + if (d >= Math.round(maxDist! * EDGE_PENALTY_START_FRAC)) { + const g = edgePenalty.grad[i] ?? 0; + const t = g / edgePenalty.barrier; + + // Only penalize sufficiently strong edges. + const ONSET = 0.6; + if (t > ONSET) { + const edgeNorm = clamp((t - ONSET) / (1 - ONSET), 0, 1); + const edgeWeight = edgeNorm * edgeNorm; + const mult = 1 - edgeWeight * edgePenaltyStrength; + effLowTol *= mult; + effHighTol *= mult; + } + } + } + } + + const v = gray[i] ?? 0; + if (v >= anchor - effLowTol && v <= anchor + effHighTol) { + // Background model is intended to prevent outward leakage; don't apply it deep inside the paint. 
+ if (bgModel && enforceDist) { + const zTumor = Math.abs(v - bgModel.tumor.mu) / bgModel.tumor.sigma; + const zBg = Math.abs(v - bgModel.bg.mu) / bgModel.bg.sigma; + + // Only reject when the pixel is substantially more background-like. + if (zBg + bgModel.rejectMarginZ < zTumor) { + continue; + } + } + + allowed[i] = 1; + } + } + } + + return allowed; + } + + // Adaptive path: compare locally normalized intensity (z-score) to a paint-derived anchor. + // This can help in the presence of intensity inhomogeneity / bias fields, but can also hurt. + const integrals = getIntegralImages(gray, w, h); + + // Window radius for local stats. ~17x17 at 512px. + const radius = clamp(Math.round(Math.min(w, h) * 0.015), 5, 14); + + // Estimate anchorZ and sigmaPaint from painted pixels. + const sampleCount = 64; + const step = Math.max(1, Math.floor(paintPx.length / sampleCount)); + + const zs: number[] = []; + const sigmas: number[] = []; + + for (let k = 0; k < paintPx.length; k += step) { + const px = paintPx[k]; + const x = clamp(Math.round(px.x), 0, w - 1); + const y = clamp(Math.round(px.y), 0, h - 1); + const i = y * w + x; + + const { mean, std } = localMeanStd(integrals, x, y, radius); + const s = std > 1e-6 ? std : 1; + const v = gray[i] ?? 0; + zs.push((v - mean) / s); + sigmas.push(s); + } + + const anchorZ = medianOfNumbers(zs); + const sigmaPaint = Math.max(6, medianOfNumbers(sigmas)); + + // Convert intensity tolerance (0..127) into a normalized tolerance. + const zTolLowBase = (tolerance * tolLowScale) / sigmaPaint; + const zTolHighBase = (tolerance * tolHighScale) / sigmaPaint; + + // Edge-aware soft penalty. + const gradMag = getGradientMagnitude(gray, w, h); + + // Estimate edge strength near painted boundary ring. 
+ const edgeSamples: number[] = []; + for (let i = 0; i < paintPx.length; i += step) { + if (paintD2s[i] < ringD2) continue; + const x = clamp(Math.round(paintPx[i]!.x), 0, w - 1); + const y = clamp(Math.round(paintPx[i]!.y), 0, h - 1); + edgeSamples.push(gradMag[y * w + x] ?? 0); + } + + const edgeMedian = medianOfNumbers(edgeSamples); + const edgeBarrier = edgeMedian >= 25 ? Math.max(1, edgeMedian * 1.2) : null; + + // Distance penalty: pixels far from the painted boundary get a tighter tolerance. + // 1.0 near boundary -> distanceToleranceScaleMin (default 0.2) at max distance. + const distScaleMin = + typeof params.distanceToleranceScaleMin === 'number' + ? clamp(params.distanceToleranceScaleMin, 0.15, 1) + : v2Enabled + ? 1 + : 0.2; + const distScaleFor = (d: number) => { + if (!maxDist || maxDist <= 0) return 1; + const frac = clamp(d / maxDist, 0, 1); + return distScaleMin + (1 - distScaleMin) * (1 - frac); + }; + + for (let y = y0; y <= y1; y++) { + for (let x = x0; x <= x1; x++) { + const i = y * w + x; + + const dx = x - paintCx; + const dy = y - paintCy; + const radialD2 = dx * dx + dy * dy; + + // Only apply distance penalties outside the inner paint region. + const enforceDist = hasDist && radialD2 >= paintInnerD2; + + let scale = 1; + if (enforceDist) { + const d = geoDistToPaint ? geoDistToPaint[i]! : distToPaint![i]; + if (!Number.isFinite(d) || d < 0 || d > maxDist!) continue; + scale *= distScaleFor(d); + } + + if (edgeBarrier) { + const g = gradMag[i] ?? 0; + const edgeNorm = clamp(g / edgeBarrier, 0, 1); + // Tighten tolerance near strong edges to avoid leaking across boundaries. + scale *= 1 - edgeNorm * 0.5; + } + + const zTolLow = zTolLowBase * scale; + const zTolHigh = zTolHighBase * scale; + if (zTolLow <= 0 && zTolHigh <= 0) continue; + + const { mean, std } = localMeanStd(integrals, x, y, radius); + const s = std > 1e-6 ? std : 1; + const v = gray[i] ?? 
0; + const z = (v - mean) / s; + if (z >= anchorZ - zTolLow && z <= anchorZ + zTolHigh) { + // Background model is intended to prevent outward leakage; don't apply it deep inside the paint. + if (bgModel && enforceDist) { + const zTumor = Math.abs(v - bgModel.tumor.mu) / bgModel.tumor.sigma; + const zBg = Math.abs(v - bgModel.bg.mu) / bgModel.bg.sigma; + + // Only reject when the pixel is substantially more background-like. + if (zBg + bgModel.rejectMarginZ < zTumor) { + continue; + } + } + + allowed[i] = 1; + } + } + } + + return allowed; +} + +export function segmentTumorFromGrayscale( + gray: Uint8Array, + w: number, + h: number, + seedPointsNorm: NormalizedPoint[], + threshold: TumorThreshold, + opts?: SegmentTumorOptions +): SegmentationResult { + // If the caller provided an anchor+tolerance (tolerance mode), normalize low/high from it. + // This guarantees monotonic behavior when the UI adjusts tolerance. + const normalizedThreshold: TumorThreshold = + typeof threshold.anchor === 'number' && typeof threshold.tolerance === 'number' + ? (() => { + const anchor = clamp(Math.round(threshold.anchor), 0, 255); + const tolerance = clamp(Math.round(threshold.tolerance), 0, 127); + return { + ...threshold, + anchor, + tolerance, + low: clamp(anchor - tolerance, 0, 255), + high: clamp(anchor + tolerance, 0, 255), + }; + })() + : threshold; + + const paintPx = seedPointsNorm.map((p) => ({ + x: clamp(p.x, 0, 1) * (w - 1), + y: clamp(p.y, 0, 1) * (h - 1), + })); + + // Painted region is a rough hint. We use it to determine the flood-fill seed (centroid) + // and a generous search ROI, but we do NOT treat it as a hard boundary. + const seed = computeSeedCentroid(paintPx, w, h); + const seedPx = { x: seed.x * (w - 1), y: seed.y * (h - 1) }; + + const roi = computeSeedRoi(paintPx, w, h); + + // If this looks like a real paint gesture (not the tiny seed cross used in propagation), + // compute a distance-to-paint map so pixels far outside the painted region are penalized. 
+ const bbox = (() => { + let minX = Number.POSITIVE_INFINITY; + let minY = Number.POSITIVE_INFINITY; + let maxX = Number.NEGATIVE_INFINITY; + let maxY = Number.NEGATIVE_INFINITY; + for (const p of paintPx) { + minX = Math.min(minX, p.x); + minY = Math.min(minY, p.y); + maxX = Math.max(maxX, p.x); + maxY = Math.max(maxY, p.y); + } + return { minX, minY, maxX, maxY }; + })(); + + const bboxW = Number.isFinite(bbox.maxX) ? Math.max(0, bbox.maxX - bbox.minX) : 0; + const bboxH = Number.isFinite(bbox.maxY) ? Math.max(0, bbox.maxY - bbox.minY) : 0; + const paintScale = Math.max(bboxW, bboxH); + + const looksLikePaintGesture = paintPx.length >= 16 || bboxW * bboxH >= 400; + + // Default: symmetric tolerance band. Asymmetry should come from auto-tune or explicit opts, + // because the "right" asymmetry depends on whether the tumor is brighter or darker than + // surrounding tissue (varies case-by-case). + const tolLowScale = + typeof opts?.toleranceLowScale === 'number' ? clamp(opts.toleranceLowScale, 0.25, 2) : 1; + const tolHighScale = + typeof opts?.toleranceHighScale === 'number' ? clamp(opts.toleranceHighScale, 0.25, 2) : 1; + + const baseTol = + typeof normalizedThreshold.tolerance === 'number' + ? Math.max(0, normalizedThreshold.tolerance) + : Math.max(0, (normalizedThreshold.high - normalizedThreshold.low) / 2); + + const thresholdWidth = baseTol * (tolLowScale + tolHighScale); + + // Default distance gating tuned to avoid large false-positive expansions. + // + // Precision note: + // We bias defaults toward staying close to the painted region. It's better UX if the first + // segmentation is conservative (higher precision) and the user can paint a bit more to recover FN, + // rather than the first result exploding into far-away same-intensity tissue. + const distParams = opts?.maxDistToPaint ?? { + // Tuned from GT-driven auto-tune on axial T2 FLAIR. 
+ // + // Note: maxDist is a hard cutoff; these defaults intentionally allow a wider search region, + // while distanceToleranceScaleMin (default 0.25) prevents far-away leakage. + baseMin: 2, + paintScaleFactor: 0.6, + thresholdWidthFactor: 0.1, + }; + + // Max allowed manhattan distance from the painted boundary. + // + // IMPORTANT: For large paint blobs, use a tighter cap. When the user paints a filled region, + // they're expressing "the tumor is approximately here" — we should stay close to that boundary. + // + // For small/medium strokes: allow some expansion (user is giving a rough hint). + // For large filled blobs: be conservative (user has already outlined the region). + const paintAreaPx = bboxW * bboxH; + const isLargePaintBlob = paintAreaPx > 2500 || paintPx.length > 200; + + const maxDistCap = isLargePaintBlob + ? Math.round(Math.min(w, h) * 0.04) // ~20px for 512×512 + : Math.round(Math.min(w, h) * 0.12); // ~61px for 512×512 + + const maxDistToPaint = looksLikePaintGesture + ? Math.min( + maxDistCap, + Math.round( + Math.max(distParams.baseMin, paintScale * distParams.paintScaleFactor) + + thresholdWidth * distParams.thresholdWidthFactor + ) + ) + : undefined; + + // Cache the distance transform across threshold updates while the paint strokes stay the same. + // This keeps the slider responsive (distance transform is O(ROI area)). 
+ const distToPaint = (() => { + if (!looksLikePaintGesture) return undefined; + + const sampleCount = 32; + const step = Math.max(1, Math.floor(paintPx.length / sampleCount)); + const sampled = paintPx + .filter((_, i) => i % step === 0) + .map((p) => `${Math.round(p.x)},${Math.round(p.y)}`) + .join(';'); + + const key = `${w}x${h}|${Math.round(roi.x0)},${Math.round(roi.y0)},${Math.round(roi.x1)},${Math.round(roi.y1)}|${sampled}`; + if (cachedDistKey === key && cachedDist) { + return cachedDist; + } + + const computed = computeDistanceToPaint(paintPx, w, h, roi); + cachedDistKey = key; + cachedDist = computed; + return computed; + })(); + + const allowed = buildAllowedMask({ + gray, + w, + h, + roi, + paintPx, + threshold: normalizedThreshold, + looksLikePaintGesture, + distToPaint, + maxDistToPaint, + distanceToleranceScaleMin: opts?.distanceToleranceScaleMin, + edgePenaltyStrength: opts?.edgePenaltyStrength, + toleranceLowScale: tolLowScale, + toleranceHighScale: tolHighScale, + bgModel: opts?.bgModel, + geodesic: opts?.geodesic, + adaptiveEnabled: opts?.adaptiveEnabled, + }); + + // Flood fill from the painted region. + // + // IMPORTANT: Do not only seed from the centroid. The allowed mask can be disconnected (e.g. due to + // local intensity changes, edge gating, or the user's stroke spanning multiple lobes). Seeding from + // multiple paint points makes the result much more robust and typically improves recall. 
+ const floodSeeds = (() => { + const seeds: Array<{ x: number; y: number }> = [seedPx]; + if (!looksLikePaintGesture) return seeds; + + const sampleCount = Math.min(12, paintPx.length); + const step = Math.max(1, Math.floor(paintPx.length / sampleCount)); + for (let i = 0; i < paintPx.length; i += step) { + seeds.push(paintPx[i]!); + } + + return seeds; + })(); + + const { mask, area } = regionGrowMask(allowed, w, h, floodSeeds, roi); + + if (area === 0) { + throw new Error('No tumor region found in threshold range'); + } + + // Light morphology before contour extraction. + // + // - Open removes thin spurs / narrow bridges that often cause leakage FP. + // - Close fills tiny holes / gaps but can also bridge and create FP. + // + // Defaults tuned from GT-driven auto-tune on axial T2 FLAIR. + // + // - Close=1 helps fill small interior holes/gaps without overly blurring boundaries. + // - Open remains off by default to avoid deleting thin tumor structures. + const openIterations = clamp(Math.round(opts?.morphologicalOpenIterations ?? 0), 0, 3); + const closeIterations = clamp(Math.round(opts?.morphologicalCloseIterations ?? 1), 0, 3); + + let cleaned = mask; + for (let i = 0; i < openIterations; i++) { + cleaned = morphologicalOpen(cleaned, w, h); + } + for (let i = 0; i < closeIterations; i++) { + cleaned = morphologicalClose(cleaned, w, h); + } + + // Recompute area after cleanup. + let cleanedArea = 0; + for (let i = 0; i < cleaned.length; i++) { + if (cleaned[i]) cleanedArea++; + } + + // Extract a clean outer contour (largest loop) from the binary mask. + const contourPx = marchingSquaresContour(cleaned, w, h, roi); + if (contourPx.length < 3) { + throw new Error('Failed to extract tumor boundary'); + } + + const contourNorm = contourPx.map((p) => ({ + x: p.x / Math.max(1, w - 1), + y: p.y / Math.max(1, h - 1), + })); + + // Smooth jagged pixel edges, then simplify. 
+ // + // Keep epsilon fairly small so the polygon stays detailed enough to track the tumor boundary. + // (Too much simplification looks "blocky" / coarse.) + // Default is no smoothing; smoothing can slightly shrink boundaries and hurt overlap metrics. + const smoothingIterations = clamp(Math.round(opts?.smoothingIterations ?? 0), 0, 4); + + // Slightly higher default epsilon reduces tiny boundary wiggles without materially impacting overlap. + const simplifyEpsilon = opts?.simplifyEpsilon ?? 0.0024; + + const smoothed = smoothingIterations > 0 ? chaikinSmooth(contourNorm, smoothingIterations) : contourNorm; + const simplified = rdpSimplify(smoothed, simplifyEpsilon); + + return { + polygon: { points: simplified }, + threshold: normalizedThreshold, + seed, + meta: { + areaPx: cleanedArea, + areaNorm: cleanedArea / Math.max(1, w * h), + imageWidth: w, + imageHeight: h, + }, + }; +} + +export async function segmentTumorFromCapturedPng( + png: Blob, + paintPointsNorm: NormalizedPoint[], + thresholdOverride?: TumorThreshold +): Promise { + if (paintPointsNorm.length < 2) { + throw new Error('Not enough paint points to segment'); + } + + const { gray, width: w, height: h } = await decodeCapturedPngToGrayscale(png); + + const paintPx = paintPointsNorm.map((p) => ({ + x: clamp(p.x, 0, 1) * (w - 1), + y: clamp(p.y, 0, 1) * (h - 1), + })); + + const threshold = thresholdOverride ?? 
estimateThresholdFromSeeds(gray, w, h, paintPx); + + return segmentTumorFromGrayscale(gray, w, h, paintPointsNorm, threshold); +} diff --git a/frontend/src/utils/segmentation/simplify.ts b/frontend/src/utils/segmentation/simplify.ts new file mode 100644 index 0000000..8cf1920 --- /dev/null +++ b/frontend/src/utils/segmentation/simplify.ts @@ -0,0 +1,70 @@ +import type { NormalizedPoint } from '../../db/schema'; + +function sqr(x: number) { + return x * x; +} + +function distPointToSegmentSq(p: NormalizedPoint, a: NormalizedPoint, b: NormalizedPoint): number { + const vx = b.x - a.x; + const vy = b.y - a.y; + const wx = p.x - a.x; + const wy = p.y - a.y; + + const c1 = vx * wx + vy * wy; + if (c1 <= 0) return sqr(p.x - a.x) + sqr(p.y - a.y); + + const c2 = vx * vx + vy * vy; + if (c2 <= c1) return sqr(p.x - b.x) + sqr(p.y - b.y); + + const t = c1 / c2; + const projX = a.x + t * vx; + const projY = a.y + t * vy; + return sqr(p.x - projX) + sqr(p.y - projY); +} + +function rdp(points: NormalizedPoint[], epsSq: number): NormalizedPoint[] { + if (points.length <= 2) return points; + + const a = points[0]; + const b = points[points.length - 1]; + + let maxD = -1; + let idx = -1; + + for (let i = 1; i < points.length - 1; i++) { + const d = distPointToSegmentSq(points[i], a, b); + if (d > maxD) { + maxD = d; + idx = i; + } + } + + if (maxD <= epsSq || idx === -1) { + return [a, b]; + } + + const left = rdp(points.slice(0, idx + 1), epsSq); + const right = rdp(points.slice(idx), epsSq); + return [...left.slice(0, -1), ...right]; +} + +export function rdpSimplify(points: NormalizedPoint[], epsilon: number): NormalizedPoint[] { + if (points.length <= 3) return points; + + // Ensure the polygon is closed for simplification stability, then drop the repeated point. + const first = points[0]; + const last = points[points.length - 1]; + const closed = first.x === last.x && first.y === last.y ? 
points : [...points, first]; + + const simplified = rdp(closed, epsilon * epsilon); + // Remove closing duplicate if present. + if (simplified.length >= 2) { + const s0 = simplified[0]; + const sl = simplified[simplified.length - 1]; + if (s0.x === sl.x && s0.y === sl.y) { + simplified.pop(); + } + } + + return simplified; +} diff --git a/frontend/src/utils/segmentation/smooth.ts b/frontend/src/utils/segmentation/smooth.ts new file mode 100644 index 0000000..11e26b0 --- /dev/null +++ b/frontend/src/utils/segmentation/smooth.ts @@ -0,0 +1,41 @@ +export type Point = { x: number; y: number }; + +/** + * Chaikin corner-cutting smoothing for a closed polygon. + * + * Notes: + * - This assumes `points` is a simple (non-self-intersecting) polygon. + * - It returns a new point array and does not repeat the start point at the end. + * - Chaikin smoothing shrinks the polygon slightly; that's desirable here to remove pixel jaggies. + */ +export function chaikinSmooth(points: Point[], iterations: number = 2): Point[] { + if (points.length < 3) return points; + + let pts = points; + + for (let it = 0; it < iterations; it++) { + const n = pts.length; + const next: Point[] = []; + + for (let i = 0; i < n; i++) { + const p0 = pts[i]; + const p1 = pts[(i + 1) % n]; + + // Q and R points for the edge. 
+ const q = { + x: 0.75 * p0.x + 0.25 * p1.x, + y: 0.75 * p0.y + 0.25 * p1.y, + }; + const r = { + x: 0.25 * p0.x + 0.75 * p1.x, + y: 0.25 * p0.y + 0.75 * p1.y, + }; + + next.push(q, r); + } + + pts = next; + } + + return pts; +} diff --git a/frontend/src/utils/segmentation/traceBoundary.ts b/frontend/src/utils/segmentation/traceBoundary.ts new file mode 100644 index 0000000..6ad1318 --- /dev/null +++ b/frontend/src/utils/segmentation/traceBoundary.ts @@ -0,0 +1,148 @@ +type PxPoint = { x: number; y: number }; + +function clamp(v: number, lo: number, hi: number) { + return Math.max(lo, Math.min(hi, v)); +} + +function idx(x: number, y: number, w: number) { + return y * w + x; +} + +function isSet(mask: Uint8Array, w: number, h: number, x: number, y: number): boolean { + if (x < 0 || y < 0 || x >= w || y >= h) return false; + return mask[idx(x, y, w)] !== 0; +} + +function isBoundaryPixel(mask: Uint8Array, w: number, h: number, x: number, y: number): boolean { + if (!isSet(mask, w, h, x, y)) return false; + // 4-neighborhood boundary. + return ( + !isSet(mask, w, h, x - 1, y) || + !isSet(mask, w, h, x + 1, y) || + !isSet(mask, w, h, x, y - 1) || + !isSet(mask, w, h, x, y + 1) + ); +} + +// Moore-Neighbor tracing for an 8-connected boundary. +// Neighbor directions (clockwise) - defined once outside function to avoid allocation. +const DIRS: readonly PxPoint[] = [ + { x: 1, y: 0 }, + { x: 1, y: 1 }, + { x: 0, y: 1 }, + { x: -1, y: 1 }, + { x: -1, y: 0 }, + { x: -1, y: -1 }, + { x: 0, y: -1 }, + { x: 1, y: -1 }, +]; + +// Pre-computed direction lookup: dirLookup[dx+1][dy+1] = direction index +const DIR_LOOKUP: number[][] = [ + [5, 4, 3], // dx=-1: dy=-1,0,1 + [6, -1, 2], // dx=0: dy=-1,0,1 (center is invalid) + [7, 0, 1], // dx=1: dy=-1,0,1 +]; + +function traceFrom(mask: Uint8Array, w: number, h: number, start: PxPoint): PxPoint[] { + const boundary: PxPoint[] = []; + + let cx = start.x; + let cy = start.y; + // Backtrack point is initially the pixel to the left. 
+ let bx = start.x - 1; + let by = start.y; + + // Hard limit to prevent infinite loops - max boundary length is perimeter of image. + const maxIters = Math.min(50000, (w + h) * 4); + + for (let iter = 0; iter < maxIters; iter++) { + boundary.push({ x: cx, y: cy }); + + // Find direction index from current -> back using lookup table. + const dx = clamp(bx - cx, -1, 1); + const dy = clamp(by - cy, -1, 1); + let startDir = DIR_LOOKUP[dx + 1][dy + 1]; + if (startDir < 0) startDir = 4; // fallback + + // Search neighbors clockwise starting from back direction. + let foundNext = false; + let nx = 0, ny = 0, nbx = 0, nby = 0; + + for (let k = 0; k < 8; k++) { + const di = (startDir + 1 + k) % 8; + const dir = DIRS[di]; + const testX = cx + dir.x; + const testY = cy + dir.y; + + if (isBoundaryPixel(mask, w, h, testX, testY)) { + nx = testX; + ny = testY; + // The new backtrack is the neighbor just before nx,ny in the search order. + const backDi = (di + 7) % 8; + const backDir = DIRS[backDi]; + nbx = cx + backDir.x; + nby = cy + backDir.y; + foundNext = true; + break; + } + } + + if (!foundNext) break; + + // Close when we return to the start with the same backtrack. + if (nx === start.x && ny === start.y && nbx === bx && nby === by) { + break; + } + + cx = nx; + cy = ny; + bx = nbx; + by = nby; + } + + return boundary; +} + +export function traceLargestBoundary( + mask: Uint8Array, + w: number, + h: number, + roi?: { x0: number; y0: number; x1: number; y1: number } +): PxPoint[] { + const x0 = roi ? clamp(Math.floor(roi.x0), 0, w - 1) : 0; + const y0 = roi ? clamp(Math.floor(roi.y0), 0, h - 1) : 0; + const x1 = roi ? clamp(Math.ceil(roi.x1), 0, w - 1) : w - 1; + const y1 = roi ? 
clamp(Math.ceil(roi.y1), 0, h - 1) : h - 1; + + console.log('[traceLargestBoundary] START', { w, h, roi: { x0, y0, x1, y1 } }); + const t0 = performance.now(); + + let best: PxPoint[] = []; + const visited = new Uint8Array(w * h); + let boundariesFound = 0; + + // Only scan within ROI to avoid O(w*h) scan of entire image. + for (let y = y0; y <= y1; y++) { + for (let x = x0; x <= x1; x++) { + const i = idx(x, y, w); + if (visited[i]) continue; + if (!isBoundaryPixel(mask, w, h, x, y)) continue; + + const b = traceFrom(mask, w, h, { x, y }); + boundariesFound++; + for (const p of b) { + visited[idx(p.x, p.y, w)] = 1; + } + + if (b.length > best.length) { + best = b; + } + } + } + + const elapsed = performance.now() - t0; + console.log('[traceLargestBoundary] DONE', { boundariesFound, bestLength: best.length, elapsed: elapsed.toFixed(1) + 'ms' }); + + return best; +} diff --git a/frontend/src/utils/ssim.ts b/frontend/src/utils/ssim.ts new file mode 100644 index 0000000..619cf09 --- /dev/null +++ b/frontend/src/utils/ssim.ts @@ -0,0 +1,243 @@ +import type { ExclusionMask } from '../types/api'; + +export type BlockSimilarityResult = { + /** Block-averaged SSIM (higher is better; typically ~[-1..1], often [0..1] in practice). */ + ssim: number; + /** Block-averaged local normalized cross correlation (LNCC). Range: ~[-1..1]. */ + lncc: number; + /** Global zero-mean normalized cross correlation (ZNCC). Range: ~[-1..1]. */ + zncc: number; + /** Number of blocks that contributed (had >= 1 included pixel). */ + blocksUsed: number; + /** Number of pixels used after masking. */ + pixelsUsed: number; +}; + +export type SsimResult = { + /** Block-averaged SSIM (higher is better; typically ~[-1..1], often [0..1] in practice). */ + ssim: number; + /** Number of blocks that contributed (had >= 1 included pixel). */ + blocksUsed: number; + /** Number of pixels used after masking. */ + pixelsUsed: number; +}; + +export type SsimOptions = { + /** Block size in pixels. Default: 16. 
Larger is faster but less local. */
+  blockSize?: number;
+
+  /** Optional inclusion mask (same shape as images). Keep pixels where mask[idx] != 0. */
+  inclusionMask?: Uint8Array;
+
+  /** Optional exclusion rectangle in normalized [0,1] image coordinates. */
+  exclusionRect?: ExclusionMask;
+
+  /** Image width in pixels (required if exclusionRect is provided). */
+  imageWidth?: number;
+  /** Image height in pixels (required if exclusionRect is provided). */
+  imageHeight?: number;
+
+  /** SSIM constants (defaults match common SSIM settings). */
+  k1?: number; // default 0.01
+  k2?: number; // default 0.03
+  dynamicRange?: number; // L, default 1.0 for normalized pixels
+};
+
+function inferSquareSize(n: number): number {
+  const s = Math.round(Math.sqrt(n));
+  if (s <= 0 || s * s !== n) {
+    throw new Error('computeBlockSSIM: expected square image (provide imageWidth/imageHeight)');
+  }
+  return s;
+}
+
+/**
+ * Compute a fast approximation of SSIM by averaging SSIM over non-overlapping blocks.
+ *
+ * Notes:
+ * - This is not the classic Gaussian-window SSIM; it's a block-based approximation that is much
+ *   faster in JS/TS while still capturing local structure.
+ * - Pixels are assumed to be normalized grayscale (typically [0..1]).
+ */
+export function computeBlockSimilarity(imageA: Float32Array, imageB: Float32Array, opts: SsimOptions = {}): BlockSimilarityResult {
+  const n = imageA.length;
+  if (n === 0 || imageB.length !== n) {
+    return { ssim: 0, lncc: 0, zncc: 0, blocksUsed: 0, pixelsUsed: 0 };
+  }
+
+  const inclusionMask = opts.inclusionMask;
+  if (inclusionMask && inclusionMask.length !== n) {
+    throw new Error(`computeBlockSSIM: inclusionMask length mismatch (mask=${inclusionMask.length}, image=${n})`);
+  }
+
+  // Resolve the pixel grid dimensions.
+  //
+  // Accept caller-provided width/height whenever they are consistent with the buffer length
+  // (width * height === n). This supports non-square images — which is exactly what the
+  // inferSquareSize() error message tells callers to do by providing imageWidth/imageHeight.
+  // (Previously the provided dims were only honored when imageWidth === imageHeight, so
+  // non-square inputs always threw, and equal-but-wrong dims were accepted silently.)
+  // When the dims are absent or inconsistent, fall back to inferring a square image.
+  const width =
+    typeof opts.imageWidth === 'number' && typeof opts.imageHeight === 'number' && opts.imageWidth * opts.imageHeight === n
+      ? opts.imageWidth
+      : inferSquareSize(n);
+  const height =
+    typeof opts.imageHeight === 'number' && typeof opts.imageWidth === 'number' && opts.imageWidth * opts.imageHeight === n
+      ? opts.imageHeight
+      : width;
+
+  const blockSizeRaw = opts.blockSize ?? 16;
+  const blockSize = Math.max(4, Math.round(blockSizeRaw));
+
+  const blockCols = Math.ceil(width / blockSize);
+  const blockRows = Math.ceil(height / blockSize);
+  const numBlocks = blockCols * blockRows;
+
+  // Per-block accumulators.
+  const sumA = new Float64Array(numBlocks);
+  const sumB = new Float64Array(numBlocks);
+  const sumA2 = new Float64Array(numBlocks);
+  const sumB2 = new Float64Array(numBlocks);
+  const sumAB = new Float64Array(numBlocks);
+  const count = new Uint32Array(numBlocks);
+
+  // Precompute exclusion bounds.
+  const exclusionRect = opts.exclusionRect;
+  let hasExclusion = false;
+  let exclX0 = 0;
+  let exclY0 = 0;
+  let exclX1 = 0;
+  let exclY1 = 0;
+  if (exclusionRect && width > 0 && height > 0) {
+    exclX0 = Math.floor(exclusionRect.x * width);
+    exclY0 = Math.floor(exclusionRect.y * height);
+    exclX1 = Math.ceil((exclusionRect.x + exclusionRect.width) * width);
+    exclY1 = Math.ceil((exclusionRect.y + exclusionRect.height) * height);
+    hasExclusion = exclX1 > exclX0 && exclY1 > exclY0;
+  }
+
+  // Fast path for power-of-two blocks (default 16): use bitshift instead of division.
+  const isPowerOfTwo = (v: number) => (v & (v - 1)) === 0;
+  const useShift = isPowerOfTwo(blockSize);
+  const blockShift = useShift ? Math.round(Math.log2(blockSize)) : 0;
+
+  let pixelsUsed = 0;
+
+  // Global accumulators (for ZNCC).
+  let sumATotal = 0;
+  let sumBTotal = 0;
+  let sumA2Total = 0;
+  let sumB2Total = 0;
+  let sumABTotal = 0;
+
+  for (let y = 0; y < height; y++) {
+    const row = y * width;
+    const blockRow = useShift ? 
(y >> blockShift) : Math.floor(y / blockSize); + + for (let x = 0; x < width; x++) { + const idx = row + x; + + if (inclusionMask && inclusionMask[idx] === 0) continue; + + if (hasExclusion && x >= exclX0 && x < exclX1 && y >= exclY0 && y < exclY1) { + continue; + } + + const a = imageA[idx] ?? 0; + const b = imageB[idx] ?? 0; + + const blockCol = useShift ? (x >> blockShift) : Math.floor(x / blockSize); + const bi = blockRow * blockCols + blockCol; + + sumA[bi] += a; + sumB[bi] += b; + sumA2[bi] += a * a; + sumB2[bi] += b * b; + sumAB[bi] += a * b; + count[bi]++; + pixelsUsed++; + + sumATotal += a; + sumBTotal += b; + sumA2Total += a * a; + sumB2Total += b * b; + sumABTotal += a * b; + } + } + + if (pixelsUsed === 0) { + return { ssim: 0, lncc: 0, zncc: 0, blocksUsed: 0, pixelsUsed: 0 }; + } + + const L = opts.dynamicRange ?? 1; + const k1 = opts.k1 ?? 0.01; + const k2 = opts.k2 ?? 0.03; + const c1 = (k1 * L) * (k1 * L); + const c2 = (k2 * L) * (k2 * L); + + // ZNCC (global, zero-mean normalized cross correlation). + // + // We use population stats (divide by N) for stability. + const invN = 1 / pixelsUsed; + const meanA = sumATotal * invN; + const meanB = sumBTotal * invN; + let varA = sumA2Total * invN - meanA * meanA; + let varB = sumB2Total * invN - meanB * meanB; + let covAB = sumABTotal * invN - meanA * meanB; + + if (varA < 0) varA = 0; + if (varB < 0) varB = 0; + if (!Number.isFinite(covAB)) covAB = 0; + + const eps = 1e-12; + const denomZncc = Math.sqrt(varA * varB); + const zncc = denomZncc > eps ? 
covAB / denomZncc : 0; + + let weightedSsimSum = 0; + let weightedLnccSum = 0; + let weightTotal = 0; + let blocksUsed = 0; + + + for (let bi = 0; bi < numBlocks; bi++) { + const m = count[bi]; + if (m === 0) continue; + + const invM = 1 / m; + const meanA = sumA[bi] * invM; + const meanB = sumB[bi] * invM; + + let varA = sumA2[bi] * invM - meanA * meanA; + let varB = sumB2[bi] * invM - meanB * meanB; + let covAB = sumAB[bi] * invM - meanA * meanB; + + // Numerical safety. + if (varA < 0) varA = 0; + if (varB < 0) varB = 0; + + // Clamp extreme cov due to numeric issues. + if (!Number.isFinite(covAB)) covAB = 0; + + const num1 = 2 * meanA * meanB + c1; + const den1 = meanA * meanA + meanB * meanB + c1; + + const num2 = 2 * covAB + c2; + const den2 = varA + varB + c2; + + const denom = den1 * den2; + const ssimBlock = denom !== 0 ? (num1 * num2) / denom : 0; + + // LNCC for this block. + const denomLncc = Math.sqrt(varA * varB); + const lnccBlock = denomLncc > eps ? covAB / denomLncc : 0; + + // Weight by included pixels so partially-masked blocks don't dominate. + weightedSsimSum += ssimBlock * m; + weightedLnccSum += lnccBlock * m; + weightTotal += m; + blocksUsed++; + } + + const ssim = weightTotal > 0 ? weightedSsimSum / weightTotal : 0; + const lncc = weightTotal > 0 ? 
weightedLnccSum / weightTotal : 0;
+
+  return { ssim, lncc, zncc, blocksUsed, pixelsUsed };
+}
+
+/**
+ * SSIM-only view of computeBlockSimilarity: same computation, but drops the lncc/zncc
+ * fields from the result. Kept for callers that only need the SSIM score.
+ */
+export function computeBlockSSIM(imageA: Float32Array, imageB: Float32Array, opts: SsimOptions = {}): SsimResult {
+  const r = computeBlockSimilarity(imageA, imageB, opts);
+  return { ssim: r.ssim, blocksUsed: r.blocksUsed, pixelsUsed: r.pixelsUsed };
+}
diff --git a/frontend/src/utils/svr/dicomGeometry.ts b/frontend/src/utils/svr/dicomGeometry.ts
new file mode 100644
index 0000000..b6fb7fe
--- /dev/null
+++ b/frontend/src/utils/svr/dicomGeometry.ts
@@ -0,0 +1,171 @@
+import type { DicomInstance } from '../../db/schema';
+import type { Vec3 } from './vec3';
+import { cross, dot, normalize, v3 } from './vec3';
+
+/**
+ * Parse a multi-valued DICOM tag string into a list of finite numbers.
+ * Non-numeric fragments are silently dropped.
+ */
+function parseMultiNumberString(value: string): number[] {
+  // Multi-valued DICOM tags are typically separated by backslashes.
+  // Some exporters use commas/spaces; accept those as well.
+  return value
+    .split(/[\\,\s]+/)
+    .filter(Boolean)
+    .map((s) => Number.parseFloat(s))
+    .filter((n) => Number.isFinite(n));
+}
+
+// In-plane unit axes plus the slice normal, in patient coordinates
+// (as encoded by the DICOM ImageOrientationPatient tag).
+export type SliceAxes = {
+  rowDir: Vec3;
+  colDir: Vec3;
+  normalDir: Vec3;
+};
+
+/**
+ * Parse ImageOrientationPatient (six direction cosines: first row triplet, then
+ * first column triplet) into normalized axes and the slice normal (row x col).
+ * Returns null when the tag is absent, has fewer than 6 finite values, or the
+ * resulting normal is degenerate.
+ */
+export function parseImageOrientationPatient(iop: string | undefined): SliceAxes | null {
+  if (!iop) return null;
+  const nums = parseMultiNumberString(iop);
+  if (nums.length < 6) return null;
+
+  const rowDir = normalize(v3(nums[0] ?? 0, nums[1] ?? 0, nums[2] ?? 0));
+  const colDir = normalize(v3(nums[3] ?? 0, nums[4] ?? 0, nums[5] ?? 0));
+
+  // Slice normal is row x col.
+  const normalDir = normalize(cross(rowDir, colDir));
+
+  // If the DICOM is malformed (col/row not orthogonal), normal could be zero. 
+  if (!Number.isFinite(normalDir.x) || !Number.isFinite(normalDir.y) || !Number.isFinite(normalDir.z)) {
+    return null;
+  }
+
+  return { rowDir, colDir, normalDir };
+}
+
+/**
+ * Parse ImagePositionPatient into a point in patient coordinates (mm).
+ * Returns null when the tag is absent or has fewer than 3 finite values.
+ */
+export function parseImagePositionPatient(ipp: string | undefined): Vec3 | null {
+  if (!ipp) return null;
+  const nums = parseMultiNumberString(ipp);
+  if (nums.length < 3) return null;
+  return v3(nums[0] ?? 0, nums[1] ?? 0, nums[2] ?? 0);
+}
+
+/**
+ * Parse PixelSpacing ([rowSpacing, colSpacing] in mm). Returns null when the tag is
+ * absent, malformed, or either spacing is non-positive.
+ */
+export function parsePixelSpacingMm(pixelSpacing: string | undefined): { rowSpacingMm: number; colSpacingMm: number } | null {
+  if (!pixelSpacing) return null;
+  const nums = parseMultiNumberString(pixelSpacing);
+  if (nums.length < 2) return null;
+
+  const rowSpacingMm = nums[0] ?? NaN;
+  const colSpacingMm = nums[1] ?? NaN;
+
+  if (!Number.isFinite(rowSpacingMm) || !Number.isFinite(colSpacingMm) || rowSpacingMm <= 0 || colSpacingMm <= 0) {
+    return null;
+  }
+
+  return { rowSpacingMm, colSpacingMm };
+}
+
+// Complete spatial description of one slice: pixel grid size, patient-space origin
+// (IPP), orthonormal axes, and in-plane spacings in mm.
+export type SliceGeometry = {
+  rows: number;
+  cols: number;
+  ippMm: Vec3;
+  rowDir: Vec3;
+  colDir: Vec3;
+  normalDir: Vec3;
+  rowSpacingMm: number;
+  colSpacingMm: number;
+};
+
+/**
+ * Assemble SliceGeometry from a stored DICOM instance's spatial tags.
+ * Throws when any of IPP / IOP / PixelSpacing is missing or unparseable.
+ *
+ * NOTE(review): the Pick<…> type argument on `instance` appears garbled in this patch
+ * (angle-bracket generics lost in transit) — restore before applying.
+ */
+export function getSliceGeometryFromInstance(instance: Pick): SliceGeometry {
+  const axes = parseImageOrientationPatient(instance.imageOrientationPatient);
+  const ipp = parseImagePositionPatient(instance.imagePositionPatient);
+  const spacing = parsePixelSpacingMm(instance.pixelSpacing);
+
+  if (!axes || !ipp || !spacing) {
+    throw new Error('Missing spatial metadata (ImagePositionPatient / ImageOrientationPatient / PixelSpacing)');
+  }
+
+  return {
+    rows: instance.rows,
+    cols: instance.columns,
+    ippMm: ipp,
+    rowDir: axes.rowDir,
+    colDir: axes.colDir,
+    normalDir: axes.normalDir,
+    rowSpacingMm: spacing.rowSpacingMm,
+    colSpacingMm: spacing.colSpacingMm,
+  };
+}
+
+/**
+ * Compute the four corner points of a slice in patient space (mm).
+ * Returned order is [p00, p01, p10, p11] = [(r0,c0), (r0,cMax), (rMax,c0), (rMax,cMax)].
+ */
+export function sliceCornersMm(params: {
+  ippMm: Vec3;
+  rowDir: Vec3;
+  colDir: Vec3;
+  rowSpacingMm: number;
+  colSpacingMm: number;
+  rows: number;
+  cols: number;
+}): Vec3[] {
+  const { ippMm, rowDir, colDir, rowSpacingMm, colSpacingMm, rows, cols } = params;
+
+  const rMax = Math.max(0, rows - 1);
+  const cMax = Math.max(0, cols - 1);
+
+  // DICOM convention recap:
+  // - ImageOrientationPatient (IOP): first triplet is the direction of increasing *column* index,
+  //   second triplet is the direction of increasing *row* index.
+  // - PixelSpacing: [rowSpacing, colSpacing] in mm.
+  //
+  // Therefore: world(r, c) = IPP + colDir * (r * rowSpacing) + rowDir * (c * colSpacing).
+  const p00 = ippMm;
+  const p10 = v3(
+    ippMm.x + colDir.x * (rMax * rowSpacingMm),
+    ippMm.y + colDir.y * (rMax * rowSpacingMm),
+    ippMm.z + colDir.z * (rMax * rowSpacingMm)
+  );
+  const p01 = v3(
+    ippMm.x + rowDir.x * (cMax * colSpacingMm),
+    ippMm.y + rowDir.y * (cMax * colSpacingMm),
+    ippMm.z + rowDir.z * (cMax * colSpacingMm)
+  );
+  const p11 = v3(
+    p10.x + rowDir.x * (cMax * colSpacingMm),
+    p10.y + rowDir.y * (cMax * colSpacingMm),
+    p10.z + rowDir.z * (cMax * colSpacingMm)
+  );
+
+  return [p00, p01, p10, p11];
+}
+
+/** Median of the finite values in `values`; null when none are finite. */
+function median(values: number[]): number | null {
+  const v = values.filter((x) => Number.isFinite(x)).sort((a, b) => a - b);
+  if (v.length === 0) return null;
+  const mid = Math.floor(v.length / 2);
+  return v.length % 2 === 1 ? v[mid] ?? null : ((v[mid - 1] ?? 0) + (v[mid] ?? 
0)) / 2;
+}
+
+/**
+ * Estimate inter-slice spacing (mm) as the median absolute displacement between
+ * consecutive ImagePositionPatient values, projected onto the first slice's normal.
+ * Returns null when fewer than two usable positions exist.
+ *
+ * NOTE(review): assumes `instances` is already sorted along the slice stack — deltas
+ * are taken between *consecutive* entries. Confirm callers pass spatially sorted
+ * instances; an unsorted list would inflate the estimate.
+ * NOTE(review): the Array<…> type argument appears garbled in this patch
+ * (angle-bracket generics lost in transit) — restore before applying.
+ */
+export function estimateSliceSpacingMm(
+  instances: Array>
+): number | null {
+  if (instances.length < 2) return null;
+
+  // All slices are assumed coplanar-parallel; use the first slice's normal throughout.
+  const firstAxes = parseImageOrientationPatient(instances[0]?.imageOrientationPatient);
+  if (!firstAxes) return null;
+
+  const ipps: Vec3[] = [];
+  for (const inst of instances) {
+    const ipp = parseImagePositionPatient(inst.imagePositionPatient);
+    if (ipp) ipps.push(ipp);
+  }
+
+  if (ipps.length < 2) return null;
+
+  const deltas: number[] = [];
+  for (let i = 0; i < ipps.length - 1; i++) {
+    const a = ipps[i];
+    const b = ipps[i + 1];
+    if (!a || !b) continue;
+
+    const d = v3(b.x - a.x, b.y - a.y, b.z - a.z);
+    // Zero/negative projections (duplicate positions) are excluded from the median.
+    const along = Math.abs(dot(d, firstAxes.normalDir));
+    if (Number.isFinite(along) && along > 0) {
+      deltas.push(along);
+    }
+  }
+
+  return median(deltas);
+}
diff --git a/frontend/src/utils/svr/downsample.ts b/frontend/src/utils/svr/downsample.ts
new file mode 100644
index 0000000..f9208ba
--- /dev/null
+++ b/frontend/src/utils/svr/downsample.ts
@@ -0,0 +1,31 @@
+export type SliceDownsampleMode = 'fixed' | 'voxel-aware';
+
+/**
+ * Choose the downsampled slice grid (and scale factor) for SVR ingestion.
+ *
+ * - 'fixed': clamp the largest dimension to `maxSize`.
+ * - 'voxel-aware': additionally refuse to shrink past the reconstruction target, i.e.
+ *   keep scale >= maxInPlaneSpacingMm / targetVoxelSizeMm (capped at 1), so the slice
+ *   is never sampled coarser than the target voxel size.
+ *
+ * Only the sampling density changes; the caller rescales spacings so the physical
+ * field of view is preserved.
+ */
+export function computeSvrDownsampleSize(params: {
+  rows: number;
+  cols: number;
+  maxSize: number;
+  mode: SliceDownsampleMode;
+  rowSpacingMm: number;
+  colSpacingMm: number;
+  targetVoxelSizeMm: number;
+}): { dsRows: number; dsCols: number; scale: number } {
+  const { rows, cols, maxSize, mode, rowSpacingMm, colSpacingMm, targetVoxelSizeMm } = params;
+
+  const maxDim = Math.max(rows, cols);
+
+  let scale = 1;
+  if (Number.isFinite(maxSize) && maxSize > 1 && maxDim > maxSize) {
+    scale = maxSize / maxDim;
+  }
+
+  if (mode === 'voxel-aware') {
+    const maxSpacingMm = Math.max(rowSpacingMm, colSpacingMm);
+    // Guard targetVoxelSizeMm against zero/negative values before dividing.
+    const minScale = Math.min(1, Math.max(0, maxSpacingMm / Math.max(1e-6, targetVoxelSizeMm)));
+    if (scale < minScale) scale = minScale;
+  }
+
+  const dsRows = Math.max(1, Math.round(rows * scale));
+  const dsCols = Math.max(1, Math.round(cols * scale));
+
+  return { dsRows, dsCols, scale };
+} diff --git a/frontend/src/utils/svr/reconstructVolume.ts b/frontend/src/utils/svr/reconstructVolume.ts new file mode 100644 index 0000000..43d4344 --- /dev/null +++ b/frontend/src/utils/svr/reconstructVolume.ts @@ -0,0 +1,1565 @@ +import cornerstone from 'cornerstone-core'; +import { getDB } from '../../db/db'; +import type { DicomInstance } from '../../db/schema'; +import type { SvrParams, SvrProgress, SvrResult, SvrRoi, SvrSelectedSeries } from '../../types/svr'; +import { getSortedSopInstanceUidsForSeries } from '../localApi'; +import type { SliceGeometry } from './dicomGeometry'; +import { getSliceGeometryFromInstance, sliceCornersMm } from './dicomGeometry'; +import type { VolumeDims } from './trilinear'; +import { sampleTrilinear } from './trilinear'; +import type { SvrReconstructionGrid, SvrReconstructionOptions } from './reconstructionCore'; +import { reconstructVolumeFromSlices, refineVolumeInPlace, resampleVolumeToGridTrilinear } from './reconstructionCore'; +import { computeSvrDownsampleSize } from './downsample'; +import { resample2dAreaAverage, resample2dLanczos3 } from './resample2d'; +import type { Vec3 } from './vec3'; +import { cross, dot, normalize, v3 } from './vec3'; +import { boundsCornersMm, cropSliceToRoiInPlace } from './sliceRoiCrop'; +import { generateVolumePreviews } from './volumePreview'; +import { debugSvrLog, isDebugSvrEnabled } from '../debugSvr'; + +type SvrSliceResampleKernel = 'area' | 'lanczos3'; + +function getSvrSliceResampleKernel(debug?: boolean): SvrSliceResampleKernel { + if (!debug) return 'area'; + + try { + const v = localStorage.getItem('miraviewer:svr-resample-kernel'); + return v === 'lanczos3' ? 
'lanczos3' : 'area'; + } catch { + return 'area'; + } +} + +function yieldToMain(): Promise { + return new Promise((resolve) => setTimeout(resolve, 0)); +} + +function assertNotAborted(signal?: AbortSignal): void { + if (signal?.aborted) { + throw new Error('SVR cancelled'); + } +} + +function clamp01(x: number): number { + return x < 0 ? 0 : x > 1 ? 1 : x; +} + +type BoundsMm = { min: Vec3; max: Vec3 }; + +function boundsFromRoi(roi: SvrRoi): BoundsMm { + return { + min: v3(roi.boundsMm.min[0], roi.boundsMm.min[1], roi.boundsMm.min[2]), + max: v3(roi.boundsMm.max[0], roi.boundsMm.max[1], roi.boundsMm.max[2]), + }; +} + +function intersectBoundsMm(a: BoundsMm, b: BoundsMm): BoundsMm { + return { + min: v3(Math.max(a.min.x, b.min.x), Math.max(a.min.y, b.min.y), Math.max(a.min.z, b.min.z)), + max: v3(Math.min(a.max.x, b.max.x), Math.min(a.max.y, b.max.y), Math.min(a.max.z, b.max.z)), + }; +} + +function assertNonEmptyBounds(bounds: BoundsMm, label: string): void { + if (!(bounds.min.x < bounds.max.x && bounds.min.y < bounds.max.y && bounds.min.z < bounds.max.z)) { + throw new Error(`SVR ROI does not overlap reconstruction bounds (${label})`); + } +} + +function withinTrilinearSupport(dims: VolumeDims, x: number, y: number, z: number): boolean { + // sampleTrilinear/splatTrilinear require x0>=0 and x1= 0 && y >= 0 && z >= 0 && x < dims.nx - 1 && y < dims.ny - 1 && z < dims.nz - 1; +} + +type LoadedSlice = { + seriesUid: string; + sopInstanceUid: string; + + // Downsampled pixel grid (normalized to [0,1]) + pixels: Float32Array; + dsRows: number; + dsCols: number; + + // Original slice geometry (useful for logging/validation; not used in the hot loops) + srcRows: number; + srcCols: number; + rowSpacingMm: number; + colSpacingMm: number; + + // Optional thickness/spacing hints (if present in DICOM metadata) + sliceThicknessMm: number | null; + spacingBetweenSlicesMm: number | null; + + // Spatial mapping + ippMm: Vec3; + rowDir: Vec3; + colDir: Vec3; + normalDir: Vec3; 
+ + rowSpacingDsMm: number; + colSpacingDsMm: number; +}; + +type Mat3 = [number, number, number, number, number, number, number, number, number]; + +type RigidParams = { + // Translation in world/patient mm. + tx: number; + ty: number; + tz: number; + // Rotation in radians about patient/world axes. + rx: number; + ry: number; + rz: number; +}; + +type SeriesSamples = { + // Observed intensities (normalized [0,1]). + obs: Float32Array; + // Original world positions for each sample (x,y,z per sample). + pos: Float32Array; + count: number; +}; + +function boundsCenterMm(b: BoundsMm): Vec3 { + return v3((b.min.x + b.max.x) * 0.5, (b.min.y + b.max.y) * 0.5, (b.min.z + b.max.z) * 0.5); +} + +function isWithinBoundsMm(p: Vec3, b: BoundsMm): boolean { + return p.x >= b.min.x && p.x <= b.max.x && p.y >= b.min.y && p.y <= b.max.y && p.z >= b.min.z && p.z <= b.max.z; +} + +function clampAbs(x: number, maxAbs: number): number { + if (!Number.isFinite(x)) return 0; + if (!Number.isFinite(maxAbs) || maxAbs <= 0) return 0; + return x < -maxAbs ? -maxAbs : x > maxAbs ? 
maxAbs : x; +} + +function mat3FromEulerXYZ(rx: number, ry: number, rz: number): Mat3 { + // R = Rz(rz) * Ry(ry) * Rx(rx) + const cx = Math.cos(rx); + const sx = Math.sin(rx); + const cy = Math.cos(ry); + const sy = Math.sin(ry); + const cz = Math.cos(rz); + const sz = Math.sin(rz); + + const m00 = cz * cy; + const m01 = cz * sy * sx - sz * cx; + const m02 = cz * sy * cx + sz * sx; + + const m10 = sz * cy; + const m11 = sz * sy * sx + cz * cx; + const m12 = sz * sy * cx - cz * sx; + + const m20 = -sy; + const m21 = cy * sx; + const m22 = cy * cx; + + return [m00, m01, m02, m10, m11, m12, m20, m21, m22]; +} + +function mat3MulVec3(m: Mat3, x: number, y: number, z: number): Vec3 { + return v3(m[0] * x + m[1] * y + m[2] * z, m[3] * x + m[4] * y + m[5] * z, m[6] * x + m[7] * y + m[8] * z); +} + +function applyRigidToPoint(p: Vec3, centerMm: Vec3, rot: Mat3, tMm: Vec3): Vec3 { + // Rotate about `centerMm`, then translate. + const dx = p.x - centerMm.x; + const dy = p.y - centerMm.y; + const dz = p.z - centerMm.z; + + const r = mat3MulVec3(rot, dx, dy, dz); + return v3(centerMm.x + r.x + tMm.x, centerMm.y + r.y + tMm.y, centerMm.z + r.z + tMm.z); +} + +function applyRotToDir(d: Vec3, rot: Mat3): Vec3 { + const r = mat3MulVec3(rot, d.x, d.y, d.z); + return normalize(r); +} + +function orthonormalizeRowCol(rowDir: Vec3, colDir: Vec3): { rowDir: Vec3; colDir: Vec3 } { + // Keep these as an orthonormal basis; this prevents numerical drift after repeated rotations. 
+ const r = normalize(rowDir); + const c0 = normalize(colDir); + const n = normalize(cross(r, c0)); + const c = normalize(cross(n, r)); + return { rowDir: r, colDir: c }; +} + +function applyRigidToSeriesSlices(params: { + slices: LoadedSlice[]; + centerMm: Vec3; + rot: Mat3; + tMm: Vec3; +}): void { + const { slices, centerMm, rot, tMm } = params; + + for (const s of slices) { + s.ippMm = applyRigidToPoint(s.ippMm, centerMm, rot, tMm); + + const row = applyRotToDir(s.rowDir, rot); + const col = applyRotToDir(s.colDir, rot); + const ortho = orthonormalizeRowCol(row, col); + s.rowDir = ortho.rowDir; + s.colDir = ortho.colDir; + s.normalDir = normalize(cross(s.rowDir, s.colDir)); + } +} + +function buildSeriesSamples(params: { + slices: LoadedSlice[]; + roiBounds: BoundsMm; + maxSamples: number; + signal?: AbortSignal; +}): SeriesSamples { + const { slices, roiBounds, maxSamples, signal } = params; + + const maxN = Math.max(1, Math.round(maxSamples)); + const perSliceTarget = Math.max(64, Math.ceil(maxN / Math.max(1, slices.length))); + + let totalPixels = 0; + for (const s of slices) totalPixels += s.dsRows * s.dsCols; + + // Choose a roughly-uniform stride so we don't spend time scoring every pixel. + const stride = Math.max(1, Math.floor(Math.sqrt(totalPixels / maxN))); + + const obs: number[] = []; + const pos: number[] = []; + + for (let sIdx = 0; sIdx < slices.length; sIdx++) { + assertNotAborted(signal); + const s = slices[sIdx]; + + let usedThisSlice = 0; + + for (let r = 0; r < s.dsRows; r += stride) { + const baseX = s.ippMm.x + s.colDir.x * (r * s.rowSpacingDsMm); + const baseY = s.ippMm.y + s.colDir.y * (r * s.rowSpacingDsMm); + const baseZ = s.ippMm.z + s.colDir.z * (r * s.rowSpacingDsMm); + + const rowBase = r * s.dsCols; + + for (let c = 0; c < s.dsCols; c += stride) { + const v = s.pixels[rowBase + c] ?? 
0; + if (v <= 0) continue; + + const wx = baseX + s.rowDir.x * (c * s.colSpacingDsMm); + const wy = baseY + s.rowDir.y * (c * s.colSpacingDsMm); + const wz = baseZ + s.rowDir.z * (c * s.colSpacingDsMm); + + const p = v3(wx, wy, wz); + if (!isWithinBoundsMm(p, roiBounds)) continue; + + obs.push(v); + pos.push(wx, wy, wz); + usedThisSlice++; + + if (usedThisSlice >= perSliceTarget) break; + if (obs.length >= maxN) break; + } + + if (usedThisSlice >= perSliceTarget) break; + if (obs.length >= maxN) break; + } + + if (obs.length >= maxN) break; + } + + return { + obs: Float32Array.from(obs), + pos: Float32Array.from(pos), + count: obs.length, + }; +} + +function scoreNcc(params: { + samples: SeriesSamples; + refVolume: Float32Array; + dims: VolumeDims; + originMm: Vec3; + voxelSizeMm: number; + centerMm: Vec3; + rigid: RigidParams; +}): { ncc: number; used: number } { + const { samples, refVolume, dims, originMm, voxelSizeMm, centerMm, rigid } = params; + + if (samples.count <= 0) return { ncc: Number.NEGATIVE_INFINITY, used: 0 }; + + const rot = mat3FromEulerXYZ(rigid.rx, rigid.ry, rigid.rz); + const tMm = v3(rigid.tx, rigid.ty, rigid.tz); + + const invVox = 1 / voxelSizeMm; + + let sumA = 0; + let sumB = 0; + let sumAA = 0; + let sumBB = 0; + let sumAB = 0; + let used = 0; + + const obs = samples.obs; + const pos = samples.pos; + + for (let i = 0; i < samples.count; i++) { + const a = obs[i] ?? 0; + const x = pos[i * 3] ?? 0; + const y = pos[i * 3 + 1] ?? 0; + const z = pos[i * 3 + 2] ?? 0; + + // Apply candidate rigid transform about ROI center. 
+ const p = applyRigidToPoint(v3(x, y, z), centerMm, rot, tMm); + + const vx = (p.x - originMm.x) * invVox; + const vy = (p.y - originMm.y) * invVox; + const vz = (p.z - originMm.z) * invVox; + + if (!withinTrilinearSupport(dims, vx, vy, vz)) continue; + + const b = sampleTrilinear(refVolume, dims, vx, vy, vz); + + sumA += a; + sumB += b; + sumAA += a * a; + sumBB += b * b; + sumAB += a * b; + used++; + } + + if (used < 512) { + // Too few in-bounds samples to reliably optimize. + return { ncc: Number.NEGATIVE_INFINITY, used }; + } + + const invN = 1 / used; + const cov = sumAB - sumA * sumB * invN; + const varA = sumAA - sumA * sumA * invN; + const varB = sumBB - sumB * sumB * invN; + + const denom = Math.sqrt(Math.max(1e-12, varA * varB)); + const ncc = denom > 0 ? cov / denom : Number.NEGATIVE_INFINITY; + + return { ncc, used }; +} + +async function optimizeRigidNcc(params: { + samples: SeriesSamples; + refVolume: Float32Array; + dims: VolumeDims; + originMm: Vec3; + voxelSizeMm: number; + centerMm: Vec3; + signal?: AbortSignal; +}): Promise<{ best: RigidParams; bestScore: number; used: number; evals: number }> { + const { samples, refVolume, dims, originMm, voxelSizeMm, centerMm, signal } = params; + + // Assumptions: the coarse alignment got us "close". + // We only search a small neighborhood around the current placement to avoid silly transforms. 
+ const maxTransMm = 20; + const maxRotRad = (10 * Math.PI) / 180; + + const stages = [ + { transStepMm: 2.0, rotStepRad: (2 * Math.PI) / 180 }, + { transStepMm: 1.0, rotStepRad: (1 * Math.PI) / 180 }, + { transStepMm: 0.5, rotStepRad: (0.5 * Math.PI) / 180 }, + ]; + + let cur: RigidParams = { tx: 0, ty: 0, tz: 0, rx: 0, ry: 0, rz: 0 }; + const bestEval = scoreNcc({ samples, refVolume, dims, originMm, voxelSizeMm, centerMm, rigid: cur }); + let bestScore = bestEval.ncc; + let bestUsed = bestEval.used; + let evals = 1; + + const tryUpdate = (next: RigidParams): boolean => { + const e = scoreNcc({ samples, refVolume, dims, originMm, voxelSizeMm, centerMm, rigid: next }); + evals++; + if (e.ncc > bestScore + 1e-4) { + cur = next; + bestScore = e.ncc; + bestUsed = e.used; + return true; + } + return false; + }; + + for (const stage of stages) { + let improved = true; + let iter = 0; + + while (improved && iter < 20) { + assertNotAborted(signal); + improved = false; + iter++; + + const t = stage.transStepMm; + const r = stage.rotStepRad; + + const candidates: Array = ['tx', 'ty', 'tz', 'rx', 'ry', 'rz']; + + for (const key of candidates) { + const step = key.startsWith('t') ? t : r; + + const plus: RigidParams = { ...cur }; + const minus: RigidParams = { ...cur }; + (plus as Record)[key] = cur[key] + step; + (minus as Record)[key] = cur[key] - step; + + // Clamp each dimension independently. 
+ plus.tx = clampAbs(plus.tx, maxTransMm); + plus.ty = clampAbs(plus.ty, maxTransMm); + plus.tz = clampAbs(plus.tz, maxTransMm); + plus.rx = clampAbs(plus.rx, maxRotRad); + plus.ry = clampAbs(plus.ry, maxRotRad); + plus.rz = clampAbs(plus.rz, maxRotRad); + + minus.tx = clampAbs(minus.tx, maxTransMm); + minus.ty = clampAbs(minus.ty, maxTransMm); + minus.tz = clampAbs(minus.tz, maxTransMm); + minus.rx = clampAbs(minus.rx, maxRotRad); + minus.ry = clampAbs(minus.ry, maxRotRad); + minus.rz = clampAbs(minus.rz, maxRotRad); + + if (tryUpdate(plus)) { + improved = true; + } + + if (tryUpdate(minus)) { + improved = true; + } + + if (evals % 25 === 0) { + await yieldToMain(); + } + } + } + } + + return { best: cur, bestScore, used: bestUsed, evals }; +} + +async function rigidAlignSeriesInRoi(params: { + allSlices: LoadedSlice[]; + selectedSeries: SvrSelectedSeries[]; + roiBounds: BoundsMm; + dims: VolumeDims; + originMm: Vec3; + voxelSizeMm: number; + roi: SvrRoi; + signal?: AbortSignal; + onProgress?: (p: SvrProgress) => void; + debug: boolean; +}): Promise { + const { allSlices, selectedSeries, roiBounds, dims, originMm, voxelSizeMm, roi, signal, onProgress, debug } = params; + + // This stage exists because multi-plane fusion is extremely sensitive to even small spatial-tag mismatches. + // If series are misregistered, SVR will smear details rather than sharpen them. + + const bySeries = new Map(); + for (const s of allSlices) { + const arr = bySeries.get(s.seriesUid); + if (arr) arr.push(s); + else bySeries.set(s.seriesUid, [s]); + } + + const labelByUid = new Map(); + for (const s of selectedSeries) labelByUid.set(s.seriesUid, s.label); + + const roiReferenceUid = roi.sourceSeriesUid ?? 
null; + let referenceUid: string | null = null; + + if (roiReferenceUid && bySeries.has(roiReferenceUid)) { + referenceUid = roiReferenceUid; + } else { + let bestCount = -1; + for (const [uid, arr] of bySeries) { + if (arr.length > bestCount) { + referenceUid = uid; + bestCount = arr.length; + } + } + } + + const centerMm = boundsCenterMm(roiBounds); + + debugSvrLog( + 'registration.roi-rigid.plan', + { + referenceUid, + centerMm: { x: Number(centerMm.x.toFixed(3)), y: Number(centerMm.y.toFixed(3)), z: Number(centerMm.z.toFixed(3)) }, + dims, + voxelSizeMm: Number(voxelSizeMm.toFixed(4)), + }, + debug + ); + + // Align each non-reference series to the reconstruction of the other series. + const seriesUids = Array.from(bySeries.keys()); + for (let idx = 0; idx < seriesUids.length; idx++) { + assertNotAborted(signal); + + const uid = seriesUids[idx]; + if (!uid) continue; + if (referenceUid && uid === referenceUid) continue; + + const movingSlices = bySeries.get(uid); + if (!movingSlices || movingSlices.length === 0) continue; + + onProgress?.({ + phase: 'initializing', + current: 57, + total: 100, + message: `ROI rigid align… (${labelByUid.get(uid) ?? uid})`, + }); + + // Build a reference volume from all other series (used only for scoring). + const otherSlices: LoadedSlice[] = []; + for (const [otherUid, slices] of bySeries) { + if (otherUid === uid) continue; + otherSlices.push(...slices); + } + + if (otherSlices.length === 0) continue; + + const refGrid: SvrReconstructionGrid = { dims, originMm, voxelSizeMm }; + const refOptions: SvrReconstructionOptions = { + iterations: 0, + stepSize: 0, + clampOutput: true, + psfMode: 'none', + robustLoss: 'none', + robustDelta: 0.1, + laplacianWeight: 0, + }; + + const refVol = await reconstructVolumeFromSlices({ + slices: otherSlices, + grid: refGrid, + options: refOptions, + hooks: { + signal, + yieldToMain, + }, + }); + + // Extract samples from the moving series within the ROI bounds. 
+ const samples = buildSeriesSamples({ slices: movingSlices, roiBounds: roiBounds, maxSamples: 40_000, signal }); + + if (samples.count < 1024) { + console.warn('[svr] ROI rigid alignment: too few samples inside ROI; skipping series', { + seriesUid: uid, + label: labelByUid.get(uid) ?? uid, + samples: samples.count, + }); + continue; + } + + const before = scoreNcc({ + samples, + refVolume: refVol, + dims, + originMm, + voxelSizeMm, + centerMm, + rigid: { tx: 0, ty: 0, tz: 0, rx: 0, ry: 0, rz: 0 }, + }); + + const opt = await optimizeRigidNcc({ samples, refVolume: refVol, dims, originMm, voxelSizeMm, centerMm, signal }); + + const after = scoreNcc({ + samples, + refVolume: refVol, + dims, + originMm, + voxelSizeMm, + centerMm, + rigid: opt.best, + }); + + // Only apply if the score actually improved. + if (!(after.ncc > before.ncc + 1e-3)) { + debugSvrLog( + 'registration.roi-rigid.skip', + { + seriesUid: uid, + label: labelByUid.get(uid) ?? uid, + nccBefore: before.ncc, + nccAfter: after.ncc, + used: after.used, + }, + debug + ); + continue; + } + + const rot = mat3FromEulerXYZ(opt.best.rx, opt.best.ry, opt.best.rz); + const tMm = v3(opt.best.tx, opt.best.ty, opt.best.tz); + + applyRigidToSeriesSlices({ slices: movingSlices, centerMm, rot, tMm }); + + console.info('[svr] ROI rigid series alignment applied', { + seriesUid: uid, + label: labelByUid.get(uid) ?? uid, + nccBefore: Number(before.ncc.toFixed(4)), + nccAfter: Number(after.ncc.toFixed(4)), + usedSamples: after.used, + evals: opt.evals, + translateMm: { + x: Number(opt.best.tx.toFixed(3)), + y: Number(opt.best.ty.toFixed(3)), + z: Number(opt.best.tz.toFixed(3)), + }, + rotateDeg: { + x: Number((opt.best.rx * (180 / Math.PI)).toFixed(3)), + y: Number((opt.best.ry * (180 / Math.PI)).toFixed(3)), + z: Number((opt.best.rz * (180 / Math.PI)).toFixed(3)), + }, + }); + + debugSvrLog( + 'registration.roi-rigid', + { + seriesUid: uid, + label: labelByUid.get(uid) ?? 
uid, + samples: samples.count, + usedSamples: after.used, + nccBefore: before.ncc, + nccAfter: after.ncc, + evals: opt.evals, + translateMm: { x: opt.best.tx, y: opt.best.ty, z: opt.best.tz }, + rotateRad: { x: opt.best.rx, y: opt.best.ry, z: opt.best.rz }, + }, + debug + ); + + await yieldToMain(); + } +} + +async function loadSeriesSlices(params: { + series: SvrSelectedSeries; + sliceDownsampleMode: SvrParams['sliceDownsampleMode']; + sliceDownsampleMaxSize: number; + targetVoxelSizeMm: number; + maxIntensitySamples: number; + signal?: AbortSignal; + onProgress?: (p: SvrProgress) => void; + progressBase: { current: number; total: number }; + debug?: boolean; +}): Promise<{ slices: LoadedSlice[]; intensitySamples: number[] }> { + const { + series, + sliceDownsampleMode, + sliceDownsampleMaxSize, + targetVoxelSizeMm, + maxIntensitySamples, + signal, + onProgress, + progressBase, + debug, + } = params; + + const db = await getDB(); + const uids = await getSortedSopInstanceUidsForSeries(series.seriesUid); + + const slices: LoadedSlice[] = []; + + // Deterministic sampling for robust global normalization. + const intensitySamples: number[] = []; + let intensityApproxMin = Number.POSITIVE_INFINITY; + let intensityApproxMax = Number.NEGATIVE_INFINITY; + + const perSliceTarget = Math.max(64, Math.ceil(maxIntensitySamples / Math.max(1, uids.length))); + + const resampleKernel = getSvrSliceResampleKernel(debug); + debugSvrLog( + 'slice.downsample', + { + seriesUid: series.seriesUid, + label: series.label, + kernel: resampleKernel, + }, + !!debug + ); + + for (let i = 0; i < uids.length; i++) { + assertNotAborted(signal); + + const sopInstanceUid = uids[i]; + if (!sopInstanceUid) continue; + + const inst = (await db.get('instances', sopInstanceUid)) as DicomInstance | undefined; + if (!inst) continue; + + const sliceThicknessMm = typeof inst.sliceThickness === 'number' && inst.sliceThickness > 0 ? 
inst.sliceThickness : null; + const spacingBetweenSlicesMm = + typeof inst.spacingBetweenSlices === 'number' && inst.spacingBetweenSlices > 0 ? inst.spacingBetweenSlices : null; + + const geom: SliceGeometry = getSliceGeometryFromInstance(inst); + + const { dsRows, dsCols } = computeSvrDownsampleSize({ + rows: geom.rows, + cols: geom.cols, + maxSize: sliceDownsampleMaxSize, + mode: sliceDownsampleMode, + rowSpacingMm: geom.rowSpacingMm, + colSpacingMm: geom.colSpacingMm, + targetVoxelSizeMm, + }); + + // Adjust spacings for the downsampled grid (physical FOV preserved). + const rowSpacingDsMm = geom.rowSpacingMm * (geom.rows / dsRows); + const colSpacingDsMm = geom.colSpacingMm * (geom.cols / dsCols); + + // Decode pixels via Cornerstone (uses our miradb: loader + codecs). + const imageId = `miradb:${sopInstanceUid}`; + const image = await cornerstone.loadImage(imageId); + + const getPixelData = (image as unknown as { getPixelData?: () => ArrayLike }).getPixelData; + if (typeof getPixelData !== 'function') { + throw new Error('Cornerstone image did not expose getPixelData()'); + } + + const pixelData = getPixelData.call(image); + + // Higher-fidelity downsampling (anti-aliasing) to reduce aliasing. + // Default is box/area averaging; Lanczos is available behind a debug flag. + const down = + resampleKernel === 'lanczos3' + ? resample2dLanczos3(pixelData, geom.rows, geom.cols, dsRows, dsCols) + : resample2dAreaAverage(pixelData, geom.rows, geom.cols, dsRows, dsCols); + + // Apply modality scaling when available. (Linear, so applying post-downsample is equivalent.) + const slope = typeof (image as unknown as { slope?: unknown }).slope === 'number' ? (image as unknown as { slope: number }).slope : 1; + const intercept = + typeof (image as unknown as { intercept?: unknown }).intercept === 'number' ? 
(image as unknown as { intercept: number }).intercept : 0; + + if (slope !== 1 || intercept !== 0) { + for (let p = 0; p < down.length; p++) { + down[p] = down[p] * slope + intercept; + } + } + + // Sample intensities deterministically for robust global normalization. + if (intensitySamples.length < maxIntensitySamples) { + const stride = Math.max(1, Math.floor(down.length / perSliceTarget)); + for (let p = 0; p < down.length && intensitySamples.length < maxIntensitySamples; p += stride) { + const v = down[p] ?? 0; + if (!Number.isFinite(v)) continue; + intensitySamples.push(v); + if (v < intensityApproxMin) intensityApproxMin = v; + if (v > intensityApproxMax) intensityApproxMax = v; + } + } + + slices.push({ + seriesUid: series.seriesUid, + sopInstanceUid, + pixels: down, + dsRows, + dsCols, + srcRows: geom.rows, + srcCols: geom.cols, + rowSpacingMm: geom.rowSpacingMm, + colSpacingMm: geom.colSpacingMm, + sliceThicknessMm, + spacingBetweenSlicesMm, + ippMm: geom.ippMm, + rowDir: geom.rowDir, + colDir: geom.colDir, + normalDir: geom.normalDir, + rowSpacingDsMm, + colSpacingDsMm, + }); + + if (i % 8 === 0) { + onProgress?.({ + phase: 'loading', + current: progressBase.current + i, + total: progressBase.total, + message: `Decoding slices (${series.label}) ${i + 1}/${uids.length}`, + }); + await yieldToMain(); + } + } + + if (debug && slices.length > 0) { + const s0 = slices[0]; + const n0 = s0.normalDir; + + let minAbsNDot = 1; + const along: number[] = []; + + for (const s of slices) { + const n = s.normalDir; + const absDot = Math.abs(dot(n, n0)); + if (absDot < minAbsNDot) minAbsNDot = absDot; + + // Use the normal from the first slice to compute approximate slice-to-slice spacing. + along.push(dot(s.ippMm, n0)); + } + + along.sort((a, b) => a - b); + const deltas: number[] = []; + for (let i = 0; i < along.length - 1; i++) { + const d = Math.abs((along[i + 1] ?? 0) - (along[i] ?? 
0)); + if (Number.isFinite(d) && d > 0) deltas.push(d); + } + deltas.sort((a, b) => a - b); + const sliceSpacingMm = deltas.length + ? deltas.length % 2 === 1 + ? deltas[Math.floor(deltas.length / 2)] + : ((deltas[deltas.length / 2 - 1] ?? 0) + (deltas[deltas.length / 2] ?? 0)) / 2 + : null; + + const median = (values: Array): number | null => { + const v = values.filter((x) => typeof x === 'number' && Number.isFinite(x)).sort((a, b) => (a as number) - (b as number)); + if (v.length === 0) return null; + const mid = Math.floor(v.length / 2); + return v.length % 2 === 1 ? (v[mid] as number) : (((v[mid - 1] as number) + (v[mid] as number)) / 2); + }; + + const sliceThicknessMedianMm = median(slices.map((s) => s.sliceThicknessMm)); + const spacingBetweenSlicesMedianMm = median(slices.map((s) => s.spacingBetweenSlicesMm)); + + debugSvrLog( + 'series.loaded', + { + label: series.label, + seriesUid: series.seriesUid, + loadedSlices: slices.length, + srcRows: s0.srcRows, + srcCols: s0.srcCols, + dsRows: s0.dsRows, + dsCols: s0.dsCols, + rowSpacingMm: s0.rowSpacingMm, + colSpacingMm: s0.colSpacingMm, + rowSpacingDsMm: s0.rowSpacingDsMm, + colSpacingDsMm: s0.colSpacingDsMm, + approxSliceSpacingMm: sliceSpacingMm, + sliceThicknessMedianMm, + spacingBetweenSlicesMedianMm, + normalConsistencyMinAbsDot: Number(minAbsNDot.toFixed(6)), + intensityApprox: { + min: Number.isFinite(intensityApproxMin) ? Number(intensityApproxMin.toFixed(4)) : null, + max: Number.isFinite(intensityApproxMax) ? 
Number(intensityApproxMax.toFixed(4)) : null, + samples: intensitySamples.length, + }, + }, + true + ); + + if (minAbsNDot < 0.999) { + console.warn('[svr] Inconsistent slice normals detected within a series (oblique drift?)', { + seriesUid: series.seriesUid, + label: series.label, + minAbsDot: minAbsNDot, + }); + } + } + + return { slices, intensitySamples }; +} + + +function computeBoundsMm(slices: LoadedSlice[]): { min: Vec3; max: Vec3 } { + let minX = Number.POSITIVE_INFINITY; + let minY = Number.POSITIVE_INFINITY; + let minZ = Number.POSITIVE_INFINITY; + let maxX = Number.NEGATIVE_INFINITY; + let maxY = Number.NEGATIVE_INFINITY; + let maxZ = Number.NEGATIVE_INFINITY; + + for (const s of slices) { + const corners = sliceCornersMm({ + ippMm: s.ippMm, + rowDir: s.rowDir, + colDir: s.colDir, + rowSpacingMm: s.rowSpacingDsMm, + colSpacingMm: s.colSpacingDsMm, + rows: s.dsRows, + cols: s.dsCols, + }); + + for (const p of corners) { + if (p.x < minX) minX = p.x; + if (p.y < minY) minY = p.y; + if (p.z < minZ) minZ = p.z; + if (p.x > maxX) maxX = p.x; + if (p.y > maxY) maxY = p.y; + if (p.z > maxZ) maxZ = p.z; + } + } + + if (!Number.isFinite(minX) || !Number.isFinite(maxX)) { + throw new Error('Failed to compute bounds for SVR'); + } + + // Small padding to avoid clipping due to rounding. 
+ const pad = 1; + return { + min: v3(minX - pad, minY - pad, minZ - pad), + max: v3(maxX + pad, maxY + pad, maxZ + pad), + }; +} + +function chooseOutputGrid(params: { bounds: { min: Vec3; max: Vec3 }; voxelSizeMm: number; maxDim: number }): { + originMm: Vec3; + voxelSizeMm: number; + dims: VolumeDims; +} { + const { bounds, maxDim } = params; + + let voxelSizeMm = params.voxelSizeMm; + if (!Number.isFinite(voxelSizeMm) || voxelSizeMm <= 0) voxelSizeMm = 1; + + const extentX = bounds.max.x - bounds.min.x; + const extentY = bounds.max.y - bounds.min.y; + const extentZ = bounds.max.z - bounds.min.z; + + const dimFor = (extent: number, vox: number) => Math.max(2, Math.ceil(extent / vox) + 1); + + // Increase voxel size if any dimension is above maxDim. + for (let attempt = 0; attempt < 10; attempt++) { + const nx = dimFor(extentX, voxelSizeMm); + const ny = dimFor(extentY, voxelSizeMm); + const nz = dimFor(extentZ, voxelSizeMm); + + const maxD = Math.max(nx, ny, nz); + if (maxD <= maxDim) { + return { + originMm: bounds.min, + voxelSizeMm, + dims: { nx, ny, nz }, + }; + } + + voxelSizeMm *= maxD / maxDim; + } + + const nx = dimFor(extentX, voxelSizeMm); + const ny = dimFor(extentY, voxelSizeMm); + const nz = dimFor(extentZ, voxelSizeMm); + + return { + originMm: bounds.min, + voxelSizeMm, + dims: { nx, ny, nz }, + }; +} + + +export async function reconstructVolumeMultiPlane(params: { + selectedSeries: SvrSelectedSeries[]; + svrParams: SvrParams; + signal?: AbortSignal; + onProgress?: (p: SvrProgress) => void; +}): Promise { + const { selectedSeries, svrParams, signal, onProgress } = params; + if (selectedSeries.length < 2) { + throw new Error('Select at least 2 series (multi-plane) for SVR'); + } + + const t0 = performance.now(); + + // 1) Decode + downsample slices. 
+ onProgress?.({ phase: 'loading', current: 0, total: 100, message: 'Loading slices…' }); + + const allSlices: LoadedSlice[] = []; + + // Intensity normalization samples (global across all selected series). + const intensitySamples: number[] = []; + const intensitySamplesBySeries = new Map(); + + // Allocate progress budget: 0..50 for decoding. + const decodeTotal = selectedSeries.reduce((acc, s) => acc + Math.max(1, s.instanceCount), 0); + let decodeBase = 0; + + const debug = isDebugSvrEnabled(); + + if (!debug) { + console.info("[svr] Tip: enable verbose SVR logs with localStorage.setItem('miraviewer:debug-svr', '1')"); + } + + console.info('[svr] Reconstruction started', { + seriesCount: selectedSeries.length, + roi: svrParams.roi ? { mode: svrParams.roi.mode, sourcePlane: svrParams.roi.sourcePlane } : null, + seriesRegistrationMode: svrParams.seriesRegistrationMode, + voxelSizeMm: svrParams.targetVoxelSizeMm, + maxVolumeDim: svrParams.maxVolumeDim, + sliceDownsampleMode: svrParams.sliceDownsampleMode, + sliceDownsampleMaxSize: svrParams.sliceDownsampleMaxSize, + iterations: svrParams.iterations, + stepSize: svrParams.stepSize, + }); + + const MAX_INTENSITY_SAMPLES_TOTAL = 50_000; + const maxIntensitySamplesPerSeries = Math.max(2048, Math.ceil(MAX_INTENSITY_SAMPLES_TOTAL / Math.max(1, selectedSeries.length))); + + for (const series of selectedSeries) { + assertNotAborted(signal); + + const loaded = await loadSeriesSlices({ + series, + sliceDownsampleMode: svrParams.sliceDownsampleMode, + sliceDownsampleMaxSize: svrParams.sliceDownsampleMaxSize, + targetVoxelSizeMm: svrParams.targetVoxelSizeMm, + maxIntensitySamples: maxIntensitySamplesPerSeries, + signal, + onProgress, + progressBase: { current: decodeBase, total: decodeTotal }, + debug, + }); + + const slices = loaded.slices; + const seriesSamples = loaded.intensitySamples; + + if (slices.length > 0) { + const s0 = slices[0]; + console.info('[svr] Series decoded', { + label: series.label, + seriesUid: 
series.seriesUid, + loadedSlices: slices.length, + srcRows: s0.srcRows, + srcCols: s0.srcCols, + dsRows: s0.dsRows, + dsCols: s0.dsCols, + rowSpacingMm: Number(s0.rowSpacingMm.toFixed(4)), + colSpacingMm: Number(s0.colSpacingMm.toFixed(4)), + rowSpacingDsMm: Number(s0.rowSpacingDsMm.toFixed(4)), + colSpacingDsMm: Number(s0.colSpacingDsMm.toFixed(4)), + }); + } + + decodeBase += Math.max(1, series.instanceCount); + allSlices.push(...slices); + + for (const v of seriesSamples) { + intensitySamples.push(v); + } + + if (seriesSamples.length > 0) { + const prev = intensitySamplesBySeries.get(series.seriesUid); + if (prev) { + prev.push(...seriesSamples); + } else { + intensitySamplesBySeries.set(series.seriesUid, [...seriesSamples]); + } + } + + await yieldToMain(); + } + + if (allSlices.length === 0) { + throw new Error('No slices loaded for SVR'); + } + + // Normalize all slices to [0,1] using a robust global percentile window. + // + // Why: + // - per-series min/max is unstable (outliers/background dominate) + // - cross-series fusion and ROI rigid alignment benefit from a shared intensity domain + const finite = intensitySamples.filter((v) => Number.isFinite(v)).sort((a, b) => a - b); + + const quantileSorted = (sorted: number[], q: number): number => { + const n = sorted.length; + if (n === 0) return 0; + const qq = q < 0 ? 0 : q > 1 ? 1 : q; + const idx = qq * (n - 1); + const i0 = Math.floor(idx); + const i1 = Math.min(n - 1, i0 + 1); + const t = idx - i0; + const a = sorted[i0] ?? 0; + const b = sorted[i1] ?? 
a; + return a + (b - a) * t; + }; + + const getHistogramMatchingEnabled = (debug?: boolean): boolean => { + if (!debug) return false; + try { + return localStorage.getItem('miraviewer:svr-histmatch') === '1'; + } catch { + return false; + } + }; + + const histMatchEnabled = getHistogramMatchingEnabled(debug); + + // If enabled, we do a simple piecewise-linear quantile mapping per series + // (approximate histogram matching) before global percentile normalization. + const HM_Q = [0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99] as const; + + const refQs = HM_Q.map((q) => quantileSorted(finite, q)); + + if (histMatchEnabled && finite.length > 0) { + const perSeriesMap = new Map(); + + for (const [uid, samples] of intensitySamplesBySeries) { + const sSorted = samples.filter((v) => Number.isFinite(v)).sort((a, b) => a - b); + if (sSorted.length < 16) continue; + + const srcQs = HM_Q.map((q) => quantileSorted(sSorted, q)); + + // Skip degenerate distributions. + const lo = srcQs[0] ?? 0; + const hi = srcQs[srcQs.length - 1] ?? lo; + if (!(hi > lo + 1e-12)) continue; + + perSeriesMap.set(uid, { srcQs, dstQs: [...refQs] }); + } + + const mapValue = (v: number, m: { srcQs: number[]; dstQs: number[] }): number => { + const src = m.srcQs; + const dst = m.dstQs; + const n = Math.min(src.length, dst.length); + if (n < 2) return v; + + if (v <= (src[0] ?? v)) return dst[0] ?? v; + if (v >= (src[n - 1] ?? v)) return dst[n - 1] ?? v; + + // Small n (9), so linear scan is fine. + let i = 0; + while (i < n - 1 && v > (src[i + 1] ?? Number.POSITIVE_INFINITY)) i++; + + const x0 = src[i] ?? v; + const x1 = src[i + 1] ?? x0; + const y0 = dst[i] ?? v; + const y1 = dst[i + 1] ?? 
y0; + + const den = x1 - x0; + if (!(den > 1e-12)) return y0; + + const t = (v - x0) / den; + return y0 + (y1 - y0) * t; + }; + + let matchedSeries = 0; + + for (const s of allSlices) { + const m = perSeriesMap.get(s.seriesUid); + if (!m) continue; + + for (let i = 0; i < s.pixels.length; i++) { + s.pixels[i] = mapValue(s.pixels[i] ?? 0, m); + } + } + + matchedSeries = perSeriesMap.size; + + console.info('[svr] Histogram matching', { + enabled: true, + seriesMatched: matchedSeries, + quantiles: HM_Q, + }); + } + + let winLo = 0; + let winHi = 1; + + if (finite.length > 0) { + winLo = refQs[0] ?? quantileSorted(finite, 0.01); + winHi = refQs[refQs.length - 1] ?? quantileSorted(finite, 0.99); + + // Fallback if the distribution is degenerate. + if (!(winHi > winLo + 1e-12)) { + winLo = finite[0] ?? 0; + winHi = finite[finite.length - 1] ?? winLo; + } + } + + const invWinRange = winHi > winLo + 1e-12 ? 1 / (winHi - winLo) : 0; + + console.info('[svr] Intensity normalization', { + method: histMatchEnabled ? 'histmatch+global-percentile' : 'global-percentile', + pLow: 1, + pHigh: 99, + window: { lo: Number(winLo.toFixed(4)), hi: Number(winHi.toFixed(4)) }, + samples: finite.length, + }); + + for (const s of allSlices) { + for (let i = 0; i < s.pixels.length; i++) { + const v = s.pixels[i] ?? 0; + const n = invWinRange > 0 ? (v - winLo) * invWinRange : 0; + s.pixels[i] = clamp01(n); + } + } + + // 2) Optional coarse inter-series alignment. + // + // Note: roi-rigid builds on top of bounds-center as a cheap initial guess. 
+ const wantsBoundsCenter = + svrParams.seriesRegistrationMode === 'bounds-center' || svrParams.seriesRegistrationMode === 'roi-rigid'; + + if (wantsBoundsCenter) { + onProgress?.({ phase: 'initializing', current: 52, total: 100, message: 'Coarse series alignment…' }); + + const bySeries = new Map(); + for (const s of allSlices) { + const arr = bySeries.get(s.seriesUid); + if (arr) arr.push(s); + else bySeries.set(s.seriesUid, [s]); + } + + // Pick reference series: + // - Prefer the ROI's source series (if provided), so the ROI stays in the same coordinate frame. + // - Otherwise fallback to "most loaded slices" (stable, data-driven heuristic). + const roiReferenceUid = svrParams.roi?.sourceSeriesUid ?? null; + + let referenceUid: string | null = null; + + if (roiReferenceUid && bySeries.has(roiReferenceUid)) { + referenceUid = roiReferenceUid; + } else { + let bestCount = -1; + for (const [uid, arr] of bySeries) { + if (arr.length > bestCount) { + referenceUid = uid; + bestCount = arr.length; + } + } + } + + const refSlices = referenceUid ? bySeries.get(referenceUid) : null; + if (referenceUid && refSlices && refSlices.length > 0) { + const refBounds = computeBoundsMm(refSlices); + const refCenter = v3( + (refBounds.min.x + refBounds.max.x) * 0.5, + (refBounds.min.y + refBounds.max.y) * 0.5, + (refBounds.min.z + refBounds.max.z) * 0.5 + ); + + debugSvrLog( + 'registration.reference', + { + referenceUid, + loadedSlices: refSlices.length, + centerMm: { x: refCenter.x, y: refCenter.y, z: refCenter.z }, + }, + debug + ); + + for (const [uid, slices] of bySeries) { + if (uid === referenceUid) continue; + if (slices.length === 0) continue; + + const b = computeBoundsMm(slices); + const center = v3((b.min.x + b.max.x) * 0.5, (b.min.y + b.max.y) * 0.5, (b.min.z + b.max.z) * 0.5); + const t = v3(refCenter.x - center.x, refCenter.y - center.y, refCenter.z - center.z); + const tMag = Math.sqrt(dot(t, t)); + + // Apply translation by shifting IPP for each slice. 
+ for (const s of slices) { + s.ippMm = v3(s.ippMm.x + t.x, s.ippMm.y + t.y, s.ippMm.z + t.z); + } + + debugSvrLog( + 'registration.bounds-center', + { + seriesUid: uid, + translateMm: { x: Number(t.x.toFixed(3)), y: Number(t.y.toFixed(3)), z: Number(t.z.toFixed(3)) }, + magnitudeMm: Number(tMag.toFixed(3)), + }, + debug + ); + + // Warn if we're doing something large; this is often a sign of inconsistent DICOM spatial tags. + if (tMag > 20) { + console.warn('[svr] Large coarse alignment translation applied', { + seriesUid: uid, + magnitudeMm: tMag, + translateMm: t, + }); + } + } + } + } + + // 3) Choose output grid (axis-aligned in patient/world coordinates). + const allBounds = computeBoundsMm(allSlices); + + const roi = svrParams.roi ?? null; + const bounds = roi ? intersectBoundsMm(allBounds, boundsFromRoi(roi)) : allBounds; + if (roi) { + assertNonEmptyBounds(bounds, `roi=${roi.mode}/${roi.sourcePlane}`); + } + + onProgress?.({ + phase: 'initializing', + current: 55, + total: 100, + message: roi ? 'Computing output grid (ROI)…' : 'Computing output grid…', + }); + + const iterations = Math.max(0, Math.round(svrParams.iterations)); + + const estimatePeakBytes = (nvox: number, iters: number): number => { + // Persistent arrays: + // - volume, weight + // Per-iteration arrays (allocated once per iter, but they overlap with volume/weight at peak): + // - update, updateW + const floatBytes = 4; + const arrays = iters > 0 ? 4 : 2; + return arrays * nvox * floatBytes; + }; + + const formatMiB = (bytes: number): string => `${Math.round(bytes / (1024 * 1024))}MiB`; + + // Rough safety budget to avoid browser OOM / tab crashes. + // Note: this is only for the core volume arrays; it does not include slice buffers, JS overhead, or GPU textures. 
+ const MAX_PEAK_BYTES = 512 * 1024 * 1024; + + let grid = chooseOutputGrid({ + bounds, + voxelSizeMm: svrParams.targetVoxelSizeMm, + maxDim: svrParams.maxVolumeDim, + }); + + // Preflight: if the volume would be huge, auto-increase voxel size until it fits a memory budget. + // This prevents hard crashes/hangs from attempting multi-hundred-MiB allocations. + for (let attempt = 0; attempt < 6; attempt++) { + const nvox = grid.dims.nx * grid.dims.ny * grid.dims.nz; + const peakBytes = estimatePeakBytes(nvox, iterations); + + if (peakBytes <= MAX_PEAK_BYTES) break; + + const factor = Math.cbrt(peakBytes / MAX_PEAK_BYTES) * 1.05; + const nextVoxelSizeMm = grid.voxelSizeMm * factor; + + console.warn('[svr] Volume would be too large; increasing voxel size to fit memory budget', { + attempt: attempt + 1, + dims: grid.dims, + voxelSizeMm: Number(grid.voxelSizeMm.toFixed(4)), + nextVoxelSizeMm: Number(nextVoxelSizeMm.toFixed(4)), + peak: formatMiB(peakBytes), + budget: formatMiB(MAX_PEAK_BYTES), + iterations, + maxVolumeDim: svrParams.maxVolumeDim, + roi: roi ? { mode: roi.mode, sourcePlane: roi.sourcePlane } : null, + }); + + grid = chooseOutputGrid({ + bounds, + voxelSizeMm: nextVoxelSizeMm, + maxDim: svrParams.maxVolumeDim, + }); + } + + const { dims, originMm, voxelSizeMm } = grid; + const nvox = dims.nx * dims.ny * dims.nz; + const peakBytes = estimatePeakBytes(nvox, iterations); + + if (peakBytes > MAX_PEAK_BYTES) { + throw new Error( + `SVR volume too large (${dims.nx}×${dims.ny}×${dims.nz}); estimated peak ${formatMiB(peakBytes)} exceeds budget ${formatMiB( + MAX_PEAK_BYTES + )}. Try enabling ROI, increasing voxel size, lowering maxVolumeDim, or reducing iterations.` + ); + } + + const voxelSizeIncreased = voxelSizeMm > svrParams.targetVoxelSizeMm + 1e-6; + console.info('[svr] Output grid chosen', { + roi: roi ? 
{ mode: roi.mode, sourcePlane: roi.sourcePlane } : null, + voxelSizeMm: Number(voxelSizeMm.toFixed(4)), + targetVoxelSizeMm: Number(svrParams.targetVoxelSizeMm.toFixed(4)), + voxelSizeIncreased, + maxVolumeDim: svrParams.maxVolumeDim, + dims, + estimatedPeak: formatMiB(peakBytes), + iterations, + boundsMm: { + min: { x: Number(bounds.min.x.toFixed(3)), y: Number(bounds.min.y.toFixed(3)), z: Number(bounds.min.z.toFixed(3)) }, + max: { x: Number(bounds.max.x.toFixed(3)), y: Number(bounds.max.y.toFixed(3)), z: Number(bounds.max.z.toFixed(3)) }, + }, + }); + + // 3) Optional ROI-local rigid alignment (translation + small rotation). + // + // This is intentionally done *after* selecting the output grid so the similarity metric is + // computed in the same coordinate frame we will use for the final reconstruction. + if (svrParams.seriesRegistrationMode === 'roi-rigid') { + if (!roi) { + console.info('[svr] roi-rigid requested but no ROI provided; falling back to bounds-center only'); + } else { + onProgress?.({ phase: 'initializing', current: 56, total: 100, message: 'ROI rigid alignment…' }); + await rigidAlignSeriesInRoi({ + allSlices, + selectedSeries, + roiBounds: bounds, + dims, + originMm, + voxelSizeMm, + roi, + signal, + onProgress, + debug, + }); + } + } + + // 4) Crop slices to ROI bounds to speed up high-detail reconstructions. + if (roi) { + onProgress?.({ phase: 'initializing', current: 58, total: 100, message: 'Cropping slices to ROI…' }); + + const roiCorners = boundsCornersMm(bounds); + + const beforeCount = allSlices.length; + const cropped: LoadedSlice[] = []; + + for (let i = 0; i < allSlices.length; i++) { + assertNotAborted(signal); + const s = allSlices[i]; + if (!s) continue; + + if (cropSliceToRoiInPlace(s, roiCorners)) { + cropped.push(s); + } + + if (i % 8 === 0) { + await yieldToMain(); + } + } + + // Replace in-place so existing references remain valid. 
+ allSlices.length = 0; + allSlices.push(...cropped); + + console.info('[svr] Cropped slices to ROI', { + beforeCount, + afterCount: allSlices.length, + }); + } + + // 5) Reconstruction (higher-fidelity forward model + solver). + onProgress?.({ phase: 'reconstructing', current: 60, total: 100, message: 'Reconstructing volume…' }); + + const solverOptions: SvrReconstructionOptions = { + iterations, + stepSize: svrParams.stepSize, + clampOutput: svrParams.clampOutput, + psfMode: svrParams.psfMode ?? 'gaussian', + robustLoss: svrParams.robustLoss ?? 'huber', + robustDelta: typeof svrParams.robustDelta === 'number' ? svrParams.robustDelta : 0.1, + laplacianWeight: typeof svrParams.laplacianWeight === 'number' ? svrParams.laplacianWeight : 0, + }; + + debugSvrLog( + 'solver.options', + { + psfMode: solverOptions.psfMode, + robustLoss: solverOptions.robustLoss, + robustDelta: solverOptions.robustDelta, + laplacianWeight: solverOptions.laplacianWeight, + multiResolution: svrParams.multiResolution, + multiResolutionFactor: svrParams.multiResolutionFactor, + multiResolutionCoarseIterations: svrParams.multiResolutionCoarseIterations, + }, + debug + ); + + const fineGrid: SvrReconstructionGrid = { dims, originMm, voxelSizeMm }; + + const multiresEnabled = + !!svrParams.multiResolution && + typeof svrParams.multiResolutionFactor === 'number' && + svrParams.multiResolutionFactor > 1.01 && + typeof svrParams.multiResolutionCoarseIterations === 'number' && + svrParams.multiResolutionCoarseIterations > 0 && + iterations > 0; + + let volume: Float32Array; + + if (multiresEnabled) { + const factor = Math.max(1.01, svrParams.multiResolutionFactor ?? 
2); + const coarseVoxelSizeMm = voxelSizeMm * factor; + + const coarseGridSelected = chooseOutputGrid({ + bounds, + voxelSizeMm: coarseVoxelSizeMm, + maxDim: svrParams.maxVolumeDim, + }); + + const coarseGrid: SvrReconstructionGrid = { + dims: coarseGridSelected.dims, + originMm: coarseGridSelected.originMm, + voxelSizeMm: coarseGridSelected.voxelSizeMm, + }; + + const coarseIters = Math.max(0, Math.round(svrParams.multiResolutionCoarseIterations ?? 0)); + + onProgress?.({ phase: 'reconstructing', current: 62, total: 100, message: 'Coarse reconstruction…' }); + + const coarse = await reconstructVolumeFromSlices({ + slices: allSlices, + grid: coarseGrid, + options: { + ...solverOptions, + iterations: coarseIters, + }, + hooks: { + signal, + yieldToMain, + }, + }); + + onProgress?.({ phase: 'reconstructing', current: 66, total: 100, message: 'Upsampling coarse volume…' }); + + volume = await resampleVolumeToGridTrilinear({ + src: coarse, + srcGrid: coarseGrid, + dstGrid: fineGrid, + hooks: { + signal, + yieldToMain, + }, + }); + + onProgress?.({ phase: 'reconstructing', current: 70, total: 100, message: 'Refining volume…' }); + + await refineVolumeInPlace({ + volume, + slices: allSlices, + grid: fineGrid, + options: solverOptions, + hooks: { + signal, + yieldToMain, + }, + }); + } else { + volume = await reconstructVolumeFromSlices({ + slices: allSlices, + grid: fineGrid, + options: solverOptions, + hooks: { + signal, + yieldToMain, + }, + }); + } + + // 5) Previews. 
+ onProgress?.({ phase: 'finalizing', current: 95, total: 100, message: 'Generating previews…' }); + + const previews = await generateVolumePreviews({ + volume, + dims, + maxSize: 256, + }); + + onProgress?.({ phase: 'finalizing', current: 100, total: 100, message: `Done (${Math.round(performance.now() - t0)}ms)` }); + + return { + volume: { + data: volume, + dims: [dims.nx, dims.ny, dims.nz], + voxelSizeMm: [voxelSizeMm, voxelSizeMm, voxelSizeMm], + originMm: [originMm.x, originMm.y, originMm.z], + boundsMm: { + min: [bounds.min.x, bounds.min.y, bounds.min.z], + max: [bounds.max.x, bounds.max.y, bounds.max.z], + }, + }, + previews, + }; +} diff --git a/frontend/src/utils/svr/reconstructionCore.ts b/frontend/src/utils/svr/reconstructionCore.ts new file mode 100644 index 0000000..8c9b105 --- /dev/null +++ b/frontend/src/utils/svr/reconstructionCore.ts @@ -0,0 +1,473 @@ +import type { VolumeDims } from './trilinear'; +import { sampleTrilinear, splatTrilinearScaled } from './trilinear'; +import type { Vec3 } from './vec3'; + +export type SvrPsfMode = 'none' | 'box' | 'gaussian'; +export type SvrRobustLoss = 'none' | 'huber' | 'tukey'; + +export type SvrReconstructionOptions = { + iterations: number; + stepSize: number; + clampOutput: boolean; + + // Forward model knobs + psfMode: SvrPsfMode; + + // Solver knobs + robustLoss: SvrRobustLoss; + robustDelta: number; + laplacianWeight: number; +}; + +export type SvrReconstructionGrid = { + dims: VolumeDims; + originMm: Vec3; + voxelSizeMm: number; +}; + +export type SvrReconstructionSlice = { + // Downsampled pixel grid (normalized to [0,1]) + pixels: Float32Array; + dsRows: number; + dsCols: number; + + // Spatial mapping + ippMm: Vec3; + rowDir: Vec3; + colDir: Vec3; + normalDir: Vec3; + + rowSpacingDsMm: number; + colSpacingDsMm: number; + + // Optional thickness/spacing hints (if present in DICOM metadata) + sliceThicknessMm: number | null; + spacingBetweenSlicesMm: number | null; +}; + +export type SvrCoreHooks = { + 
signal?: AbortSignal; + yieldToMain?: () => Promise; + onProgress?: (p: { current: number; total: number; message: string }) => void; +}; + +function assertNotAborted(signal?: AbortSignal): void { + if (signal?.aborted) { + throw new Error('SVR cancelled'); + } +} + +function clamp01(x: number): number { + return x < 0 ? 0 : x > 1 ? 1 : x; +} + +function withinTrilinearSupport(dims: VolumeDims, x: number, y: number, z: number): boolean { + // sampleTrilinear/splatTrilinear require x0>=0 and x1= 0 && y >= 0 && z >= 0 && x < dims.nx - 1 && y < dims.ny - 1 && z < dims.nz - 1; +} + +type SlicePsf = { offsetsMm: Float32Array; weights: Float32Array; count: number; effectiveThicknessMm: number }; + +function buildSlicePsf(params: { + slice: SvrReconstructionSlice; + voxelSizeMm: number; + mode: SvrPsfMode; +}): SlicePsf { + const { slice, voxelSizeMm, mode } = params; + + if (mode === 'none') { + return { + offsetsMm: new Float32Array([0]), + weights: new Float32Array([1]), + count: 1, + effectiveThicknessMm: 0, + }; + } + + const hint = slice.sliceThicknessMm ?? slice.spacingBetweenSlicesMm; + const thicknessMm = typeof hint === 'number' && Number.isFinite(hint) && hint > 0 ? hint : voxelSizeMm; + + const ratio = thicknessMm / Math.max(1e-6, voxelSizeMm); + + // Keep the forward model cheap: a handful of samples along the slice normal. + // Use an odd count so the kernel is symmetric around offset=0. + let n = Math.round(ratio); + if (n < 1) n = 1; + if (n > 7) n = 7; + if (n % 2 === 0) n += 1; + + const offsetsMm = new Float32Array(n); + const weights = new Float32Array(n); + + const half = 0.5 * thicknessMm; + const step = thicknessMm / n; + + // Gaussian: distance-to-plane weighting within the thickness support. + // We pick sigma so that the tails are non-trivial within [-half, +half]. 
+ const sigma = Math.max(1e-6, half * 0.5); + + let wSum = 0; + for (let i = 0; i < n; i++) { + const off = -half + (i + 0.5) * step; + offsetsMm[i] = off; + + let w = 1; + if (mode === 'gaussian') { + const u = off / sigma; + w = Math.exp(-0.5 * u * u); + } + + weights[i] = w; + wSum += w; + } + + if (wSum > 1e-12) { + const inv = 1 / wSum; + for (let i = 0; i < n; i++) { + weights[i] *= inv; + } + } + + return { offsetsMm, weights, count: n, effectiveThicknessMm: thicknessMm }; +} + +function robustResidualWeight(residual: number, mode: SvrRobustLoss, delta: number): number { + if (mode === 'none') return 1; + + const a = Math.abs(residual); + const d = Number.isFinite(delta) && delta > 1e-12 ? delta : 0.1; + + if (mode === 'huber') { + return a <= d ? 1 : d / a; + } + + // Tukey's biweight. + if (a >= d) return 0; + const r = a / d; + const t = 1 - r * r; + return t * t; +} + +function normalizeVolumeInPlace(volume: Float32Array, weight: Float32Array): void { + for (let i = 0; i < volume.length; i++) { + const w = weight[i]; + volume[i] = w > 1e-12 ? volume[i] / w : 0; + } +} + +function laplacianSmoothInPlace(volume: Float32Array, dims: VolumeDims, lambda: number, scratch: Float32Array): void { + if (!(lambda > 0)) return; + const { nx, ny, nz } = dims; + if (nx < 3 || ny < 3 || nz < 3) return; + + const strideY = nx; + const strideZ = nx * ny; + + // Compute Laplacian into scratch (interior only). + for (let z = 1; z < nz - 1; z++) { + const zBase = z * strideZ; + for (let y = 1; y < ny - 1; y++) { + const base = zBase + y * strideY; + for (let x = 1; x < nx - 1; x++) { + const idx = base + x; + const c = volume[idx] ?? 0; + + const sum = + (volume[idx - 1] ?? 0) + + (volume[idx + 1] ?? 0) + + (volume[idx - strideY] ?? 0) + + (volume[idx + strideY] ?? 0) + + (volume[idx - strideZ] ?? 0) + + (volume[idx + strideZ] ?? 0); + + scratch[idx] = sum - 6 * c; + } + } + } + + // Apply update (interior only). 
+ for (let z = 1; z < nz - 1; z++) { + const zBase = z * strideZ; + for (let y = 1; y < ny - 1; y++) { + const base = zBase + y * strideY; + for (let x = 1; x < nx - 1; x++) { + const idx = base + x; + const lap = scratch[idx] ?? 0; + volume[idx] = (volume[idx] ?? 0) + lambda * lap; + } + } + } +} + +export async function reconstructVolumeFromSlices(params: { + slices: SvrReconstructionSlice[]; + grid: SvrReconstructionGrid; + options: SvrReconstructionOptions; + hooks?: SvrCoreHooks; +}): Promise { + const { slices, grid, options, hooks } = params; + const { dims, originMm, voxelSizeMm } = grid; + + const yieldToMain = hooks?.yieldToMain ?? (async () => {}); + + const nvox = dims.nx * dims.ny * dims.nz; + const volume = new Float32Array(nvox); + const weight = new Float32Array(nvox); + + const psfBySlice = slices.map((s) => buildSlicePsf({ slice: s, voxelSizeMm, mode: options.psfMode })); + + // 1) Initial splat (backprojection of observations). + const invVox = 1 / voxelSizeMm; + + for (let sIdx = 0; sIdx < slices.length; sIdx++) { + assertNotAborted(hooks?.signal); + const s = slices[sIdx]; + if (!s) continue; + + const psf = psfBySlice[sIdx]; + + for (let r = 0; r < s.dsRows; r++) { + const baseX = s.ippMm.x + s.colDir.x * (r * s.rowSpacingDsMm); + const baseY = s.ippMm.y + s.colDir.y * (r * s.rowSpacingDsMm); + const baseZ = s.ippMm.z + s.colDir.z * (r * s.rowSpacingDsMm); + + const rowBase = r * s.dsCols; + + for (let c = 0; c < s.dsCols; c++) { + const obs = s.pixels[rowBase + c] ?? 0; + if (obs <= 0) continue; + + const wx0 = baseX + s.rowDir.x * (c * s.colSpacingDsMm); + const wy0 = baseY + s.rowDir.y * (c * s.colSpacingDsMm); + const wz0 = baseZ + s.rowDir.z * (c * s.colSpacingDsMm); + + for (let k = 0; k < psf.count; k++) { + const off = psf.offsetsMm[k] ?? 0; + const w = psf.weights[k] ?? 
0; + if (!(w > 0)) continue; + + const wx = wx0 + s.normalDir.x * off; + const wy = wy0 + s.normalDir.y * off; + const wz = wz0 + s.normalDir.z * off; + + const vx = (wx - originMm.x) * invVox; + const vy = (wy - originMm.y) * invVox; + const vz = (wz - originMm.z) * invVox; + + if (!withinTrilinearSupport(dims, vx, vy, vz)) continue; + + splatTrilinearScaled(volume, weight, dims, vx, vy, vz, obs, w); + } + } + } + + if (sIdx % 4 === 0) { + hooks?.onProgress?.({ + current: sIdx, + total: slices.length, + message: `Splatting slices… ${sIdx + 1}/${slices.length}`, + }); + await yieldToMain(); + } + } + + normalizeVolumeInPlace(volume, weight); + + await refineVolumeInPlace({ volume, slices, grid, options, hooks, psfBySlice }); + + return volume; +} + +export async function refineVolumeInPlace(params: { + volume: Float32Array; + slices: SvrReconstructionSlice[]; + grid: SvrReconstructionGrid; + options: SvrReconstructionOptions; + hooks?: SvrCoreHooks; + psfBySlice?: SlicePsf[]; +}): Promise { + const { volume, slices, grid, options, hooks } = params; + const { dims, originMm, voxelSizeMm } = grid; + + const yieldToMain = hooks?.yieldToMain ?? (async () => {}); + + const invVox = 1 / voxelSizeMm; + + const psfBySlice = params.psfBySlice ?? slices.map((s) => buildSlicePsf({ slice: s, voxelSizeMm, mode: options.psfMode })); + + const nvox = dims.nx * dims.ny * dims.nz; + + // Iterative refinement: forward-project → residual → backproject. + const iterations = Math.max(0, Math.round(options.iterations)); + const stepSize = options.stepSize; + + // Scratch reused for update accumulation and regularization. 
+ const update = new Float32Array(nvox); + const updateW = new Float32Array(nvox); + + for (let iter = 0; iter < iterations; iter++) { + assertNotAborted(hooks?.signal); + + update.fill(0); + updateW.fill(0); + + for (let sIdx = 0; sIdx < slices.length; sIdx++) { + assertNotAborted(hooks?.signal); + const s = slices[sIdx]; + if (!s) continue; + + const psf = psfBySlice[sIdx]; + + for (let r = 0; r < s.dsRows; r++) { + const baseX = s.ippMm.x + s.colDir.x * (r * s.rowSpacingDsMm); + const baseY = s.ippMm.y + s.colDir.y * (r * s.rowSpacingDsMm); + const baseZ = s.ippMm.z + s.colDir.z * (r * s.rowSpacingDsMm); + + const rowBase = r * s.dsCols; + + for (let c = 0; c < s.dsCols; c++) { + const obs = s.pixels[rowBase + c] ?? 0; + if (obs <= 0) continue; + + const wx0 = baseX + s.rowDir.x * (c * s.colSpacingDsMm); + const wy0 = baseY + s.rowDir.y * (c * s.colSpacingDsMm); + const wz0 = baseZ + s.rowDir.z * (c * s.colSpacingDsMm); + + // Forward projection: integrate the volume along the slice normal. + let pred = 0; + let wUsed = 0; + + for (let k = 0; k < psf.count; k++) { + const off = psf.offsetsMm[k] ?? 0; + const w = psf.weights[k] ?? 0; + if (!(w > 0)) continue; + + const wx = wx0 + s.normalDir.x * off; + const wy = wy0 + s.normalDir.y * off; + const wz = wz0 + s.normalDir.z * off; + + const vx = (wx - originMm.x) * invVox; + const vy = (wy - originMm.y) * invVox; + const vz = (wz - originMm.z) * invVox; + + if (!withinTrilinearSupport(dims, vx, vy, vz)) continue; + + pred += sampleTrilinear(volume, dims, vx, vy, vz) * w; + wUsed += w; + } + + if (!(wUsed > 1e-12)) continue; + pred /= wUsed; + + const residual = obs - pred; + const rW = robustResidualWeight(residual, options.robustLoss, options.robustDelta); + if (!(rW > 0)) continue; + + // Backproject residual into volume using the same PSF weights. + const scaleBase = rW / wUsed; + + for (let k = 0; k < psf.count; k++) { + const off = psf.offsetsMm[k] ?? 0; + const w = psf.weights[k] ?? 
0; + if (!(w > 0)) continue; + + const wx = wx0 + s.normalDir.x * off; + const wy = wy0 + s.normalDir.y * off; + const wz = wz0 + s.normalDir.z * off; + + const vx = (wx - originMm.x) * invVox; + const vy = (wy - originMm.y) * invVox; + const vz = (wz - originMm.z) * invVox; + + if (!withinTrilinearSupport(dims, vx, vy, vz)) continue; + + const scale = w * scaleBase; + splatTrilinearScaled(update, updateW, dims, vx, vy, vz, residual, scale); + } + } + } + + if (sIdx % 8 === 0) { + await yieldToMain(); + } + } + + for (let i = 0; i < nvox; i++) { + const w = updateW[i]; + if (w > 1e-12) { + volume[i] = (volume[i] ?? 0) + ((update[i] ?? 0) / w) * stepSize; + } + + if (options.clampOutput) { + volume[i] = clamp01(volume[i] ?? 0); + } + } + + // Light regularization to suppress noise without erasing edges. + if (options.laplacianWeight > 0) { + update.fill(0); + laplacianSmoothInPlace(volume, dims, options.laplacianWeight, update); + if (options.clampOutput) { + for (let i = 0; i < nvox; i++) { + volume[i] = clamp01(volume[i] ?? 0); + } + } + } + + hooks?.onProgress?.({ + current: iter + 1, + total: iterations, + message: `Refining volume… iteration ${iter + 1}/${iterations}`, + }); + + await yieldToMain(); + } +} + +export async function resampleVolumeToGridTrilinear(params: { + src: Float32Array; + srcGrid: SvrReconstructionGrid; + dstGrid: SvrReconstructionGrid; + hooks?: SvrCoreHooks; +}): Promise { + const { src, srcGrid, dstGrid, hooks } = params; + const { dims: sDims, originMm: sOrigin, voxelSizeMm: sVox } = srcGrid; + const { dims: dDims, originMm: dOrigin, voxelSizeMm: dVox } = dstGrid; + + const yieldToMain = hooks?.yieldToMain ?? 
(async () => {}); + + const out = new Float32Array(dDims.nx * dDims.ny * dDims.nz); + + const invSrcVox = 1 / sVox; + + const strideY = dDims.nx; + const strideZ = dDims.nx * dDims.ny; + + for (let z = 0; z < dDims.nz; z++) { + assertNotAborted(hooks?.signal); + const wz = dOrigin.z + z * dVox; + + for (let y = 0; y < dDims.ny; y++) { + const wy = dOrigin.y + y * dVox; + + const base = z * strideZ + y * strideY; + + for (let x = 0; x < dDims.nx; x++) { + const wx = dOrigin.x + x * dVox; + + const sx = (wx - sOrigin.x) * invSrcVox; + const sy = (wy - sOrigin.y) * invSrcVox; + const sz = (wz - sOrigin.z) * invSrcVox; + + out[base + x] = withinTrilinearSupport(sDims, sx, sy, sz) ? sampleTrilinear(src, sDims, sx, sy, sz) : 0; + } + } + + if (z % 4 === 0) { + await yieldToMain(); + } + } + + return out; +} diff --git a/frontend/src/utils/svr/resample2d.ts b/frontend/src/utils/svr/resample2d.ts new file mode 100644 index 0000000..ea417d3 --- /dev/null +++ b/frontend/src/utils/svr/resample2d.ts @@ -0,0 +1,229 @@ +export function resample2dAreaAverage( + src: ArrayLike, + srcRows: number, + srcCols: number, + dstRows: number, + dstCols: number +): Float32Array { + const outRows = Math.max(0, Math.floor(dstRows)); + const outCols = Math.max(0, Math.floor(dstCols)); + + if (outRows === 0 || outCols === 0) { + return new Float32Array(0); + } + + const inRows = Math.max(0, Math.floor(srcRows)); + const inCols = Math.max(0, Math.floor(srcCols)); + + if (inRows === 0 || inCols === 0) { + return new Float32Array(outRows * outCols); + } + + // Fast path: no resampling. + if (inRows === outRows && inCols === outCols) { + const out = new Float32Array(outRows * outCols); + for (let i = 0; i < out.length; i++) { + out[i] = Number(src[i] ?? 0); + } + return out; + } + + // Box-filter (area) resampling. + // + // Model each source pixel as a constant value over the unit square [r,r+1)×[c,c+1). 
+ // Each destination pixel corresponds to a box in source pixel coordinates: + // r ∈ [dr*rowScale, (dr+1)*rowScale), c ∈ [dc*colScale, (dc+1)*colScale) + // We compute the area-weighted average over that box. + const rowScale = inRows / outRows; + const colScale = inCols / outCols; + const invArea = 1 / (rowScale * colScale); + + const out = new Float32Array(outRows * outCols); + + for (let dr = 0; dr < outRows; dr++) { + const srcR0 = dr * rowScale; + const srcR1 = (dr + 1) * rowScale; + + const r0 = Math.max(0, Math.floor(srcR0)); + const r1 = Math.min(inRows, Math.ceil(srcR1)); + + const outRowBase = dr * outCols; + + for (let dc = 0; dc < outCols; dc++) { + const srcC0 = dc * colScale; + const srcC1 = (dc + 1) * colScale; + + const c0 = Math.max(0, Math.floor(srcC0)); + const c1 = Math.min(inCols, Math.ceil(srcC1)); + + let sum = 0; + + for (let r = r0; r < r1; r++) { + const rStart = Math.max(r, srcR0); + const rEnd = Math.min(r + 1, srcR1); + const wr = rEnd - rStart; + if (wr <= 0) continue; + + const srcRowBase = r * inCols; + + for (let c = c0; c < c1; c++) { + const cStart = Math.max(c, srcC0); + const cEnd = Math.min(c + 1, srcC1); + const wc = cEnd - cStart; + if (wc <= 0) continue; + + const v = Number(src[srcRowBase + c] ?? 
0); + sum += v * wr * wc; + } + } + + out[outRowBase + dc] = sum * invArea; + } + } + + return out; +} + +function sinc(x: number): number { + if (x === 0) return 1; + const px = Math.PI * x; + return Math.sin(px) / px; +} + +function lanczos(x: number, a: number): number { + const ax = Math.abs(x); + if (ax >= a) return 0; + return sinc(x) * sinc(x / a); +} + +type Contrib = { idx0: number; idx1: number; w: Float32Array }; + +function buildLanczosContrib(inSize: number, outSize: number, a: number): Contrib[] { + const inN = Math.max(0, Math.floor(inSize)); + const outN = Math.max(0, Math.floor(outSize)); + + const scale = outN / Math.max(1, inN); + + // When downsampling (scale<1), widen the filter footprint to act as an anti-aliasing low-pass. + const kernelScale = scale < 1 ? scale : 1; + const radius = a / Math.max(1e-6, kernelScale); + + const contrib: Contrib[] = new Array(outN); + + for (let o = 0; o < outN; o++) { + // Map output sample centers to input coordinates. + // (Matches common image resampling conventions.) + const center = (o + 0.5) / Math.max(1e-6, scale) - 0.5; + + let i0 = Math.ceil(center - radius); + let i1 = Math.floor(center + radius); + + if (i0 < 0) i0 = 0; + if (i1 > inN - 1) i1 = inN - 1; + + const len = Math.max(0, i1 - i0 + 1); + const w = new Float32Array(len); + + let sum = 0; + for (let i = 0; i < len; i++) { + const idx = i0 + i; + const x = (center - idx) * kernelScale; + const wi = lanczos(x, a) * kernelScale; + w[i] = wi; + sum += wi; + } + + // Normalize so constants stay constant. + if (sum > 1e-12) { + const inv = 1 / sum; + for (let i = 0; i < w.length; i++) { + w[i] *= inv; + } + } else if (w.length > 0) { + // Degenerate case: fall back to nearest. 
+ w.fill(0); + const nearest = Math.max(0, Math.min(w.length - 1, Math.round(center) - i0)); + w[nearest] = 1; + } + + contrib[o] = { idx0: i0, idx1: i1, w }; + } + + return contrib; +} + +export function resample2dLanczos3( + src: ArrayLike, + srcRows: number, + srcCols: number, + dstRows: number, + dstCols: number +): Float32Array { + const outRows = Math.max(0, Math.floor(dstRows)); + const outCols = Math.max(0, Math.floor(dstCols)); + + if (outRows === 0 || outCols === 0) { + return new Float32Array(0); + } + + const inRows = Math.max(0, Math.floor(srcRows)); + const inCols = Math.max(0, Math.floor(srcCols)); + + if (inRows === 0 || inCols === 0) { + return new Float32Array(outRows * outCols); + } + + // Fast path: no resampling. + if (inRows === outRows && inCols === outCols) { + const out = new Float32Array(outRows * outCols); + for (let i = 0; i < out.length; i++) { + out[i] = Number(src[i] ?? 0); + } + return out; + } + + const A = 3; + + const xContrib = buildLanczosContrib(inCols, outCols, A); + const yContrib = buildLanczosContrib(inRows, outRows, A); + + // Horizontal pass: src (inRows x inCols) -> tmp (inRows x outCols) + const tmp = new Float32Array(inRows * outCols); + + for (let r = 0; r < inRows; r++) { + const srcRowBase = r * inCols; + const tmpRowBase = r * outCols; + + for (let oc = 0; oc < outCols; oc++) { + const c = xContrib[oc]; + if (!c) continue; + + let sum = 0; + for (let i = 0; i < c.w.length; i++) { + sum += Number(src[srcRowBase + c.idx0 + i] ?? 
0) * c.w[i]; + } + tmp[tmpRowBase + oc] = sum; + } + } + + // Vertical pass: tmp (inRows x outCols) -> out (outRows x outCols) + const out = new Float32Array(outRows * outCols); + + for (let or = 0; or < outRows; or++) { + const c = yContrib[or]; + if (!c) continue; + + const outRowBase = or * outCols; + + for (let oc = 0; oc < outCols; oc++) { + let sum = 0; + for (let i = 0; i < c.w.length; i++) { + const rr = c.idx0 + i; + sum += tmp[rr * outCols + oc] * c.w[i]; + } + out[outRowBase + oc] = sum; + } + } + + return out; +} diff --git a/frontend/src/utils/svr/sliceRoiCrop.ts b/frontend/src/utils/svr/sliceRoiCrop.ts new file mode 100644 index 0000000..3bedf5f --- /dev/null +++ b/frontend/src/utils/svr/sliceRoiCrop.ts @@ -0,0 +1,112 @@ +import type { Vec3 } from './vec3'; +import { dot, v3 } from './vec3'; + +export type BoundsMm = { min: Vec3; max: Vec3 }; + +export type CropSlice = { + pixels: Float32Array; + dsRows: number; + dsCols: number; + + ippMm: Vec3; + rowDir: Vec3; + colDir: Vec3; + normalDir: Vec3; + + rowSpacingDsMm: number; + colSpacingDsMm: number; +}; + +export function boundsCornersMm(bounds: BoundsMm): Vec3[] { + const xs = [bounds.min.x, bounds.max.x]; + const ys = [bounds.min.y, bounds.max.y]; + const zs = [bounds.min.z, bounds.max.z]; + + const corners: Vec3[] = []; + for (const x of xs) { + for (const y of ys) { + for (const z of zs) { + corners.push(v3(x, y, z)); + } + } + } + return corners; +} + +export function cropSliceToRoiInPlace(slice: CropSlice, roiCorners: Vec3[]): boolean { + // Reject slices whose plane does not intersect the ROI slab along its normal. + const n = slice.normalDir; + const planeD = dot(slice.ippMm, n); + + let minD = Number.POSITIVE_INFINITY; + let maxD = Number.NEGATIVE_INFINITY; + for (const c of roiCorners) { + const d = dot(c, n); + if (d < minD) minD = d; + if (d > maxD) maxD = d; + } + + // Small tolerance to avoid dropping boundary slices due to float noise. 
+ const tol = 1e-3; + if (planeD < minD - tol || planeD > maxD + tol) { + return false; + } + + // Compute a conservative pixel-space bounding box by projecting ROI corners into the slice basis. + let minR = Number.POSITIVE_INFINITY; + let maxR = Number.NEGATIVE_INFINITY; + let minC = Number.POSITIVE_INFINITY; + let maxC = Number.NEGATIVE_INFINITY; + + for (const p of roiCorners) { + const dx = p.x - slice.ippMm.x; + const dy = p.y - slice.ippMm.y; + const dz = p.z - slice.ippMm.z; + + // DICOM mapping: world(r,c) = IPP + colDir*(r*rowSpacing) + rowDir*(c*colSpacing) + const r = (dx * slice.colDir.x + dy * slice.colDir.y + dz * slice.colDir.z) / slice.rowSpacingDsMm; + const c = (dx * slice.rowDir.x + dy * slice.rowDir.y + dz * slice.rowDir.z) / slice.colSpacingDsMm; + + if (r < minR) minR = r; + if (r > maxR) maxR = r; + if (c < minC) minC = c; + if (c > maxC) maxC = c; + } + + if (!Number.isFinite(minR) || !Number.isFinite(minC)) return false; + + // Expand slightly; we want to be conservative. + const r0 = Math.max(0, Math.min(slice.dsRows - 1, Math.floor(minR) - 1)); + const r1 = Math.max(0, Math.min(slice.dsRows - 1, Math.ceil(maxR) + 1)); + const c0 = Math.max(0, Math.min(slice.dsCols - 1, Math.floor(minC) - 1)); + const c1 = Math.max(0, Math.min(slice.dsCols - 1, Math.ceil(maxC) + 1)); + + if (r1 < r0 || c1 < c0) return false; + + const nextRows = r1 - r0 + 1; + const nextCols = c1 - c0 + 1; + + const oldCols = slice.dsCols; + const oldPixels = slice.pixels; + + const nextPixels = new Float32Array(nextRows * nextCols); + + for (let r = r0; r <= r1; r++) { + const oldBase = r * oldCols + c0; + const newBase = (r - r0) * nextCols; + nextPixels.set(oldPixels.subarray(oldBase, oldBase + nextCols), newBase); + } + + // Shift IPP so (r0,c0) becomes the new (0,0) for the cropped pixel buffer. 
+ slice.ippMm = v3( + slice.ippMm.x + slice.colDir.x * (r0 * slice.rowSpacingDsMm) + slice.rowDir.x * (c0 * slice.colSpacingDsMm), + slice.ippMm.y + slice.colDir.y * (r0 * slice.rowSpacingDsMm) + slice.rowDir.y * (c0 * slice.colSpacingDsMm), + slice.ippMm.z + slice.colDir.z * (r0 * slice.rowSpacingDsMm) + slice.rowDir.z * (c0 * slice.colSpacingDsMm) + ); + + slice.dsRows = nextRows; + slice.dsCols = nextCols; + slice.pixels = nextPixels; + + return true; +} diff --git a/frontend/src/utils/svr/trilinear.ts b/frontend/src/utils/svr/trilinear.ts new file mode 100644 index 0000000..1718719 --- /dev/null +++ b/frontend/src/utils/svr/trilinear.ts @@ -0,0 +1,142 @@ +export type VolumeDims = { nx: number; ny: number; nz: number }; + +function idxOf(x: number, y: number, z: number, dims: VolumeDims): number { + return x + y * dims.nx + z * dims.nx * dims.ny; +} + +export function sampleTrilinear(volume: Float32Array, dims: VolumeDims, x: number, y: number, z: number): number { + const { nx, ny, nz } = dims; + + const x0 = Math.floor(x); + const y0 = Math.floor(y); + const z0 = Math.floor(z); + + const x1 = x0 + 1; + const y1 = y0 + 1; + const z1 = z0 + 1; + + if (x0 < 0 || y0 < 0 || z0 < 0 || x1 >= nx || y1 >= ny || z1 >= nz) { + return 0; + } + + const fx = x - x0; + const fy = y - y0; + const fz = z - z0; + + const wx0 = 1 - fx; + const wy0 = 1 - fy; + const wz0 = 1 - fz; + const wx1 = fx; + const wy1 = fy; + const wz1 = fz; + + const c000 = volume[idxOf(x0, y0, z0, dims)]; + const c100 = volume[idxOf(x1, y0, z0, dims)]; + const c010 = volume[idxOf(x0, y1, z0, dims)]; + const c110 = volume[idxOf(x1, y1, z0, dims)]; + const c001 = volume[idxOf(x0, y0, z1, dims)]; + const c101 = volume[idxOf(x1, y0, z1, dims)]; + const c011 = volume[idxOf(x0, y1, z1, dims)]; + const c111 = volume[idxOf(x1, y1, z1, dims)]; + + const v00 = c000 * wx0 + c100 * wx1; + const v10 = c010 * wx0 + c110 * wx1; + const v01 = c001 * wx0 + c101 * wx1; + const v11 = c011 * wx0 + c111 * wx1; + + const v0 
= v00 * wy0 + v10 * wy1; + const v1 = v01 * wy0 + v11 * wy1; + + return v0 * wz0 + v1 * wz1; +} + +export function splatTrilinear( + accum: Float32Array, + weight: Float32Array, + dims: VolumeDims, + x: number, + y: number, + z: number, + value: number +): void { + splatTrilinearScaled(accum, weight, dims, x, y, z, value, 1); +} + +export function splatTrilinearScaled( + accum: Float32Array, + weight: Float32Array, + dims: VolumeDims, + x: number, + y: number, + z: number, + value: number, + weightScale: number +): void { + const { nx, ny, nz } = dims; + + const x0 = Math.floor(x); + const y0 = Math.floor(y); + const z0 = Math.floor(z); + + const x1 = x0 + 1; + const y1 = y0 + 1; + const z1 = z0 + 1; + + if (x0 < 0 || y0 < 0 || z0 < 0 || x1 >= nx || y1 >= ny || z1 >= nz) { + return; + } + + const fx = x - x0; + const fy = y - y0; + const fz = z - z0; + + const wx0 = 1 - fx; + const wy0 = 1 - fy; + const wz0 = 1 - fz; + const wx1 = fx; + const wy1 = fy; + const wz1 = fz; + + const w000 = wx0 * wy0 * wz0; + const w100 = wx1 * wy0 * wz0; + const w010 = wx0 * wy1 * wz0; + const w110 = wx1 * wy1 * wz0; + const w001 = wx0 * wy0 * wz1; + const w101 = wx1 * wy0 * wz1; + const w011 = wx0 * wy1 * wz1; + const w111 = wx1 * wy1 * wz1; + + const s = Number.isFinite(weightScale) ? 
weightScale : 0; + + let idx = idxOf(x0, y0, z0, dims); + accum[idx] += value * (w000 * s); + weight[idx] += w000 * s; + + idx = idxOf(x1, y0, z0, dims); + accum[idx] += value * (w100 * s); + weight[idx] += w100 * s; + + idx = idxOf(x0, y1, z0, dims); + accum[idx] += value * (w010 * s); + weight[idx] += w010 * s; + + idx = idxOf(x1, y1, z0, dims); + accum[idx] += value * (w110 * s); + weight[idx] += w110 * s; + + idx = idxOf(x0, y0, z1, dims); + accum[idx] += value * (w001 * s); + weight[idx] += w001 * s; + + idx = idxOf(x1, y0, z1, dims); + accum[idx] += value * (w101 * s); + weight[idx] += w101 * s; + + idx = idxOf(x0, y1, z1, dims); + accum[idx] += value * (w011 * s); + weight[idx] += w011 * s; + + idx = idxOf(x1, y1, z1, dims); + accum[idx] += value * (w111 * s); + weight[idx] += w111 * s; +} diff --git a/frontend/src/utils/svr/vec3.ts b/frontend/src/utils/svr/vec3.ts new file mode 100644 index 0000000..5e8d4a8 --- /dev/null +++ b/frontend/src/utils/svr/vec3.ts @@ -0,0 +1,39 @@ +export type Vec3 = { x: number; y: number; z: number }; + +export function v3(x: number, y: number, z: number): Vec3 { + return { x, y, z }; +} + +export function add(a: Vec3, b: Vec3): Vec3 { + return { x: a.x + b.x, y: a.y + b.y, z: a.z + b.z }; +} + +export function sub(a: Vec3, b: Vec3): Vec3 { + return { x: a.x - b.x, y: a.y - b.y, z: a.z - b.z }; +} + +export function scale(a: Vec3, s: number): Vec3 { + return { x: a.x * s, y: a.y * s, z: a.z * s }; +} + +export function dot(a: Vec3, b: Vec3): number { + return a.x * b.x + a.y * b.y + a.z * b.z; +} + +export function cross(a: Vec3, b: Vec3): Vec3 { + return { + x: a.y * b.z - a.z * b.y, + y: a.z * b.x - a.x * b.z, + z: a.x * b.y - a.y * b.x, + }; +} + +export function norm(a: Vec3): number { + return Math.sqrt(dot(a, a)); +} + +export function normalize(a: Vec3): Vec3 { + const n = norm(a); + if (!Number.isFinite(n) || n <= 0) return { x: 0, y: 0, z: 0 }; + return scale(a, 1 / n); +} diff --git 
a/frontend/src/utils/svr/volumePreview.ts b/frontend/src/utils/svr/volumePreview.ts new file mode 100644 index 0000000..5204c8e --- /dev/null +++ b/frontend/src/utils/svr/volumePreview.ts @@ -0,0 +1,112 @@ +import type { SvrPreviewImages } from '../../types/svr'; +import type { VolumeDims } from './trilinear'; + +function clamp01(x: number): number { + return x < 0 ? 0 : x > 1 ? 1 : x; +} + +function sliceToImageData(params: { + width: number; + height: number; + getValue: (x: number, y: number) => number; +}): ImageData { + const { width, height, getValue } = params; + + const img = new ImageData(width, height); + const data = img.data; + + for (let y = 0; y < height; y++) { + for (let x = 0; x < width; x++) { + const v = clamp01(getValue(x, y)); + const b = Math.round(v * 255); + const idx = (y * width + x) * 4; + data[idx] = b; + data[idx + 1] = b; + data[idx + 2] = b; + data[idx + 3] = 255; + } + } + + return img; +} + +async function imageDataToPng(imageData: ImageData, maxSize: number): Promise { + const srcCanvas = document.createElement('canvas'); + srcCanvas.width = imageData.width; + srcCanvas.height = imageData.height; + const srcCtx = srcCanvas.getContext('2d'); + if (!srcCtx) throw new Error('Failed to get canvas context'); + srcCtx.putImageData(imageData, 0, 0); + + const maxDim = Math.max(imageData.width, imageData.height); + const scale = maxDim > maxSize ? 
maxSize / maxDim : 1; + + const outCanvas = document.createElement('canvas'); + outCanvas.width = Math.max(1, Math.round(imageData.width * scale)); + outCanvas.height = Math.max(1, Math.round(imageData.height * scale)); + + const outCtx = outCanvas.getContext('2d'); + if (!outCtx) throw new Error('Failed to get canvas context'); + outCtx.imageSmoothingEnabled = true; + outCtx.imageSmoothingQuality = 'high'; + outCtx.drawImage(srcCanvas, 0, 0, outCanvas.width, outCanvas.height); + + return await new Promise((resolve, reject) => { + outCanvas.toBlob((blob) => { + if (blob) resolve(blob); + else reject(new Error('Failed to encode PNG')); + }, 'image/png'); + }); +} + +export async function generateVolumePreviews(params: { + volume: Float32Array; + dims: VolumeDims; + maxSize: number; +}): Promise { + const { volume, dims, maxSize } = params; + const { nx, ny, nz } = dims; + + if (typeof document === 'undefined') { + throw new Error('generateVolumePreviews requires a DOM'); + } + + const midX = Math.floor(nx / 2); + const midY = Math.floor(ny / 2); + const midZ = Math.floor(nz / 2); + + const strideXY = nx * ny; + + // Axial: X (width) × Y (height) at Z=mid + const axial = sliceToImageData({ + width: nx, + height: ny, + getValue: (x, y) => volume[x + y * nx + midZ * strideXY] ?? 0, + }); + + // Coronal: X (width) × Z (height) at Y=mid + const coronal = sliceToImageData({ + width: nx, + height: nz, + getValue: (x, z) => volume[x + midY * nx + z * strideXY] ?? 0, + }); + + // Sagittal: Y (width) × Z (height) at X=mid + const sagittal = sliceToImageData({ + width: ny, + height: nz, + getValue: (y, z) => volume[midX + y * nx + z * strideXY] ?? 
0, + }); + + const [axialPng, coronalPng, sagittalPng] = await Promise.all([ + imageDataToPng(axial, maxSize), + imageDataToPng(coronal, maxSize), + imageDataToPng(sagittal, maxSize), + ]); + + return { + axial: axialPng, + coronal: coronalPng, + sagittal: sagittalPng, + }; +} diff --git a/frontend/src/utils/tumorPropagation.ts b/frontend/src/utils/tumorPropagation.ts new file mode 100644 index 0000000..179be5f --- /dev/null +++ b/frontend/src/utils/tumorPropagation.ts @@ -0,0 +1,210 @@ +import cornerstone from 'cornerstone-core'; +import type { NormalizedPoint, TumorPolygon, TumorThreshold, ViewerTransform } from '../db/schema'; +import { getSortedSopInstanceUidsForSeries, saveTumorSegmentation } from './localApi'; +import type { SegmentTumorOptions } from './segmentation/segmentTumor'; +import { imageNormToViewerNorm, viewerNormToImageNorm } from './viewportMapping'; +import { propagateTumorAcrossFramesCore } from './tumorPropagationCore'; +import { remapPointBetweenViewerTransforms, type ViewportSize } from './viewTransform'; + +type CornerstoneImageLike = { + rows: number; + columns: number; + getPixelData: () => ArrayLike; + minPixelValue?: number; + maxPixelValue?: number; +}; + +const IDENTITY_VIEWER_TRANSFORM: ViewerTransform = { + zoom: 1, + rotation: 0, + panX: 0, + panY: 0, + affine00: 1, + affine01: 0, + affine10: 0, + affine11: 1, +}; + +function clamp01(v: number) { + return Math.max(0, Math.min(1, v)); +} + +function toByte(v: number) { + return Math.max(0, Math.min(255, Math.round(v))); +} + +function normalizeToByteArray(pixelData: ArrayLike, min: number, max: number): Uint8Array { + const n = pixelData.length; + const out = new Uint8Array(n); + const denom = max - min; + if (!Number.isFinite(denom) || Math.abs(denom) < 1e-8) { + out.fill(0); + return out; + } + + for (let i = 0; i < n; i++) { + const t = (pixelData[i] - min) / denom; + out[i] = toByte(t * 255); + } + return out; +} + + +export type PropagateAcrossSeriesInput = { + comboId: string; 
+ dateIso: string; + studyId: string; + seriesUid: string; + + /** Size of the viewer viewport (used to map between viewer coords and image coords). */ + viewportSize: ViewportSize; + + /** Starting slice index in effective series ordering (0..N-1). */ + startEffectiveIndex: number; + + /** Seed point in normalized *viewer* coordinates. */ + seed: NormalizedPoint; + + /** Viewer transform that `seed` was authored under. Defaults to identity. */ + seedViewTransform?: ViewerTransform; + + threshold: TumorThreshold; + + /** Optional segmentation option overrides during propagation. */ + opts?: SegmentTumorOptions; + + stop: { + minAreaPx: number; + maxMissesInARow: number; + }; + + onProgress?: (p: { direction: 'left' | 'right'; index: number; saved: number; misses: number }) => void; +}; + +export async function propagateTumorAcrossSeries(input: PropagateAcrossSeriesInput): Promise<{ saved: number }> { + const uids = await getSortedSopInstanceUidsForSeries(input.seriesUid); + const n = uids.length; + if (n <= 0) return { saved: 0 }; + + const viewSize = input.viewportSize; + if (viewSize.w <= 0 || viewSize.h <= 0) { + throw new Error('Viewer size not available for propagation'); + } + + // Remap the stored seed back into the identity viewer transform. + const seedView = remapPointBetweenViewerTransforms( + input.seed, + viewSize, + input.seedViewTransform ?? 
IDENTITY_VIEWER_TRANSFORM, + IDENTITY_VIEWER_TRANSFORM + ); + + const start = Math.max(0, Math.min(n - 1, input.startEffectiveIndex)); + + const getFrame = async (index: number) => { + const sop = uids[index]; + if (!sop) return null; + + const imageId = `miradb:${sop}`; + const image = (await cornerstone.loadImage(imageId)) as unknown as CornerstoneImageLike; + + const rows = image.rows; + const cols = image.columns; + const getPixelData = image.getPixelData; + if (!rows || !cols || typeof getPixelData !== 'function') { + return null; + } + + const imgSize = { w: cols, h: rows }; + const seedImg = viewerNormToImageNorm(seedView, viewSize, imgSize); + + // Seed jitter: add a small cross around the centroid to make region growing less brittle. + // Ensure the jitter moves at least ~1 pixel in either direction. + const jitter = Math.max(0.002, 1 / Math.max(cols, rows)); + const seedPointsImg: NormalizedPoint[] = [ + { x: clamp01(seedImg.x), y: clamp01(seedImg.y) }, + { x: clamp01(seedImg.x + jitter), y: clamp01(seedImg.y) }, + { x: clamp01(seedImg.x - jitter), y: clamp01(seedImg.y) }, + { x: clamp01(seedImg.x), y: clamp01(seedImg.y + jitter) }, + { x: clamp01(seedImg.x), y: clamp01(seedImg.y - jitter) }, + ]; + + const pd = getPixelData(); + + let min = + typeof image.minPixelValue === 'number' && Number.isFinite(image.minPixelValue) + ? image.minPixelValue + : Number.POSITIVE_INFINITY; + let max = + typeof image.maxPixelValue === 'number' && Number.isFinite(image.maxPixelValue) + ? 
image.maxPixelValue + : Number.NEGATIVE_INFINITY; + + if (!Number.isFinite(min) || !Number.isFinite(max)) { + min = Number.POSITIVE_INFINITY; + max = Number.NEGATIVE_INFINITY; + for (let j = 0; j < pd.length; j++) { + const v = pd[j]; + if (v < min) min = v; + if (v > max) max = v; + } + } + + const gray = normalizeToByteArray(pd, min, max); + + return { + sopInstanceUid: sop, + w: cols, + h: rows, + gray, + seedPointsNorm: seedPointsImg, + }; + }; + + const viewportSize = { w: Math.round(viewSize.w), h: Math.round(viewSize.h) }; + + const res = await propagateTumorAcrossFramesCore({ + minIndex: 0, + maxIndex: n - 1, + startEffectiveIndex: start, + getFrame, + threshold: input.threshold, + opts: input.opts, + stop: input.stop, + onProgress: input.onProgress, + onAcceptedResult: async ({ sopInstanceUid, segmentation }) => { + if (!sopInstanceUid) { + throw new Error('Missing SOPInstanceUID for propagated slice'); + } + + const imgSize = { w: segmentation.meta.imageWidth, h: segmentation.meta.imageHeight }; + + // Convert the predicted polygon/seed into normalized viewer coordinates under an identity transform, + // so it can later be re-projected correctly under pan/zoom/rotation/affine. 
+ const polygonViewer: TumorPolygon = { + points: segmentation.polygon.points.map((p) => imageNormToViewerNorm(p, viewSize, imgSize)), + }; + const seedViewer = imageNormToViewerNorm(segmentation.seed, viewSize, imgSize); + + await saveTumorSegmentation({ + comboId: input.comboId, + dateIso: input.dateIso, + studyId: input.studyId, + seriesUid: input.seriesUid, + sopInstanceUid, + polygon: polygonViewer, + threshold: input.threshold, + seed: seedViewer, + meta: { + areaPx: segmentation.meta.areaPx, + areaNorm: segmentation.meta.areaNorm, + viewTransform: IDENTITY_VIEWER_TRANSFORM, + viewportSize, + }, + algorithmVersion: 'v2-propagation-viewer-seed-remap', + }); + }, + }); + + return { saved: res.saved }; +} diff --git a/frontend/src/utils/tumorPropagationCore.ts b/frontend/src/utils/tumorPropagationCore.ts new file mode 100644 index 0000000..2aa4338 --- /dev/null +++ b/frontend/src/utils/tumorPropagationCore.ts @@ -0,0 +1,115 @@ +import type { NormalizedPoint, TumorPolygon, TumorThreshold } from '../db/schema'; +import type { SegmentTumorOptions, SegmentationResult } from './segmentation/segmentTumor'; +import { segmentTumorFromGrayscale } from './segmentation/segmentTumor'; + +export type PropagationFrame = { + sopInstanceUid?: string; + w: number; + h: number; + gray: Uint8Array; + /** Seed/paint points in normalized image coords (0..1). */ + seedPointsNorm: NormalizedPoint[]; +}; + +export type PropagateTumorCoreInput = { + minIndex: number; + maxIndex: number; + startEffectiveIndex: number; + + getFrame: (effectiveIndex: number) => Promise; + + threshold: TumorThreshold; + opts?: SegmentTumorOptions; + + stop: { + minAreaPx: number; + maxMissesInARow: number; + }; + + onProgress?: (p: { direction: 'left' | 'right'; index: number; saved: number; misses: number }) => void; + + /** Optional callback invoked when a slice segmentation is accepted (area >= minAreaPx). 
*/ + onAcceptedResult?: (r: { + direction: 'left' | 'right'; + index: number; + sopInstanceUid?: string; + segmentation: SegmentationResult; + }) => Promise | void; +}; + +export type PropagateTumorCoreResult = { + saved: number; + results: Array<{ index: number; sopInstanceUid?: string; segmentation: SegmentationResult }>; +}; + +export async function propagateTumorAcrossFramesCore(input: PropagateTumorCoreInput): Promise { + const minIndex = Math.min(input.minIndex, input.maxIndex); + const maxIndex = Math.max(input.minIndex, input.maxIndex); + + const start = Math.max(minIndex, Math.min(maxIndex, Math.round(input.startEffectiveIndex))); + + let saved = 0; + const results: PropagateTumorCoreResult['results'] = []; + + const runDir = async (direction: 'left' | 'right') => { + let misses = 0; + const step = direction === 'left' ? -1 : 1; + + for (let i = start + step; i >= minIndex && i <= maxIndex; i += step) { + try { + const frame = await input.getFrame(i); + if (!frame) { + misses++; + input.onProgress?.({ direction, index: i, saved, misses }); + if (misses >= input.stop.maxMissesInARow) break; + continue; + } + + const { gray, w, h, seedPointsNorm } = frame; + const seg = segmentTumorFromGrayscale(gray, w, h, seedPointsNorm, input.threshold, input.opts); + + if (seg.meta.areaPx < input.stop.minAreaPx) { + misses++; + input.onProgress?.({ direction, index: i, saved, misses }); + if (misses >= input.stop.maxMissesInARow) break; + continue; + } + + // Allow the caller to persist/stream results. If it throws, treat it like a miss (same behavior + // as the old propagation adapter which wrapped segmentation+save in a try/catch). 
+ if (input.onAcceptedResult) { + await input.onAcceptedResult({ + direction, + index: i, + sopInstanceUid: frame.sopInstanceUid, + segmentation: seg, + }); + } + + results.push({ index: i, sopInstanceUid: frame.sopInstanceUid, segmentation: seg }); + saved++; + misses = 0; + input.onProgress?.({ direction, index: i, saved, misses }); + } catch (e) { + console.warn('[propagateTumorAcrossFramesCore] Failed slice', direction, i, e); + misses++; + input.onProgress?.({ direction, index: i, saved, misses }); + if (misses >= input.stop.maxMissesInARow) break; + } + } + }; + + await runDir('left'); + await runDir('right'); + + return { saved, results }; +} + +// Backwards-compatible helper for harness code that wants polygons without the full segmentation result. +export function getPolygonsByIndexFromCoreResult(res: PropagateTumorCoreResult): Map { + const out = new Map(); + for (const r of res.results) { + out.set(r.index, r.segmentation.polygon); + } + return out; +} diff --git a/frontend/src/utils/viewTransform.ts b/frontend/src/utils/viewTransform.ts new file mode 100644 index 0000000..6620887 --- /dev/null +++ b/frontend/src/utils/viewTransform.ts @@ -0,0 +1,149 @@ +import type { NormalizedPoint, TumorPolygon, ViewerTransform } from '../db/schema'; + +export type ViewportSize = { w: number; h: number }; + +export function normalizeViewerTransform(t?: ViewerTransform | null): ViewerTransform { + return { + zoom: Number.isFinite(t?.zoom) ? t!.zoom : 1, + rotation: Number.isFinite(t?.rotation) ? t!.rotation : 0, + panX: Number.isFinite(t?.panX) ? t!.panX : 0, + panY: Number.isFinite(t?.panY) ? t!.panY : 0, + affine00: Number.isFinite(t?.affine00) ? t!.affine00 : 1, + affine01: Number.isFinite(t?.affine01) ? t!.affine01 : 0, + affine10: Number.isFinite(t?.affine10) ? t!.affine10 : 0, + affine11: Number.isFinite(t?.affine11) ? 
t!.affine11 : 1, + }; +} + +function computeLinearMatrix(t: ViewerTransform): { m00: number; m01: number; m10: number; m11: number } { + const vt = normalizeViewerTransform(t); + + const theta = (vt.rotation * Math.PI) / 180; + const cos = Math.cos(theta); + const sin = Math.sin(theta); + + // CSS/canvas rotation in screen coords (y-down) is clockwise for +theta. + // Using the standard matrix with y-down matches CSS rotate() and ctx.rotate(). + const r00 = cos; + const r01 = -sin; + const r10 = sin; + const r11 = cos; + + // Affine residual A is row-major 2x2: [[a00,a01],[a10,a11]]. + const a00 = vt.affine00; + const a01 = vt.affine01; + const a10 = vt.affine10; + const a11 = vt.affine11; + + // M = zoom * R * A + const ra00 = r00 * a00 + r01 * a10; + const ra01 = r00 * a01 + r01 * a11; + const ra10 = r10 * a00 + r11 * a10; + const ra11 = r10 * a01 + r11 * a11; + + const z = vt.zoom; + return { + m00: z * ra00, + m01: z * ra01, + m10: z * ra10, + m11: z * ra11, + }; +} + +function applyViewerTransformPx( + p: { x: number; y: number }, + size: ViewportSize, + t: ViewerTransform +): { x: number; y: number } { + const { w, h } = size; + + const cx = w / 2; + const cy = h / 2; + + const vt = normalizeViewerTransform(t); + const panXPx = vt.panX * w; + const panYPx = vt.panY * h; + + const { m00, m01, m10, m11 } = computeLinearMatrix(vt); + + const dx = p.x - cx; + const dy = p.y - cy; + + return { + x: cx + panXPx + (m00 * dx + m01 * dy), + y: cy + panYPx + (m10 * dx + m11 * dy), + }; +} + +function invertViewerTransformPx( + p: { x: number; y: number }, + size: ViewportSize, + t: ViewerTransform +): { x: number; y: number } { + const { w, h } = size; + + const cx = w / 2; + const cy = h / 2; + + const vt = normalizeViewerTransform(t); + const panXPx = vt.panX * w; + const panYPx = vt.panY * h; + + const { m00, m01, m10, m11 } = computeLinearMatrix(vt); + const det = m00 * m11 - m01 * m10; + + // If the matrix is singular (shouldn't happen in normal use), fall back to 
identity. + if (!Number.isFinite(det) || Math.abs(det) < 1e-10) { + return { x: p.x, y: p.y }; + } + + const inv00 = m11 / det; + const inv01 = -m01 / det; + const inv10 = -m10 / det; + const inv11 = m00 / det; + + const dx = p.x - cx - panXPx; + const dy = p.y - cy - panYPx; + + return { + x: cx + (inv00 * dx + inv01 * dy), + y: cy + (inv10 * dx + inv11 * dy), + }; +} + +export function remapPointBetweenViewerTransforms( + p: NormalizedPoint, + size: ViewportSize, + from: ViewerTransform, + to: ViewerTransform +): NormalizedPoint { + if (size.w <= 0 || size.h <= 0) return p; + + const pPx = { x: p.x * size.w, y: p.y * size.h }; + const worldPx = invertViewerTransformPx(pPx, size, from); + const outPx = applyViewerTransformPx(worldPx, size, to); + return { + x: outPx.x / size.w, + y: outPx.y / size.h, + }; +} + +export function remapPointsBetweenViewerTransforms( + points: NormalizedPoint[], + size: ViewportSize, + from: ViewerTransform, + to: ViewerTransform +): NormalizedPoint[] { + return points.map((p) => remapPointBetweenViewerTransforms(p, size, from, to)); +} + +export function remapPolygonBetweenViewerTransforms( + polygon: TumorPolygon, + size: ViewportSize, + from: ViewerTransform, + to: ViewerTransform +): TumorPolygon { + return { + points: remapPointsBetweenViewerTransforms(polygon.points ?? [], size, from, to), + }; +} diff --git a/frontend/src/utils/viewportMapping.ts b/frontend/src/utils/viewportMapping.ts new file mode 100644 index 0000000..8338692 --- /dev/null +++ b/frontend/src/utils/viewportMapping.ts @@ -0,0 +1,51 @@ +import type { NormalizedPoint } from '../db/schema'; +import type { ViewportSize } from './viewTransform'; + +export type ImageSizePx = { w: number; h: number }; + +function clamp01(v: number) { + return Math.max(0, Math.min(1, v)); +} + +/** + * Mirror the viewer's "contain" behavior: scale to fit while preserving aspect ratio. + * Returns the image rect in viewport pixel coordinates. 
+ */ +export function containRectPx(view: ViewportSize, img: ImageSizePx): { dx: number; dy: number; dw: number; dh: number } { + const vw = Math.max(1, view.w); + const vh = Math.max(1, view.h); + const iw = Math.max(1, img.w); + const ih = Math.max(1, img.h); + + const scale = Math.min(vw / iw, vh / ih); + const dw = iw * scale; + const dh = ih * scale; + const dx = (vw - dw) / 2; + const dy = (vh - dh) / 2; + + return { dx, dy, dw, dh }; +} + +export function viewerNormToImageNorm(p: NormalizedPoint, view: ViewportSize, img: ImageSizePx): NormalizedPoint { + const { dx, dy, dw, dh } = containRectPx(view, img); + + const xPx = clamp01(p.x) * Math.max(1, view.w); + const yPx = clamp01(p.y) * Math.max(1, view.h); + + const xi = dw > 1e-6 ? (xPx - dx) / dw : 0; + const yi = dh > 1e-6 ? (yPx - dy) / dh : 0; + + return { x: clamp01(xi), y: clamp01(yi) }; +} + +export function imageNormToViewerNorm(p: NormalizedPoint, view: ViewportSize, img: ImageSizePx): NormalizedPoint { + const { dx, dy, dw, dh } = containRectPx(view, img); + + const xPx = dx + clamp01(p.x) * dw; + const yPx = dy + clamp01(p.y) * dh; + + const xv = xPx / Math.max(1, view.w); + const yv = yPx / Math.max(1, view.h); + + return { x: clamp01(xv), y: clamp01(yv) }; +} diff --git a/frontend/tests/alignmentSliceSearch.test.ts b/frontend/tests/alignmentSliceSearch.test.ts new file mode 100644 index 0000000..1e63004 --- /dev/null +++ b/frontend/tests/alignmentSliceSearch.test.ts @@ -0,0 +1,159 @@ +import { describe, expect, test } from 'vitest'; +import { findBestMatchingSlice } from '../src/utils/alignment'; + +function makeDeterministicRandomBinary(n: number, seed: number): Float32Array { + let s = seed >>> 0; + const out = new Float32Array(n); + for (let i = 0; i < n; i++) { + // LCG (Numerical Recipes) + s = (1664525 * s + 1013904223) >>> 0; + const u = (s >>> 8) / 0x01000000; + out[i] = u < 0.5 ? 
0 : 1; + } + return out; +} + +function makeDeterministicRandomFloat(n: number, seed: number): Float32Array { + let s = seed >>> 0; + const out = new Float32Array(n); + for (let i = 0; i < n; i++) { + s = (1664525 * s + 1013904223) >>> 0; + const u = (s >>> 8) / 0x01000000; + out[i] = u; + } + return out; +} + +function affineIntensity(a: Float32Array, scale: number, offset: number): Float32Array { + const out = new Float32Array(a.length); + for (let i = 0; i < a.length; i++) { + const v = (a[i] ?? 0) * scale + offset; + out[i] = Math.max(0, Math.min(1, v)); + } + return out; +} + +function addNoise(a: Float32Array, sigma: number, seed: number): Float32Array { + let s = seed >>> 0; + const out = new Float32Array(a.length); + for (let i = 0; i < a.length; i++) { + s = (1664525 * s + 1013904223) >>> 0; + // Roughly uniform noise in [-sigma, sigma] + const u = (s >>> 8) / 0x01000000; + const n = (u * 2 - 1) * sigma; + const v = (a[i] ?? 0) + n; + out[i] = Math.max(0, Math.min(1, v)); + } + return out; +} + +function flipBinaryWithProb(a: Float32Array, flipProb: number, seed: number): Float32Array { + let s = seed >>> 0; + const out = new Float32Array(a.length); + for (let i = 0; i < a.length; i++) { + s = (1664525 * s + 1013904223) >>> 0; + const u = (s >>> 8) / 0x01000000; + const v = a[i] ?? 0; + out[i] = u < flipProb ? 1 - v : v; + } + return out; +} + +describe('findBestMatchingSlice', () => { + test('minSearchRadius prevents early-stop misses when the true peak is a few slices away', async () => { + // Use a square length so alignment.ts can infer imageWidth/imageHeight. + const size = 64; + const n = size * size; + + const reference = makeDeterministicRandomBinary(n, 123); + + // Construct a score landscape where slices 1..2 look progressively worse, + // but slice 3 is much better (true peak). 
+ const slices: Float32Array[] = []; + slices[0] = flipBinaryWithProb(reference, 0.2, 1); + slices[1] = flipBinaryWithProb(reference, 0.3, 2); + slices[2] = flipBinaryWithProb(reference, 0.4, 3); + slices[3] = flipBinaryWithProb(reference, 0.05, 4); + slices[4] = flipBinaryWithProb(reference, 0.45, 5); + slices[5] = flipBinaryWithProb(reference, 0.5, 6); + + const getSlice = async (idx: number) => { + const s = slices[idx]; + if (!s) throw new Error('missing slice'); + return s; + }; + + const noMin = await findBestMatchingSlice(reference, getSlice, 0, 1, slices.length, undefined, { + miBins: 32, + stopDecreaseStreak: 2, + minSearchRadius: 0, + }); + + const withMin = await findBestMatchingSlice(reference, getSlice, 0, 1, slices.length, undefined, { + miBins: 32, + stopDecreaseStreak: 2, + minSearchRadius: 3, + }); + + expect(noMin.bestIndex).not.toBe(3); + expect(withMin.bestIndex).toBe(3); + }); + + test('scoreMetric=mind prefers same-structure slices despite intensity remapping', async () => { + const size = 64; + const n = size * size; + + const reference = makeDeterministicRandomFloat(n, 42); + + // Best match: intensity remapped reference (scale + offset). 
+ const best = affineIntensity(reference, 0.6, 0.2); + + const slices: Float32Array[] = []; + slices[0] = makeDeterministicRandomFloat(n, 1); + slices[1] = addNoise(reference, 0.15, 2); + slices[2] = best; + slices[3] = addNoise(reference, 0.25, 3); + + const getSlice = async (idx: number) => { + const s = slices[idx]; + if (!s) throw new Error('missing slice'); + return s; + }; + + const r = await findBestMatchingSlice(reference, getSlice, 0, 1, slices.length, undefined, { + scoreMetric: 'mind', + mindSize: 64, + stopDecreaseStreak: 2, + minSearchRadius: 0, + }); + + expect(r.bestIndex).toBe(2); + }); + + test('scoreMetric=phase prefers slices with the same frequency content despite intensity remapping', async () => { + const size = 64; + const n = size * size; + + const reference = makeDeterministicRandomFloat(n, 7); + + const slices: Float32Array[] = []; + slices[0] = makeDeterministicRandomFloat(n, 123); + slices[1] = affineIntensity(reference, 1.2, -0.1); + slices[2] = makeDeterministicRandomFloat(n, 456); + + const getSlice = async (idx: number) => { + const s = slices[idx]; + if (!s) throw new Error('missing slice'); + return s; + }; + + const r = await findBestMatchingSlice(reference, getSlice, 0, 1, slices.length, undefined, { + scoreMetric: 'phase', + phaseSize: 64, + stopDecreaseStreak: 2, + minSearchRadius: 0, + }); + + expect(r.bestIndex).toBe(1); + }); +}); diff --git a/frontend/tests/geodesicDistance.test.ts b/frontend/tests/geodesicDistance.test.ts new file mode 100644 index 0000000..57a9dc6 --- /dev/null +++ b/frontend/tests/geodesicDistance.test.ts @@ -0,0 +1,57 @@ +import { describe, expect, test } from 'vitest'; +import { computeGeodesicDistanceToSeeds } from '../src/utils/segmentation/geodesicDistance'; + +describe('geodesicDistance', () => { + test('edgeCostStrength=0 matches Manhattan distance in a rectangular ROI', () => { + const w = 5; + const h = 5; + + const dist = computeGeodesicDistanceToSeeds({ + w, + h, + roi: { x0: 0, y0: 0, x1: w - 
1, y1: h - 1 }, + seeds: [{ x: 0, y: 2 }], + edgeCostStrength: 0, + }); + + const at = (x: number, y: number) => dist[y * w + x]!; + + expect(at(0, 2)).toBeCloseTo(0, 8); + expect(at(1, 2)).toBeCloseTo(1, 8); + expect(at(2, 2)).toBeCloseTo(2, 8); + expect(at(3, 2)).toBeCloseTo(3, 8); + expect(at(4, 2)).toBeCloseTo(4, 8); + + // A diagonal corner should be manhattan distance as well. + expect(at(4, 4)).toBeCloseTo(6, 8); + }); + + test('strong edge cost increases distance across a barrier', () => { + const w = 5; + const h = 5; + + // Vertical "edge barrier" at x=2. + const grad = new Uint8Array(w * h); + for (let y = 0; y < h; y++) { + grad[y * w + 2] = 255; + } + + const dist = computeGeodesicDistanceToSeeds({ + w, + h, + roi: { x0: 0, y0: 0, x1: w - 1, y1: h - 1 }, + seeds: [{ x: 0, y: 2 }], + grad, + edgeCostStrength: 10, + }); + + const at = (x: number, y: number) => dist[y * w + x]!; + + // Crossing the barrier requires entering an x=2 cell once: + // base 4 steps + extra 10 cost = 14. + expect(at(4, 2)).toBeCloseTo(14, 6); + + // The barrier cell itself should reflect the extra cost. + expect(at(2, 2)).toBeCloseTo(12, 6); + }); +}); diff --git a/frontend/tests/mutualInformation.test.ts b/frontend/tests/mutualInformation.test.ts index e8c725d..277f4a9 100644 --- a/frontend/tests/mutualInformation.test.ts +++ b/frontend/tests/mutualInformation.test.ts @@ -58,4 +58,18 @@ describe('computeMutualInformation', () => { expect(Math.abs(ab.mi - ba.mi)).toBeLessThan(1e-6); expect(Math.abs(ab.nmi - ba.nmi)).toBeLessThan(1e-6); }); + + test('inclusionMask restricts pixels used and can increase MI when it selects the matching region', () => { + const a = new Float32Array([0, 1, 0, 1, 0, 1, 0, 1]); + // Second half is inverted relative to A. 
+ const b = new Float32Array([0, 1, 0, 1, 1, 0, 1, 0]); + + const unmasked = computeMutualInformation(a, b, { bins: 8 }); + + const mask = new Uint8Array([1, 1, 1, 1, 0, 0, 0, 0]); + const masked = computeMutualInformation(a, b, { bins: 8, inclusionMask: mask }); + + expect(masked.pixelsUsed).toBe(4); + expect(masked.mi).toBeGreaterThan(unmasked.mi); + }); }); diff --git a/frontend/tests/segmentation.test.ts b/frontend/tests/segmentation.test.ts new file mode 100644 index 0000000..6e5a5aa --- /dev/null +++ b/frontend/tests/segmentation.test.ts @@ -0,0 +1,74 @@ +import { describe, expect, it } from 'vitest'; +import { marchingSquaresContour } from '../src/utils/segmentation/marchingSquares'; +import { morphologicalClose } from '../src/utils/segmentation/morphology'; +import { rdpSimplify } from '../src/utils/segmentation/simplify'; +import { chaikinSmooth } from '../src/utils/segmentation/smooth'; + +describe('segmentation utilities', () => { + it('rdpSimplify reduces points on a nearly straight polyline', () => { + const pts = Array.from({ length: 20 }, (_, i) => ({ x: i / 19, y: 0.5 + (i % 2 ? 1e-4 : -1e-4) })); + const simplified = rdpSimplify(pts, 0.01); + expect(simplified.length).toBeLessThan(pts.length); + expect(simplified.length).toBeGreaterThanOrEqual(2); + }); + + it('marchingSquaresContour returns a contour around a filled square', () => { + const w = 10; + const h = 10; + const mask = new Uint8Array(w * h); + + // Fill a 4x4 square. + for (let y = 3; y <= 6; y++) { + for (let x = 3; x <= 6; x++) { + mask[y * w + x] = 1; + } + } + + const contour = marchingSquaresContour(mask, w, h); + expect(contour.length).toBeGreaterThan(0); + + // Contour points should lie near the square boundary (midpoints between pixels). 
+ for (const p of contour) { + expect(p.x).toBeGreaterThanOrEqual(2.5); + expect(p.x).toBeLessThanOrEqual(6.5); + expect(p.y).toBeGreaterThanOrEqual(2.5); + expect(p.y).toBeLessThanOrEqual(6.5); + } + }); + + it('morphologicalClose fills a 1px hole', () => { + const w = 7; + const h = 7; + const mask = new Uint8Array(w * h); + + // Fill a 3x3 block, but leave a 1px hole in the center. + for (let y = 2; y <= 4; y++) { + for (let x = 2; x <= 4; x++) { + mask[y * w + x] = 1; + } + } + mask[3 * w + 3] = 0; + + const closed = morphologicalClose(mask, w, h); + expect(closed[3 * w + 3]).toBe(1); + }); + + it('chaikinSmooth increases point count and keeps points in bounds', () => { + const square = [ + { x: 0, y: 0 }, + { x: 1, y: 0 }, + { x: 1, y: 1 }, + { x: 0, y: 1 }, + ]; + + const smoothed = chaikinSmooth(square, 2); + expect(smoothed.length).toBeGreaterThan(square.length); + + for (const p of smoothed) { + expect(p.x).toBeGreaterThanOrEqual(0); + expect(p.x).toBeLessThanOrEqual(1); + expect(p.y).toBeGreaterThanOrEqual(0); + expect(p.y).toBeLessThanOrEqual(1); + } + }); +}); diff --git a/frontend/tests/svrDicomGeometry.test.ts b/frontend/tests/svrDicomGeometry.test.ts new file mode 100644 index 0000000..5f25390 --- /dev/null +++ b/frontend/tests/svrDicomGeometry.test.ts @@ -0,0 +1,59 @@ +import { describe, expect, it } from 'vitest'; +import { + parseImageOrientationPatient, + parseImagePositionPatient, + parsePixelSpacingMm, + sliceCornersMm, +} from '../src/utils/svr/dicomGeometry'; + +describe('svr/dicomGeometry', () => { + it('parses PixelSpacing', () => { + expect(parsePixelSpacingMm('0.5\\0.6')).toEqual({ rowSpacingMm: 0.5, colSpacingMm: 0.6 }); + }); + + it('parses ImagePositionPatient', () => { + expect(parseImagePositionPatient('1\\2\\3')).toEqual({ x: 1, y: 2, z: 3 }); + }); + + it('parses ImageOrientationPatient and computes a normal', () => { + const axes = parseImageOrientationPatient('1\\0\\0\\0\\1\\0'); + expect(axes).not.toBeNull(); + if (!axes) 
return; + + expect(axes.rowDir.x).toBeCloseTo(1); + expect(axes.rowDir.y).toBeCloseTo(0); + expect(axes.rowDir.z).toBeCloseTo(0); + + expect(axes.colDir.x).toBeCloseTo(0); + expect(axes.colDir.y).toBeCloseTo(1); + expect(axes.colDir.z).toBeCloseTo(0); + + // Right-hand rule: row x col = +Z + expect(axes.normalDir.x).toBeCloseTo(0); + expect(axes.normalDir.y).toBeCloseTo(0); + expect(axes.normalDir.z).toBeCloseTo(1); + }); + + it('computes slice corners using DICOM row/col conventions (non-square spacing)', () => { + const axes = parseImageOrientationPatient('1\\0\\0\\0\\1\\0'); + expect(axes).not.toBeNull(); + if (!axes) return; + + const corners = sliceCornersMm({ + ippMm: { x: 0, y: 0, z: 0 }, + rowDir: axes.rowDir, + colDir: axes.colDir, + rowSpacingMm: 2, + colSpacingMm: 3, + rows: 2, + cols: 2, + }); + + // (row=0,col=1) goes +X by colSpacing + expect(corners[1]).toEqual({ x: 3, y: 0, z: 0 }); + // (row=1,col=0) goes +Y by rowSpacing + expect(corners[2]).toEqual({ x: 0, y: 2, z: 0 }); + // (row=1,col=1) + expect(corners[3]).toEqual({ x: 3, y: 2, z: 0 }); + }); +}); diff --git a/frontend/tests/svrDownsample.test.ts b/frontend/tests/svrDownsample.test.ts new file mode 100644 index 0000000..e9cfeb8 --- /dev/null +++ b/frontend/tests/svrDownsample.test.ts @@ -0,0 +1,50 @@ +import { describe, expect, it } from 'vitest'; +import { computeSvrDownsampleSize } from '../src/utils/svr/downsample'; + +describe('svr/downsample', () => { + it('fixed mode obeys maxSize', () => { + const r = computeSvrDownsampleSize({ + rows: 512, + cols: 512, + maxSize: 128, + mode: 'fixed', + rowSpacingMm: 0.5, + colSpacingMm: 0.5, + targetVoxelSizeMm: 1.0, + }); + + expect(r.dsRows).toBe(128); + expect(r.dsCols).toBe(128); + }); + + it('voxel-aware mode refuses to downsample beyond the target voxel size', () => { + // With 0.5mm pixels and a 1.0mm voxel target, we can downsample by at most 2x (512 -> 256). 
+ const r = computeSvrDownsampleSize({ + rows: 512, + cols: 512, + maxSize: 128, + mode: 'voxel-aware', + rowSpacingMm: 0.5, + colSpacingMm: 0.5, + targetVoxelSizeMm: 1.0, + }); + + expect(r.dsRows).toBe(256); + expect(r.dsCols).toBe(256); + }); + + it('voxel-aware mode keeps full resolution when target voxels are as small as pixels', () => { + const r = computeSvrDownsampleSize({ + rows: 512, + cols: 512, + maxSize: 128, + mode: 'voxel-aware', + rowSpacingMm: 0.5, + colSpacingMm: 0.5, + targetVoxelSizeMm: 0.5, + }); + + expect(r.dsRows).toBe(512); + expect(r.dsCols).toBe(512); + }); +}); diff --git a/frontend/tests/svrGeometryInvariants.test.ts b/frontend/tests/svrGeometryInvariants.test.ts new file mode 100644 index 0000000..5d68aed --- /dev/null +++ b/frontend/tests/svrGeometryInvariants.test.ts @@ -0,0 +1,119 @@ +import { describe, expect, it } from 'vitest'; +import { parseImageOrientationPatient, parseImagePositionPatient, parsePixelSpacingMm } from '../src/utils/svr/dicomGeometry'; +import type { Vec3 } from '../src/utils/svr/vec3'; +import { dot, v3 } from '../src/utils/svr/vec3'; + +function worldFromRc(params: { + ippMm: Vec3; + rowDir: Vec3; + colDir: Vec3; + rowSpacingMm: number; + colSpacingMm: number; + r: number; + c: number; +}): Vec3 { + const { ippMm, rowDir, colDir, rowSpacingMm, colSpacingMm, r, c } = params; + + // NOTE: This intentionally matches the convention used throughout SVR: + // world(r, c) = IPP + colDir * (r * rowSpacing) + rowDir * (c * colSpacing) + return v3( + ippMm.x + colDir.x * (r * rowSpacingMm) + rowDir.x * (c * colSpacingMm), + ippMm.y + colDir.y * (r * rowSpacingMm) + rowDir.y * (c * colSpacingMm), + ippMm.z + colDir.z * (r * rowSpacingMm) + rowDir.z * (c * colSpacingMm) + ); +} + +function rcFromWorld(params: { + ippMm: Vec3; + rowDir: Vec3; + colDir: Vec3; + rowSpacingMm: number; + colSpacingMm: number; + worldMm: Vec3; +}): { r: number; c: number } { + const { ippMm, rowDir, colDir, rowSpacingMm, colSpacingMm, worldMm } 
= params; + + const dx = v3(worldMm.x - ippMm.x, worldMm.y - ippMm.y, worldMm.z - ippMm.z); + + return { + r: dot(dx, colDir) / rowSpacingMm, + c: dot(dx, rowDir) / colSpacingMm, + }; +} + +describe('svr geometry invariants', () => { + it('world(r,c) roundtrips back to (r,c) (axis-aligned, non-square spacing)', () => { + const ippMm = parseImagePositionPatient('10\\20\\30'); + const axes = parseImageOrientationPatient('1\\0\\0\\0\\1\\0'); + const spacing = parsePixelSpacingMm('2\\3'); + + expect(ippMm).not.toBeNull(); + expect(axes).not.toBeNull(); + expect(spacing).not.toBeNull(); + if (!ippMm || !axes || !spacing) return; + + const r = 5; + const c = 7; + + const worldMm = worldFromRc({ + ippMm, + rowDir: axes.rowDir, + colDir: axes.colDir, + rowSpacingMm: spacing.rowSpacingMm, + colSpacingMm: spacing.colSpacingMm, + r, + c, + }); + + const back = rcFromWorld({ + ippMm, + rowDir: axes.rowDir, + colDir: axes.colDir, + rowSpacingMm: spacing.rowSpacingMm, + colSpacingMm: spacing.colSpacingMm, + worldMm, + }); + + expect(back.r).toBeCloseTo(r, 6); + expect(back.c).toBeCloseTo(c, 6); + }); + + it('world(r,c) roundtrips for rotated in-plane axes', () => { + const ippMm = parseImagePositionPatient('0\\0\\0'); + + // 90° rotation in the XY plane. + // First triplet (IOP[0..2]) is +Y, second triplet (IOP[3..5]) is -X. 
+ const axes = parseImageOrientationPatient('0\\1\\0\\-1\\0\\0'); + const spacing = parsePixelSpacingMm('0.5\\2.0'); + + expect(ippMm).not.toBeNull(); + expect(axes).not.toBeNull(); + expect(spacing).not.toBeNull(); + if (!ippMm || !axes || !spacing) return; + + const r = 11; + const c = 3; + + const worldMm = worldFromRc({ + ippMm, + rowDir: axes.rowDir, + colDir: axes.colDir, + rowSpacingMm: spacing.rowSpacingMm, + colSpacingMm: spacing.colSpacingMm, + r, + c, + }); + + const back = rcFromWorld({ + ippMm, + rowDir: axes.rowDir, + colDir: axes.colDir, + rowSpacingMm: spacing.rowSpacingMm, + colSpacingMm: spacing.colSpacingMm, + worldMm, + }); + + expect(back.r).toBeCloseTo(r, 6); + expect(back.c).toBeCloseTo(c, 6); + }); +}); diff --git a/frontend/tests/svrPhantom.test.ts b/frontend/tests/svrPhantom.test.ts new file mode 100644 index 0000000..821a09b --- /dev/null +++ b/frontend/tests/svrPhantom.test.ts @@ -0,0 +1,232 @@ +import { describe, expect, it } from 'vitest'; +import type { VolumeDims } from '../src/utils/svr/trilinear'; +import { sampleTrilinear } from '../src/utils/svr/trilinear'; +import type { SvrReconstructionGrid, SvrReconstructionOptions, SvrReconstructionSlice } from '../src/utils/svr/reconstructionCore'; +import { reconstructVolumeFromSlices } from '../src/utils/svr/reconstructionCore'; + +function idxOf(x: number, y: number, z: number, dims: VolumeDims): number { + return x + y * dims.nx + z * dims.nx * dims.ny; +} + +function makePhantomVolume(dims: VolumeDims): Float32Array { + // Simple sharp-edged structure: a filled cube + a smaller offset cube. 
+ const vol = new Float32Array(dims.nx * dims.ny * dims.nz); + + const fillBox = (min: [number, number, number], max: [number, number, number], v: number) => { + for (let z = min[2]; z <= max[2]; z++) { + for (let y = min[1]; y <= max[1]; y++) { + for (let x = min[0]; x <= max[0]; x++) { + if (x < 0 || y < 0 || z < 0 || x >= dims.nx || y >= dims.ny || z >= dims.nz) continue; + vol[idxOf(x, y, z, dims)] = v; + } + } + } + }; + + fillBox([10, 10, 10], [20, 20, 20], 1); + fillBox([22, 12, 14], [27, 16, 18], 0.6); + + return vol; +} + +function sampleVolumeAtWorldMm(params: { vol: Float32Array; dims: VolumeDims; x: number; y: number; z: number }): number { + const { vol, dims, x, y, z } = params; + return sampleTrilinear(vol, dims, x, y, z); +} + +function sampleWithThicknessBox(params: { + vol: Float32Array; + dims: VolumeDims; + world: { x: number; y: number; z: number }; + normal: { x: number; y: number; z: number }; + thicknessMm: number; +}): number { + const { vol, dims, world, normal, thicknessMm } = params; + + const t = Math.max(0, thicknessMm); + if (!(t > 0)) { + return sampleVolumeAtWorldMm({ vol, dims, x: world.x, y: world.y, z: world.z }); + } + + // Deterministic box integration across thickness. 
+ const n = 7; + const half = 0.5 * t; + const step = t / n; + + let sum = 0; + for (let i = 0; i < n; i++) { + const off = -half + (i + 0.5) * step; + sum += sampleVolumeAtWorldMm({ + vol, + dims, + x: world.x + normal.x * off, + y: world.y + normal.y * off, + z: world.z + normal.z * off, + }); + } + + return sum / n; +} + +function makeSliceSeries(params: { + vol: Float32Array; + dims: VolumeDims; + plane: 'axial' | 'coronal' | 'sagittal'; + rows: number; + cols: number; + slicePositions: number[]; + spacingMm: number; + thicknessMm: number; +}): SvrReconstructionSlice[] { + const { vol, dims, plane, rows, cols, slicePositions, spacingMm, thicknessMm } = params; + + const slices: SvrReconstructionSlice[] = []; + + for (const sPos of slicePositions) { + // Coordinate frame conventions: + // world(r,c) = IPP + colDir*(r*rowSpacing) + rowDir*(c*colSpacing) + // (matches the SVR DICOM convention used in reconstruction). + + let rowDir = { x: 1, y: 0, z: 0 }; + let colDir = { x: 0, y: 1, z: 0 }; + let normalDir = { x: 0, y: 0, z: 1 }; + let ippMm = { x: 0, y: 0, z: 0 }; + + if (plane === 'axial') { + // z fixed, rows +Y, cols +X + rowDir = { x: 1, y: 0, z: 0 }; + colDir = { x: 0, y: 1, z: 0 }; + normalDir = { x: 0, y: 0, z: 1 }; + ippMm = { x: 0, y: 0, z: sPos }; + } else if (plane === 'coronal') { + // y fixed, rows +Z, cols +X, normal -Y + rowDir = { x: 1, y: 0, z: 0 }; + colDir = { x: 0, y: 0, z: 1 }; + normalDir = { x: 0, y: -1, z: 0 }; + ippMm = { x: 0, y: sPos, z: 0 }; + } else { + // sagittal: x fixed, rows +Z, cols +Y, normal +X + rowDir = { x: 0, y: 1, z: 0 }; + colDir = { x: 0, y: 0, z: 1 }; + normalDir = { x: 1, y: 0, z: 0 }; + ippMm = { x: sPos, y: 0, z: 0 }; + } + + const pixels = new Float32Array(rows * cols); + + for (let r = 0; r < rows; r++) { + const baseX = ippMm.x + colDir.x * (r * spacingMm); + const baseY = ippMm.y + colDir.y * (r * spacingMm); + const baseZ = ippMm.z + colDir.z * (r * spacingMm); + + const rowBase = r * cols; + + for (let c = 0; 
c < cols; c++) { + const wx = baseX + rowDir.x * (c * spacingMm); + const wy = baseY + rowDir.y * (c * spacingMm); + const wz = baseZ + rowDir.z * (c * spacingMm); + + const v = sampleWithThicknessBox({ + vol, + dims, + world: { x: wx, y: wy, z: wz }, + normal: normalDir, + thicknessMm, + }); + + pixels[rowBase + c] = v; + } + } + + slices.push({ + pixels, + dsRows: rows, + dsCols: cols, + ippMm, + rowDir, + colDir, + normalDir, + rowSpacingDsMm: spacingMm, + colSpacingDsMm: spacingMm, + sliceThicknessMm: thicknessMm, + spacingBetweenSlicesMm: null, + }); + } + + return slices; +} + +function mse(a: Float32Array, b: Float32Array): number { + const n = Math.min(a.length, b.length); + let sum = 0; + for (let i = 0; i < n; i++) { + const d = (a[i] ?? 0) - (b[i] ?? 0); + sum += d * d; + } + return sum / Math.max(1, n); +} + +function psnrFromMse(m: number): number { + // Standard PSNR with MAX=1 (since our phantom is in [0,1]): PSNR = 10 * log10(MAX^2 / MSE). + const mm = Math.max(1e-12, m); + return -10 * Math.log10(mm); +} + +describe('svr/phantom', () => { + it('PSF-aware reconstruction reduces error when slices have non-zero thickness', async () => { + const dims: VolumeDims = { nx: 34, ny: 34, nz: 34 }; + const gt = makePhantomVolume(dims); + + const grid: SvrReconstructionGrid = { + dims, + originMm: { x: 0, y: 0, z: 0 }, + voxelSizeMm: 1, + }; + + const rows = 33; + const cols = 33; + const spacingMm = 1; + const thicknessMm = 4; + + const slicePositions = [6, 10, 14, 18, 22, 26]; + + const slices: SvrReconstructionSlice[] = [ + ...makeSliceSeries({ vol: gt, dims, plane: 'axial', rows, cols, slicePositions, spacingMm, thicknessMm }), + ...makeSliceSeries({ vol: gt, dims, plane: 'coronal', rows, cols, slicePositions, spacingMm, thicknessMm }), + ...makeSliceSeries({ vol: gt, dims, plane: 'sagittal', rows, cols, slicePositions, spacingMm, thicknessMm }), + ]; + + const base: SvrReconstructionOptions = { + iterations: 3, + stepSize: 0.6, + clampOutput: true, + 
psfMode: 'none', + robustLoss: 'none', + robustDelta: 0.1, + laplacianWeight: 0, + }; + + const psfAware: SvrReconstructionOptions = { + ...base, + psfMode: 'box', + robustLoss: 'huber', + laplacianWeight: 0.02, + }; + + const recBase = await reconstructVolumeFromSlices({ slices, grid, options: base }); + const recPsf = await reconstructVolumeFromSlices({ slices, grid, options: psfAware }); + + const mseBase = mse(recBase, gt); + const msePsf = mse(recPsf, gt); + + const psnrBase = psnrFromMse(mseBase); + const psnrPsf = psnrFromMse(msePsf); + + // The PSF-aware model should do meaningfully better on thick slices. + expect(msePsf).toBeLessThan(mseBase * 0.9); + expect(psnrPsf).toBeGreaterThan(psnrBase + 0.2); + + // Sanity bounds (avoid a totally broken solver passing the relative check). + expect(msePsf).toBeLessThan(0.25); + }); +}); diff --git a/frontend/tests/svrResample2d.test.ts b/frontend/tests/svrResample2d.test.ts new file mode 100644 index 0000000..b6d649b --- /dev/null +++ b/frontend/tests/svrResample2d.test.ts @@ -0,0 +1,89 @@ +import { describe, expect, it } from 'vitest'; +import { resample2dAreaAverage, resample2dLanczos3 } from '../src/utils/svr/resample2d'; + +describe('svr/resample2dAreaAverage', () => { + it('returns an identical copy when dimensions match', () => { + const src = new Float32Array([1, 2, 3, 4, 5, 6]); + const out = resample2dAreaAverage(src, 2, 3, 2, 3); + + expect(out).not.toBe(src); + expect(Array.from(out)).toEqual([1, 2, 3, 4, 5, 6]); + }); + + it('preserves constant images under downsampling', () => { + const src = new Float32Array(8 * 6).fill(7.25); + const out = resample2dAreaAverage(src, 8, 6, 4, 3); + + expect(out.length).toBe(4 * 3); + for (const v of out) { + expect(v).toBeCloseTo(7.25, 6); + } + }); + + it('downsamples 2x2 -> 1x1 by averaging all pixels', () => { + // [[1, 2], + // [3, 4]] => avg = 2.5 + const src = new Float32Array([1, 2, 3, 4]); + const out = resample2dAreaAverage(src, 2, 2, 1, 1); + + 
expect(out.length).toBe(1); + expect(out[0]).toBeCloseTo(2.5, 6); + }); + + it('downsamples 4x4 -> 2x2 by averaging 2x2 blocks', () => { + // src: + // 0 1 2 3 + // 4 5 6 7 + // 8 9 10 11 + // 12 13 14 15 + // blocks (2x2): + // [0,1,4,5] avg=2.5, [2,3,6,7] avg=4.5 + // [8,9,12,13] avg=10.5, [10,11,14,15] avg=12.5 + const src = new Float32Array([...Array.from({ length: 16 }, (_, i) => i)]); + const out = resample2dAreaAverage(src, 4, 4, 2, 2); + + expect(out.length).toBe(4); + expect(out[0]).toBeCloseTo(2.5, 6); + expect(out[1]).toBeCloseTo(4.5, 6); + expect(out[2]).toBeCloseTo(10.5, 6); + expect(out[3]).toBeCloseTo(12.5, 6); + }); + + it('upsamples 2x2 -> 4x4 replicates pixels for integer scales', () => { + // Each source pixel becomes a 2x2 block. + const src = new Float32Array([ + 1, 2, + 3, 4, + ]); + + const out = resample2dAreaAverage(src, 2, 2, 4, 4); + + const expected = [ + 1, 1, 2, 2, + 1, 1, 2, 2, + 3, 3, 4, 4, + 3, 3, 4, 4, + ]; + + expect(out.length).toBe(16); + expect(Array.from(out)).toEqual(expected); + }); + + it('Lanczos3 preserves constant images under downsampling', () => { + const src = new Float32Array(8 * 6).fill(3.125); + const out = resample2dLanczos3(src, 8, 6, 4, 3); + + expect(out.length).toBe(4 * 3); + for (const v of out) { + expect(v).toBeCloseTo(3.125, 5); + } + }); + + it('Lanczos3 2x2 -> 1x1 equals the mean for this symmetric case', () => { + const src = new Float32Array([1, 2, 3, 4]); + const out = resample2dLanczos3(src, 2, 2, 1, 1); + + expect(out.length).toBe(1); + expect(out[0]).toBeCloseTo(2.5, 5); + }); +}); diff --git a/frontend/tests/svrSliceRoiCrop.test.ts b/frontend/tests/svrSliceRoiCrop.test.ts new file mode 100644 index 0000000..fd21488 --- /dev/null +++ b/frontend/tests/svrSliceRoiCrop.test.ts @@ -0,0 +1,86 @@ +import { describe, expect, it } from 'vitest'; +import { boundsCornersMm, cropSliceToRoiInPlace } from '../src/utils/svr/sliceRoiCrop'; + +function makeGridPixels(rows: number, cols: number): Float32Array { + 
const out = new Float32Array(rows * cols); + for (let r = 0; r < rows; r++) { + for (let c = 0; c < cols; c++) { + out[r * cols + c] = r * 100 + c; + } + } + return out; +} + +describe('svr/sliceRoiCrop', () => { + it('crops an axial slice and shifts IPP so (r0,c0) becomes the new origin', () => { + const slice = { + pixels: makeGridPixels(10, 10), + dsRows: 10, + dsCols: 10, + ippMm: { x: 0, y: 0, z: 5 }, + // world(r,c) = IPP + colDir*(r*rowSpacing) + rowDir*(c*colSpacing) + rowDir: { x: 1, y: 0, z: 0 }, + colDir: { x: 0, y: 1, z: 0 }, + normalDir: { x: 0, y: 0, z: 1 }, + rowSpacingDsMm: 1, + colSpacingDsMm: 1, + }; + + const bounds = { + min: { x: 2, y: 3, z: 4 }, + max: { x: 5, y: 7, z: 6 }, + }; + + const corners = boundsCornersMm(bounds); + const ok = cropSliceToRoiInPlace(slice, corners); + expect(ok).toBe(true); + + // For this axis-aligned slice: + // r corresponds to +Y, c corresponds to +X. + // We conservatively expand by 1px: r0=floor(3)-1=2, r1=ceil(7)+1=8 => 7 rows + // c0=floor(2)-1=1, c1=ceil(5)+1=6 => 6 cols + expect(slice.dsRows).toBe(7); + expect(slice.dsCols).toBe(6); + + // New IPP is shifted by (r0,c0) in world space: + // IPP' = IPP + colDir*(r0*rowSpacing) + rowDir*(c0*colSpacing) + expect(slice.ippMm).toEqual({ x: 1, y: 2, z: 5 }); + + // New (0,0) pixel should match old (r0,c0). 
+ expect(slice.pixels[0]).toBe(2 * 100 + 1); + }); + + it('rejects a slice when ROI slab does not intersect the slice plane', () => { + const slice = { + pixels: makeGridPixels(10, 10), + dsRows: 10, + dsCols: 10, + ippMm: { x: 0, y: 0, z: 5 }, + rowDir: { x: 1, y: 0, z: 0 }, + colDir: { x: 0, y: 1, z: 0 }, + normalDir: { x: 0, y: 0, z: 1 }, + rowSpacingDsMm: 1, + colSpacingDsMm: 1, + }; + + const before = { + dsRows: slice.dsRows, + dsCols: slice.dsCols, + ipp: { ...slice.ippMm }, + }; + + const bounds = { + min: { x: 0, y: 0, z: 10 }, + max: { x: 1, y: 1, z: 11 }, + }; + + const corners = boundsCornersMm(bounds); + const ok = cropSliceToRoiInPlace(slice, corners); + expect(ok).toBe(false); + + // Slice should remain unchanged. + expect(slice.dsRows).toBe(before.dsRows); + expect(slice.dsCols).toBe(before.dsCols); + expect(slice.ippMm).toEqual(before.ipp); + }); +}); diff --git a/frontend/tests/svrTrilinear.test.ts b/frontend/tests/svrTrilinear.test.ts new file mode 100644 index 0000000..0175927 --- /dev/null +++ b/frontend/tests/svrTrilinear.test.ts @@ -0,0 +1,46 @@ +import { describe, expect, it } from 'vitest'; +import { sampleTrilinear, splatTrilinear, splatTrilinearScaled } from '../src/utils/svr/trilinear'; + +describe('svr/trilinear', () => { + it('sampleTrilinear samples the center of a 2x2x2 volume', () => { + const dims = { nx: 2, ny: 2, nz: 2 }; + const volume = new Float32Array([0, 1, 2, 3, 4, 5, 6, 7]); + + const v = sampleTrilinear(volume, dims, 0.5, 0.5, 0.5); + expect(v).toBeCloseTo(3.5); + }); + + it('splatTrilinear distributes weights to 8 neighbors', () => { + const dims = { nx: 2, ny: 2, nz: 2 }; + const accum = new Float32Array(8); + const weight = new Float32Array(8); + + splatTrilinear(accum, weight, dims, 0.5, 0.5, 0.5, 1); + + const sumAccum = accum.reduce((a, b) => a + b, 0); + const sumWeight = weight.reduce((a, b) => a + b, 0); + + expect(sumAccum).toBeCloseTo(1); + expect(sumWeight).toBeCloseTo(1); + + for (let i = 0; i < 8; i++) { + 
expect(weight[i]).toBeCloseTo(1 / 8); + expect(accum[i]).toBeCloseTo(1 / 8); + } + }); + + it('splatTrilinearScaled scales both accum and weight', () => { + const dims = { nx: 2, ny: 2, nz: 2 }; + const accum = new Float32Array(8); + const weight = new Float32Array(8); + + splatTrilinearScaled(accum, weight, dims, 0.5, 0.5, 0.5, 2, 0.25); + + const sumAccum = accum.reduce((a, b) => a + b, 0); + const sumWeight = weight.reduce((a, b) => a + b, 0); + + // splatTrilinearScaled should be equivalent to splatTrilinear(val * scale) AND weight scaled. + expect(sumAccum).toBeCloseTo(2 * 0.25); + expect(sumWeight).toBeCloseTo(0.25); + }); +}); diff --git a/frontend/tests/tumorHarnessRunner.test.ts b/frontend/tests/tumorHarnessRunner.test.ts new file mode 100644 index 0000000..a5194a5 --- /dev/null +++ b/frontend/tests/tumorHarnessRunner.test.ts @@ -0,0 +1,84 @@ +import { readFile, mkdir, writeFile } from 'node:fs/promises'; +import path from 'node:path'; +import JSZip from 'jszip'; + +import type { SegmentTumorOptions } from '../src/utils/segmentation/segmentTumor'; +import { + parseTumorHarnessDataset, + runTumorHarnessDataset, + summarizeReport, +} from '../src/utils/segmentation/harness/runTumorHarness'; + +async function loadDatasetJsonText(datasetPath: string): Promise { + const buf = await readFile(datasetPath); + + if (datasetPath.toLowerCase().endsWith('.zip')) { + const zip = await JSZip.loadAsync(buf); + const entry = zip.file('dataset.json'); + if (!entry) { + throw new Error('Zip does not contain dataset.json at root'); + } + return await entry.async('string'); + } + + return buf.toString('utf8'); +} + +const DATASET_PATH = process.env.TUMOR_HARNESS_DATASET; + +if (!DATASET_PATH) { + test.skip('tumor harness runner (set TUMOR_HARNESS_DATASET to enable)', () => {}); +} else { + test('tumor harness runner', async () => { + const jsonText = await loadDatasetJsonText(DATASET_PATH); + const dataset = parseTumorHarnessDataset(jsonText); + + if (dataset.cases.length === 
0) { + throw new Error('Dataset contains 0 cases'); + } + + const v2Off: SegmentTumorOptions = { + bgModel: { enabled: false }, + geodesic: { enabled: false }, + }; + + const v2Bg: SegmentTumorOptions = { + bgModel: { enabled: true }, + geodesic: { enabled: false }, + }; + + const v2BgGeo: SegmentTumorOptions = { + bgModel: { enabled: true }, + geodesic: { enabled: true }, + }; + + const configs = [ + { name: 'baseline', opts: v2Off }, + { name: 'v2:bg', opts: v2Bg }, + { name: 'v2:bg+geo', opts: v2BgGeo }, + ]; + + const report = await runTumorHarnessDataset({ dataset, configs }); + + const outPath = + process.env.TUMOR_HARNESS_OUT ?? + path.resolve(process.cwd(), 'tmp', `tumor-harness-report.${new Date().toISOString().replace(/[:.]/g, '-')}.json`); + + await mkdir(path.dirname(outPath), { recursive: true }); + await writeFile(outPath, JSON.stringify(report, null, 2), 'utf8'); + + const summary = summarizeReport(report); + // Keep console output intentionally small. + console.log('[tumor-harness] cases:', dataset.cases.length); + console.log('[tumor-harness] scenarios:', dataset.propagationScenarios?.length ?? 
0); + console.log('[tumor-harness] report:', outPath); + if (summary.bestSegConfigByDice) { + console.log('[tumor-harness] best dice:', summary.bestSegConfigByDice.name, summary.bestSegConfigByDice.dice.toFixed(4)); + } + if (summary.bestSegConfigByF2) { + console.log('[tumor-harness] best f2:', summary.bestSegConfigByF2.name, summary.bestSegConfigByF2.f2.toFixed(4)); + } + + expect(report.version).toBe(1); + }); +} diff --git a/frontend/tests/viewTransform.test.ts b/frontend/tests/viewTransform.test.ts new file mode 100644 index 0000000..a6a40d1 --- /dev/null +++ b/frontend/tests/viewTransform.test.ts @@ -0,0 +1,98 @@ +import { describe, expect, test } from 'vitest'; +import type { ViewerTransform } from '../src/db/schema'; +import { + remapPointBetweenViewerTransforms, + remapPolygonBetweenViewerTransforms, + type ViewportSize, +} from '../src/utils/viewTransform'; + +describe('viewTransform', () => { + test('remapPointBetweenViewerTransforms round-trips between two transforms', () => { + const size: ViewportSize = { w: 400, h: 300 }; + + const a: ViewerTransform = { + zoom: 1, + rotation: 0, + panX: 0, + panY: 0, + affine00: 1, + affine01: 0, + affine10: 0, + affine11: 1, + }; + + // Non-trivial but invertible matrix (det != 0) + rotation/zoom + pan. 
+ const b: ViewerTransform = { + zoom: 1.6, + rotation: 27, + panX: 0.12, + panY: -0.08, + affine00: 1, + affine01: 0.2, + affine10: -0.15, + affine11: 0.95, + }; + + const pts = [ + { x: 0.5, y: 0.5 }, + { x: 0.1, y: 0.2 }, + { x: 0.9, y: 0.8 }, + { x: 0, y: 0 }, + { x: 1, y: 1 }, + ]; + + for (const p of pts) { + const q = remapPointBetweenViewerTransforms(p, size, a, b); + expect(Number.isFinite(q.x)).toBe(true); + expect(Number.isFinite(q.y)).toBe(true); + + const r = remapPointBetweenViewerTransforms(q, size, b, a); + expect(r.x).toBeCloseTo(p.x, 8); + expect(r.y).toBeCloseTo(p.y, 8); + } + }); + + test('remapPolygonBetweenViewerTransforms remaps each point and round-trips', () => { + const size: ViewportSize = { w: 512, h: 512 }; + + const from: ViewerTransform = { + zoom: 1.25, + rotation: -15, + panX: -0.05, + panY: 0.07, + affine00: 1, + affine01: 0.12, + affine10: 0, + affine11: 0.9, + }; + + const to: ViewerTransform = { + zoom: 0.85, + rotation: 42, + panX: 0.04, + panY: 0.02, + affine00: 0.95, + affine01: 0, + affine10: 0.08, + affine11: 1.05, + }; + + const poly = { + points: [ + { x: 0.2, y: 0.2 }, + { x: 0.8, y: 0.2 }, + { x: 0.75, y: 0.75 }, + { x: 0.25, y: 0.8 }, + ], + }; + + const remapped = remapPolygonBetweenViewerTransforms(poly, size, from, to); + expect(remapped.points).toHaveLength(poly.points.length); + + const roundTripped = remapPolygonBetweenViewerTransforms(remapped, size, to, from); + for (let i = 0; i < poly.points.length; i++) { + expect(roundTripped.points[i]!.x).toBeCloseTo(poly.points[i]!.x, 8); + expect(roundTripped.points[i]!.y).toBeCloseTo(poly.points[i]!.y, 8); + } + }); +}); From 2b3418ffdbf3cb0aace5382d9d77fc7985b9c49b Mon Sep 17 00:00:00 2001 From: Siqi Chen Date: Mon, 2 Feb 2026 12:23:52 -0800 Subject: [PATCH 02/16] refactor(svr): Extract shared utilities and rigid registration to separate modules MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This refactoring improves the SVR 
codebase by: 1. Creating svrUtils.ts - shared utility functions: - clamp01: Clamps values to [0,1] range - assertNotAborted: Checks abort signal for cancellation - yieldToMain: Yields control for UI responsiveness - withinTrilinearSupport: Bounds checking for trilinear interpolation - clampAbs: Clamps absolute values with NaN handling - formatMiB: Human-readable byte formatting - quantileSorted: Quantile computation from sorted arrays 2. Creating rigidRegistration.ts - well-documented rigid registration module: - Full JSDoc documentation explaining the algorithm - Exported types: Mat3, RigidParams, SeriesSamples, BoundsMm, LoadedSlice - Matrix utilities: mat3FromEulerXYZ, mat3MulVec3 - Transform functions: applyRigidToPoint, applyRigidToSeriesSlices - Registration scoring: scoreNcc (Normalized Cross-Correlation) - Optimization: optimizeRigidNcc (multi-scale coordinate descent) - Main function: rigidAlignSeriesInRoi 3. Updating reconstructionCore.ts to use shared utilities: - Imports from svrUtils.ts instead of local definitions - Reduces code duplication 4. 
Adding comprehensive test coverage (15 new tests): - mat3FromEulerXYZ: identity, 90° rotations, orthonormality - applyRigidToPoint: identity, translation, rotation, combined - boundsCenterMm: positive and negative coordinates - scoreNcc: empty samples, insufficient samples, perfect correlation --- frontend/src/utils/svr/reconstructionCore.ts | 17 +- frontend/src/utils/svr/rigidRegistration.ts | 721 +++++++++++++++++++ frontend/src/utils/svr/svrUtils.ts | 95 +++ frontend/tests/svrRigidRegistration.test.ts | 293 ++++++++ 4 files changed, 1110 insertions(+), 16 deletions(-) create mode 100644 frontend/src/utils/svr/rigidRegistration.ts create mode 100644 frontend/src/utils/svr/svrUtils.ts create mode 100644 frontend/tests/svrRigidRegistration.test.ts diff --git a/frontend/src/utils/svr/reconstructionCore.ts b/frontend/src/utils/svr/reconstructionCore.ts index 8c9b105..4bc9af4 100644 --- a/frontend/src/utils/svr/reconstructionCore.ts +++ b/frontend/src/utils/svr/reconstructionCore.ts @@ -1,6 +1,7 @@ import type { VolumeDims } from './trilinear'; import { sampleTrilinear, splatTrilinearScaled } from './trilinear'; import type { Vec3 } from './vec3'; +import { assertNotAborted, clamp01, withinTrilinearSupport } from './svrUtils'; export type SvrPsfMode = 'none' | 'box' | 'gaussian'; export type SvrRobustLoss = 'none' | 'huber' | 'tukey'; @@ -51,22 +52,6 @@ export type SvrCoreHooks = { onProgress?: (p: { current: number; total: number; message: string }) => void; }; -function assertNotAborted(signal?: AbortSignal): void { - if (signal?.aborted) { - throw new Error('SVR cancelled'); - } -} - -function clamp01(x: number): number { - return x < 0 ? 0 : x > 1 ? 
1 : x; -} - -function withinTrilinearSupport(dims: VolumeDims, x: number, y: number, z: number): boolean { - // sampleTrilinear/splatTrilinear require x0>=0 and x1= 0 && y >= 0 && z >= 0 && x < dims.nx - 1 && y < dims.ny - 1 && z < dims.nz - 1; -} - type SlicePsf = { offsetsMm: Float32Array; weights: Float32Array; count: number; effectiveThicknessMm: number }; function buildSlicePsf(params: { diff --git a/frontend/src/utils/svr/rigidRegistration.ts b/frontend/src/utils/svr/rigidRegistration.ts new file mode 100644 index 0000000..5b297bc --- /dev/null +++ b/frontend/src/utils/svr/rigidRegistration.ts @@ -0,0 +1,721 @@ +/** + * Rigid Registration for SVR (Slice-to-Volume Reconstruction) + * + * This module implements ROI-constrained rigid registration for aligning + * multiple MRI series before fusion. The registration uses normalized + * cross-correlation (NCC) as the similarity metric and performs coordinate + * descent optimization with multi-scale step sizes. + * + * Key concepts: + * - Each series is aligned to a reference volume built from other series + * - Transforms are applied about the ROI center to keep the region of interest stable + * - Small rotation and translation limits prevent unreasonable transforms + */ + +import type { SvrProgress, SvrRoi, SvrSelectedSeries } from '../../types/svr'; +import type { VolumeDims } from './trilinear'; +import { sampleTrilinear } from './trilinear'; +import type { Vec3 } from './vec3'; +import { cross, normalize, v3 } from './vec3'; +import { assertNotAborted, clampAbs, withinTrilinearSupport, yieldToMain } from './svrUtils'; +import type { SvrReconstructionGrid, SvrReconstructionOptions, SvrReconstructionSlice } from './reconstructionCore'; +import { reconstructVolumeFromSlices } from './reconstructionCore'; +import { debugSvrLog } from '../debugSvr'; + +// ============================================================================ +// Types +// 
============================================================================ + +/** 3×3 rotation matrix stored as a flat 9-element tuple (row-major order) */ +export type Mat3 = [number, number, number, number, number, number, number, number, number]; + +/** + * Parameters for a rigid transform (rotation + translation). + * Rotation is specified as Euler angles in radians (XYZ convention). + */ +export type RigidParams = { + /** Translation in world/patient mm along X axis */ + tx: number; + /** Translation in world/patient mm along Y axis */ + ty: number; + /** Translation in world/patient mm along Z axis */ + tz: number; + /** Rotation in radians about X axis */ + rx: number; + /** Rotation in radians about Y axis */ + ry: number; + /** Rotation in radians about Z axis */ + rz: number; +}; + +/** + * Samples extracted from a series for registration scoring. + * Stores both intensity values and their world positions. + */ +export type SeriesSamples = { + /** Observed intensities (normalized [0,1]) */ + obs: Float32Array; + /** Original world positions (x,y,z interleaved, 3 values per sample) */ + pos: Float32Array; + /** Number of samples */ + count: number; +}; + +/** Axis-aligned bounding box in world/patient mm coordinates */ +export type BoundsMm = { min: Vec3; max: Vec3 }; + +/** + * LoadedSlice extends SvrReconstructionSlice with additional metadata + * needed for the full reconstruction pipeline. 
+ */ +export type LoadedSlice = SvrReconstructionSlice & { + /** Series UID this slice belongs to */ + seriesUid: string; + /** SOP Instance UID for this specific slice */ + sopInstanceUid: string; + + /** Original (pre-downsample) row count */ + srcRows: number; + /** Original (pre-downsample) column count */ + srcCols: number; + /** Original row spacing in mm (pre-downsample) */ + rowSpacingMm: number; + /** Original column spacing in mm (pre-downsample) */ + colSpacingMm: number; +}; + +// ============================================================================ +// Matrix and transform utilities +// ============================================================================ + +/** + * Constructs a 3×3 rotation matrix from Euler angles using XYZ convention. + * The rotation order is: R = Rz(rz) * Ry(ry) * Rx(rx) + * + * @param rx - Rotation about X axis in radians + * @param ry - Rotation about Y axis in radians + * @param rz - Rotation about Z axis in radians + * @returns 3×3 rotation matrix as a flat array + */ +export function mat3FromEulerXYZ(rx: number, ry: number, rz: number): Mat3 { + const cx = Math.cos(rx); + const sx = Math.sin(rx); + const cy = Math.cos(ry); + const sy = Math.sin(ry); + const cz = Math.cos(rz); + const sz = Math.sin(rz); + + const m00 = cz * cy; + const m01 = cz * sy * sx - sz * cx; + const m02 = cz * sy * cx + sz * sx; + + const m10 = sz * cy; + const m11 = sz * sy * sx + cz * cx; + const m12 = sz * sy * cx - cz * sx; + + const m20 = -sy; + const m21 = cy * sx; + const m22 = cy * cx; + + return [m00, m01, m02, m10, m11, m12, m20, m21, m22]; +} + +/** + * Multiplies a 3×3 matrix by a 3D vector. 
+ * + * @param m - 3×3 matrix (row-major) + * @param x - X component of vector + * @param y - Y component of vector + * @param z - Z component of vector + * @returns Transformed vector + */ +export function mat3MulVec3(m: Mat3, x: number, y: number, z: number): Vec3 { + return v3(m[0] * x + m[1] * y + m[2] * z, m[3] * x + m[4] * y + m[5] * z, m[6] * x + m[7] * y + m[8] * z); +} + +/** + * Applies a rigid transform to a point. + * The transform rotates about a center point, then translates. + * + * @param p - Point to transform + * @param centerMm - Center of rotation in mm + * @param rot - Rotation matrix + * @param tMm - Translation vector in mm + * @returns Transformed point + */ +export function applyRigidToPoint(p: Vec3, centerMm: Vec3, rot: Mat3, tMm: Vec3): Vec3 { + const dx = p.x - centerMm.x; + const dy = p.y - centerMm.y; + const dz = p.z - centerMm.z; + + const r = mat3MulVec3(rot, dx, dy, dz); + return v3(centerMm.x + r.x + tMm.x, centerMm.y + r.y + tMm.y, centerMm.z + r.z + tMm.z); +} + +/** + * Applies a rotation to a direction vector. + * + * @param d - Direction vector to rotate + * @param rot - Rotation matrix + * @returns Rotated and normalized direction vector + */ +function applyRotToDir(d: Vec3, rot: Mat3): Vec3 { + const r = mat3MulVec3(rot, d.x, d.y, d.z); + return normalize(r); +} + +/** + * Re-orthonormalizes row and column direction vectors. + * This prevents numerical drift after repeated rotations. 
+ * + * @param rowDir - Row direction vector + * @param colDir - Column direction vector + * @returns Orthonormalized row and column vectors + */ +function orthonormalizeRowCol(rowDir: Vec3, colDir: Vec3): { rowDir: Vec3; colDir: Vec3 } { + const r = normalize(rowDir); + const c0 = normalize(colDir); + const n = normalize(cross(r, c0)); + const c = normalize(cross(n, r)); + return { rowDir: r, colDir: c }; +} + +// ============================================================================ +// Bounds utilities +// ============================================================================ + +/** + * Computes the center point of a bounding box. + */ +export function boundsCenterMm(b: BoundsMm): Vec3 { + return v3((b.min.x + b.max.x) * 0.5, (b.min.y + b.max.y) * 0.5, (b.min.z + b.max.z) * 0.5); +} + +/** + * Checks if a point is within a bounding box (inclusive). + */ +function isWithinBoundsMm(p: Vec3, b: BoundsMm): boolean { + return p.x >= b.min.x && p.x <= b.max.x && p.y >= b.min.y && p.y <= b.max.y && p.z >= b.min.z && p.z <= b.max.z; +} + +// ============================================================================ +// Slice transform application +// ============================================================================ + +/** + * Applies a rigid transform to all slices in a series. + * Modifies slices in-place. 
+ * + * @param params.slices - Slices to transform + * @param params.centerMm - Center of rotation + * @param params.rot - Rotation matrix + * @param params.tMm - Translation vector + */ +export function applyRigidToSeriesSlices(params: { + slices: LoadedSlice[]; + centerMm: Vec3; + rot: Mat3; + tMm: Vec3; +}): void { + const { slices, centerMm, rot, tMm } = params; + + for (const s of slices) { + s.ippMm = applyRigidToPoint(s.ippMm, centerMm, rot, tMm); + + const row = applyRotToDir(s.rowDir, rot); + const col = applyRotToDir(s.colDir, rot); + const ortho = orthonormalizeRowCol(row, col); + s.rowDir = ortho.rowDir; + s.colDir = ortho.colDir; + s.normalDir = normalize(cross(s.rowDir, s.colDir)); + } +} + +// ============================================================================ +// Sample extraction for registration +// ============================================================================ + +/** + * Extracts intensity samples from slices within an ROI for registration scoring. + * Uses strided sampling to limit computation while maintaining spatial coverage. + * + * @param params.slices - Source slices + * @param params.roiBounds - ROI to sample within + * @param params.maxSamples - Maximum number of samples to extract + * @param params.signal - Optional abort signal + * @returns Extracted samples with positions + */ +export function buildSeriesSamples(params: { + slices: LoadedSlice[]; + roiBounds: BoundsMm; + maxSamples: number; + signal?: AbortSignal; +}): SeriesSamples { + const { slices, roiBounds, maxSamples, signal } = params; + + const maxN = Math.max(1, Math.round(maxSamples)); + const perSliceTarget = Math.max(64, Math.ceil(maxN / Math.max(1, slices.length))); + + let totalPixels = 0; + for (const s of slices) totalPixels += s.dsRows * s.dsCols; + + // Choose a roughly-uniform stride so we don't spend time scoring every pixel. 
+ const stride = Math.max(1, Math.floor(Math.sqrt(totalPixels / maxN))); + + const obs: number[] = []; + const pos: number[] = []; + + for (let sIdx = 0; sIdx < slices.length; sIdx++) { + assertNotAborted(signal); + const s = slices[sIdx]; + if (!s) continue; + + let usedThisSlice = 0; + + for (let r = 0; r < s.dsRows; r += stride) { + const baseX = s.ippMm.x + s.colDir.x * (r * s.rowSpacingDsMm); + const baseY = s.ippMm.y + s.colDir.y * (r * s.rowSpacingDsMm); + const baseZ = s.ippMm.z + s.colDir.z * (r * s.rowSpacingDsMm); + + const rowBase = r * s.dsCols; + + for (let c = 0; c < s.dsCols; c += stride) { + const v = s.pixels[rowBase + c] ?? 0; + if (v <= 0) continue; + + const wx = baseX + s.rowDir.x * (c * s.colSpacingDsMm); + const wy = baseY + s.rowDir.y * (c * s.colSpacingDsMm); + const wz = baseZ + s.rowDir.z * (c * s.colSpacingDsMm); + + const p = v3(wx, wy, wz); + if (!isWithinBoundsMm(p, roiBounds)) continue; + + obs.push(v); + pos.push(wx, wy, wz); + usedThisSlice++; + + if (usedThisSlice >= perSliceTarget) break; + if (obs.length >= maxN) break; + } + + if (usedThisSlice >= perSliceTarget) break; + if (obs.length >= maxN) break; + } + + if (obs.length >= maxN) break; + } + + return { + obs: Float32Array.from(obs), + pos: Float32Array.from(pos), + count: obs.length, + }; +} + +// ============================================================================ +// Registration scoring +// ============================================================================ + +/** + * Computes Normalized Cross-Correlation (NCC) between series samples + * and a reference volume, given a candidate rigid transform. + * + * NCC is defined as: cov(A,B) / sqrt(var(A) * var(B)) + * where A = observed intensities, B = sampled volume intensities. 
+ * + * @returns NCC score (higher is better, max 1.0) and count of valid samples + */ +export function scoreNcc(params: { + samples: SeriesSamples; + refVolume: Float32Array; + dims: VolumeDims; + originMm: Vec3; + voxelSizeMm: number; + centerMm: Vec3; + rigid: RigidParams; +}): { ncc: number; used: number } { + const { samples, refVolume, dims, originMm, voxelSizeMm, centerMm, rigid } = params; + + if (samples.count <= 0) return { ncc: Number.NEGATIVE_INFINITY, used: 0 }; + + const rot = mat3FromEulerXYZ(rigid.rx, rigid.ry, rigid.rz); + const tMm = v3(rigid.tx, rigid.ty, rigid.tz); + + const invVox = 1 / voxelSizeMm; + + let sumA = 0; + let sumB = 0; + let sumAA = 0; + let sumBB = 0; + let sumAB = 0; + let used = 0; + + const obs = samples.obs; + const pos = samples.pos; + + for (let i = 0; i < samples.count; i++) { + const a = obs[i] ?? 0; + const x = pos[i * 3] ?? 0; + const y = pos[i * 3 + 1] ?? 0; + const z = pos[i * 3 + 2] ?? 0; + + // Apply candidate rigid transform about ROI center. + const p = applyRigidToPoint(v3(x, y, z), centerMm, rot, tMm); + + const vx = (p.x - originMm.x) * invVox; + const vy = (p.y - originMm.y) * invVox; + const vz = (p.z - originMm.z) * invVox; + + if (!withinTrilinearSupport(dims, vx, vy, vz)) continue; + + const b = sampleTrilinear(refVolume, dims, vx, vy, vz); + + sumA += a; + sumB += b; + sumAA += a * a; + sumBB += b * b; + sumAB += a * b; + used++; + } + + // Require minimum samples for reliable optimization + const MIN_SAMPLES_FOR_OPTIMIZATION = 512; + if (used < MIN_SAMPLES_FOR_OPTIMIZATION) { + return { ncc: Number.NEGATIVE_INFINITY, used }; + } + + const invN = 1 / used; + const cov = sumAB - sumA * sumB * invN; + const varA = sumAA - sumA * sumA * invN; + const varB = sumBB - sumB * sumB * invN; + + const denom = Math.sqrt(Math.max(1e-12, varA * varB)); + const ncc = denom > 0 ? 
cov / denom : Number.NEGATIVE_INFINITY; + + return { ncc, used }; +} + +// ============================================================================ +// Optimization +// ============================================================================ + +/** + * Optimizes rigid transform parameters to maximize NCC with reference volume. + * + * Uses coordinate descent with multi-scale step sizes: + * 1. Coarse: 2mm translation, 2° rotation + * 2. Medium: 1mm translation, 1° rotation + * 3. Fine: 0.5mm translation, 0.5° rotation + * + * The search is bounded to prevent unreasonable transforms: + * - Max translation: ±20mm per axis + * - Max rotation: ±10° per axis + * + * @returns Best transform found, its score, and optimization statistics + */ +export async function optimizeRigidNcc(params: { + samples: SeriesSamples; + refVolume: Float32Array; + dims: VolumeDims; + originMm: Vec3; + voxelSizeMm: number; + centerMm: Vec3; + signal?: AbortSignal; +}): Promise<{ best: RigidParams; bestScore: number; used: number; evals: number }> { + const { samples, refVolume, dims, originMm, voxelSizeMm, centerMm, signal } = params; + + // Search bounds - assumes coarse alignment got us "close" + const MAX_TRANS_MM = 20; + const MAX_ROT_RAD = (10 * Math.PI) / 180; + + // Multi-scale optimization stages (coarse to fine) + const stages = [ + { transStepMm: 2.0, rotStepRad: (2 * Math.PI) / 180 }, + { transStepMm: 1.0, rotStepRad: (1 * Math.PI) / 180 }, + { transStepMm: 0.5, rotStepRad: (0.5 * Math.PI) / 180 }, + ]; + + let cur: RigidParams = { tx: 0, ty: 0, tz: 0, rx: 0, ry: 0, rz: 0 }; + const bestEval = scoreNcc({ samples, refVolume, dims, originMm, voxelSizeMm, centerMm, rigid: cur }); + let bestScore = bestEval.ncc; + let bestUsed = bestEval.used; + let evals = 1; + + const tryUpdate = (next: RigidParams): boolean => { + const e = scoreNcc({ samples, refVolume, dims, originMm, voxelSizeMm, centerMm, rigid: next }); + evals++; + if (e.ncc > bestScore + 1e-4) { + cur = next; + 
bestScore = e.ncc; + bestUsed = e.used; + return true; + } + return false; + }; + + for (const stage of stages) { + let improved = true; + let iter = 0; + const MAX_ITERATIONS_PER_STAGE = 20; + + while (improved && iter < MAX_ITERATIONS_PER_STAGE) { + assertNotAborted(signal); + improved = false; + iter++; + + const t = stage.transStepMm; + const r = stage.rotStepRad; + + const candidates: Array = ['tx', 'ty', 'tz', 'rx', 'ry', 'rz']; + + for (const key of candidates) { + const step = key.startsWith('t') ? t : r; + const maxVal = key.startsWith('t') ? MAX_TRANS_MM : MAX_ROT_RAD; + + const plus: RigidParams = { ...cur }; + const minus: RigidParams = { ...cur }; + (plus as Record)[key] = clampAbs(cur[key] + step, maxVal); + (minus as Record)[key] = clampAbs(cur[key] - step, maxVal); + + if (tryUpdate(plus)) improved = true; + if (tryUpdate(minus)) improved = true; + + // Yield periodically to avoid blocking the main thread + if (evals % 25 === 0) { + await yieldToMain(); + } + } + } + } + + return { best: cur, bestScore, used: bestUsed, evals }; +} + +// ============================================================================ +// Main registration function +// ============================================================================ + +/** + * Performs ROI-constrained rigid registration for all non-reference series. + * + * Algorithm: + * 1. Group slices by series + * 2. Pick a reference series (preferably the ROI source series, or largest) + * 3. For each non-reference series: + * a. Build a reference volume from all OTHER series + * b. Extract samples from the moving series within ROI + * c. Optimize rigid transform to maximize NCC + * d. If improved, apply transform to moving series slices + * + * This approach handles the "leave-one-out" registration problem where + * we can't include the moving series in its own reference volume. 
+ */ +export async function rigidAlignSeriesInRoi(params: { + allSlices: LoadedSlice[]; + selectedSeries: SvrSelectedSeries[]; + roiBounds: BoundsMm; + dims: VolumeDims; + originMm: Vec3; + voxelSizeMm: number; + roi: SvrRoi; + signal?: AbortSignal; + onProgress?: (p: SvrProgress) => void; + debug: boolean; +}): Promise { + const { allSlices, selectedSeries, roiBounds, dims, originMm, voxelSizeMm, roi, signal, onProgress, debug } = params; + + // Group slices by series for independent processing + const bySeries = new Map(); + for (const s of allSlices) { + const arr = bySeries.get(s.seriesUid); + if (arr) arr.push(s); + else bySeries.set(s.seriesUid, [s]); + } + + // Build label lookup for logging + const labelByUid = new Map(); + for (const s of selectedSeries) labelByUid.set(s.seriesUid, s.label); + + // Select reference series: + // - Prefer the ROI source series (keeps ROI coordinates stable) + // - Fallback to series with most slices (most data = most stable reference) + const roiReferenceUid = roi.sourceSeriesUid ?? 
null; + let referenceUid: string | null = null; + + if (roiReferenceUid && bySeries.has(roiReferenceUid)) { + referenceUid = roiReferenceUid; + } else { + let bestCount = -1; + for (const [uid, arr] of bySeries) { + if (arr.length > bestCount) { + referenceUid = uid; + bestCount = arr.length; + } + } + } + + const centerMm = boundsCenterMm(roiBounds); + + debugSvrLog( + 'registration.roi-rigid.plan', + { + referenceUid, + centerMm: { x: Number(centerMm.x.toFixed(3)), y: Number(centerMm.y.toFixed(3)), z: Number(centerMm.z.toFixed(3)) }, + dims, + voxelSizeMm: Number(voxelSizeMm.toFixed(4)), + }, + debug + ); + + // Align each non-reference series + const seriesUids = Array.from(bySeries.keys()); + for (let idx = 0; idx < seriesUids.length; idx++) { + assertNotAborted(signal); + + const uid = seriesUids[idx]; + if (!uid) continue; + if (referenceUid && uid === referenceUid) continue; + + const movingSlices = bySeries.get(uid); + if (!movingSlices || movingSlices.length === 0) continue; + + onProgress?.({ + phase: 'initializing', + current: 57, + total: 100, + message: `ROI rigid align… (${labelByUid.get(uid) ?? 
uid})`, + }); + + // Build reference volume from all OTHER series (leave-one-out) + const otherSlices: LoadedSlice[] = []; + for (const [otherUid, slices] of bySeries) { + if (otherUid === uid) continue; + otherSlices.push(...slices); + } + + if (otherSlices.length === 0) continue; + + // Quick reconstruction for scoring (no iterations, basic settings) + const refGrid: SvrReconstructionGrid = { dims, originMm, voxelSizeMm }; + const refOptions: SvrReconstructionOptions = { + iterations: 0, + stepSize: 0, + clampOutput: true, + psfMode: 'none', + robustLoss: 'none', + robustDelta: 0.1, + laplacianWeight: 0, + }; + + const refVol = await reconstructVolumeFromSlices({ + slices: otherSlices, + grid: refGrid, + options: refOptions, + hooks: { signal, yieldToMain }, + }); + + // Extract samples from moving series within ROI + const MAX_SAMPLES_FOR_REGISTRATION = 40_000; + const samples = buildSeriesSamples({ + slices: movingSlices, + roiBounds: roiBounds, + maxSamples: MAX_SAMPLES_FOR_REGISTRATION, + signal, + }); + + const MIN_SAMPLES_TO_REGISTER = 1024; + if (samples.count < MIN_SAMPLES_TO_REGISTER) { + console.warn('[svr] ROI rigid alignment: too few samples inside ROI; skipping series', { + seriesUid: uid, + label: labelByUid.get(uid) ?? 
uid, + samples: samples.count, + }); + continue; + } + + // Score before optimization + const before = scoreNcc({ + samples, + refVolume: refVol, + dims, + originMm, + voxelSizeMm, + centerMm, + rigid: { tx: 0, ty: 0, tz: 0, rx: 0, ry: 0, rz: 0 }, + }); + + // Optimize + const opt = await optimizeRigidNcc({ samples, refVolume: refVol, dims, originMm, voxelSizeMm, centerMm, signal }); + + // Score after optimization + const after = scoreNcc({ + samples, + refVolume: refVol, + dims, + originMm, + voxelSizeMm, + centerMm, + rigid: opt.best, + }); + + // Only apply if score actually improved + const MIN_NCC_IMPROVEMENT = 1e-3; + if (!(after.ncc > before.ncc + MIN_NCC_IMPROVEMENT)) { + debugSvrLog( + 'registration.roi-rigid.skip', + { + seriesUid: uid, + label: labelByUid.get(uid) ?? uid, + nccBefore: before.ncc, + nccAfter: after.ncc, + used: after.used, + }, + debug + ); + continue; + } + + // Apply the optimized transform + const rot = mat3FromEulerXYZ(opt.best.rx, opt.best.ry, opt.best.rz); + const tMm = v3(opt.best.tx, opt.best.ty, opt.best.tz); + + applyRigidToSeriesSlices({ slices: movingSlices, centerMm, rot, tMm }); + + console.info('[svr] ROI rigid series alignment applied', { + seriesUid: uid, + label: labelByUid.get(uid) ?? uid, + nccBefore: Number(before.ncc.toFixed(4)), + nccAfter: Number(after.ncc.toFixed(4)), + usedSamples: after.used, + evals: opt.evals, + translateMm: { + x: Number(opt.best.tx.toFixed(3)), + y: Number(opt.best.ty.toFixed(3)), + z: Number(opt.best.tz.toFixed(3)), + }, + rotateDeg: { + x: Number((opt.best.rx * (180 / Math.PI)).toFixed(3)), + y: Number((opt.best.ry * (180 / Math.PI)).toFixed(3)), + z: Number((opt.best.rz * (180 / Math.PI)).toFixed(3)), + }, + }); + + debugSvrLog( + 'registration.roi-rigid', + { + seriesUid: uid, + label: labelByUid.get(uid) ?? 
uid, + samples: samples.count, + usedSamples: after.used, + nccBefore: before.ncc, + nccAfter: after.ncc, + evals: opt.evals, + translateMm: { x: opt.best.tx, y: opt.best.ty, z: opt.best.tz }, + rotateRad: { x: opt.best.rx, y: opt.best.ry, z: opt.best.rz }, + }, + debug + ); + + await yieldToMain(); + } +} diff --git a/frontend/src/utils/svr/svrUtils.ts b/frontend/src/utils/svr/svrUtils.ts new file mode 100644 index 0000000..d71e523 --- /dev/null +++ b/frontend/src/utils/svr/svrUtils.ts @@ -0,0 +1,95 @@ +/** + * Shared utility functions for SVR (Slice-to-Volume Reconstruction). + * + * These utilities are used across multiple SVR modules to avoid duplication + * and ensure consistent behavior. + */ + +import type { VolumeDims } from './trilinear'; + +/** + * Clamps a value to the [0, 1] range. + * Used for normalizing intensity values. + */ +export function clamp01(x: number): number { + return x < 0 ? 0 : x > 1 ? 1 : x; +} + +/** + * Throws if the provided AbortSignal has been aborted. + * Used throughout async SVR operations to support cancellation. + */ +export function assertNotAborted(signal?: AbortSignal): void { + if (signal?.aborted) { + throw new Error('SVR cancelled'); + } +} + +/** + * Yields control back to the main thread to prevent UI blocking. + * Should be called periodically during long-running SVR computations. + */ +export function yieldToMain(): Promise { + return new Promise((resolve) => setTimeout(resolve, 0)); +} + +/** + * Checks if a voxel coordinate is within the support of trilinear interpolation. 
+ * + * For trilinear sampling/splatting, we need both the floor and ceil of each + * coordinate to be valid indices: + * - floor >= 0 + * - ceil < dim + * + * This is equivalent to: 0 <= coord < dim - 1 + * + * @param dims - Volume dimensions + * @param x - X coordinate in voxel space + * @param y - Y coordinate in voxel space + * @param z - Z coordinate in voxel space + * @returns true if the coordinate is within valid trilinear interpolation bounds + */ +export function withinTrilinearSupport(dims: VolumeDims, x: number, y: number, z: number): boolean { + return x >= 0 && y >= 0 && z >= 0 && x < dims.nx - 1 && y < dims.ny - 1 && z < dims.nz - 1; +} + +/** + * Clamps a number's absolute value to a maximum. + * Returns 0 for non-finite inputs. + * + * @param x - Value to clamp + * @param maxAbs - Maximum absolute value allowed + */ +export function clampAbs(x: number, maxAbs: number): number { + if (!Number.isFinite(x)) return 0; + if (!Number.isFinite(maxAbs) || maxAbs <= 0) return 0; + return x < -maxAbs ? -maxAbs : x > maxAbs ? maxAbs : x; +} + +/** + * Formats a byte count as a human-readable MiB string. + */ +export function formatMiB(bytes: number): string { + return `${Math.round(bytes / (1024 * 1024))}MiB`; +} + +/** + * Computes a quantile value from a pre-sorted array. + * + * @param sorted - Array of numbers, already sorted ascending + * @param q - Quantile in [0, 1] (e.g., 0.5 for median) + * @returns Interpolated quantile value + */ +export function quantileSorted(sorted: number[], q: number): number { + const n = sorted.length; + if (n === 0) return 0; + + const qq = q < 0 ? 0 : q > 1 ? 1 : q; + const idx = qq * (n - 1); + const i0 = Math.floor(idx); + const i1 = Math.min(n - 1, i0 + 1); + const t = idx - i0; + const a = sorted[i0] ?? 0; + const b = sorted[i1] ?? 
a; + return a + (b - a) * t; +} diff --git a/frontend/tests/svrRigidRegistration.test.ts b/frontend/tests/svrRigidRegistration.test.ts new file mode 100644 index 0000000..2b36de9 --- /dev/null +++ b/frontend/tests/svrRigidRegistration.test.ts @@ -0,0 +1,293 @@ +/** + * Tests for SVR rigid registration module. + * + * These tests verify the correctness of: + * - Euler angle to rotation matrix conversion + * - Rigid transform application (rotation + translation) + * - Normalized cross-correlation (NCC) scoring + * - Optimization convergence + */ + +import { describe, expect, it } from 'vitest'; +import { + mat3FromEulerXYZ, + mat3MulVec3, + applyRigidToPoint, + boundsCenterMm, + scoreNcc, +} from '../src/utils/svr/rigidRegistration'; +import type { SeriesSamples, BoundsMm } from '../src/utils/svr/rigidRegistration'; +import { v3 } from '../src/utils/svr/vec3'; + +describe('svr/rigidRegistration', () => { + describe('mat3FromEulerXYZ', () => { + it('produces identity matrix for zero angles', () => { + const m = mat3FromEulerXYZ(0, 0, 0); + + // Identity matrix: diagonal 1s, off-diagonal 0s + expect(m[0]).toBeCloseTo(1); // m00 + expect(m[4]).toBeCloseTo(1); // m11 + expect(m[8]).toBeCloseTo(1); // m22 + + expect(m[1]).toBeCloseTo(0); // m01 + expect(m[2]).toBeCloseTo(0); // m02 + expect(m[3]).toBeCloseTo(0); // m10 + expect(m[5]).toBeCloseTo(0); // m12 + expect(m[6]).toBeCloseTo(0); // m20 + expect(m[7]).toBeCloseTo(0); // m21 + }); + + it('rotates 90° about X axis correctly', () => { + const m = mat3FromEulerXYZ(Math.PI / 2, 0, 0); + + // After 90° X rotation: Y → Z, Z → -Y + const v = mat3MulVec3(m, 0, 1, 0); // Rotate unit Y + expect(v.x).toBeCloseTo(0); + expect(v.y).toBeCloseTo(0); + expect(v.z).toBeCloseTo(1); + }); + + it('rotates 90° about Y axis correctly', () => { + const m = mat3FromEulerXYZ(0, Math.PI / 2, 0); + + // After 90° Y rotation: X → -Z, Z → X + const v = mat3MulVec3(m, 1, 0, 0); // Rotate unit X + expect(v.x).toBeCloseTo(0); + 
expect(v.y).toBeCloseTo(0); + expect(v.z).toBeCloseTo(-1); + }); + + it('rotates 90° about Z axis correctly', () => { + const m = mat3FromEulerXYZ(0, 0, Math.PI / 2); + + // After 90° Z rotation: X → Y, Y → -X + const v = mat3MulVec3(m, 1, 0, 0); // Rotate unit X + expect(v.x).toBeCloseTo(0); + expect(v.y).toBeCloseTo(1); + expect(v.z).toBeCloseTo(0); + }); + + it('produces orthonormal matrix for arbitrary angles', () => { + const m = mat3FromEulerXYZ(0.3, 0.5, 0.7); + + // Check that columns are unit vectors + const col0 = Math.sqrt(m[0] ** 2 + m[3] ** 2 + m[6] ** 2); + const col1 = Math.sqrt(m[1] ** 2 + m[4] ** 2 + m[7] ** 2); + const col2 = Math.sqrt(m[2] ** 2 + m[5] ** 2 + m[8] ** 2); + + expect(col0).toBeCloseTo(1); + expect(col1).toBeCloseTo(1); + expect(col2).toBeCloseTo(1); + + // Check that columns are orthogonal (dot products = 0) + const dot01 = m[0] * m[1] + m[3] * m[4] + m[6] * m[7]; + const dot02 = m[0] * m[2] + m[3] * m[5] + m[6] * m[8]; + const dot12 = m[1] * m[2] + m[4] * m[5] + m[7] * m[8]; + + expect(dot01).toBeCloseTo(0); + expect(dot02).toBeCloseTo(0); + expect(dot12).toBeCloseTo(0); + }); + }); + + describe('applyRigidToPoint', () => { + it('returns same point when no rotation or translation', () => { + const p = v3(5, 10, 15); + const center = v3(0, 0, 0); + const rot = mat3FromEulerXYZ(0, 0, 0); + const t = v3(0, 0, 0); + + const result = applyRigidToPoint(p, center, rot, t); + + expect(result.x).toBeCloseTo(5); + expect(result.y).toBeCloseTo(10); + expect(result.z).toBeCloseTo(15); + }); + + it('applies translation only (no rotation)', () => { + const p = v3(5, 10, 15); + const center = v3(0, 0, 0); + const rot = mat3FromEulerXYZ(0, 0, 0); + const t = v3(1, 2, 3); + + const result = applyRigidToPoint(p, center, rot, t); + + expect(result.x).toBeCloseTo(6); + expect(result.y).toBeCloseTo(12); + expect(result.z).toBeCloseTo(18); + }); + + it('rotates about center point correctly', () => { + // Point on the X axis, 10 units from center + const 
p = v3(10, 0, 0); + const center = v3(0, 0, 0); + const rot = mat3FromEulerXYZ(0, 0, Math.PI / 2); // 90° about Z + const t = v3(0, 0, 0); + + const result = applyRigidToPoint(p, center, rot, t); + + // After 90° Z rotation: (10,0,0) → (0,10,0) + expect(result.x).toBeCloseTo(0); + expect(result.y).toBeCloseTo(10); + expect(result.z).toBeCloseTo(0); + }); + + it('rotates about non-origin center correctly', () => { + // Point at (20, 10, 0), center at (10, 10, 0) + // Offset from center is (10, 0, 0) + const p = v3(20, 10, 0); + const center = v3(10, 10, 0); + const rot = mat3FromEulerXYZ(0, 0, Math.PI / 2); // 90° about Z + const t = v3(0, 0, 0); + + const result = applyRigidToPoint(p, center, rot, t); + + // After 90° Z rotation about (10,10,0): offset (10,0,0) → (0,10,0) + // Final position: (10,10,0) + (0,10,0) = (10,20,0) + expect(result.x).toBeCloseTo(10); + expect(result.y).toBeCloseTo(20); + expect(result.z).toBeCloseTo(0); + }); + + it('combines rotation and translation correctly', () => { + const p = v3(10, 0, 0); + const center = v3(0, 0, 0); + const rot = mat3FromEulerXYZ(0, 0, Math.PI / 2); + const t = v3(5, 5, 0); + + const result = applyRigidToPoint(p, center, rot, t); + + // (10,0,0) rotated 90° about Z → (0,10,0), then translated by (5,5,0) → (5,15,0) + expect(result.x).toBeCloseTo(5); + expect(result.y).toBeCloseTo(15); + expect(result.z).toBeCloseTo(0); + }); + }); + + describe('boundsCenterMm', () => { + it('computes center of axis-aligned box', () => { + const bounds: BoundsMm = { + min: v3(0, 0, 0), + max: v3(10, 20, 30), + }; + + const center = boundsCenterMm(bounds); + + expect(center.x).toBeCloseTo(5); + expect(center.y).toBeCloseTo(10); + expect(center.z).toBeCloseTo(15); + }); + + it('handles negative coordinates', () => { + const bounds: BoundsMm = { + min: v3(-10, -20, -30), + max: v3(10, 20, 30), + }; + + const center = boundsCenterMm(bounds); + + expect(center.x).toBeCloseTo(0); + expect(center.y).toBeCloseTo(0); + 
expect(center.z).toBeCloseTo(0); + }); + }); + + describe('scoreNcc', () => { + it('returns -Infinity for empty samples', () => { + const samples: SeriesSamples = { + obs: new Float32Array(0), + pos: new Float32Array(0), + count: 0, + }; + + const dims = { nx: 10, ny: 10, nz: 10 }; + const volume = new Float32Array(dims.nx * dims.ny * dims.nz); + + const result = scoreNcc({ + samples, + refVolume: volume, + dims, + originMm: v3(0, 0, 0), + voxelSizeMm: 1, + centerMm: v3(5, 5, 5), + rigid: { tx: 0, ty: 0, tz: 0, rx: 0, ry: 0, rz: 0 }, + }); + + expect(result.ncc).toBe(Number.NEGATIVE_INFINITY); + expect(result.used).toBe(0); + }); + + it('returns -Infinity when too few samples are in bounds', () => { + // Create a small number of samples (less than the MIN_SAMPLES threshold of 512) + const samples: SeriesSamples = { + obs: new Float32Array([0.5, 0.6, 0.7]), + pos: new Float32Array([1, 1, 1, 2, 2, 2, 3, 3, 3]), + count: 3, + }; + + const dims = { nx: 10, ny: 10, nz: 10 }; + const volume = new Float32Array(dims.nx * dims.ny * dims.nz); + + const result = scoreNcc({ + samples, + refVolume: volume, + dims, + originMm: v3(0, 0, 0), + voxelSizeMm: 1, + centerMm: v3(5, 5, 5), + rigid: { tx: 0, ty: 0, tz: 0, rx: 0, ry: 0, rz: 0 }, + }); + + expect(result.ncc).toBe(Number.NEGATIVE_INFINITY); + expect(result.used).toBeLessThan(512); + }); + + it('returns high NCC for identical signals', () => { + const dims = { nx: 20, ny: 20, nz: 20 }; + const volume = new Float32Array(dims.nx * dims.ny * dims.nz); + + // Fill volume with a gradient + for (let z = 0; z < dims.nz; z++) { + for (let y = 0; y < dims.ny; y++) { + for (let x = 0; x < dims.nx; x++) { + const idx = x + y * dims.nx + z * dims.nx * dims.ny; + volume[idx] = (x + y + z) / (dims.nx + dims.ny + dims.nz); + } + } + } + + // Create samples that match the volume exactly (large enough to pass threshold) + const obs: number[] = []; + const pos: number[] = []; + for (let z = 2; z < dims.nz - 2; z += 2) { + for (let y = 2; y < 
dims.ny - 2; y += 2) { + for (let x = 2; x < dims.nx - 2; x += 2) { + const idx = x + y * dims.nx + z * dims.nx * dims.ny; + obs.push(volume[idx] ?? 0); + pos.push(x, y, z); + } + } + } + + const samples: SeriesSamples = { + obs: Float32Array.from(obs), + pos: Float32Array.from(pos), + count: obs.length, + }; + + const result = scoreNcc({ + samples, + refVolume: volume, + dims, + originMm: v3(0, 0, 0), + voxelSizeMm: 1, + centerMm: v3(10, 10, 10), + rigid: { tx: 0, ty: 0, tz: 0, rx: 0, ry: 0, rz: 0 }, + }); + + // NCC of identical signals should be 1 (or very close) + expect(result.ncc).toBeGreaterThan(0.99); + expect(result.used).toBeGreaterThan(100); + }); + }); +}); From e2c11ba08186396c6e77ff6500534e029d738def Mon Sep 17 00:00:00 2001 From: Siqi Chen Date: Mon, 2 Feb 2026 13:16:46 -0800 Subject: [PATCH 03/16] svr3d: add label overlay rendering --- frontend/src/components/SvrVolume3DViewer.tsx | 315 ++++++++++++++++-- frontend/src/types/svr.ts | 22 ++ .../src/utils/segmentation/labelPalette.ts | 47 +++ frontend/tests/labelPalette.test.ts | 44 +++ 4 files changed, 408 insertions(+), 20 deletions(-) create mode 100644 frontend/src/utils/segmentation/labelPalette.ts create mode 100644 frontend/tests/labelPalette.test.ts diff --git a/frontend/src/components/SvrVolume3DViewer.tsx b/frontend/src/components/SvrVolume3DViewer.tsx index 0f000fb..1ff4478 100644 --- a/frontend/src/components/SvrVolume3DViewer.tsx +++ b/frontend/src/components/SvrVolume3DViewer.tsx @@ -1,6 +1,7 @@ import { forwardRef, useCallback, useEffect, useImperativeHandle, useMemo, useRef, useState } from 'react'; import { ChevronLeft, ChevronRight } from 'lucide-react'; -import type { SvrVolume } from '../types/svr'; +import type { SvrLabelVolume, SvrVolume } from '../types/svr'; +import { buildRgbaPalette256, rgbCss } from '../utils/segmentation/labelPalette'; import { resample2dAreaAverage } from '../utils/svr/resample2d'; function clamp(x: number, min: number, max: number): number { @@ -449,6 
+450,7 @@ function createProgram(gl: WebGL2RenderingContext, vsSrc: string, fsSrc: string) export type SvrVolume3DViewerProps = { volume: SvrVolume | null; + labels?: SvrLabelVolume | null; }; export type SvrVolume3DViewerHandle = { @@ -459,13 +461,23 @@ export type SvrVolume3DViewerHandle = { }; export const SvrVolume3DViewer = forwardRef(function SvrVolume3DViewer( - { volume }, + { volume, labels }, ref ) { const canvasRef = useRef(null); const axesCanvasRef = useRef(null); const pendingCapture3dRef = useRef<{ resolve: (b: Blob | null) => void } | null>(null); + const glLabelStateRef = useRef< + | { + gl: WebGL2RenderingContext; + texLabels: WebGLTexture; + texPalette: WebGLTexture; + dims: { nx: number; ny: number; nz: number }; + } + | null + >(null); + const [initError, setInitError] = useState(null); // Viewer controls (composite-only) @@ -476,15 +488,30 @@ export const SvrVolume3DViewer = forwardRef { + if (!volume) return false; + if (!labels) return false; + + const [nx, ny, nz] = volume.dims; + const [lx, ly, lz] = labels.dims; + if (nx !== lx || ny !== ly || nz !== lz) return false; + + return labels.data.length === nx * ny * nz; + }, [labels, volume]); + // Slice inspector (orthogonal slices). 
const sliceCanvasRef = useRef(null); const [inspectPlane, setInspectPlane] = useState<'axial' | 'coronal' | 'sagittal'>('axial'); const [inspectIndex, setInspectIndex] = useState(0); - const paramsRef = useRef({ threshold, steps, gamma, opacity, zoom }); + const paramsRef = useRef({ threshold, steps, gamma, opacity, zoom, labelsEnabled, labelMix, hasLabels }); useEffect(() => { - paramsRef.current = { threshold, steps, gamma, opacity, zoom }; - }, [gamma, opacity, steps, threshold, zoom]); + paramsRef.current = { threshold, steps, gamma, opacity, zoom, labelsEnabled, labelMix, hasLabels }; + }, [gamma, hasLabels, labelMix, labelsEnabled, opacity, steps, threshold, zoom]); const rotationRef = useRef([0, 0, 0, 1]); @@ -739,19 +766,66 @@ export const SvrVolume3DViewer = forwardRef 0 && labels) { + const px = i % dsCols; + const py = Math.floor(i / dsCols); + + const srcX = dsCols > 1 ? Math.round((px / (dsCols - 1)) * (srcCols - 1)) : 0; + const srcY = dsRows > 1 ? Math.round((py / (dsRows - 1)) * (srcRows - 1)) : 0; + + let vx = 0; + let vy = 0; + let vz = 0; + + if (inspectPlane === 'axial') { + vx = srcX; + vy = srcY; + vz = idx; + } else if (inspectPlane === 'coronal') { + vx = srcX; + vy = idx; + vz = srcY; + } else { + // sagittal + vx = idx; + vy = srcX; + vz = srcY; + } + + const labelId = labels.data[vz * strideZ + vy * strideY + vx] ?? 0; + if (labelId !== 0) { + const o = labelId * 4; + const lr = palette[o] ?? 0; + const lg = palette[o + 1] ?? 0; + const lb = palette[o + 2] ?? 
0; + + const a = overlayAlpha; + r = Math.round((1 - a) * r + a * lr); + g = Math.round((1 - a) * g + a * lg); + b = Math.round((1 - a) * b + a * lb); + } + } const j = i * 4; - out[j] = b; - out[j + 1] = b; + out[j] = r; + out[j + 1] = g; out[j + 2] = b; out[j + 3] = 255; } ctx.putImageData(img, 0, 0); - }, [inspectIndex, inspectPlane, inspectorInfo.maxIndex, inspectorInfo.srcCols, inspectorInfo.srcRows, volume]); + }, [hasLabels, inspectIndex, inspectPlane, inspectorInfo.maxIndex, inspectorInfo.srcCols, inspectorInfo.srcRows, labelMix, labels, labelsEnabled, volume]); useEffect(() => { setInitError(null); @@ -791,11 +865,18 @@ void main() { const fsSrc = `#version 300 es precision highp float; precision highp sampler3D; +precision highp usampler3D; +precision highp sampler2D; in vec2 v_uv; out vec4 outColor; uniform sampler3D u_vol; +uniform usampler3D u_labels; +uniform sampler2D u_palette; +uniform int u_labelsEnabled; +uniform float u_labelMix; + uniform mat3 u_rot; uniform vec3 u_box; uniform float u_aspect; @@ -873,7 +954,7 @@ void main() { const float EDGE_K = 14.0; const float CENTER_EDGE_GAIN = 2.5; - float accum = 0.0; + vec3 accum = vec3(0.0); float aAccum = 0.0; float t = max(t0, 0.0); @@ -938,7 +1019,18 @@ void main() { float sampleV = v * shade * (0.6 + 0.4 * edge); - accum += (1.0 - aAccum) * sampleV * aStep; + vec3 sampleColor = vec3(sampleV); + + if (u_labelsEnabled != 0) { + uint lid = texture(u_labels, tc).r; + if (lid != 0u) { + vec3 labelRgb = texelFetch(u_palette, ivec2(int(lid), 0), 0).rgb; + float mixK = clamp(u_labelMix, 0.0, 1.0); + sampleColor = mix(sampleColor, labelRgb, mixK); + } + } + + accum += (1.0 - aAccum) * sampleColor * aStep; aAccum += (1.0 - aAccum) * aStep; if (aAccum > 0.98) { @@ -947,13 +1039,15 @@ void main() { } } - outColor = vec4(vec3(saturate(accum)), 1.0); + outColor = vec4(clamp(accum, 0.0, 1.0), 1.0); }`; let program: WebGLProgram | null = null; let vao: WebGLVertexArrayObject | null = null; let vbo: WebGLBuffer | 
null = null; - let tex: WebGLTexture | null = null; + let texVol: WebGLTexture | null = null; + let texLabels: WebGLTexture | null = null; + let texPalette: WebGLTexture | null = null; let raf = 0; try { @@ -979,11 +1073,11 @@ void main() { gl.bindBuffer(gl.ARRAY_BUFFER, null); // Volume texture (prefer float for fidelity; fall back to 8-bit for compatibility) - tex = gl.createTexture(); - if (!tex) throw new Error('Failed to allocate 3D texture'); + texVol = gl.createTexture(); + if (!texVol) throw new Error('Failed to allocate 3D texture'); gl.activeTexture(gl.TEXTURE0); - gl.bindTexture(gl.TEXTURE_3D, tex); + gl.bindTexture(gl.TEXTURE_3D, texVol); gl.pixelStorei(gl.UNPACK_ALIGNMENT, 1); // We'll try float first; if WebGL rejects it, re-upload as R8. @@ -1035,8 +1129,63 @@ void main() { gl.bindTexture(gl.TEXTURE_3D, null); + // Label texture (uint8 IDs). We always allocate a valid texture to keep the shader path stable, + // even when no segmentation is present yet. + texLabels = gl.createTexture(); + if (!texLabels) throw new Error('Failed to allocate label 3D texture'); + + gl.activeTexture(gl.TEXTURE1); + gl.bindTexture(gl.TEXTURE_3D, texLabels); + gl.pixelStorei(gl.UNPACK_ALIGNMENT, 1); + + gl.texParameteri(gl.TEXTURE_3D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE); + gl.texParameteri(gl.TEXTURE_3D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE); + gl.texParameteri(gl.TEXTURE_3D, gl.TEXTURE_WRAP_R, gl.CLAMP_TO_EDGE); + gl.texParameteri(gl.TEXTURE_3D, gl.TEXTURE_MIN_FILTER, gl.NEAREST); + gl.texParameteri(gl.TEXTURE_3D, gl.TEXTURE_MAG_FILTER, gl.NEAREST); + + // Initialize to zeros so sampling produces "no label" deterministically. + const zeros = new Uint8Array(dims.nx * dims.ny * dims.nz); + gl.texImage3D( + gl.TEXTURE_3D, + 0, + gl.R8UI, + dims.nx, + dims.ny, + dims.nz, + 0, + gl.RED_INTEGER, + gl.UNSIGNED_BYTE, + zeros + ); + + gl.bindTexture(gl.TEXTURE_3D, null); + + // Palette texture: 256x1 RGBA8 lookup table for label->color. 
+ texPalette = gl.createTexture(); + if (!texPalette) throw new Error('Failed to allocate label palette texture'); + + gl.activeTexture(gl.TEXTURE2); + gl.bindTexture(gl.TEXTURE_2D, texPalette); + gl.pixelStorei(gl.UNPACK_ALIGNMENT, 1); + + gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE); + gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE); + gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.NEAREST); + gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.NEAREST); + + gl.texImage2D(gl.TEXTURE_2D, 0, gl.RGBA8, 256, 1, 0, gl.RGBA, gl.UNSIGNED_BYTE, new Uint8Array(256 * 4)); + gl.bindTexture(gl.TEXTURE_2D, null); + + glLabelStateRef.current = { gl, texLabels, texPalette, dims }; + const u = { vol: gl.getUniformLocation(program, 'u_vol'), + labels: gl.getUniformLocation(program, 'u_labels'), + palette: gl.getUniformLocation(program, 'u_palette'), + labelsEnabled: gl.getUniformLocation(program, 'u_labelsEnabled'), + labelMix: gl.getUniformLocation(program, 'u_labelMix'), + rot: gl.getUniformLocation(program, 'u_rot'), box: gl.getUniformLocation(program, 'u_box'), aspect: gl.getUniformLocation(program, 'u_aspect'), @@ -1076,7 +1225,7 @@ void main() { const draw = () => { resizeAndViewport(); - const { threshold, steps, gamma, opacity, zoom } = paramsRef.current; + const { threshold, steps, gamma, opacity, zoom, labelsEnabled, labelMix, hasLabels } = paramsRef.current; gl.disable(gl.DEPTH_TEST); gl.disable(gl.CULL_FACE); @@ -1084,11 +1233,23 @@ void main() { gl.useProgram(program); gl.bindVertexArray(vao); - // Bind texture + // Bind textures gl.activeTexture(gl.TEXTURE0); - gl.bindTexture(gl.TEXTURE_3D, tex); + gl.bindTexture(gl.TEXTURE_3D, texVol); gl.uniform1i(u.vol, 0); + gl.activeTexture(gl.TEXTURE1); + gl.bindTexture(gl.TEXTURE_3D, texLabels); + gl.uniform1i(u.labels, 1); + + gl.activeTexture(gl.TEXTURE2); + gl.bindTexture(gl.TEXTURE_2D, texPalette); + gl.uniform1i(u.palette, 2); + + const labelsOn = 
labelsEnabled && hasLabels ? 1 : 0; + gl.uniform1i(u.labelsEnabled, labelsOn); + gl.uniform1f(u.labelMix, clamp(labelMix, 0, 1)); + // Uniforms mat3FromQuat(rotationRef.current, rotMat); gl.uniformMatrix3fv(u.rot, false, rotMat); @@ -1144,7 +1305,14 @@ void main() { } } + // Reset bindings (avoid leaking WebGL state across frames). + gl.activeTexture(gl.TEXTURE2); + gl.bindTexture(gl.TEXTURE_2D, null); + gl.activeTexture(gl.TEXTURE1); gl.bindTexture(gl.TEXTURE_3D, null); + gl.activeTexture(gl.TEXTURE0); + gl.bindTexture(gl.TEXTURE_3D, null); + gl.bindVertexArray(null); raf = window.requestAnimationFrame(draw); @@ -1165,8 +1333,12 @@ void main() { if (raf) window.cancelAnimationFrame(raf); + glLabelStateRef.current = null; + if (gl) { - if (tex) gl.deleteTexture(tex); + if (texVol) gl.deleteTexture(texVol); + if (texLabels) gl.deleteTexture(texLabels); + if (texPalette) gl.deleteTexture(texPalette); if (vbo) gl.deleteBuffer(vbo); if (vao) gl.deleteVertexArray(vao); if (program) gl.deleteProgram(program); @@ -1174,6 +1346,59 @@ void main() { }; }, [boxScale, dims, volume]); + // Incrementally upload label data + palette without re-initializing the whole GL program. 
+ useEffect(() => { + if (!volume) return; + if (!labels) return; + + if (!hasLabels) { + console.warn('[svr3d] Ignoring label volume (dims mismatch)', { + volumeDims: volume.dims, + labelDims: labels.dims, + labelLen: labels.data.length, + }); + return; + } + + const st = glLabelStateRef.current; + if (!st) return; + + const { gl, texLabels, texPalette, dims } = st; + + try { + // Label IDs + gl.activeTexture(gl.TEXTURE1); + gl.bindTexture(gl.TEXTURE_3D, texLabels); + gl.pixelStorei(gl.UNPACK_ALIGNMENT, 1); + gl.texSubImage3D( + gl.TEXTURE_3D, + 0, + 0, + 0, + 0, + dims.nx, + dims.ny, + dims.nz, + gl.RED_INTEGER, + gl.UNSIGNED_BYTE, + labels.data + ); + gl.bindTexture(gl.TEXTURE_3D, null); + + // Palette lookup table + const rgba = buildRgbaPalette256(labels.meta); + gl.activeTexture(gl.TEXTURE2); + gl.bindTexture(gl.TEXTURE_2D, texPalette); + gl.pixelStorei(gl.UNPACK_ALIGNMENT, 1); + gl.texSubImage2D(gl.TEXTURE_2D, 0, 0, 0, 256, 1, gl.RGBA, gl.UNSIGNED_BYTE, rgba); + gl.bindTexture(gl.TEXTURE_2D, null); + } catch (e) { + console.warn('[svr3d] Failed to upload label textures', e); + } finally { + gl.activeTexture(gl.TEXTURE0); + } + }, [hasLabels, labels, volume]); + return (
@@ -1297,6 +1522,56 @@ void main() {
{zoom.toFixed(2)}
+
+
Segmentation
+
+ + + + + {!hasLabels || !labels ? ( +
No segmentation labels available yet.
+ ) : ( +
+ {labels.meta + .filter((m) => m.id !== 0) + .map((m) => ( +
+ + {m.name} + {m.id} +
+ ))} +
+ )} +
+
+
+
+ +
+ + + +
+ +
+ + + {growStatus.running ? ( + + ) : null} + + +
+ + {growStatus.error ? ( +
{growStatus.error}
+ ) : growStatus.message ? ( +
{growStatus.message}
+ ) : null} + {!hasLabels || !labels ? (
No segmentation labels available yet.
) : ( @@ -1628,7 +1883,12 @@ void main() {
- +
{volume ? ( diff --git a/frontend/src/utils/segmentation/brats.ts b/frontend/src/utils/segmentation/brats.ts new file mode 100644 index 0000000..10639d7 --- /dev/null +++ b/frontend/src/utils/segmentation/brats.ts @@ -0,0 +1,27 @@ +import type { SvrLabelMeta } from '../../types/svr'; + +/** + * BraTS-style base label IDs. + * + * Common convention: + * - 0: Background + * - 1: Necrotic / non-enhancing tumor core (NCR/NET) + * - 2: Peritumoral edema (ED) + * - 4: Enhancing tumor (ET) + */ +export const BRATS_LABEL_ID = { + BACKGROUND: 0, + NCR_NET: 1, + EDEMA: 2, + ENHANCING: 4, +} as const; + +export type BratsBaseLabelId = (typeof BRATS_LABEL_ID)[keyof typeof BRATS_LABEL_ID]; + +// Colors are arbitrary but chosen to be visually distinct on a dark background. +export const BRATS_BASE_LABEL_META: SvrLabelMeta[] = [ + { id: BRATS_LABEL_ID.BACKGROUND, name: 'Background', color: [0, 0, 0] }, + { id: BRATS_LABEL_ID.NCR_NET, name: 'Tumor core (NCR/NET)', color: [255, 176, 0] }, + { id: BRATS_LABEL_ID.EDEMA, name: 'Edema (ED)', color: [0, 170, 255] }, + { id: BRATS_LABEL_ID.ENHANCING, name: 'Enhancing tumor (ET)', color: [255, 0, 128] }, +]; diff --git a/frontend/src/utils/segmentation/connectedComponents3D.ts b/frontend/src/utils/segmentation/connectedComponents3D.ts new file mode 100644 index 0000000..edf7d6b --- /dev/null +++ b/frontend/src/utils/segmentation/connectedComponents3D.ts @@ -0,0 +1,152 @@ +export type Connectivity3D = 6 | 26; + +function idx3(x: number, y: number, z: number, nx: number, ny: number): number { + return z * (nx * ny) + y * nx + x; +} + +function inBounds(x: number, y: number, z: number, nx: number, ny: number, nz: number): boolean { + return x >= 0 && x < nx && y >= 0 && y < ny && z >= 0 && z < nz; +} + +/** + * Keep only the largest connected component in a 3D binary mask. + * + * This is useful as a cleanup step when region-growing or morphological operations + * leave tiny disconnected islands. 
+ */ +export function keepLargestConnectedComponent3D(params: { + mask: Uint8Array; + dims: [number, number, number]; + connectivity?: Connectivity3D; +}): { mask: Uint8Array; keptSize: number } { + const { mask, dims } = params; + const nx = dims[0]; + const ny = dims[1]; + const nz = dims[2]; + + const n = nx * ny * nz; + if (mask.length !== n) { + throw new Error(`keepLargestConnectedComponent3D: mask length mismatch (expected ${n}, got ${mask.length})`); + } + + const connectivity: Connectivity3D = params.connectivity ?? 6; + + const visited = new Uint8Array(n); + const queue = new Uint32Array(n); + + const strideY = nx; + const strideZ = nx * ny; + + let bestStart = -1; + let bestSize = 0; + + const bfsCount = (start: number): number => { + let head = 0; + let tail = 0; + queue[tail++] = start; + visited[start] = 1; + + let size = 0; + + while (head < tail) { + const i = queue[head++]!; + size++; + + const z = Math.floor(i / strideZ); + const yz = i - z * strideZ; + const y = Math.floor(yz / strideY); + const x = yz - y * strideY; + + const tryNeighbor = (nx0: number, ny0: number, nz0: number) => { + if (!inBounds(nx0, ny0, nz0, nx, ny, nz)) return; + const ni = idx3(nx0, ny0, nz0, nx, ny); + if (visited[ni]) return; + if (!mask[ni]) return; + visited[ni] = 1; + queue[tail++] = ni; + }; + + if (connectivity === 6) { + tryNeighbor(x - 1, y, z); + tryNeighbor(x + 1, y, z); + tryNeighbor(x, y - 1, z); + tryNeighbor(x, y + 1, z); + tryNeighbor(x, y, z - 1); + tryNeighbor(x, y, z + 1); + } else { + for (let dz = -1; dz <= 1; dz++) { + for (let dy = -1; dy <= 1; dy++) { + for (let dx = -1; dx <= 1; dx++) { + if (dx === 0 && dy === 0 && dz === 0) continue; + tryNeighbor(x + dx, y + dy, z + dz); + } + } + } + } + } + + return size; + }; + + for (let i = 0; i < n; i++) { + if (!mask[i] || visited[i]) continue; + const size = bfsCount(i); + if (size > bestSize) { + bestSize = size; + bestStart = i; + } + } + + if (bestStart < 0 || bestSize === 0) { + return { mask: new 
Uint8Array(n), keptSize: 0 }; + } + + // Second pass: BFS from bestStart to build the kept mask. + visited.fill(0); + const kept = new Uint8Array(n); + + let head = 0; + let tail = 0; + queue[tail++] = bestStart; + visited[bestStart] = 1; + kept[bestStart] = 1; + + while (head < tail) { + const i = queue[head++]!; + + const z = Math.floor(i / strideZ); + const yz = i - z * strideZ; + const y = Math.floor(yz / strideY); + const x = yz - y * strideY; + + const tryNeighbor = (nx0: number, ny0: number, nz0: number) => { + if (!inBounds(nx0, ny0, nz0, nx, ny, nz)) return; + const ni = idx3(nx0, ny0, nz0, nx, ny); + if (visited[ni]) return; + if (!mask[ni]) return; + visited[ni] = 1; + kept[ni] = 1; + queue[tail++] = ni; + }; + + if (connectivity === 6) { + tryNeighbor(x - 1, y, z); + tryNeighbor(x + 1, y, z); + tryNeighbor(x, y - 1, z); + tryNeighbor(x, y + 1, z); + tryNeighbor(x, y, z - 1); + tryNeighbor(x, y, z + 1); + } else { + for (let dz = -1; dz <= 1; dz++) { + for (let dy = -1; dy <= 1; dy++) { + for (let dx = -1; dx <= 1; dx++) { + if (dx === 0 && dy === 0 && dz === 0) continue; + tryNeighbor(x + dx, y + dy, z + dz); + } + } + } + } + } + + return { mask: kept, keptSize: bestSize }; +} diff --git a/frontend/src/utils/segmentation/morphology3D.ts b/frontend/src/utils/segmentation/morphology3D.ts new file mode 100644 index 0000000..981fe35 --- /dev/null +++ b/frontend/src/utils/segmentation/morphology3D.ts @@ -0,0 +1,120 @@ +function idx3(x: number, y: number, z: number, nx: number, ny: number): number { + return z * (nx * ny) + y * nx + x; +} + +/** + * 3D dilation with a 3x3x3 structuring element. + * + * Mask values are expected to be 0 or 1. 
+ */ +export function dilate3x3x3(mask: Uint8Array, dims: [number, number, number]): Uint8Array { + const nx = dims[0]; + const ny = dims[1]; + const nz = dims[2]; + + const n = nx * ny * nz; + if (mask.length !== n) { + throw new Error(`dilate3x3x3: mask length mismatch (expected ${n}, got ${mask.length})`); + } + + const out = new Uint8Array(n); + + for (let z = 0; z < nz; z++) { + for (let y = 0; y < ny; y++) { + for (let x = 0; x < nx; x++) { + let on = 0; + + for (let dz = -1; dz <= 1 && !on; dz++) { + const zz = z + dz; + if (zz < 0 || zz >= nz) continue; + + for (let dy = -1; dy <= 1 && !on; dy++) { + const yy = y + dy; + if (yy < 0 || yy >= ny) continue; + + for (let dx = -1; dx <= 1; dx++) { + const xx = x + dx; + if (xx < 0 || xx >= nx) continue; + if (mask[idx3(xx, yy, zz, nx, ny)]) { + on = 1; + break; + } + } + } + } + + out[idx3(x, y, z, nx, ny)] = on; + } + } + } + + return out; +} + +/** + * 3D erosion with a 3x3x3 structuring element. + * + * Out-of-bounds neighbors are treated as 0. 
+ */ +export function erode3x3x3(mask: Uint8Array, dims: [number, number, number]): Uint8Array { + const nx = dims[0]; + const ny = dims[1]; + const nz = dims[2]; + + const n = nx * ny * nz; + if (mask.length !== n) { + throw new Error(`erode3x3x3: mask length mismatch (expected ${n}, got ${mask.length})`); + } + + const out = new Uint8Array(n); + + for (let z = 0; z < nz; z++) { + for (let y = 0; y < ny; y++) { + for (let x = 0; x < nx; x++) { + let on = 1; + + for (let dz = -1; dz <= 1 && on; dz++) { + const zz = z + dz; + if (zz < 0 || zz >= nz) { + on = 0; + break; + } + + for (let dy = -1; dy <= 1 && on; dy++) { + const yy = y + dy; + if (yy < 0 || yy >= ny) { + on = 0; + break; + } + + for (let dx = -1; dx <= 1; dx++) { + const xx = x + dx; + if (xx < 0 || xx >= nx) { + on = 0; + break; + } + if (!mask[idx3(xx, yy, zz, nx, ny)]) { + on = 0; + break; + } + } + } + } + + out[idx3(x, y, z, nx, ny)] = on; + } + } + } + + return out; +} + +export function morphologicalClose3D(mask: Uint8Array, dims: [number, number, number]): Uint8Array { + const dilated = dilate3x3x3(mask, dims); + return erode3x3x3(dilated, dims); +} + +export function morphologicalOpen3D(mask: Uint8Array, dims: [number, number, number]): Uint8Array { + const eroded = erode3x3x3(mask, dims); + return dilate3x3x3(eroded, dims); +} diff --git a/frontend/src/utils/segmentation/regionGrow3D.ts b/frontend/src/utils/segmentation/regionGrow3D.ts new file mode 100644 index 0000000..f2af80a --- /dev/null +++ b/frontend/src/utils/segmentation/regionGrow3D.ts @@ -0,0 +1,200 @@ +export type Vec3i = { x: number; y: number; z: number }; + +export type RegionGrow3DResult = { + /** Binary mask (0/1) in the same indexing order as the input volume. */ + mask: Uint8Array; + /** Number of voxels included in the region (<= mask.length). */ + count: number; + /** Seed intensity value (raw value from `volume[seedIdx]`). 
*/ + seedValue: number; + /** Whether the grow hit the configured max voxel limit and stopped early. */ + hitMaxVoxels: boolean; +}; + +export type RegionGrow3DOptions = { + /** + * Maximum number of voxels to include before stopping early. + * + * This is a safety valve to prevent accidental runaway segmentation. + */ + maxVoxels?: number; + + /** + * Neighborhood connectivity. + * - 6: faces only (less leakage) + * - 26: faces+edges+corners (more permissive) + */ + connectivity?: 6 | 26; + + /** Yield to the event loop every N dequeued voxels (helps keep UI responsive). */ + yieldEvery?: number; + + /** Optional abort signal for cancellation. */ + signal?: AbortSignal; + + /** Optional progress callback. */ + onProgress?: (p: { processed: number; queued: number }) => void; + + /** Optional yield function (defaults to setTimeout(0)). Useful for tests. */ + yieldFn?: () => Promise; +}; + +function clamp01(x: number): number { + return x < 0 ? 0 : x > 1 ? 1 : x; +} + +function inBounds(x: number, y: number, z: number, nx: number, ny: number, nz: number): boolean { + return x >= 0 && x < nx && y >= 0 && y < ny && z >= 0 && z < nz; +} + +function idx3(x: number, y: number, z: number, nx: number, ny: number): number { + return z * (nx * ny) + y * nx + x; +} + +/** + * Simple intensity-threshold 3D region growing. + * + * Intended as a baseline interactive segmentation tool: + * - user picks a seed voxel (via the slice inspector) + * - we flood-fill neighbors whose intensity lies in [min,max] + * + * Notes: + * - Input volume is expected to be roughly normalized to [0,1] but this isn't strictly required. + * - `yieldEvery` lets the implementation cooperate with the UI thread for large regions. 
+ */ +export async function regionGrow3D(params: { + volume: Float32Array; + dims: [number, number, number]; + seed: Vec3i; + min: number; + max: number; + opts?: RegionGrow3DOptions; +}): Promise { + const { volume, dims, seed } = params; + const nx = dims[0]; + const ny = dims[1]; + const nz = dims[2]; + + const n = nx * ny * nz; + if (volume.length !== n) { + throw new Error(`regionGrow3D: volume length mismatch (expected ${n}, got ${volume.length})`); + } + + if (!inBounds(seed.x, seed.y, seed.z, nx, ny, nz)) { + throw new Error(`regionGrow3D: seed out of bounds: (${seed.x}, ${seed.y}, ${seed.z})`); + } + + const minV = Math.min(params.min, params.max); + const maxV = Math.max(params.min, params.max); + + const opts = params.opts; + const connectivity: 6 | 26 = opts?.connectivity ?? 6; + const maxVoxels = Math.max(1, Math.min(opts?.maxVoxels ?? n, n)); + const yieldEvery = Math.max(0, Math.floor(opts?.yieldEvery ?? 120_000)); + const yieldFn = opts?.yieldFn ?? (() => new Promise((r) => window.setTimeout(r, 0))); + + const mask = new Uint8Array(n); + + const seedIdx = idx3(seed.x, seed.y, seed.z, nx, ny); + const seedValue = volume[seedIdx] ?? 0; + + // Fast exit if seed is outside the acceptance range. + if (!(seedValue >= minV && seedValue <= maxV)) { + return { mask, count: 0, seedValue, hitMaxVoxels: false }; + } + + // Queue holds voxel indices. + const queue = new Uint32Array(maxVoxels); + let head = 0; + let tail = 0; + + queue[tail++] = seedIdx; + mask[seedIdx] = 1; + + const strideY = nx; + const strideZ = nx * ny; + + const accept = (i: number): boolean => { + const v = volume[i] ?? 0; + return v >= minV && v <= maxV; + }; + + const enqueue = (i: number): void => { + mask[i] = 1; + queue[tail++] = i; + }; + + let processed = 0; + let hitMaxVoxels = false; + + while (head < tail) { + if (opts?.signal?.aborted) { + // Preserve partial work: caller can decide whether to keep it. 
+ break; + } + + const i = queue[head++]!; + processed++; + + if (yieldEvery > 0 && (processed % yieldEvery === 0)) { + opts?.onProgress?.({ processed, queued: tail }); + // Yield to keep the UI responsive. + await yieldFn(); + } + + // Decode x/y/z for boundary checks. + const z = Math.floor(i / strideZ); + const yz = i - z * strideZ; + const y = Math.floor(yz / strideY); + const x = yz - y * strideY; + + const tryNeighbor = (nx0: number, ny0: number, nz0: number) => { + if (tail >= maxVoxels) { + hitMaxVoxels = true; + return; + } + if (!inBounds(nx0, ny0, nz0, nx, ny, nz)) return; + const ni = idx3(nx0, ny0, nz0, nx, ny); + if (mask[ni]) return; + if (!accept(ni)) return; + enqueue(ni); + }; + + if (connectivity === 6) { + tryNeighbor(x - 1, y, z); + tryNeighbor(x + 1, y, z); + tryNeighbor(x, y - 1, z); + tryNeighbor(x, y + 1, z); + tryNeighbor(x, y, z - 1); + tryNeighbor(x, y, z + 1); + if (hitMaxVoxels) break; + } else { + for (let dz = -1; dz <= 1; dz++) { + for (let dy = -1; dy <= 1; dy++) { + for (let dx = -1; dx <= 1; dx++) { + if (dx === 0 && dy === 0 && dz === 0) continue; + tryNeighbor(x + dx, y + dy, z + dz); + if (hitMaxVoxels) break; + } + if (hitMaxVoxels) break; + } + if (hitMaxVoxels) break; + } + if (hitMaxVoxels) break; + } + } + + return { + mask, + count: tail, + seedValue, + hitMaxVoxels, + }; +} + +export function computeSeedRange01(params: { seedValue: number; tolerance: number }): { min: number; max: number } { + const tol = Math.max(0, params.tolerance); + const min = clamp01(params.seedValue - tol); + const max = clamp01(params.seedValue + tol); + return { min, max }; +} diff --git a/frontend/tests/segmentation3d.test.ts b/frontend/tests/segmentation3d.test.ts new file mode 100644 index 0000000..73c4b5c --- /dev/null +++ b/frontend/tests/segmentation3d.test.ts @@ -0,0 +1,118 @@ +import { describe, expect, it } from 'vitest'; +import { keepLargestConnectedComponent3D } from '../src/utils/segmentation/connectedComponents3D'; +import { 
dilate3x3x3, erode3x3x3 } from '../src/utils/segmentation/morphology3D'; +import { regionGrow3D } from '../src/utils/segmentation/regionGrow3D'; + +function sumMask(mask: Uint8Array): number { + let s = 0; + for (let i = 0; i < mask.length; i++) s += mask[i] ? 1 : 0; + return s; +} + +describe('regionGrow3D', () => { + it('grows a simple 3D cube region from a seed', async () => { + const dims: [number, number, number] = [4, 4, 4]; + const [nx, ny, nz] = dims; + const n = nx * ny * nz; + const vol = new Float32Array(n); + vol.fill(0.1); + + // 2x2x2 cube at x,y,z in [1,2] + for (let z = 1; z <= 2; z++) { + for (let y = 1; y <= 2; y++) { + for (let x = 1; x <= 2; x++) { + vol[z * (nx * ny) + y * nx + x] = 0.8; + } + } + } + + const res = await regionGrow3D({ + volume: vol, + dims, + seed: { x: 1, y: 1, z: 1 }, + min: 0.7, + max: 0.9, + opts: { maxVoxels: 100, connectivity: 6, yieldEvery: 0 }, + }); + + expect(res.seedValue).toBeCloseTo(0.8, 6); + expect(res.hitMaxVoxels).toBe(false); + expect(res.count).toBe(8); + expect(sumMask(res.mask)).toBe(8); + }); + + it('returns an empty mask if the seed is out of range', async () => { + const dims: [number, number, number] = [3, 3, 3]; + const vol = new Float32Array(27); + vol.fill(0.2); + + const res = await regionGrow3D({ + volume: vol, + dims, + seed: { x: 1, y: 1, z: 1 }, + min: 0.5, + max: 0.6, + opts: { yieldEvery: 0 }, + }); + + expect(res.count).toBe(0); + expect(sumMask(res.mask)).toBe(0); + }); +}); + +describe('keepLargestConnectedComponent3D', () => { + it('keeps only the largest component', () => { + const dims: [number, number, number] = [4, 4, 1]; + const n = 4 * 4 * 1; + const mask = new Uint8Array(n); + + // Component A: 3 voxels along the top row. + mask[0] = 1; // (0,0,0) + mask[1] = 1; // (1,0,0) + mask[2] = 1; // (2,0,0) + + // Component B: 5 voxels along the bottom row + one above (connected). 
+ mask[12] = 1; // (0,3,0) + mask[13] = 1; // (1,3,0) + mask[14] = 1; // (2,3,0) + mask[15] = 1; // (3,3,0) + mask[11] = 1; // (3,2,0) + + const out = keepLargestConnectedComponent3D({ mask, dims, connectivity: 6 }); + expect(out.keptSize).toBe(5); + expect(sumMask(out.mask)).toBe(5); + + // Ensure A is removed. + expect(out.mask[0]).toBe(0); + expect(out.mask[1]).toBe(0); + expect(out.mask[2]).toBe(0); + + // Ensure B remains. + expect(out.mask[12]).toBe(1); + expect(out.mask[15]).toBe(1); + expect(out.mask[11]).toBe(1); + }); +}); + +describe('morphology3D', () => { + it('dilate3x3x3 grows a single voxel into a 3x3x3 block', () => { + const dims: [number, number, number] = [3, 3, 3]; + const mask = new Uint8Array(27); + + // Center voxel (1,1,1) + mask[13] = 1; + + const dilated = dilate3x3x3(mask, dims); + expect(sumMask(dilated)).toBe(27); + }); + + it('erode3x3x3 shrinks a full 3x3x3 block to the center voxel', () => { + const dims: [number, number, number] = [3, 3, 3]; + const mask = new Uint8Array(27); + mask.fill(1); + + const eroded = erode3x3x3(mask, dims); + expect(sumMask(eroded)).toBe(1); + expect(eroded[13]).toBe(1); + }); +}); From 595944d52026586698514ca01a78ff4c9f521e7a Mon Sep 17 00:00:00 2001 From: Siqi Chen Date: Mon, 2 Feb 2026 14:20:34 -0800 Subject: [PATCH 05/16] svr3d: add offline onnx tumor segmentation --- frontend/package-lock.json | 39 +++ frontend/package.json | 1 + frontend/src/components/SvrVolume3DViewer.tsx | 230 ++++++++++++++++++ .../utils/segmentation/onnx/logitsToLabels.ts | 86 +++++++ .../src/utils/segmentation/onnx/modelCache.ts | 44 ++++ .../src/utils/segmentation/onnx/ortLoader.ts | 67 +++++ .../segmentation/onnx/tumorSegmentation.ts | 72 ++++++ frontend/tests/onnxLogitsToLabels.test.ts | 51 ++++ frontend/tests/onnxTumorSegmentation.test.ts | 64 +++++ frontend/vite.config.ts | 8 +- 10 files changed, 661 insertions(+), 1 deletion(-) create mode 100644 frontend/src/utils/segmentation/onnx/logitsToLabels.ts create mode 100644 
frontend/src/utils/segmentation/onnx/modelCache.ts create mode 100644 frontend/src/utils/segmentation/onnx/ortLoader.ts create mode 100644 frontend/src/utils/segmentation/onnx/tumorSegmentation.ts create mode 100644 frontend/tests/onnxLogitsToLabels.test.ts create mode 100644 frontend/tests/onnxTumorSegmentation.test.ts diff --git a/frontend/package-lock.json b/frontend/package-lock.json index d9f1418..c109ba4 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -20,6 +20,7 @@ "idb": "^8.0.3", "jszip": "^3.10.1", "lucide-react": "^0.562.0", + "onnxruntime-web": "^1.23.2", "react": "^19.2.0", "react-dom": "^19.2.0", "tailwindcss": "^4.1.18" @@ -4244,6 +4245,12 @@ "node": ">=16" } }, + "node_modules/flatbuffers": { + "version": "25.9.23", + "resolved": "https://registry.npmjs.org/flatbuffers/-/flatbuffers-25.9.23.tgz", + "integrity": "sha512-MI1qs7Lo4Syw0EOzUl0xjs2lsoeqFku44KpngfIduHBYvzm8h2+7K8YMQh1JtVVVrUvhLpNwqVi4DERegUJhPQ==", + "license": "Apache-2.0" + }, "node_modules/flatted": { "version": "3.3.3", "resolved": "https://registry.npmjs.org/flatted/-/flatted-3.3.3.tgz", @@ -4523,6 +4530,12 @@ "integrity": "sha512-RbJ5/jmFcNNCcDV5o9eTnBLJ/HszWV0P73bc+Ff4nS/rJj+YaS6IGyiOL0VoBYX+l1Wrl3k63h/KrH+nhJ0XvQ==", "license": "ISC" }, + "node_modules/guid-typescript": { + "version": "1.0.9", + "resolved": "https://registry.npmjs.org/guid-typescript/-/guid-typescript-1.0.9.tgz", + "integrity": "sha512-Y8T4vYhEfwJOTbouREvG+3XDsjr8E3kIr7uf+JZ0BYloFsttiHU0WfvANVsR7TxNUJa/WpCnw/Ino/p+DeBhBQ==", + "license": "ISC" + }, "node_modules/hammerjs": { "version": "2.0.8", "resolved": "https://registry.npmjs.org/hammerjs/-/hammerjs-2.0.8.tgz", @@ -5877,6 +5890,26 @@ "wrappy": "1" } }, + "node_modules/onnxruntime-common": { + "version": "1.23.2", + "resolved": "https://registry.npmjs.org/onnxruntime-common/-/onnxruntime-common-1.23.2.tgz", + "integrity": "sha512-5LFsC9Dukzp2WV6kNHYLNzp8sT6V02IubLCbzw2Xd6X5GOlr65gAX6xiJwyi2URJol/s71gaQLC5F2C25AAR2w==", + "license": 
"MIT" + }, + "node_modules/onnxruntime-web": { + "version": "1.23.2", + "resolved": "https://registry.npmjs.org/onnxruntime-web/-/onnxruntime-web-1.23.2.tgz", + "integrity": "sha512-T09JUtMn+CZLk3mFwqiH0lgQf+4S7+oYHHtk6uhaYAAJI95bTcKi5bOOZYwORXfS/RLZCjDDEXGWIuOCAFlEjg==", + "license": "MIT", + "dependencies": { + "flatbuffers": "^25.1.24", + "guid-typescript": "^1.0.9", + "long": "^5.2.3", + "onnxruntime-common": "1.23.2", + "platform": "^1.3.6", + "protobufjs": "^7.2.4" + } + }, "node_modules/optionator": { "version": "0.9.4", "resolved": "https://registry.npmjs.org/optionator/-/optionator-0.9.4.tgz", @@ -6087,6 +6120,12 @@ "node": ">=0.10.0" } }, + "node_modules/platform": { + "version": "1.3.6", + "resolved": "https://registry.npmjs.org/platform/-/platform-1.3.6.tgz", + "integrity": "sha512-fnWVljUchTro6RiCFvCXBbNhJc2NijN7oIQxbwsyL0buWJPG85v81ehlHI9fXrJsMNgTofEoWIQeClKpgxFLrg==", + "license": "MIT" + }, "node_modules/possible-typed-array-names": { "version": "1.1.0", "resolved": "https://registry.npmjs.org/possible-typed-array-names/-/possible-typed-array-names-1.1.0.tgz", diff --git a/frontend/package.json b/frontend/package.json index ecb4d4b..cf9e622 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -25,6 +25,7 @@ "idb": "^8.0.3", "jszip": "^3.10.1", "lucide-react": "^0.562.0", + "onnxruntime-web": "^1.23.2", "react": "^19.2.0", "react-dom": "^19.2.0", "tailwindcss": "^4.1.18" diff --git a/frontend/src/components/SvrVolume3DViewer.tsx b/frontend/src/components/SvrVolume3DViewer.tsx index efd1d64..77d4fa2 100644 --- a/frontend/src/components/SvrVolume3DViewer.tsx +++ b/frontend/src/components/SvrVolume3DViewer.tsx @@ -1,8 +1,12 @@ import { forwardRef, useCallback, useEffect, useImperativeHandle, useMemo, useRef, useState } from 'react'; import { ChevronLeft, ChevronRight } from 'lucide-react'; +import type * as Ort from 'onnxruntime-web'; import type { SvrLabelVolume, SvrVolume } from '../types/svr'; import { BRATS_BASE_LABEL_META, 
BRATS_LABEL_ID, type BratsBaseLabelId } from '../utils/segmentation/brats'; import { buildRgbaPalette256, rgbCss } from '../utils/segmentation/labelPalette'; +import { deleteModelBlob, getModelBlob, getModelSavedAtMs, putModelBlob } from '../utils/segmentation/onnx/modelCache'; +import { createOrtSessionFromModelBlob } from '../utils/segmentation/onnx/ortLoader'; +import { runTumorSegmentationOnnx } from '../utils/segmentation/onnx/tumorSegmentation'; import { computeSeedRange01, regionGrow3D, type Vec3i } from '../utils/segmentation/regionGrow3D'; import { resample2dAreaAverage } from '../utils/svr/resample2d'; @@ -20,6 +24,11 @@ function clamp(x: number, min: number, max: number): number { const SVR3D_CAMERA_Z = 1.6; const SVR3D_FOCAL_Z = 1.2; +// IndexedDB key for the cached tumor segmentation ONNX model. +const ONNX_TUMOR_MODEL_KEY = 'brats-tumor-v1'; + +type OnnxSessionMode = 'webgpu-preferred' | 'wasm'; + async function rgbaToPngBlob(params: { rgba: Uint8ClampedArray; width: number; height: number }): Promise { const { rgba, width, height } = params; @@ -486,6 +495,39 @@ export const SvrVolume3DViewer = forwardRef(null); const labels = labelsOverride ?? generatedLabels; + // Phase 3: ONNX model execution (offline; model cached in IndexedDB). + const onnxSessionRef = useRef(null); + const onnxSessionModeRef = useRef(null); + const onnxFileInputRef = useRef(null); + const [onnxStatus, setOnnxStatus] = useState<{ + cached: boolean; + savedAtMs: number | null; + loading: boolean; + sessionReady: boolean; + message?: string; + error?: string; + }>(() => ({ + cached: false, + savedAtMs: null, + loading: false, + sessionReady: false, + })); + + const refreshOnnxCacheStatus = useCallback(() => { + void getModelSavedAtMs(ONNX_TUMOR_MODEL_KEY) + .then((savedAtMs) => { + setOnnxStatus((s) => ({ ...s, cached: savedAtMs !== null, savedAtMs })); + }) + .catch((e) => { + const msg = e instanceof Error ? 
e.message : String(e); + setOnnxStatus((s) => ({ ...s, error: msg })); + }); + }, []); + + useEffect(() => { + refreshOnnxCacheStatus(); + }, [refreshOnnxCacheStatus]); + // Viewer controls (composite-only) const [controlsCollapsed, setControlsCollapsed] = useState(false); const [threshold, setThreshold] = useState(0.05); @@ -831,6 +873,125 @@ export const SvrVolume3DViewer = forwardRef { + onnxFileInputRef.current?.click(); + }, []); + + const onnxClearModel = useCallback(() => { + onnxSessionRef.current = null; + onnxSessionModeRef.current = null; + setOnnxStatus((s) => ({ ...s, sessionReady: false, loading: true, message: 'Clearing cached model…', error: undefined })); + + void deleteModelBlob(ONNX_TUMOR_MODEL_KEY) + .then(() => { + setOnnxStatus((s) => ({ ...s, loading: false, message: 'Cleared cached model' })); + refreshOnnxCacheStatus(); + }) + .catch((e) => { + const msg = e instanceof Error ? e.message : String(e); + setOnnxStatus((s) => ({ ...s, loading: false, error: msg })); + }); + }, [refreshOnnxCacheStatus]); + + const onnxHandleSelectedFile = useCallback( + (file: File) => { + onnxSessionRef.current = null; + onnxSessionModeRef.current = null; + setOnnxStatus((s) => ({ + ...s, + loading: true, + sessionReady: false, + message: `Caching model: ${file.name}`, + error: undefined, + })); + + void putModelBlob(ONNX_TUMOR_MODEL_KEY, file) + .then(() => { + setOnnxStatus((s) => ({ ...s, loading: false, message: 'Model cached' })); + refreshOnnxCacheStatus(); + }) + .catch((e) => { + const msg = e instanceof Error ? e.message : String(e); + setOnnxStatus((s) => ({ ...s, loading: false, error: msg })); + }); + }, + [refreshOnnxCacheStatus] + ); + + const ensureOnnxSession = useCallback(async (): Promise<{ session: Ort.InferenceSession; mode: OnnxSessionMode }> => { + if (onnxSessionRef.current) { + return { + session: onnxSessionRef.current, + mode: onnxSessionModeRef.current ?? 
'webgpu-preferred', + }; + } + + const blob = await getModelBlob(ONNX_TUMOR_MODEL_KEY); + if (!blob) { + throw new Error('No cached ONNX model found. Upload one first.'); + } + + try { + const session = await createOrtSessionFromModelBlob({ model: blob, preferWebGpu: true, logLevel: 'warning' }); + onnxSessionRef.current = session; + onnxSessionModeRef.current = 'webgpu-preferred'; + return { session, mode: 'webgpu-preferred' }; + } catch (_eWebGpu) { + // Fallback to WASM-only. + const session = await createOrtSessionFromModelBlob({ model: blob, preferWebGpu: false, logLevel: 'warning' }); + onnxSessionRef.current = session; + onnxSessionModeRef.current = 'wasm'; + return { session, mode: 'wasm' }; + } + }, []); + + const initOnnxSession = useCallback(() => { + setOnnxStatus((s) => ({ ...s, loading: true, message: 'Initializing ONNX runtime…', error: undefined })); + + void ensureOnnxSession() + .then(({ mode }) => { + setOnnxStatus((s) => ({ + ...s, + loading: false, + sessionReady: true, + message: mode === 'wasm' ? 'ONNX session ready (WASM)' : 'ONNX session ready (WebGPU preferred)', + })); + }) + .catch((e) => { + const msg = e instanceof Error ? e.message : String(e); + setOnnxStatus((s) => ({ ...s, loading: false, sessionReady: false, error: msg })); + }); + }, [ensureOnnxSession]); + + const runOnnxSegmentation = useCallback(() => { + if (!volume) return; + + const started = performance.now(); + setOnnxStatus((s) => ({ ...s, loading: true, message: 'Running ONNX segmentation…', error: undefined })); + + void (async () => { + const { session, mode } = await ensureOnnxSession(); + setOnnxStatus((s) => ({ + ...s, + sessionReady: true, + loading: true, + message: mode === 'wasm' ? 
'Running ONNX segmentation… (WASM)' : 'Running ONNX segmentation… (WebGPU preferred)', + })); + + const res = await runTumorSegmentationOnnx({ session, volume: volume.data, dims: volume.dims }); + + setGeneratedLabels({ data: res.labels, dims: volume.dims, meta: BRATS_BASE_LABEL_META }); + setLabelsEnabled(true); + + const ms = Math.round(performance.now() - started); + setOnnxStatus((s) => ({ ...s, loading: false, sessionReady: true, message: `Segmentation complete (${ms}ms)` })); + })().catch((e) => { + const msg = e instanceof Error ? e.message : String(e); + const hasSession = onnxSessionRef.current !== null; + setOnnxStatus((s) => ({ ...s, loading: false, sessionReady: hasSession, error: msg })); + }); + }, [ensureOnnxSession, volume]); + // Draw the inspector slice to a 2D canvas. useEffect(() => { const canvas = sliceCanvasRef.current; @@ -1805,6 +1966,75 @@ void main() {
{growStatus.message}
) : null} +
+
ONNX tumor model
+ + { + const f = e.target.files?.[0]; + if (f) { + onnxHandleSelectedFile(f); + } + // Allow re-uploading the same file. + e.target.value = ''; + }} + /> + +
+ + + + + + + +
+ +
+ Cached: {onnxStatus.cached ? 'yes' : 'no'} + {onnxStatus.savedAtMs ? ` · saved ${new Date(onnxStatus.savedAtMs).toLocaleString()}` : ''} + {onnxStatus.sessionReady ? ' · session ready' : ''} +
+ + {onnxStatus.error ? ( +
{onnxStatus.error}
+ ) : onnxStatus.message ? ( +
{onnxStatus.message}
+ ) : null} +
+ {!hasLabels || !labels ? (
No segmentation labels available yet.
) : ( diff --git a/frontend/src/utils/segmentation/onnx/logitsToLabels.ts b/frontend/src/utils/segmentation/onnx/logitsToLabels.ts new file mode 100644 index 0000000..bb5a389 --- /dev/null +++ b/frontend/src/utils/segmentation/onnx/logitsToLabels.ts @@ -0,0 +1,86 @@ +export type LogitsTensorLike = { + data: Float32Array; + /** Expected to be [1,C,Z,Y,X] or [C,Z,Y,X]. */ + dims: readonly number[]; +}; + +export type LogitsToLabelsResult = { + /** Flattened label volume in Z,Y,X order (x fastest). */ + labels: Uint8Array; + /** Spatial dims in [x,y,z] order for convenience. */ + spatialDims: [number, number, number]; +}; + +function assertFiniteInt(v: number, name: string): number { + if (!Number.isFinite(v)) { + throw new Error(`logitsToLabels: ${name} must be finite`); + } + const vi = Math.floor(v); + if (vi !== v) { + throw new Error(`logitsToLabels: ${name} must be an integer`); + } + return vi; +} + +export function logitsToLabels(params: { + logits: LogitsTensorLike; + /** Maps class index -> uint8 label id (0..255). */ + labelMap: readonly number[]; +}): LogitsToLabelsResult { + const { logits, labelMap } = params; + + const dims = logits.dims; + const data = logits.data; + + let c = 0; + let z = 0; + let y = 0; + let x = 0; + + if (dims.length === 5) { + const n = assertFiniteInt(dims[0] ?? 0, 'N'); + if (n !== 1) { + throw new Error(`logitsToLabels: only N=1 is supported (got ${n})`); + } + c = assertFiniteInt(dims[1] ?? 0, 'C'); + z = assertFiniteInt(dims[2] ?? 0, 'Z'); + y = assertFiniteInt(dims[3] ?? 0, 'Y'); + x = assertFiniteInt(dims[4] ?? 0, 'X'); + } else if (dims.length === 4) { + c = assertFiniteInt(dims[0] ?? 0, 'C'); + z = assertFiniteInt(dims[1] ?? 0, 'Z'); + y = assertFiniteInt(dims[2] ?? 0, 'Y'); + x = assertFiniteInt(dims[3] ?? 
0, 'X'); + } else { + throw new Error(`logitsToLabels: unsupported logits dims length ${dims.length}`); + } + + if (!(c > 0 && x > 0 && y > 0 && z > 0)) { + throw new Error(`logitsToLabels: invalid dims C=${c} Z=${z} Y=${y} X=${x}`); + } + + const spatial = x * y * z; + if (data.length !== c * spatial) { + throw new Error(`logitsToLabels: data length mismatch (expected ${c * spatial}, got ${data.length})`); + } + + const out = new Uint8Array(spatial); + + for (let p = 0; p < spatial; p++) { + let bestC = 0; + let best = -Infinity; + + for (let ci = 0; ci < c; ci++) { + const v = data[ci * spatial + p] ?? -Infinity; + if (v > best) { + best = v; + bestC = ci; + } + } + + const labelId = labelMap[bestC] ?? 0; + out[p] = (labelId & 0xff) >>> 0; + } + + return { labels: out, spatialDims: [x, y, z] }; +} diff --git a/frontend/src/utils/segmentation/onnx/modelCache.ts b/frontend/src/utils/segmentation/onnx/modelCache.ts new file mode 100644 index 0000000..9319c86 --- /dev/null +++ b/frontend/src/utils/segmentation/onnx/modelCache.ts @@ -0,0 +1,44 @@ +import { openDB } from 'idb'; + +const DB_NAME = 'miraviewer:model-cache'; +const DB_VERSION = 1; +const STORE = 'models'; + +type ModelRecord = { + key: string; + blob: Blob; + savedAtMs: number; +}; + +async function getDb() { + return openDB(DB_NAME, DB_VERSION, { + upgrade(db) { + if (!db.objectStoreNames.contains(STORE)) { + db.createObjectStore(STORE); + } + }, + }); +} + +export async function putModelBlob(key: string, blob: Blob): Promise { + const db = await getDb(); + const rec: ModelRecord = { key, blob, savedAtMs: Date.now() }; + await db.put(STORE, rec, key); +} + +export async function getModelBlob(key: string): Promise { + const db = await getDb(); + const rec = (await db.get(STORE, key)) as ModelRecord | undefined; + return rec?.blob ?? 
null; +} + +export async function deleteModelBlob(key: string): Promise { + const db = await getDb(); + await db.delete(STORE, key); +} + +export async function getModelSavedAtMs(key: string): Promise { + const db = await getDb(); + const rec = (await db.get(STORE, key)) as ModelRecord | undefined; + return rec?.savedAtMs ?? null; +} diff --git a/frontend/src/utils/segmentation/onnx/ortLoader.ts b/frontend/src/utils/segmentation/onnx/ortLoader.ts new file mode 100644 index 0000000..7a013ea --- /dev/null +++ b/frontend/src/utils/segmentation/onnx/ortLoader.ts @@ -0,0 +1,67 @@ +import type * as Ort from 'onnxruntime-web'; + +// We intentionally load ORT from: +// - dev: the installed `onnxruntime-web` module (Vite can serve wasm assets correctly) +// - prod: statically-copied assets under /onnxruntime/ (works fully offline) +let ortPromise: Promise | null = null; + +export async function loadOrtAll(): Promise { + if (ortPromise) return ortPromise; + + ortPromise = (async () => { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + let mod: any; + + if (import.meta.env.DEV) { + // In dev, load from the package so Vite can handle wasm asset URLs. + mod = await import('onnxruntime-web'); + } else { + // In production builds, load from our vendored runtime assets. + mod = await import(/* @vite-ignore */ '/onnxruntime/ort.all.bundle.min.mjs'); + } + + // The ESM bundles export both named exports and a default export. + const ort: typeof Ort = (mod?.default ?? mod) as typeof Ort; + + // Prefer stability: threads require COOP/COEP (crossOriginIsolated) which we don't assume. + ort.env.wasm.numThreads = 1; + + if (!import.meta.env.DEV) { + // Ensure ORT can locate its runtime assets. 
+ ort.env.wasm.wasmPaths = '/onnxruntime/'; + } + + return ort; + })(); + + return ortPromise; +} + +export async function createOrtSessionFromModelBlob(params: { + model: Blob; + preferWebGpu?: boolean; + logLevel?: Ort.Env['logLevel']; +}): Promise { + const ort = await loadOrtAll(); + + if (params.logLevel) { + ort.env.logLevel = params.logLevel; + } + + + const bytes = await params.model.arrayBuffer(); + + const baseOpts: Ort.InferenceSession.SessionOptions = { + graphOptimizationLevel: 'all', + }; + + if (params.preferWebGpu) { + // Try WebGPU first; if unavailable, ORT will pick the best available provider. + // NOTE: this requires the /onnxruntime/ assets to be present in the build output. + baseOpts.executionProviders = ['webgpu', 'wasm']; + } else { + baseOpts.executionProviders = ['wasm']; + } + + return ort.InferenceSession.create(bytes, baseOpts); +} diff --git a/frontend/src/utils/segmentation/onnx/tumorSegmentation.ts b/frontend/src/utils/segmentation/onnx/tumorSegmentation.ts new file mode 100644 index 0000000..dc6985f --- /dev/null +++ b/frontend/src/utils/segmentation/onnx/tumorSegmentation.ts @@ -0,0 +1,72 @@ +import type * as Ort from 'onnxruntime-web'; +import { BRATS_LABEL_ID } from '../brats'; +import { loadOrtAll } from './ortLoader'; +import { logitsToLabels } from './logitsToLabels'; + +export type TumorOnnxSegmentationResult = { + /** Flattened label IDs (length = nx*ny*nz). */ + labels: Uint8Array; + /** The output logits dims returned by the model. */ + logitsDims: readonly number[]; +}; + +export async function runTumorSegmentationOnnx(params: { + session: Ort.InferenceSession; + volume: Float32Array; + dims: [number, number, number]; + /** Override model input name. Defaults to first session input. */ + inputName?: string; + /** Override model output name. Defaults to first session output. */ + outputName?: string; + /** Map class index -> label id. Default assumes 4 classes [0,1,2,4]. 
*/ + labelMap?: readonly number[]; +}): Promise { + const { session, volume, dims } = params; + const [nx, ny, nz] = dims; + + const ort = await loadOrtAll(); + + const inputName = params.inputName ?? session.inputNames[0]; + const outputName = params.outputName ?? session.outputNames[0]; + if (!inputName) { + throw new Error('ONNX session has no inputs'); + } + if (!outputName) { + throw new Error('ONNX session has no outputs'); + } + + // ORT expects NCHW-like layout for 3D conv models: [N, C, Z, Y, X]. + // Our Float32Array is already in X-fastest order, so [Z,Y,X] is consistent. + const inputTensor = new ort.Tensor('float32', volume, [1, 1, nz, ny, nx]); + + const outputs = await session.run({ [inputName]: inputTensor } as Record); + const logitsTensor = outputs[outputName]; + if (!logitsTensor) { + throw new Error(`ONNX run did not return expected output: ${outputName}`); + } + + if (logitsTensor.type !== 'float32') { + throw new Error(`Unsupported logits tensor type: ${logitsTensor.type}`); + } + + const labelMap = params.labelMap ?? [BRATS_LABEL_ID.BACKGROUND, BRATS_LABEL_ID.NCR_NET, BRATS_LABEL_ID.EDEMA, BRATS_LABEL_ID.ENHANCING]; + + const { labels, spatialDims } = logitsToLabels({ + logits: { data: logitsTensor.data as Float32Array, dims: logitsTensor.dims }, + labelMap, + }); + + // Sanity check that the model output matches the current SVR volume. + const expected = nx * ny * nz; + if (labels.length !== expected) { + throw new Error(`Model output spatial size mismatch (expected ${expected}, got ${labels.length}).`); + } + + // NOTE: spatialDims is [X,Y,Z] for convenience. This should match the SVR dims. + if (spatialDims[0] !== nx || spatialDims[1] !== ny || spatialDims[2] !== nz) { + // Don't fail hard: some models output in a different orientation; callers can add remapping later. 
+ console.warn('[onnx] Output dims differ from SVR volume dims', { spatialDims, svrDims: dims }); + } + + return { labels, logitsDims: logitsTensor.dims }; +} diff --git a/frontend/tests/onnxLogitsToLabels.test.ts b/frontend/tests/onnxLogitsToLabels.test.ts new file mode 100644 index 0000000..8eedf12 --- /dev/null +++ b/frontend/tests/onnxLogitsToLabels.test.ts @@ -0,0 +1,51 @@ +import { describe, expect, it } from 'vitest'; +import { logitsToLabels } from '../src/utils/segmentation/onnx/logitsToLabels'; + +describe('logitsToLabels', () => { + it('converts [1,C,Z,Y,X] logits to uint8 labels using a label map', () => { + // C=3, Z=1, Y=1, X=4 + const dims = [1, 3, 1, 1, 4] as const; + const spatial = 4; + + // Layout: [C, spatial] + // voxel 0: class0 + // voxel 1: class1 + // voxel 2: class2 + // voxel 3: class1 + const data = new Float32Array([ + // c0 + 10, 0, 0, 0, + // c1 + 0, 9, 0, 8, + // c2 + 0, 0, 7, 0, + ]); + + const out = logitsToLabels({ logits: { data, dims }, labelMap: [0, 1, 4] }); + expect(out.spatialDims).toEqual([4, 1, 1]); + expect(Array.from(out.labels)).toEqual([0, 1, 4, 1]); + expect(out.labels.length).toBe(spatial); + }); + + it('supports [C,Z,Y,X] logits', () => { + const dims = [2, 1, 1, 3] as const; // C=2, Z=1, Y=1, X=3 + const data = new Float32Array([ + // c0 + 0, 5, 0, + // c1 + 1, 0, 2, + ]); + + const out = logitsToLabels({ logits: { data, dims }, labelMap: [0, 2] }); + expect(Array.from(out.labels)).toEqual([2, 0, 2]); + }); + + it('throws on shape/data mismatch', () => { + expect(() => + logitsToLabels({ + logits: { data: new Float32Array([1, 2, 3]), dims: [1, 2, 1, 1, 2] }, + labelMap: [0, 1], + }) + ).toThrow(/data length mismatch/i); + }); +}); diff --git a/frontend/tests/onnxTumorSegmentation.test.ts b/frontend/tests/onnxTumorSegmentation.test.ts new file mode 100644 index 0000000..78d7d06 --- /dev/null +++ b/frontend/tests/onnxTumorSegmentation.test.ts @@ -0,0 +1,64 @@ +import { describe, expect, it, vi } from 'vitest'; + +// 
Mock ORT loader to avoid pulling real onnxruntime-web + wasm during unit tests. +vi.mock('../src/utils/segmentation/onnx/ortLoader', () => { + class Tensor { + type: string; + data: unknown; + dims: number[]; + + constructor(type: string, data: unknown, dims: number[]) { + this.type = type; + this.data = data; + this.dims = dims; + } + } + + return { + loadOrtAll: async () => ({ Tensor, env: { wasm: {} } }), + }; +}); + +import { runTumorSegmentationOnnx } from '../src/utils/segmentation/onnx/tumorSegmentation'; + +describe('runTumorSegmentationOnnx', () => { + it('feeds [1,1,Z,Y,X] tensor and converts logits to labels', async () => { + const dims: [number, number, number] = [2, 1, 1]; // nx=2, ny=1, nz=1 + const volume = new Float32Array([0.1, 0.9]); + + // 4 classes (0,1,2,4) and spatial=2 + // logits layout: [C, spatial] + // voxel0 -> class1, voxel1 -> class3 + const logits = new Float32Array([ + // c0 + 0, 0, + // c1 + 5, 0, + // c2 + 0, 0, + // c3 + 0, 7, + ]); + + const session = { + inputNames: ['input'], + outputNames: ['logits'], + run: vi.fn(async (feeds: Record) => { + expect(Object.keys(feeds)).toEqual(['input']); + expect(feeds.input.type).toBe('float32'); + expect(feeds.input.dims).toEqual([1, 1, 1, 1, 2]); // [N,C,Z,Y,X] + + return { + logits: { + type: 'float32', + dims: [1, 4, 1, 1, 2], + data: logits, + }, + }; + }), + } as any; + + const out = await runTumorSegmentationOnnx({ session, volume, dims }); + expect(Array.from(out.labels)).toEqual([1, 4]); + }); +}); diff --git a/frontend/vite.config.ts b/frontend/vite.config.ts index 6b1cabd..f80cd35 100644 --- a/frontend/vite.config.ts +++ b/frontend/vite.config.ts @@ -20,6 +20,12 @@ export default defineConfig(() => { src: 'node_modules/@itk-wasm/elastix/dist/pipelines/*.{js,wasm,wasm.zst}', dest: 'pipelines/', }, + // onnxruntime-web dynamically loads helper .mjs modules + .wasm binaries at runtime. + // We vendor these into the output so segmentation can run fully offline. 
+ { + src: 'node_modules/onnxruntime-web/dist/ort*.{mjs,wasm}', + dest: 'onnxruntime/', + }, ], }) ); @@ -40,7 +46,7 @@ export default defineConfig(() => { // Avoid pre-bundling ITK-Wasm packages. These rely on lazy-loaded web workers // and Emscripten modules that can break when optimized. optimizeDeps: { - exclude: ['itk-wasm', '@itk-wasm/elastix', '@thewtex/zstddec'], + exclude: ['itk-wasm', '@itk-wasm/elastix', '@thewtex/zstddec', 'onnxruntime-web'], }, // Expose only Vite-prefixed env vars to the client. envPrefix: ['VITE_'], From 4f826cc9a5c4295cdcda575ba58481d9c650715d Mon Sep 17 00:00:00 2001 From: Siqi Chen Date: Mon, 2 Feb 2026 14:31:37 -0800 Subject: [PATCH 06/16] svr3d: auto-run ml, metrics, and nifti export --- frontend/src/components/SvrVolume3DViewer.tsx | 208 +++++++++++++++--- frontend/src/utils/segmentation/nifti1.ts | 154 +++++++++++++ frontend/tests/nifti1.test.ts | 50 +++++ 3 files changed, 384 insertions(+), 28 deletions(-) create mode 100644 frontend/src/utils/segmentation/nifti1.ts create mode 100644 frontend/tests/nifti1.test.ts diff --git a/frontend/src/components/SvrVolume3DViewer.tsx b/frontend/src/components/SvrVolume3DViewer.tsx index 77d4fa2..737151d 100644 --- a/frontend/src/components/SvrVolume3DViewer.tsx +++ b/frontend/src/components/SvrVolume3DViewer.tsx @@ -4,6 +4,7 @@ import type * as Ort from 'onnxruntime-web'; import type { SvrLabelVolume, SvrVolume } from '../types/svr'; import { BRATS_BASE_LABEL_META, BRATS_LABEL_ID, type BratsBaseLabelId } from '../utils/segmentation/brats'; import { buildRgbaPalette256, rgbCss } from '../utils/segmentation/labelPalette'; +import { buildNifti1Uint8 } from '../utils/segmentation/nifti1'; import { deleteModelBlob, getModelBlob, getModelSavedAtMs, putModelBlob } from '../utils/segmentation/onnx/modelCache'; import { createOrtSessionFromModelBlob } from '../utils/segmentation/onnx/ortLoader'; import { runTumorSegmentationOnnx } from '../utils/segmentation/onnx/tumorSegmentation'; @@ -513,6 
+514,13 @@ export const SvrVolume3DViewer = forwardRef(null); + const refreshOnnxCacheStatus = useCallback(() => { void getModelSavedAtMs(ONNX_TUMOR_MODEL_KEY) .then((savedAtMs) => { @@ -556,6 +564,12 @@ export const SvrVolume3DViewer = forwardRef (s.loading ? { ...s, loading: false } : s)); }, [volume]); const hasLabels = useMemo(() => { @@ -569,6 +583,33 @@ export const SvrVolume3DViewer = forwardRef { + if (!volume) return null; + if (!labels) return null; + if (!hasLabels) return null; + + const counts = new Map(); + const data = labels.data; + + for (let i = 0; i < data.length; i++) { + const id = data[i] ?? 0; + if (id === 0) continue; + counts.set(id, (counts.get(id) ?? 0) + 1); + } + + const [vx, vy, vz] = volume.voxelSizeMm; + const voxelVolMm3 = Math.abs(vx * vy * vz); + + let totalCount = 0; + for (const c of counts.values()) { + totalCount += c; + } + + const totalMl = voxelVolMm3 > 0 ? (totalCount * voxelVolMm3) / 1000 : 0; + + return { counts, voxelVolMm3, totalCount, totalMl }; + }, [hasLabels, labels, volume]); + // Slice inspector (orthogonal slices). 
const sliceCanvasRef = useRef(null); const [inspectPlane, setInspectPlane] = useState<'axial' | 'coronal' | 'sagittal'>('axial'); @@ -602,6 +643,34 @@ export const SvrVolume3DViewer = forwardRef { + if (!volume) return; + if (!labels) return; + if (!hasLabels) return; + + const buf = buildNifti1Uint8({ + data: labels.data, + dims: volume.dims, + voxelSizeMm: volume.voxelSizeMm, + description: 'MiraViewer SVR labels (uint8)', + units: { spatial: 'mm' }, + }); + + const blob = new Blob([buf], { type: 'application/octet-stream' }); + const url = URL.createObjectURL(blob); + try { + const date = new Date().toISOString().slice(0, 10); + const [nx, ny, nz] = volume.dims; + const a = document.createElement('a'); + a.href = url; + a.download = `svr_labels_${nx}x${ny}x${nz}_${date}.nii`; + a.rel = 'noopener'; + a.click(); + } finally { + URL.revokeObjectURL(url); + } + }, [hasLabels, labels, volume]); + useImperativeHandle( ref, () => ({ @@ -966,32 +1035,71 @@ export const SvrVolume3DViewer = forwardRef { if (!volume) return; + const runId = ++onnxSegRunIdRef.current; + setOnnxSegRunning(true); + const started = performance.now(); setOnnxStatus((s) => ({ ...s, loading: true, message: 'Running ONNX segmentation…', error: undefined })); void (async () => { - const { session, mode } = await ensureOnnxSession(); - setOnnxStatus((s) => ({ - ...s, - sessionReady: true, - loading: true, - message: mode === 'wasm' ? 'Running ONNX segmentation… (WASM)' : 'Running ONNX segmentation… (WebGPU preferred)', - })); + try { + const { session, mode } = await ensureOnnxSession(); + if (onnxSegRunIdRef.current !== runId) return; - const res = await runTumorSegmentationOnnx({ session, volume: volume.data, dims: volume.dims }); + setOnnxStatus((s) => ({ + ...s, + sessionReady: true, + loading: true, + message: mode === 'wasm' ? 
'Running ONNX segmentation… (WASM)' : 'Running ONNX segmentation… (WebGPU preferred)', + })); - setGeneratedLabels({ data: res.labels, dims: volume.dims, meta: BRATS_BASE_LABEL_META }); - setLabelsEnabled(true); + const res = await runTumorSegmentationOnnx({ session, volume: volume.data, dims: volume.dims }); + if (onnxSegRunIdRef.current !== runId) return; - const ms = Math.round(performance.now() - started); - setOnnxStatus((s) => ({ ...s, loading: false, sessionReady: true, message: `Segmentation complete (${ms}ms)` })); - })().catch((e) => { - const msg = e instanceof Error ? e.message : String(e); - const hasSession = onnxSessionRef.current !== null; - setOnnxStatus((s) => ({ ...s, loading: false, sessionReady: hasSession, error: msg })); - }); + setGeneratedLabels({ data: res.labels, dims: volume.dims, meta: BRATS_BASE_LABEL_META }); + setLabelsEnabled(true); + + const ms = Math.round(performance.now() - started); + setOnnxStatus((s) => ({ ...s, loading: false, sessionReady: true, message: `Segmentation complete (${ms}ms)` })); + } catch (e) { + if (onnxSegRunIdRef.current !== runId) return; + const msg = e instanceof Error ? e.message : String(e); + const hasSession = onnxSessionRef.current !== null; + setOnnxStatus((s) => ({ ...s, loading: false, sessionReady: hasSession, error: msg })); + } finally { + if (onnxSegRunIdRef.current === runId) { + setOnnxSegRunning(false); + } + } + })(); }, [ensureOnnxSession, volume]); + const cancelOnnxSegmentation = useCallback(() => { + if (!onnxSegRunning) return; + onnxSegRunIdRef.current++; + setOnnxSegRunning(false); + setOnnxStatus((s) => ({ ...s, loading: false, message: 'Segmentation cancelled', error: undefined })); + }, [onnxSegRunning]); + + // Auto-run ONNX segmentation once per SVR volume (when enabled and a model is cached). + useEffect(() => { + if (!autoRunOnnx) return; + if (!volume) return; + if (!onnxStatus.cached) return; + + // Don't clobber externally-provided labels or manual work. 
+ if (labels) { + onnxAutoRunAttemptedForVolumeRef.current = volume.data; + return; + } + + // Only attempt once per volume. + if (onnxAutoRunAttemptedForVolumeRef.current === volume.data) return; + onnxAutoRunAttemptedForVolumeRef.current = volume.data; + + runOnnxSegmentation(); + }, [autoRunOnnx, labels, onnxStatus.cached, runOnnxSegmentation, volume]); + // Draw the inspector slice to a 2D canvas. useEffect(() => { const canvas = sliceCanvasRef.current; @@ -2012,6 +2120,16 @@ void main() { Run ML + {onnxSegRunning ? ( + + ) : null} + +
Exports a uint8 label volume in NIfTI-1 format (single-file .nii).
+
diff --git a/frontend/src/utils/segmentation/nifti1.ts b/frontend/src/utils/segmentation/nifti1.ts new file mode 100644 index 0000000..a2d70fe --- /dev/null +++ b/frontend/src/utils/segmentation/nifti1.ts @@ -0,0 +1,154 @@ +export type Nifti1Units = { + spatial: 'mm' | 'm' | 'um'; + temporal?: 'sec' | 'msec' | 'usec' | 'hz' | 'ppm' | 'rads'; +}; + +function clampAscii(s: string, maxBytes: number): Uint8Array { + const out = new Uint8Array(maxBytes); + const enc = new TextEncoder(); + const bytes = enc.encode(s); + out.set(bytes.subarray(0, maxBytes)); + return out; +} + +function unitsToXyzt(units: Nifti1Units | undefined): number { + const spatial = units?.spatial ?? 'mm'; + const temporal = units?.temporal; + + const spatialCode = spatial === 'm' ? 1 : spatial === 'mm' ? 2 : 3; // um + + const temporalCode = + temporal === 'msec' + ? 16 + : temporal === 'usec' + ? 24 + : temporal === 'hz' + ? 32 + : temporal === 'ppm' + ? 40 + : temporal === 'rads' + ? 48 + : temporal === 'sec' + ? 8 + : 0; + + return spatialCode | temporalCode; +} + +export function buildNifti1Uint8(params: { + /** Flattened in X-fastest order (length = nx*ny*nz). */ + data: Uint8Array; + dims: [number, number, number]; + voxelSizeMm: [number, number, number]; + description?: string; + /** Defaults to mm. */ + units?: Nifti1Units; +}): ArrayBuffer { + const [nx, ny, nz] = params.dims; + + if (!(nx > 0 && ny > 0 && nz > 0)) { + throw new Error(`buildNifti1Uint8: invalid dims ${nx}x${ny}x${nz}`); + } + + const expected = nx * ny * nz; + if (params.data.length !== expected) { + throw new Error(`buildNifti1Uint8: data length mismatch (expected ${expected}, got ${params.data.length})`); + } + + // NIfTI-1 .nii = 348-byte header + 4-byte extension + payload. + const HEADER_BYTES = 348; + const VOX_OFFSET = 352; + + const header = new ArrayBuffer(HEADER_BYTES); + const dv = new DataView(header); + + // Little-endian NIfTI-1. 
+ dv.setInt32(0, HEADER_BYTES, true); // sizeof_hdr + + // dim[0..7] (int16), starting at offset 40. + dv.setInt16(40, 3, true); // 3D + dv.setInt16(42, nx, true); + dv.setInt16(44, ny, true); + dv.setInt16(46, nz, true); + dv.setInt16(48, 1, true); // dim[4] + dv.setInt16(50, 1, true); + dv.setInt16(52, 1, true); + dv.setInt16(54, 1, true); + + // datatype + bitpix. + // NIfTI datatype codes: uint8 = 2, bitpix = 8. + dv.setInt16(70, 2, true); // datatype + dv.setInt16(72, 8, true); // bitpix + + // pixdim[0..7] (float32), starting at offset 76. + // pixdim[0] is qfac; keep 1. + dv.setFloat32(76, 1, true); + + const vx = Math.abs(params.voxelSizeMm[0]); + const vy = Math.abs(params.voxelSizeMm[1]); + const vz = Math.abs(params.voxelSizeMm[2]); + dv.setFloat32(80, vx, true); + dv.setFloat32(84, vy, true); + dv.setFloat32(88, vz, true); + + // vox_offset (float32) + dv.setFloat32(108, VOX_OFFSET, true); + + // Scaling: identity. + dv.setFloat32(112, 1, true); // scl_slope + dv.setFloat32(116, 0, true); // scl_inter + + // Units. + dv.setUint8(123, unitsToXyzt(params.units)); + + // descrip[80] at offset 148. + if (params.description) { + new Uint8Array(header, 148, 80).set(clampAscii(params.description, 80)); + } + + // Prefer sform affine: voxel -> mm (diagonal). + dv.setInt16(252, 0, true); // qform_code + dv.setInt16(254, 1, true); // sform_code + + // srow_x/y/z at offsets 280/296/312. + // Affine maps (i,j,k,1) to world mm. 
+ // srow_x = [vx, 0, 0, 0] + // srow_y = [0, vy, 0, 0] + // srow_z = [0, 0, vz, 0] + dv.setFloat32(280, vx, true); + dv.setFloat32(284, 0, true); + dv.setFloat32(288, 0, true); + dv.setFloat32(292, 0, true); + + dv.setFloat32(296, 0, true); + dv.setFloat32(300, vy, true); + dv.setFloat32(304, 0, true); + dv.setFloat32(308, 0, true); + + dv.setFloat32(312, 0, true); + dv.setFloat32(316, 0, true); + dv.setFloat32(320, vz, true); + dv.setFloat32(324, 0, true); + + // magic[4] at offset 344: "n+1\0" for .nii + dv.setUint8(344, 'n'.charCodeAt(0)); + dv.setUint8(345, '+'.charCodeAt(0)); + dv.setUint8(346, '1'.charCodeAt(0)); + dv.setUint8(347, 0); + + const out = new Uint8Array(VOX_OFFSET + params.data.length); + + // Header. + out.set(new Uint8Array(header), 0); + + // Extension bytes (4): all zeros. + out[348] = 0; + out[349] = 0; + out[350] = 0; + out[351] = 0; + + // Payload. + out.set(params.data, VOX_OFFSET); + + return out.buffer; +} diff --git a/frontend/tests/nifti1.test.ts b/frontend/tests/nifti1.test.ts new file mode 100644 index 0000000..458c393 --- /dev/null +++ b/frontend/tests/nifti1.test.ts @@ -0,0 +1,50 @@ +import { describe, expect, it } from 'vitest'; +import { buildNifti1Uint8 } from '../src/utils/segmentation/nifti1'; + +describe('buildNifti1Uint8', () => { + it('writes a minimal .nii with correct header + payload', () => { + const dims: [number, number, number] = [2, 3, 1]; + const voxelSizeMm: [number, number, number] = [1.5, 2.0, 3.0]; + const data = new Uint8Array([0, 1, 2, 3, 4, 5]); + + const buf = buildNifti1Uint8({ data, dims, voxelSizeMm, description: 'unit test' }); + expect(buf.byteLength).toBe(352 + data.length); + + const dv = new DataView(buf); + + // sizeof_hdr + expect(dv.getInt32(0, true)).toBe(348); + + // dim + expect(dv.getInt16(40, true)).toBe(3); + expect(dv.getInt16(42, true)).toBe(2); + expect(dv.getInt16(44, true)).toBe(3); + expect(dv.getInt16(46, true)).toBe(1); + + // datatype (uint8) + bitpix + expect(dv.getInt16(70, 
true)).toBe(2); + expect(dv.getInt16(72, true)).toBe(8); + + // vox_offset + expect(dv.getFloat32(108, true)).toBe(352); + + // pixdim + expect(dv.getFloat32(80, true)).toBeCloseTo(1.5); + expect(dv.getFloat32(84, true)).toBeCloseTo(2.0); + expect(dv.getFloat32(88, true)).toBeCloseTo(3.0); + + // magic + const magic = String.fromCharCode(dv.getUint8(344), dv.getUint8(345), dv.getUint8(346)); + expect(magic).toBe('n+1'); + + // payload + const payload = new Uint8Array(buf, 352); + expect(Array.from(payload)).toEqual(Array.from(data)); + }); + + it('throws on size mismatch', () => { + expect(() => buildNifti1Uint8({ data: new Uint8Array([1, 2, 3]), dims: [2, 2, 1], voxelSizeMm: [1, 1, 1] })).toThrow( + /data length mismatch/i + ); + }); +}); From ce4c5bc14539ebb4f049c7ec21bc31212fe6b28c Mon Sep 17 00:00:00 2001 From: Siqi Chen Date: Mon, 2 Feb 2026 14:37:43 -0800 Subject: [PATCH 07/16] svr3d: add brush refinement in slice inspector --- frontend/src/components/SvrVolume3DViewer.tsx | 309 +++++++++++++++++- 1 file changed, 294 insertions(+), 15 deletions(-) diff --git a/frontend/src/components/SvrVolume3DViewer.tsx b/frontend/src/components/SvrVolume3DViewer.tsx index 737151d..7bab204 100644 --- a/frontend/src/components/SvrVolume3DViewer.tsx +++ b/frontend/src/components/SvrVolume3DViewer.tsx @@ -557,6 +557,13 @@ export const SvrVolume3DViewer = forwardRef(null); + // Phase 4b: manual refinement brush. + const [segTool, setSegTool] = useState<'seed' | 'brush'>('seed'); + const [brushLabel, setBrushLabel] = useState(BRATS_LABEL_ID.ENHANCING); + const [brushRadiusVox, setBrushRadiusVox] = useState(2); + const [labelsEditTick, setLabelsEditTick] = useState(0); + const brushDragRef = useRef<{ pointerId: number; last: Vec3i | null; data: Uint8Array } | null>(null); + // When the underlying volume changes, drop any internally-generated labels and seed. 
useEffect(() => { setGeneratedLabels(null); @@ -565,6 +572,9 @@ export const SvrVolume3DViewer = forwardRef) => { - if (!volume) return; + const inspectorPointerToVoxel = useCallback( + (e: React.PointerEvent): Vec3i | null => { + if (!volume) return null; const rect = e.currentTarget.getBoundingClientRect(); - const nx = Math.max(1, rect.width); - const ny = Math.max(1, rect.height); + const w = Math.max(1, rect.width); + const h = Math.max(1, rect.height); - const u = (e.clientX - rect.left) / nx; - const v = (e.clientY - rect.top) / ny; + const u = (e.clientX - rect.left) / w; + const v = (e.clientY - rect.top) / h; const srcCols = inspectorInfo.srcCols; const srcRows = inspectorInfo.srcRows; @@ -851,24 +861,224 @@ export const SvrVolume3DViewer = forwardRef { + if (!volume) return; + + const [nx, ny, nz] = volume.dims; + const strideY = nx; + const strideZ = nx * ny; + + const r = Math.max(0, Math.round(brushRadiusVox)); + const r2 = r * r; + const labelId = (brushLabel & 0xff) >>> 0; + + const set = (x: number, y: number, z: number) => { + if (x < 0 || x >= nx) return; + if (y < 0 || y >= ny) return; + if (z < 0 || z >= nz) return; + labelData[z * strideZ + y * strideY + x] = labelId; + }; + + if (inspectPlane === 'axial') { + const z = voxel.z; + for (let dy = -r; dy <= r; dy++) { + const y = voxel.y + dy; + const dy2 = dy * dy; + for (let dx = -r; dx <= r; dx++) { + if (dx * dx + dy2 > r2) continue; + set(voxel.x + dx, y, z); + } + } + return; + } + + if (inspectPlane === 'coronal') { + const y = voxel.y; + for (let dz = -r; dz <= r; dz++) { + const z = voxel.z + dz; + const dz2 = dz * dz; + for (let dx = -r; dx <= r; dx++) { + if (dx * dx + dz2 > r2) continue; + set(voxel.x + dx, y, z); + } + } + return; + } + + // sagittal + const x = voxel.x; + for (let dz = -r; dz <= r; dz++) { + const z = voxel.z + dz; + const dz2 = dz * dz; + for (let dy = -r; dy <= r; dy++) { + if (dy * dy + dz2 > r2) continue; + set(x, voxel.y + dy, z); + } + } + }, + [brushLabel, 
brushRadiusVox, inspectPlane, volume] + ); + + const paintBrushStroke = useCallback( + (labelData: Uint8Array, from: Vec3i | null, to: Vec3i) => { + if (!from) { + paintBrushAtVoxel(labelData, to); + return; + } + + // Interpolate in the 2D slice plane so fast drags don't leave gaps. + let a0 = 0; + let b0 = 0; + let a1 = 0; + let b1 = 0; + let fixed = 0; + + if (inspectPlane === 'axial') { + a0 = from.x; + b0 = from.y; + a1 = to.x; + b1 = to.y; + fixed = to.z; } else if (inspectPlane === 'coronal') { - seed = { x: sx, y: sliceIdx, z: sy }; + a0 = from.x; + b0 = from.z; + a1 = to.x; + b1 = to.z; + fixed = to.y; } else { // sagittal - seed = { x: sliceIdx, y: sx, z: sy }; + a0 = from.y; + b0 = from.z; + a1 = to.y; + b1 = to.z; + fixed = to.x; + } + + const da = a1 - a0; + const db = b1 - b0; + const steps = Math.max(Math.abs(da), Math.abs(db)); + + if (steps <= 1) { + paintBrushAtVoxel(labelData, to); + return; } - setSeedVoxel(seed); + for (let i = 0; i <= steps; i++) { + const t = i / steps; + const aa = Math.round(a0 + da * t); + const bb = Math.round(b0 + db * t); + + const v: Vec3i = + inspectPlane === 'axial' + ? { x: aa, y: bb, z: fixed } + : inspectPlane === 'coronal' + ? { x: aa, y: fixed, z: bb } + : { x: fixed, y: aa, z: bb }; + + paintBrushAtVoxel(labelData, v); + } + }, + [inspectPlane, paintBrushAtVoxel] + ); + + const onSliceInspectorPointerDown = useCallback( + (e: React.PointerEvent) => { + if (!volume) return; + + const voxel = inspectorPointerToVoxel(e); + if (!voxel) return; + + if (segTool === 'brush') { + if (growStatus.running) return; + + // Ensure we have a mutable label volume to edit. 
+ let editable: SvrLabelVolume; + if (generatedLabels) { + editable = generatedLabels; + } else if (labelsOverride) { + editable = { data: new Uint8Array(labelsOverride.data), dims: labelsOverride.dims, meta: labelsOverride.meta }; + setGeneratedLabels(editable); + } else { + editable = { data: new Uint8Array(volume.data.length), dims: volume.dims, meta: BRATS_BASE_LABEL_META }; + setGeneratedLabels(editable); + } + + setLabelsEnabled(true); + + brushDragRef.current = { pointerId: e.pointerId, last: voxel, data: editable.data }; + e.currentTarget.setPointerCapture(e.pointerId); + + paintBrushStroke(editable.data, null, voxel); + setLabelsEditTick((t) => t + 1); + + e.preventDefault(); + e.stopPropagation(); + return; + } + + // Seed tool: click to set the seed voxel. + setSeedVoxel(voxel); e.preventDefault(); e.stopPropagation(); }, - [inspectIndex, inspectPlane, inspectorInfo.maxIndex, inspectorInfo.srcCols, inspectorInfo.srcRows, volume] + [generatedLabels, growStatus.running, inspectorPointerToVoxel, labelsOverride, paintBrushStroke, segTool, volume] + ); + + const onSliceInspectorPointerMove = useCallback( + (e: React.PointerEvent) => { + const st = brushDragRef.current; + if (!st || st.pointerId !== e.pointerId) return; + if (!volume) return; + if (segTool !== 'brush') return; + + const voxel = inspectorPointerToVoxel(e); + if (!voxel) return; + + paintBrushStroke(st.data, st.last, voxel); + st.last = voxel; + setLabelsEditTick((t) => t + 1); + + e.preventDefault(); + e.stopPropagation(); + }, + [inspectorPointerToVoxel, paintBrushStroke, segTool, volume] ); + const onSliceInspectorPointerUp = useCallback((e: React.PointerEvent) => { + const st = brushDragRef.current; + if (!st || st.pointerId !== e.pointerId) return; + + brushDragRef.current = null; + + try { + e.currentTarget.releasePointerCapture(e.pointerId); + } catch { + // Ignore. + } + + // Commit labels object to trigger the 3D label texture upload once at the end of the stroke. 
+ setGeneratedLabels((prev) => (prev ? { ...prev } : prev)); + + e.preventDefault(); + e.stopPropagation(); + }, []); + const cancelSeedGrow = useCallback(() => { growAbortRef.current?.abort(); growAbortRef.current = null; @@ -1258,7 +1468,20 @@ export const SvrVolume3DViewer = forwardRef { setInitError(null); @@ -1983,6 +2206,54 @@ void main() {
{labelMix.toFixed(2)}
+
+ + + +
+ + +
Seed:{' '} @@ -1990,6 +2261,8 @@ void main() { {seedVoxel.x},{seedVoxel.y},{seedVoxel.z} + ) : segTool === 'brush' ? ( + Switch tool to Seed, then click slice inspector ) : ( Click the slice inspector to set )} @@ -2268,8 +2541,14 @@ void main() {
From 05847ea14be356776c50ca19a7b0192b1a68c62a Mon Sep 17 00:00:00 2001 From: Siqi Chen Date: Mon, 2 Feb 2026 14:59:41 -0800 Subject: [PATCH 08/16] onnx: fix vitest import + lint --- frontend/src/components/SvrVolume3DViewer.tsx | 2 +- frontend/src/utils/segmentation/onnx/ortLoader.ts | 4 +++- frontend/tests/onnxTumorSegmentation.test.ts | 9 ++++++--- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/frontend/src/components/SvrVolume3DViewer.tsx b/frontend/src/components/SvrVolume3DViewer.tsx index 7bab204..9f82b43 100644 --- a/frontend/src/components/SvrVolume3DViewer.tsx +++ b/frontend/src/components/SvrVolume3DViewer.tsx @@ -1215,7 +1215,7 @@ export const SvrVolume3DViewer = forwardRef { mod = await import('onnxruntime-web'); } else { // In production builds, load from our vendored runtime assets. - mod = await import(/* @vite-ignore */ '/onnxruntime/ort.all.bundle.min.mjs'); + // IMPORTANT: keep the specifier non-literal so Vite/Vitest don't try to resolve it during import analysis. + const bundleUrl = '/onnxruntime/' + 'ort.all.bundle.min.mjs'; + mod = await import(/* @vite-ignore */ bundleUrl); } // The ESM bundles export both named exports and a default export. diff --git a/frontend/tests/onnxTumorSegmentation.test.ts b/frontend/tests/onnxTumorSegmentation.test.ts index 78d7d06..9dea168 100644 --- a/frontend/tests/onnxTumorSegmentation.test.ts +++ b/frontend/tests/onnxTumorSegmentation.test.ts @@ -1,4 +1,7 @@ import { describe, expect, it, vi } from 'vitest'; +import type * as Ort from 'onnxruntime-web'; + +type TensorLike = { type: string; dims: number[]; data: unknown }; // Mock ORT loader to avoid pulling real onnxruntime-web + wasm during unit tests. 
vi.mock('../src/utils/segmentation/onnx/ortLoader', () => { @@ -43,7 +46,7 @@ describe('runTumorSegmentationOnnx', () => { const session = { inputNames: ['input'], outputNames: ['logits'], - run: vi.fn(async (feeds: Record) => { + run: vi.fn(async (feeds: Record) => { expect(Object.keys(feeds)).toEqual(['input']); expect(feeds.input.type).toBe('float32'); expect(feeds.input.dims).toEqual([1, 1, 1, 1, 2]); // [N,C,Z,Y,X] @@ -55,8 +58,8 @@ describe('runTumorSegmentationOnnx', () => { data: logits, }, }; - }), - } as any; + }) as unknown as Ort.InferenceSession['run'], + } as unknown as Ort.InferenceSession; const out = await runTumorSegmentationOnnx({ session, volume, dims }); expect(Array.from(out.labels)).toEqual([1, 4]); From c83b720ced446ff83ba71aea3441df9d633d627e Mon Sep 17 00:00:00 2001 From: Siqi Chen Date: Mon, 2 Feb 2026 15:31:14 -0800 Subject: [PATCH 09/16] svr3d: fix onnx session lifecycle + seed-grow labels --- frontend/src/components/SvrVolume3DViewer.tsx | 35 ++++++++++++++----- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/frontend/src/components/SvrVolume3DViewer.tsx b/frontend/src/components/SvrVolume3DViewer.tsx index 9f82b43..af54d5b 100644 --- a/frontend/src/components/SvrVolume3DViewer.tsx +++ b/frontend/src/components/SvrVolume3DViewer.tsx @@ -500,6 +500,25 @@ export const SvrVolume3DViewer = forwardRef(null); const onnxSessionModeRef = useRef(null); const onnxFileInputRef = useRef(null); + + const releaseOnnxSession = useCallback((reason: string) => { + const session = onnxSessionRef.current; + onnxSessionRef.current = null; + onnxSessionModeRef.current = null; + + if (session) { + // Avoid leaking WebGPU/WASM resources if the user swaps/clears models. 
+ void session.release().catch((e) => { + console.warn('[onnx] Failed to release session', { reason, e }); + }); + } + }, []); + + useEffect(() => { + return () => { + releaseOnnxSession('unmount'); + }; + }, [releaseOnnxSession]); const [onnxStatus, setOnnxStatus] = useState<{ cached: boolean; savedAtMs: number | null; @@ -1127,7 +1146,7 @@ export const SvrVolume3DViewer = forwardRef { if (controller.signal.aborted) return; - const next = labels ? new Uint8Array(labels.data) : new Uint8Array(volume.data.length); + const next = hasLabels && labels ? new Uint8Array(labels.data) : new Uint8Array(volume.data.length); for (let i = 0; i < res.mask.length; i++) { if (res.mask[i]) next[i] = growTargetLabel; } @@ -1150,15 +1169,14 @@ export const SvrVolume3DViewer = forwardRef { onnxFileInputRef.current?.click(); }, []); const onnxClearModel = useCallback(() => { - onnxSessionRef.current = null; - onnxSessionModeRef.current = null; + releaseOnnxSession('clear-model'); setOnnxStatus((s) => ({ ...s, sessionReady: false, loading: true, message: 'Clearing cached model…', error: undefined })); void deleteModelBlob(ONNX_TUMOR_MODEL_KEY) @@ -1170,12 +1188,11 @@ export const SvrVolume3DViewer = forwardRef ({ ...s, loading: false, error: msg })); }); - }, [refreshOnnxCacheStatus]); + }, [refreshOnnxCacheStatus, releaseOnnxSession]); const onnxHandleSelectedFile = useCallback( (file: File) => { - onnxSessionRef.current = null; - onnxSessionModeRef.current = null; + releaseOnnxSession('upload-model'); setOnnxStatus((s) => ({ ...s, loading: true, @@ -1193,8 +1210,8 @@ export const SvrVolume3DViewer = forwardRef ({ ...s, loading: false, error: msg })); }); - }, - [refreshOnnxCacheStatus] + }, + [refreshOnnxCacheStatus, releaseOnnxSession] ); const ensureOnnxSession = useCallback(async (): Promise<{ session: Ort.InferenceSession; mode: OnnxSessionMode }> => { From 687f9bd6dd47357af9a1dac80035adcbda61eb6b Mon Sep 17 00:00:00 2001 From: Siqi Chen Date: Mon, 2 Feb 2026 15:38:34 -0800 
Subject: [PATCH 10/16] svr3d: lock manual edits during onnx segmentation --- frontend/src/components/SvrVolume3DViewer.tsx | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/frontend/src/components/SvrVolume3DViewer.tsx b/frontend/src/components/SvrVolume3DViewer.tsx index af54d5b..c7e3a45 100644 --- a/frontend/src/components/SvrVolume3DViewer.tsx +++ b/frontend/src/components/SvrVolume3DViewer.tsx @@ -543,7 +543,7 @@ export const SvrVolume3DViewer = forwardRef { void getModelSavedAtMs(ONNX_TUMOR_MODEL_KEY) .then((savedAtMs) => { - setOnnxStatus((s) => ({ ...s, cached: savedAtMs !== null, savedAtMs })); + setOnnxStatus((s) => ({ ...s, cached: savedAtMs !== null, savedAtMs, error: undefined })); }) .catch((e) => { const msg = e instanceof Error ? e.message : String(e); @@ -1024,6 +1024,7 @@ export const SvrVolume3DViewer = forwardRef) => { @@ -2230,7 +2232,7 @@ void main() { value={segTool} onChange={(e) => setSegTool(e.target.value as 'seed' | 'brush')} className="mt-1 w-full px-2 py-1 rounded border border-[var(--border-color)] bg-[var(--bg-secondary)]" - disabled={!volume || growStatus.running} + disabled={!volume || growStatus.running || onnxSegRunning} > @@ -2259,7 +2261,7 @@ void main() { value={brushLabel} onChange={(e) => setBrushLabel(Number(e.target.value))} className="mt-1 w-full px-2 py-1 rounded border border-[var(--border-color)] bg-[var(--bg-secondary)]" - disabled={!volume || segTool !== 'brush' || growStatus.running} + disabled={!volume || segTool !== 'brush' || growStatus.running || onnxSegRunning} > @@ -2287,7 +2289,7 @@ void main() { @@ -2117,468 +2123,485 @@ void main() {
{controlsCollapsed ? null : ( -
+
3D Controls
- - - - - - - - - - -
-
Segmentation
-
-