diff --git a/.gitignore b/.gitignore index ec03a4c..a257ec8 100644 --- a/.gitignore +++ b/.gitignore @@ -208,3 +208,6 @@ cython_debug/ marimo/_static/ marimo/_lsp/ __marimo__/ + +# Claude Code +CLAUDE.md diff --git a/gui/app.py b/gui/app.py index e0afddf..bb5a408 100644 --- a/gui/app.py +++ b/gui/app.py @@ -31,6 +31,8 @@ QGraphicsDropShadowEffect, QComboBox, QSlider, + QRadioButton, + QButtonGroup, ) from PyQt5.QtCore import Qt, QThread, pyqtSignal, QMimeData, QPropertyAnimation, QEasingCurve from PyQt5.QtGui import QDragEnterEvent, QDropEvent, QFont, QColor, QPalette, QLinearGradient @@ -222,6 +224,48 @@ QScrollBar::add-line:vertical, QScrollBar::sub-line:vertical { height: 0px; } + +QRadioButton { + spacing: 8px; + font-size: 13px; + color: #1d1d1f; +} + +QRadioButton::indicator { + width: 16px; + height: 16px; + border-radius: 8px; + border: 2px solid #c7c7cc; + background-color: #ffffff; +} + +QRadioButton::indicator:checked { + background-color: #0071e3; + border-color: #0071e3; +} + +QRadioButton::indicator:hover { + border-color: #0071e3; +} + +QPushButton#calcFlatfieldButton { + background-color: #5856d6; + color: white; + font-size: 13px; + font-weight: 600; + border: none; + border-radius: 8px; + padding: 8px 16px; +} + +QPushButton#calcFlatfieldButton:hover { + background-color: #6866e0; +} + +QPushButton#calcFlatfieldButton:disabled { + background-color: #c7c7cc; + color: #8e8e93; +} """ @@ -232,12 +276,22 @@ class PreviewWorker(QThread): finished = pyqtSignal(object, object, object) # color_before, color_after, fused error = pyqtSignal(str) - def __init__(self, tiff_path, preview_cols, preview_rows, downsample_factor): + def __init__( + self, + tiff_path, + preview_cols, + preview_rows, + downsample_factor, + flatfield=None, + darkfield=None, + ): super().__init__() self.tiff_path = tiff_path self.preview_cols = preview_cols self.preview_rows = preview_rows self.downsample_factor = downsample_factor + self.flatfield = flatfield + self.darkfield = darkfield def run(self): try: @@ -248,7 +302,10 @@ def run(self): # Create TileFusion instance - handles both OME-TIFF and SQUID formats tf_full = TileFusion( - self.tiff_path, downsample_factors=(self.downsample_factor, self.downsample_factor) + self.tiff_path, + downsample_factors=(self.downsample_factor, self.downsample_factor), + flatfield=self.flatfield, + darkfield=self.darkfield, ) positions = np.array(tf_full._tile_positions) @@ -416,7 +473,14 @@ class FusionWorker(QThread): error = pyqtSignal(str) def __init__( - self, tiff_path, do_registration, blend_pixels, downsample_factor, fusion_mode="blended" + self, + tiff_path, + do_registration, + blend_pixels, + downsample_factor, + fusion_mode="blended", + flatfield=None, + darkfield=None, ): super().__init__() self.tiff_path = tiff_path @@ -424,6 +488,8 @@ def __init__( self.blend_pixels = blend_pixels self.downsample_factor = downsample_factor self.fusion_mode = fusion_mode + self.flatfield = flatfield + self.darkfield = darkfield self.output_path = None def run(self): @@ -464,6 +530,8 @@ def run(self): output_path=output_path, blend_pixels=self.blend_pixels, downsample_factors=(self.downsample_factor, self.downsample_factor), + flatfield=self.flatfield, + darkfield=self.darkfield, ) load_time = time.time() - step_start self.progress.emit(f"Loaded {tf.n_tiles} tiles ({tf.Y}x{tf.X} each) [{load_time:.1f}s]") @@ -676,19 +744,196 @@ def setFile(self, file_path): ) +class FlatfieldDropArea(QFrame): + """Small drag and drop area for flatfield .npy files.""" + + fileDropped = pyqtSignal(str) + + def __init__(self): + super().__init__() + self.setAcceptDrops(True) + self.setFrameStyle(QFrame.StyledPanel | QFrame.Sunken) + self.setMinimumHeight(60) + self.setMaximumHeight(80) + self.setStyleSheet( + """ + QFrame { + border: 2px dashed #c7c7cc; + border-radius: 8px; + background-color: #ffffff; + } + QFrame:hover { + border-color: #5856d6; + background-color: #f5f5ff; + } + """ + ) + + layout = QHBoxLayout(self) + layout.setSpacing(8) + + self.icon_label = QLabel("📄") + self.icon_label.setStyleSheet("font-size: 20px; border: none; background: transparent;") + layout.addWidget(self.icon_label) + + self.label = QLabel("Drop flatfield .npy file here or click to browse") + self.label.setStyleSheet( + "color: #86868b; font-size: 12px; border: none; background: transparent;" + ) + layout.addWidget(self.label) + layout.addStretch() + + self.file_path = None + + def dragEnterEvent(self, event: QDragEnterEvent): + if event.mimeData().hasUrls(): + event.acceptProposedAction() + self.setStyleSheet( + """ + QFrame { + border: 2px dashed #5856d6; + border-radius: 8px; + background-color: #ebebff; + } + """ + ) + + def dragLeaveEvent(self, event): + self.setStyleSheet( + """ + QFrame { + border: 2px dashed #c7c7cc; + border-radius: 8px; + background-color: #ffffff; + } + """ + ) + + def dropEvent(self, event: QDropEvent): + self.setStyleSheet( + """ + QFrame { + border: 2px dashed #c7c7cc; + border-radius: 8px; + background-color: #ffffff; + } + """ + ) + + urls = event.mimeData().urls() + if urls: + file_path = urls[0].toLocalFile() + if file_path.endswith(".npy"): + self.setFile(file_path) + self.fileDropped.emit(file_path) + + def mousePressEvent(self, event): + file_path, _ = QFileDialog.getOpenFileName( + self, "Select flatfield file", "", "NumPy files (*.npy);;All files (*.*)" + ) + if file_path: + self.setFile(file_path) + self.fileDropped.emit(file_path) + + def setFile(self, file_path): + self.file_path = file_path + path = Path(file_path) + self.icon_label.setText("✅") + self.label.setText(path.name) + self.label.setStyleSheet( + "color: #5856d6; font-size: 12px; font-weight: 600; border: none; background: transparent;" + ) + + def clear(self): + self.file_path = None + self.icon_label.setText("📄") + self.label.setText("Drop flatfield .npy file here or click to browse") + self.label.setStyleSheet( + "color: #86868b; font-size: 12px; border: none; background: transparent;" + ) + + +class FlatfieldWorker(QThread): + """Worker thread for calculating flatfield using BaSiCPy.""" + + progress = pyqtSignal(str) + finished = pyqtSignal(object, object) # flatfield, darkfield (or None) + error = pyqtSignal(str) + + def __init__(self, file_path, n_samples=50, use_darkfield=False): + super().__init__() + self.file_path = file_path + self.n_samples = n_samples + self.use_darkfield = use_darkfield + + def run(self): + try: + import numpy as np + from tilefusion import TileFusion, calculate_flatfield, HAS_BASICPY + + if not HAS_BASICPY: + self.error.emit("BaSiCPy is not installed. Install with: pip install basicpy") + return + + self.progress.emit("Loading metadata...") + + # Create TileFusion instance to read tiles. + # NOTE: No flatfield/darkfield passed intentionally - flatfield estimation + # must be performed on raw, uncorrected tiles. + tf = TileFusion(self.file_path) + + # Determine how many tiles to sample + n_tiles = tf.n_tiles + n_samples = min(self.n_samples, n_tiles) + + self.progress.emit(f"Sampling {n_samples} tiles from {n_tiles} total...") + + # Random sample of tile indices + rng = np.random.default_rng(42) + sample_indices = rng.choice(n_tiles, size=n_samples, replace=False) + sample_indices = sorted(sample_indices) + + # Read sampled tiles + # NOTE: Using private method tf._read_tile intentionally. + # FlatfieldWorker needs direct access to raw tile data for sampling. + tiles = [] + for i, tile_idx in enumerate(sample_indices): + self.progress.emit(f"Reading tile {i+1}/{n_samples}...") + tile = tf._read_tile(tile_idx) + tiles.append(tile) + + self.progress.emit("Calculating flatfield with BaSiCPy...") + flatfield, darkfield = calculate_flatfield( + tiles, use_darkfield=self.use_darkfield, constant_darkfield=True + ) + + self.progress.emit("Flatfield calculation complete!") + self.finished.emit(flatfield, darkfield) + + except Exception as e: + import traceback + + self.error.emit(f"Error: {str(e)}\n{traceback.format_exc()}") + + class StitcherGUI(QMainWindow): """Main GUI window for the stitcher.""" def __init__(self): super().__init__() self.setWindowTitle("Stitcher") - self.setMinimumSize(500, 600) + self.setMinimumSize(500, 850) self.worker = None self.output_path = None self.regions = [] # List of region names for multi-region outputs self.is_multi_region = False + # Flatfield correction state + self.flatfield = None # Shape (C, Y, X) or None + self.darkfield = None # Shape (C, Y, X) or None + self.flatfield_worker = None + self.setup_ui() def setup_ui(self): @@ -740,6 +985,116 @@ def setup_ui(self): layout.addWidget(preview_group) + # Flatfield correction section + flatfield_group = QGroupBox("Flatfield Correction") + flatfield_layout = QVBoxLayout(flatfield_group) + flatfield_layout.setSpacing(10) + + # Enable flatfield checkbox + self.flatfield_checkbox = QCheckBox("Enable flatfield correction") + self.flatfield_checkbox.setChecked(True) # Default enabled + self.flatfield_checkbox.setMinimumHeight(32) + self.flatfield_checkbox.toggled.connect(self.on_flatfield_toggled) + flatfield_layout.addWidget(self.flatfield_checkbox) + + # Container for flatfield options (shown when enabled) + self.flatfield_options_widget = QWidget() + flatfield_options_layout = QVBoxLayout(self.flatfield_options_widget) + flatfield_options_layout.setContentsMargins(24, 0, 0, 0) + flatfield_options_layout.setSpacing(8) + + # Radio buttons for Calculate vs Load + self.flatfield_mode_group = QButtonGroup(self) + radio_layout = QHBoxLayout() + + self.calc_radio = QRadioButton("Calculate from tiles") + self.calc_radio.setChecked(True) + self.flatfield_mode_group.addButton(self.calc_radio, 0) + radio_layout.addWidget(self.calc_radio) + + self.load_radio = QRadioButton("Load from file") + self.flatfield_mode_group.addButton(self.load_radio, 1) + radio_layout.addWidget(self.load_radio) + + radio_layout.addStretch() + flatfield_options_layout.addLayout(radio_layout) + + # Calculate options container + self.calc_options_widget = QWidget() + calc_options_layout = QVBoxLayout(self.calc_options_widget) + calc_options_layout.setContentsMargins(0, 0, 0, 0) + calc_options_layout.setSpacing(8) + + # Darkfield checkbox + self.darkfield_checkbox = QCheckBox("Include darkfield correction") + self.darkfield_checkbox.setChecked(False) + calc_options_layout.addWidget(self.darkfield_checkbox) + + # Calculate and save buttons + calc_btn_layout = QHBoxLayout() + self.calc_flatfield_button = QPushButton("Calculate Flatfield") + self.calc_flatfield_button.setObjectName("calcFlatfieldButton") + self.calc_flatfield_button.setCursor(Qt.PointingHandCursor) + self.calc_flatfield_button.clicked.connect(self.calculate_flatfield) + self.calc_flatfield_button.setEnabled(False) + calc_btn_layout.addWidget(self.calc_flatfield_button) + + self.save_flatfield_button = QPushButton("Save") + self.save_flatfield_button.setCursor(Qt.PointingHandCursor) + self.save_flatfield_button.clicked.connect(self.save_flatfield) + self.save_flatfield_button.setEnabled(False) + self.save_flatfield_button.setToolTip("Save calculated flatfield to .npy file") + calc_btn_layout.addWidget(self.save_flatfield_button) + + calc_btn_layout.addStretch() + calc_options_layout.addLayout(calc_btn_layout) + + flatfield_options_layout.addWidget(self.calc_options_widget) + + # Load options container + self.load_options_widget = QWidget() + self.load_options_widget.setVisible(False) + load_options_layout = QVBoxLayout(self.load_options_widget) + load_options_layout.setContentsMargins(0, 0, 0, 0) + + self.flatfield_drop_area = FlatfieldDropArea() + self.flatfield_drop_area.fileDropped.connect(self.on_flatfield_dropped) + load_options_layout.addWidget(self.flatfield_drop_area) + + flatfield_options_layout.addWidget(self.load_options_widget) + + # Flatfield status and view button + status_layout = QHBoxLayout() + self.flatfield_status = QLabel("No flatfield") + self.flatfield_status.setStyleSheet("color: #86868b; font-size: 11px;") + status_layout.addWidget(self.flatfield_status) + + self.view_flatfield_button = QPushButton("View") + self.view_flatfield_button.setCursor(Qt.PointingHandCursor) + self.view_flatfield_button.clicked.connect(self.view_flatfield) + self.view_flatfield_button.setEnabled(False) + self.view_flatfield_button.setToolTip("View flatfield and darkfield") + self.view_flatfield_button.setFixedWidth(60) + status_layout.addWidget(self.view_flatfield_button) + + self.clear_flatfield_button = QPushButton("Clear") + self.clear_flatfield_button.setCursor(Qt.PointingHandCursor) + self.clear_flatfield_button.clicked.connect(self.clear_flatfield) + self.clear_flatfield_button.setEnabled(False) + self.clear_flatfield_button.setToolTip("Clear loaded flatfield") + self.clear_flatfield_button.setFixedWidth(60) + status_layout.addWidget(self.clear_flatfield_button) + status_layout.addStretch() + + flatfield_options_layout.addLayout(status_layout) + + flatfield_layout.addWidget(self.flatfield_options_widget) + + # Connect radio button signals + self.flatfield_mode_group.buttonClicked.connect(self.on_flatfield_mode_changed) + + layout.addWidget(flatfield_group) + # Registration settings reg_group = QGroupBox("Settings") reg_layout = QVBoxLayout(reg_group) @@ -862,6 +1217,33 @@ def on_file_dropped(self, file_path): self.log(f"Selected OME-TIFF: {file_path}") self.run_button.setEnabled(True) self.preview_button.setEnabled(True) + self.calc_flatfield_button.setEnabled(True) + # Clear previous flatfield when new file is selected + self.flatfield = None + self.darkfield = None + self.flatfield_status.setText("No flatfield") + self.flatfield_status.setStyleSheet("color: #86868b; font-size: 11px;") + self.flatfield_drop_area.clear() + self.view_flatfield_button.setEnabled(False) + self.clear_flatfield_button.setEnabled(False) + self.save_flatfield_button.setEnabled(False) + + # Auto-load existing flatfield if present, otherwise disable correction + # For directories (SQUID folders), also check inside the directory + if path.is_dir(): + flatfield_path = path / f"{path.name}_flatfield.npy" + if not flatfield_path.exists(): + # Fallback: check next to the directory + flatfield_path = path.parent / f"{path.name}_flatfield.npy" + else: + flatfield_path = path.parent / f"{path.stem}_flatfield.npy" + + if flatfield_path.exists(): + self.log(f"Found existing flatfield: {flatfield_path.name}") + self.on_flatfield_dropped(str(flatfield_path)) + self.flatfield_drop_area.setFile(str(flatfield_path)) + else: + self.flatfield_checkbox.setChecked(False) def on_registration_toggled(self, checked): self.downsample_widget.setVisible(checked) @@ -869,6 +1251,208 @@ def on_registration_toggled(self, checked): def on_blend_toggled(self, checked): self.blend_value_widget.setVisible(checked) + def on_flatfield_toggled(self, checked): + # Only show/hide flatfield options; preserve any loaded/calculated data + self.flatfield_options_widget.setVisible(checked) + + def on_flatfield_mode_changed(self, button): + is_calculate = self.calc_radio.isChecked() + self.calc_options_widget.setVisible(is_calculate) + self.load_options_widget.setVisible(not is_calculate) + + def calculate_flatfield(self): + if not self.drop_area.file_path: + return + + self.calc_flatfield_button.setEnabled(False) + self.flatfield_status.setText("Calculating flatfield...") + self.flatfield_status.setStyleSheet("color: #ff9500; font-size: 11px;") + + self.flatfield_worker = FlatfieldWorker( + self.drop_area.file_path, + n_samples=50, + use_darkfield=self.darkfield_checkbox.isChecked(), + ) + self.flatfield_worker.progress.connect(self.log) + self.flatfield_worker.finished.connect(self.on_flatfield_calculated) + self.flatfield_worker.error.connect(self.on_flatfield_error) + self.flatfield_worker.start() + + def on_flatfield_calculated(self, flatfield, darkfield): + self.flatfield = flatfield + self.darkfield = darkfield + self.calc_flatfield_button.setEnabled(True) + self.save_flatfield_button.setEnabled(True) + self.view_flatfield_button.setEnabled(True) + self.clear_flatfield_button.setEnabled(True) + + n_channels = flatfield.shape[0] + status = f"Flatfield ready ({n_channels} channels)" + if darkfield is not None: + status += " + darkfield" + self.flatfield_status.setText(status) + self.flatfield_status.setStyleSheet("color: #34c759; font-size: 11px; font-weight: 600;") + self.log(f"Flatfield calculation complete: {flatfield.shape}") + + # Auto-save flatfield next to input file + if self.drop_area.file_path: + try: + from tilefusion import save_flatfield as save_ff + + input_path = Path(self.drop_area.file_path) + # Use path.name for directories, path.stem for files (consistent with auto-load) + if input_path.is_dir(): + auto_save_path = input_path / f"{input_path.name}_flatfield.npy" + else: + auto_save_path = input_path.parent / f"{input_path.stem}_flatfield.npy" + save_ff(auto_save_path, self.flatfield, self.darkfield) + self.log(f"Auto-saved flatfield to {auto_save_path}") + except Exception as e: + self.log(f"Warning: Could not auto-save flatfield: {e}") + + def save_flatfield(self): + if self.flatfield is None: + return + + # Default path based on input (consistent with auto-save/auto-load) + default_path = "flatfield.npy" + if self.drop_area.file_path: + input_path = Path(self.drop_area.file_path) + if input_path.is_dir(): + default_path = str(input_path / f"{input_path.name}_flatfield.npy") + else: + default_path = str(input_path.parent / f"{input_path.stem}_flatfield.npy") + + file_path, _ = QFileDialog.getSaveFileName( + self, + "Save Flatfield", + default_path, + "NumPy files (*.npy);;All files (*.*)", + ) + if file_path: + try: + from tilefusion import save_flatfield as save_ff + + save_ff(Path(file_path), self.flatfield, self.darkfield) + self.log(f"Saved flatfield to {file_path}") + except Exception as e: + self.log(f"Error saving flatfield: {e}") + + def on_flatfield_error(self, error_msg): + self.calc_flatfield_button.setEnabled(True) + self.flatfield_status.setText("Calculation failed") + self.flatfield_status.setStyleSheet("color: #ff3b30; font-size: 11px;") + self.log(error_msg) + + def on_flatfield_dropped(self, file_path): + import numpy as np + + try: + from tilefusion import load_flatfield + + self.flatfield, self.darkfield = load_flatfield(Path(file_path)) + n_channels = self.flatfield.shape[0] + status = f"Loaded ({n_channels} channels)" + if self.darkfield is not None: + status += " + darkfield" + self.flatfield_status.setText(status) + self.flatfield_status.setStyleSheet( + "color: #34c759; font-size: 11px; font-weight: 600;" + ) + self.view_flatfield_button.setEnabled(True) + self.clear_flatfield_button.setEnabled(True) + self.save_flatfield_button.setEnabled(True) + # Enable flatfield correction when successfully loaded + self.flatfield_checkbox.setChecked(True) + self.log(f"Loaded flatfield from {file_path}: {self.flatfield.shape}") + except Exception as e: + # Clear any stale flatfield data on load failure + self.flatfield = None + self.darkfield = None + self.flatfield_status.setText(f"Load failed: {e}") + self.flatfield_status.setStyleSheet("color: #ff3b30; font-size: 11px;") + self.view_flatfield_button.setEnabled(False) + self.log(f"Error loading flatfield: {e}") + + def view_flatfield(self): + if self.flatfield is None: + return + + try: + import matplotlib + + matplotlib.use("Agg") # Non-interactive backend + import matplotlib.pyplot as plt + import numpy as np + import tempfile + import subprocess + + n_channels = self.flatfield.shape[0] + has_darkfield = self.darkfield is not None + n_rows = 2 if has_darkfield else 1 + + fig, axes = plt.subplots(n_rows, n_channels, figsize=(4 * n_channels, 4 * n_rows)) + + # Handle single channel case (axes not 2D) + if n_channels == 1 and n_rows == 1: + axes = [[axes]] + elif n_channels == 1: + axes = [[ax] for ax in axes] + elif n_rows == 1: + axes = [axes] + + # First row: flatfield + for ch in range(n_channels): + ax = axes[0][ch] + im = ax.imshow(self.flatfield[ch], cmap="viridis", vmin=0) + ax.set_title(f"Flatfield Ch{ch}") + ax.axis("off") + plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04) + + # Second row: darkfield (if available) + if has_darkfield: + for ch in range(n_channels): + ax = axes[1][ch] + im = ax.imshow(self.darkfield[ch], cmap="magma", vmin=0) + # Show constant value in title if darkfield is uniform + df_val = self.darkfield[ch].ravel()[0] + if np.allclose(self.darkfield[ch], df_val): + ax.set_title(f"Darkfield Ch{ch} (={df_val:.1f})") + else: + ax.set_title(f"Darkfield Ch{ch}") + ax.axis("off") + plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04) + + plt.tight_layout() + + # Save to temp file and open with system viewer + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f: + fig.savefig(f.name, dpi=150, bbox_inches="tight") + plt.close(fig) + # Open with default image viewer + if sys.platform == "darwin": + subprocess.Popen(["open", f.name]) + elif sys.platform == "win32": + subprocess.Popen(["cmd", "/c", "start", "", f.name]) + else: + subprocess.Popen(["xdg-open", f.name]) + + self.log("Opened flatfield viewer") + except Exception as e: + self.log(f"Error opening viewer: {e}") + + def clear_flatfield(self): + """Clear loaded/calculated flatfield.""" + self.flatfield = None + self.darkfield = None + self.flatfield_status.setText("No flatfield") + self.flatfield_status.setStyleSheet("color: #86868b; font-size: 11px;") + self.view_flatfield_button.setEnabled(False) + self.clear_flatfield_button.setEnabled(False) + self.save_flatfield_button.setEnabled(False) + self.flatfield_drop_area.clear() + self.log("Flatfield cleared") + def log(self, message): self.log_text.append(message) self.log_text.verticalScrollBar().setValue(self.log_text.verticalScrollBar().maximum()) @@ -890,12 +1474,18 @@ def run_stitching(self): blend_pixels = (0, 0) fusion_mode = "direct" + # Get flatfield if enabled + flatfield = self.flatfield if self.flatfield_checkbox.isChecked() else None + darkfield = self.darkfield if self.flatfield_checkbox.isChecked() else None + self.worker = FusionWorker( self.drop_area.file_path, self.registration_checkbox.isChecked(), blend_pixels, self.downsample_spin.value(), fusion_mode, + flatfield=flatfield, + darkfield=darkfield, ) self.worker.progress.connect(self.log) self.worker.finished.connect(self.on_fusion_finished) @@ -950,11 +1540,17 @@ def run_preview(self): self.log_text.clear() self.log("Starting preview...") + # Get flatfield if enabled + flatfield = self.flatfield if self.flatfield_checkbox.isChecked() else None + darkfield = self.darkfield if self.flatfield_checkbox.isChecked() else None + self.preview_worker = PreviewWorker( self.drop_area.file_path, self.preview_cols_spin.value(), self.preview_rows_spin.value(), self.downsample_spin.value(), + flatfield=flatfield, + darkfield=darkfield, ) self.preview_worker.progress.connect(self.log) self.preview_worker.finished.connect(self.on_preview_finished) diff --git a/src/tilefusion/__init__.py b/src/tilefusion/__init__.py index c88ef52..496a351 100644 --- a/src/tilefusion/__init__.py +++ b/src/tilefusion/__init__.py @@ -12,6 +12,24 @@ from .core import TileFusion from .utils import USING_GPU +from .flatfield import ( + calculate_flatfield, + apply_flatfield, + apply_flatfield_region, + save_flatfield, + load_flatfield, + HAS_BASICPY, +) __version__ = "0.1.0" -__all__ = ["TileFusion", "USING_GPU", "__version__"] +__all__ = [ + "TileFusion", + "USING_GPU", + "__version__", + "calculate_flatfield", + "apply_flatfield", + "apply_flatfield_region", + "save_flatfield", + "load_flatfield", + "HAS_BASICPY", +] diff --git a/src/tilefusion/core.py b/src/tilefusion/core.py index 4086a55..1c49d68 100644 --- a/src/tilefusion/core.py +++ b/src/tilefusion/core.py @@ -32,6 +32,7 @@ ) from .fusion import accumulate_tile_shard, normalize_shard from .optimization import links_from_pairwise_metrics, solve_global, two_round_optimization +from .flatfield import apply_flatfield, apply_flatfield_region from .io import ( load_ome_tiff_metadata, load_individual_tiffs_metadata, @@ -107,6 +108,8 @@ def __init__( channel_to_use: int = 0, multiscale_downsample: str = "stride", region: Optional[str] = None, + flatfield: Optional[np.ndarray] = None, + darkfield: Optional[np.ndarray] = None, ): self.tiff_path = Path(tiff_path) if not self.tiff_path.exists(): @@ -223,6 +226,23 @@ def __init__( self.fused_ts = None self.center = None + # Flatfield correction (optional) + self._flatfield = flatfield # Shape (C, Y, X) or None + self._darkfield = darkfield # Shape (C, Y, X) or None + + # Validate flatfield/darkfield shapes match tile dimensions + expected_shape = (self.channels, self.Y, self.X) + if flatfield is not None and flatfield.shape != expected_shape: + raise ValueError( + f"flatfield.shape {flatfield.shape} does not match expected " + f"tile shape {expected_shape} (channels, Y, X)" + ) + if darkfield is not None and darkfield.shape != expected_shape: + raise ValueError( + f"darkfield.shape {darkfield.shape} does not match expected " + f"tile shape {expected_shape} (channels, Y, X)" + ) + # Thread-local storage for TiffFile handles (thread-safe concurrent access) self._thread_local = threading.local() self._handles_lock = threading.Lock() @@ -435,9 +455,9 @@ def _read_tile(self, tile_idx: int, z_level: int = None, time_idx: int = 0) -> n if self._is_zarr_format: zarr_ts = self._metadata["tensorstore"] is_3d = self._metadata.get("is_3d", False) - return read_zarr_tile(zarr_ts, tile_idx, is_3d) + tile = read_zarr_tile(zarr_ts, tile_idx, is_3d) elif self._is_individual_tiffs_format: - return read_individual_tiffs_tile( + tile = read_individual_tiffs_tile( self._metadata["image_folder"], self._metadata["channel_names"], self._metadata["tile_identifiers"], @@ -447,7 +467,7 @@ def _read_tile(self, tile_idx: int, z_level: int = None, time_idx: int = 0) -> n time_folders=self._time_folders, ) elif self._is_ome_tiff_tiles_format: - return read_ome_tiff_tiles_tile( + tile = read_ome_tiff_tiles_tile( self._metadata["ome_tiff_folder"], self._metadata["tile_identifiers"], self._metadata["tile_file_map"], @@ -459,7 +479,13 @@ def _read_tile(self, tile_idx: int, z_level: int = None, time_idx: int = 0) -> n else: # Use thread-local handle for thread-safe concurrent reads handle = self._get_thread_local_handle() - return read_ome_tiff_tile(self.tiff_path, tile_idx, handle) + tile = read_ome_tiff_tile(self.tiff_path, tile_idx, handle) + + # Apply flatfield correction if enabled + if self._flatfield is not None: + tile = apply_flatfield(tile, self._flatfield, self._darkfield) + + return tile def _read_tile_region( self, @@ -476,9 +502,11 @@ def _read_tile_region( if self._is_zarr_format: zarr_ts = self._metadata["tensorstore"] is_3d = self._metadata.get("is_3d", False) - return read_zarr_region(zarr_ts, tile_idx, y_slice, x_slice, self.channel_to_use, is_3d) + region = read_zarr_region( + zarr_ts, tile_idx, y_slice, x_slice, self.channel_to_use, is_3d + ) elif self._is_individual_tiffs_format: - return read_individual_tiffs_region( + region = read_individual_tiffs_region( self._metadata["image_folder"], self._metadata["channel_names"], self._metadata["tile_identifiers"], @@ -491,7 +519,7 @@ def _read_tile_region( time_folders=self._time_folders, ) elif self._is_ome_tiff_tiles_format: - return read_ome_tiff_tiles_region( + region = read_ome_tiff_tiles_region( self._metadata["ome_tiff_folder"], self._metadata["tile_identifiers"], self._metadata["tile_file_map"], @@ -506,7 +534,15 @@ def _read_tile_region( else: # Use thread-local handle for thread-safe concurrent reads handle = self._get_thread_local_handle() - return read_ome_tiff_region(self.tiff_path, tile_idx, y_slice, x_slice, handle) + region = read_ome_tiff_region(self.tiff_path, tile_idx, y_slice, x_slice, handle) + + # Apply flatfield correction if enabled + if self._flatfield is not None: + region = apply_flatfield_region( + region, self._flatfield, self._darkfield, y_slice, x_slice + ) + + return region # ------------------------------------------------------------------------- # Registration diff --git a/src/tilefusion/flatfield.py b/src/tilefusion/flatfield.py new file mode 100644 index 0000000..1e858e1 --- /dev/null +++ b/src/tilefusion/flatfield.py @@ -0,0 +1,309 @@ +""" +Flatfield correction module using BaSiCPy. + +Provides functions to calculate and apply flatfield (and optionally darkfield) +correction for microscopy images. +""" + +from pathlib import Path +from typing import List, Optional, Tuple + +import numpy as np + +try: + from basicpy import BaSiC + + HAS_BASICPY = True +except ImportError: + HAS_BASICPY = False + + +def calculate_flatfield( + tiles: List[np.ndarray], + use_darkfield: bool = False, + constant_darkfield: bool = True, +) -> Tuple[np.ndarray, Optional[np.ndarray]]: + """ + Calculate flatfield (and optionally darkfield) using BaSiCPy. + + Parameters + ---------- + tiles : list of ndarray + List of tile images, each with shape (C, Y, X) or (Y, X) for single-channel. + 2D arrays are automatically converted to 3D with shape (1, Y, X). + use_darkfield : bool + Whether to also compute darkfield correction. + constant_darkfield : bool + If True, darkfield is reduced to a single constant value (median) per + channel. This is physically appropriate since dark current is typically + uniform across the sensor. Default is True. + + Returns + ------- + flatfield : ndarray + Flatfield correction array with shape (C, Y, X), float32. + darkfield : ndarray or None + Darkfield correction array with shape (C, Y, X), or None if not computed. + If constant_darkfield=True, each channel slice will be a constant value. + + Raises + ------ + ImportError + If basicpy is not installed. + ValueError + If tiles list is empty or tiles have inconsistent shapes. + """ + if not HAS_BASICPY: + raise ImportError( + "basicpy is required for flatfield calculation. Install with: pip install basicpy" + ) + + if not tiles: + raise ValueError("tiles list is empty") + + # Validate tile dimensionality: only 2D (Y, X) or 3D (C, Y, X) supported + for i, t in enumerate(tiles): + if t.ndim not in (2, 3): + raise ValueError(f"Tile {i} has {t.ndim} dimensions; expected 2 (Y, X) or 3 (C, Y, X)") + + # Support 2D (Y, X) arrays by converting to 3D (1, Y, X) + tiles = [t[np.newaxis, ...] if t.ndim == 2 else t for t in tiles] + + # Get shape from first tile + n_channels = tiles[0].shape[0] + tile_shape = tiles[0].shape[1:] # (Y, X) + + # Validate all tiles have same shape + for i, tile in enumerate(tiles): + if tile.shape[0] != n_channels: + raise ValueError(f"Tile {i} has {tile.shape[0]} channels, expected {n_channels}") + if tile.shape[1:] != tile_shape: + raise ValueError(f"Tile {i} has shape {tile.shape[1:]}, expected {tile_shape}") + + # Calculate flatfield per channel + flatfield = np.zeros((n_channels,) + tile_shape, dtype=np.float32) + darkfield = np.zeros((n_channels,) + tile_shape, dtype=np.float32) if use_darkfield else None + + for ch in range(n_channels): + # Stack channel data from all tiles: shape (n_tiles, Y, X) + channel_stack = np.stack([tile[ch] for tile in tiles], axis=0) + + # Create BaSiC instance and fit + basic = BaSiC(get_darkfield=use_darkfield, smoothness_flatfield=1.0) + try: + basic.fit(channel_stack) + except Exception as exc: + raise RuntimeError( + f"BaSiCPy flatfield fitting failed for channel {ch} " + f"with data shape {channel_stack.shape}" + ) from exc + + flatfield[ch] = basic.flatfield.astype(np.float32) + + if use_darkfield: + if constant_darkfield: + # Use median value for constant darkfield (more robust than mean) + df_value = np.median(basic.darkfield) + darkfield[ch] = np.full(tile_shape, df_value, dtype=np.float32) + else: + darkfield[ch] = basic.darkfield.astype(np.float32) + + return flatfield, darkfield + + +def apply_flatfield( + tile: np.ndarray, + flatfield: np.ndarray, + darkfield: Optional[np.ndarray] = None, +) -> np.ndarray: + """ + Apply flatfield correction to a tile. + + Formula: + If darkfield is provided: corrected = (raw - darkfield) / flatfield + Otherwise: corrected = raw / flatfield + + Parameters + ---------- + tile : ndarray + Input tile with shape (C, Y, X). + flatfield : ndarray + Flatfield correction array with shape (C, Y, X). + darkfield : ndarray, optional + Darkfield correction array with shape (C, Y, X). + + Returns + ------- + corrected : ndarray + Corrected tile with shape (C, Y, X), cast back to the input dtype. + For integer dtypes, values are clipped to the valid range before + casting (e.g., negative values clipped to 0 for unsigned types). + + Raises + ------ + ValueError + If tile and flatfield shapes are incompatible. + """ + # Validate shapes + if tile.shape != flatfield.shape: + raise ValueError( + f"Tile shape {tile.shape} does not match flatfield shape {flatfield.shape}" + ) + if darkfield is not None and tile.shape != darkfield.shape: + raise ValueError( + f"Tile shape {tile.shape} does not match darkfield shape {darkfield.shape}" + ) + + # Convert to float32 to avoid underflow with unsigned integer types + tile_f = tile.astype(np.float32) + # For flatfield values <= 1e-6, use 1.0 to avoid division by zero/near-zero + flatfield_safe = np.where(flatfield > 1e-6, flatfield, 1.0).astype(np.float32) + + if darkfield is not None: + corrected = (tile_f - darkfield.astype(np.float32)) / flatfield_safe + else: + corrected = tile_f / flatfield_safe + + # Clip to valid range for integer dtypes to avoid wraparound + if np.issubdtype(tile.dtype, np.integer): + info = np.iinfo(tile.dtype) + corrected = np.clip(corrected, info.min, info.max) + + return corrected.astype(tile.dtype) + + +def apply_flatfield_region( + region: np.ndarray, + flatfield: np.ndarray, + darkfield: Optional[np.ndarray], + y_slice: slice, + x_slice: slice, +) -> np.ndarray: + """ + Apply flatfield correction to a tile region. + + Parameters + ---------- + region : ndarray + Input region with shape (C, h, w) or (h, w). + flatfield : ndarray + Full flatfield correction array with shape (C, Y, X). + darkfield : ndarray, optional + Full darkfield correction array with shape (C, Y, X). + y_slice, x_slice : slice + Slices defining the region within the full tile. + + Returns + ------- + corrected : ndarray + Corrected region with same shape as input. + + Raises + ------ + ValueError + If region and flatfield shapes are incompatible. + """ + # Validate channel count for 3D regions + if region.ndim == 3 and region.shape[0] != flatfield.shape[0]: + raise ValueError( + f"Region has {region.shape[0]} channels but flatfield has {flatfield.shape[0]} channels" + ) + + # Extract corresponding flatfield/darkfield regions + if region.ndim == 2: + ff_region = flatfield[0, y_slice, x_slice] + df_region = darkfield[0, y_slice, x_slice] if darkfield is not None else None + else: + ff_region = flatfield[:, y_slice, x_slice] + df_region = darkfield[:, y_slice, x_slice] if darkfield is not None else None + + # Convert to float32 to avoid underflow with unsigned integer types + region_f = region.astype(np.float32) + # For flatfield values <= 1e-6, use 1.0 to avoid division by zero/near-zero + ff_safe = np.where(ff_region > 1e-6, ff_region, 1.0).astype(np.float32) + + if df_region is not None: + corrected = (region_f - df_region.astype(np.float32)) / ff_safe + else: + corrected = region_f / ff_safe + + # Clip to valid range for integer dtypes to avoid wraparound + if np.issubdtype(region.dtype, np.integer): + info = np.iinfo(region.dtype) + corrected = np.clip(corrected, info.min, info.max) + + return corrected.astype(region.dtype) + + +def save_flatfield( + path: Path, + flatfield: np.ndarray, + darkfield: Optional[np.ndarray] = None, +) -> None: + """ + Save flatfield (and optionally darkfield) to a .npy file. + + Parameters + ---------- + path : Path + Output path (should end with .npy). + flatfield : ndarray + Flatfield array with shape (C, Y, X). + darkfield : ndarray, optional + Darkfield array with shape (C, Y, X). + """ + data = { + "flatfield": flatfield.astype(np.float32), + "darkfield": darkfield.astype(np.float32) if darkfield is not None else None, + "channels": flatfield.shape[0], + "shape": flatfield.shape[1:], + } + np.save(path, data, allow_pickle=True) + + +def load_flatfield(path: Path) -> Tuple[np.ndarray, Optional[np.ndarray]]: + """ + Load flatfield (and optionally darkfield) from a .npy file. + + Parameters + ---------- + path : Path + Path to .npy file. + + Returns + ------- + flatfield : ndarray + Flatfield array with shape (C, Y, X). + darkfield : ndarray or None + Darkfield array with shape (C, Y, X), or None if not present. + + Raises + ------ + OSError + If the file cannot be read (not found, permission denied, etc.). + ValueError + If the file format is invalid (not a dictionary with 'flatfield' key). + """ + try: + loaded = np.load(path, allow_pickle=True) + except OSError as exc: + raise OSError(f"Cannot read flatfield file '{path}': {exc}") from exc + + try: + data = loaded.item() + except (AttributeError, ValueError) as exc: + raise ValueError( + f"Invalid flatfield file format at '{path}'. " + "Expected a NumPy .npy file containing a dictionary as saved by " + "`save_flatfield` (with keys like 'flatfield' and 'darkfield')." + ) from exc + + if not isinstance(data, dict) or "flatfield" not in data: + raise ValueError( + f"Invalid flatfield file format at '{path}'. " + "Expected a dictionary with at least a 'flatfield' entry." + ) + + flatfield = data["flatfield"] + darkfield = data.get("darkfield", None) + return flatfield, darkfield