diff --git a/software/README.md b/software/README.md
index 726aba311..ea9b1e13a 100644
--- a/software/README.md
+++ b/software/README.md
@@ -10,32 +10,6 @@ Reboot the computer to finish the installation.
## Optional or Hardware-specific dependencies
-
-Image stitching dependencies (optional)
-For optional image stitching using ImageJ, additionally run the following:
-
-```
-sudo apt-get update
-sudo apt-get install openjdk-11-jdk
-sudo apt-get install maven
-pip3 install pyimagej
-pip3 install scyjava
-pip3 install tifffile
-pip3 install imagecodecs
-```
-
-Then, add the following line to the top of `/etc/environment` (needs to be edited with `sudo [your text editor]`):
-```
-JAVA_HOME=/usr/lib/jvm/default-java
-```
-Then, add the following lines to the top of `~/.bashrc` (or whichever file your terminal sources upon startup):
-```
-source /etc/environment
-export JAVA_HOME=$JAVA_HOME
-export PATH=$JAVA_HOME/bin:$PATH
-```
-
-
Installing drivers and libraries for FLIR camera support
Go to FLIR's page for downloading their Spinnaker SDK (https://www.flir.com/support/products/spinnaker-sdk/) and register.
diff --git a/software/control/_def.py b/software/control/_def.py
index b59a803c0..335a24854 100644
--- a/software/control/_def.py
+++ b/software/control/_def.py
@@ -7,7 +7,7 @@
import json
import csv
import squid.logging
-from enum import Enum, auto
+from enum import Enum
log = squid.logging.get_logger(__name__)
@@ -225,10 +225,6 @@ class ILLUMINATION_CODE:
ILLUMINATION_SOURCE_730NM = 15
-class VOLUMETRIC_IMAGING:
- NUM_PLANES_PER_VOLUME = 20
-
-
class CMD_EXECUTION_STATUS:
COMPLETED_WITHOUT_ERRORS = 0
IN_PROGRESS = 1
@@ -443,15 +439,6 @@ def convert_to_enum(option: Union[str, "FocusMeasureOperator"]) -> "FocusMeasure
FILE_ID_PADDING = 0
-class PLATE_READER:
- NUMBER_OF_ROWS = 8
- NUMBER_OF_COLUMNS = 12
- ROW_SPACING_MM = 9
- COLUMN_SPACING_MM = 9
- OFFSET_COLUMN_1_MM = 20
- OFFSET_ROW_A_MM = 20
-
-
CAMERA_PIXEL_SIZE_UM = {
"IMX290": 2.9,
"IMX178": 2.4,
@@ -467,11 +454,6 @@ class PLATE_READER:
TUBE_LENS_MM = 50
CAMERA_SENSOR = "IMX226"
-TRACKERS = ["csrt", "kcf", "mil", "tld", "medianflow", "mosse", "daSiamRPN"]
-DEFAULT_TRACKER = "csrt"
-
-ENABLE_TRACKING = False
-TRACKING_SHOW_MICROSCOPE_CONFIGURATIONS = False # set to true when doing multimodal acquisition
class CAMERA_CONFIG:
@@ -499,15 +481,6 @@ class AF:
CROP_HEIGHT = 800
-class Tracking:
- SEARCH_AREA_RATIO = 10 # @@@ check
- CROPPED_IMG_RATIO = 10 # @@@ check
- BBOX_SCALE_FACTOR = 1.2
- DEFAULT_TRACKER = "csrt"
- INIT_METHODS = ["roi"]
- DEFAULT_INIT_METHOD = "roi"
-
-
SHOW_DAC_CONTROL = False
@@ -543,9 +516,6 @@ class SOFTWARE_POS_LIMIT:
Z_NEGATIVE = 0.05
-SHOW_AUTOLEVEL_BTN = False
-AUTOLEVEL_DEFAULT_SETTING = False
-
MULTIPOINT_AUTOFOCUS_CHANNEL = "BF LED matrix full"
# MULTIPOINT_AUTOFOCUS_CHANNEL = 'BF LED matrix left half'
MULTIPOINT_AUTOFOCUS_ENABLE_BY_DEFAULT = False
@@ -559,7 +529,6 @@ class SOFTWARE_POS_LIMIT:
ENABLE_FLEXIBLE_MULTIPOINT = True
USE_OVERLAP_FOR_FLEXIBLE = True
ENABLE_WELLPLATE_MULTIPOINT = True
-ENABLE_RECORDING = False
RESUME_LIVE_AFTER_ACQUISITION = True
@@ -569,7 +538,6 @@ class SOFTWARE_POS_LIMIT:
CAMERA_SN = {"ch 1": "SN1", "ch 2": "SN2"} # for multiple cameras, to be overwritten in the configuration file
-ENABLE_STROBE_OUTPUT = False
ACQUISITION_PATTERN = "S-Pattern" # 'S-Pattern', 'Unidirectional'
FOV_PATTERN = "Unidirectional" # 'S-Pattern', 'Unidirectional'
@@ -613,7 +581,6 @@ class SOFTWARE_POS_LIMIT:
LASER_AF_MIN_PEAK_DISTANCE = 10
LASER_AF_MIN_PEAK_PROMINENCE = 0.20
LASER_AF_SPOT_SPACING = 100
-SHOW_LEGACY_DISPLACEMENT_MEASUREMENT_WINDOWS = False
LASER_AF_FILTER_SIGMA = None
LASER_AF_INITIALIZE_CROP_WIDTH = 1200
LASER_AF_INITIALIZE_CROP_HEIGHT = 800
@@ -624,20 +591,11 @@ class SOFTWARE_POS_LIMIT:
RETRACT_OBJECTIVE_BEFORE_MOVING_TO_LOADING_POSITION = True
OBJECTIVE_RETRACTED_POS_MM = 0.1
-TWO_CLASSIFICATION_MODELS = False
-CLASSIFICATION_MODEL_PATH = "models/resnet18_en/version1/best.pt"
-CLASSIFICATION_MODEL_PATH2 = "models/resnet18_en/version2/best.pt"
-CLASSIFICATION_TEST_MODE = False
-CLASSIFICATION_TH = 0.3
-
SEGMENTATION_MODEL_PATH = "models/m2unet_model_flat_erode1_wdecay5_smallbatch/model_4000_11.pth"
ENABLE_SEGMENTATION = True
USE_TRT_SEGMENTATION = False
SEGMENTATION_CROP = 1500
-DISP_TH_DURING_MULTIPOINT = 0.95
-SORT_DURING_MULTIPOINT = False
-
INVERTED_OBJECTIVE = False
ILLUMINATION_INTENSITY_FACTOR = 0.6
@@ -687,7 +645,6 @@ class SOFTWARE_POS_LIMIT:
# Napari integration
USE_NAPARI_FOR_LIVE_VIEW = False
-USE_NAPARI_FOR_MULTIPOINT = True
USE_NAPARI_FOR_MOSAIC_DISPLAY = True
USE_NAPARI_WELL_SELECTION = False
USE_NAPARI_FOR_LIVE_CONTROL = False
@@ -708,10 +665,7 @@ class SOFTWARE_POS_LIMIT:
# Navigation Settings
ENABLE_CLICK_TO_MOVE_BY_DEFAULT = True
-# Stitcher
IS_HCS = False
-DYNAMIC_REGISTRATION = False
-STITCH_COMPLETE_ACQUISITION = False
# Pseudo color settings
CHANNEL_COLORS_MAP = {
diff --git a/software/control/core/channel_configuration_mananger.py b/software/control/core/channel_configuration_mananger.py
index 50ba14383..0f5b632c2 100644
--- a/software/control/core/channel_configuration_mananger.py
+++ b/software/control/core/channel_configuration_mananger.py
@@ -75,12 +75,6 @@ def save_configurations(self, objective: str) -> None:
# Save only channel configurations
self._save_xml_config(objective, ConfigType.CHANNEL)
- def save_current_configuration_to_path(self, objective: str, path: Path) -> None:
- """Only used in TrackingController. Might be temporary."""
- config = self.all_configs[self.active_config_type][objective]
- xml_str = config.to_xml(pretty_print=True, encoding="utf-8")
- path.write_bytes(xml_str)
-
def get_configurations(self, objective: str) -> List[ChannelMode]:
"""Get channel modes for current active type"""
config = self.all_configs[self.active_config_type].get(objective)
diff --git a/software/control/core/core.py b/software/control/core/core.py
index e412d4840..d80530777 100644
--- a/software/control/core/core.py
+++ b/software/control/core/core.py
@@ -1,7 +1,5 @@
# set QT_API environment variable
import os
-import sys
-import tempfile
# qt libraries
os.environ["QT_API"] = "pyqt5"
@@ -18,20 +16,14 @@
from control.core.contrast_manager import ContrastManager
from control.core.laser_af_settings_manager import LaserAFSettingManager
from control.core.live_controller import LiveController
-from control.core.multi_point_worker import MultiPointWorker
from control.core.objective_store import ObjectiveStore
from control.core.scan_coordinates import ScanCoordinates
from control.core.stream_handler import StreamHandlerFunctions, StreamHandler
from control.microcontroller import Microcontroller
-from control.piezo import PiezoStage
-from squid.abc import AbstractStage, AbstractCamera, CameraAcquisitionMode, CameraFrame
+from squid.abc import CameraFrame
import control._def
import control.serial_peripherals as serial_peripherals
-import control.tracking as tracking
import control.utils as utils
-import control.utils_acquisition as utils_acquisition
-import control.utils_channel as utils_channel
-import control.utils_config as utils_config
import squid.logging
@@ -40,14 +32,10 @@
from threading import Thread, Lock
from pathlib import Path
from datetime import datetime
-from enum import Enum
-from control.utils_config import ChannelConfig, ChannelMode, LaserAFConfig
import time
-import itertools
import json
import math
import numpy as np
-import pandas as pd
import cv2
import imageio as iio
import squid.abc
@@ -190,63 +178,6 @@ def close(self):
self.thread.join()
-class ImageSaver_Tracking(QObject):
- def __init__(self, base_path, image_format="bmp"):
- QObject.__init__(self)
- self.base_path = base_path
- self.image_format = image_format
- self.max_num_image_per_folder = 1000
- self.queue = Queue(100) # max 100 items in the queue
- self.image_lock = Lock()
- self.stop_signal_received = False
- self.thread = Thread(target=self.process_queue, daemon=True)
- self.thread.start()
-
- def process_queue(self):
- while True:
- # stop the thread if stop signal is received
- if self.stop_signal_received:
- return
- # process the queue
- try:
- [image, frame_counter, postfix] = self.queue.get(timeout=0.1)
- self.image_lock.acquire(True)
- folder_ID = int(frame_counter / self.max_num_image_per_folder)
- file_ID = int(frame_counter % self.max_num_image_per_folder)
- # create a new folder
- if file_ID == 0:
- utils.ensure_directory_exists(os.path.join(self.base_path, str(folder_ID)))
- if image.dtype == np.uint16:
- saving_path = os.path.join(
- self.base_path,
- str(folder_ID),
- str(file_ID) + "_" + str(frame_counter) + "_" + postfix + ".tiff",
- )
- iio.imwrite(saving_path, image)
- else:
- saving_path = os.path.join(
- self.base_path,
- str(folder_ID),
- str(file_ID) + "_" + str(frame_counter) + "_" + postfix + "." + self.image_format,
- )
- cv2.imwrite(saving_path, image)
- self.queue.task_done()
- self.image_lock.release()
- except:
- pass
-
- def enqueue(self, image, frame_counter, postfix):
- try:
- self.queue.put_nowait([image, frame_counter, postfix])
- except:
- print("imageSaver queue is full, image discarded")
-
- def close(self):
- self.queue.join()
- self.stop_signal_received = True
- self.thread.join()
-
-
class ImageDisplay(QObject):
image_to_display = Signal(np.ndarray)
@@ -292,376 +223,6 @@ def close(self):
self.thread.join()
-class TrackingController(QObject):
- signal_tracking_stopped = Signal()
- image_to_display = Signal(np.ndarray)
- image_to_display_multi = Signal(np.ndarray, int)
- signal_current_configuration = Signal(ChannelMode)
-
- def __init__(
- self,
- camera: AbstractCamera,
- microcontroller: Microcontroller,
- stage: AbstractStage,
- objectiveStore,
- channelConfigurationManager,
- liveController: LiveController,
- autofocusController,
- imageDisplayWindow,
- ):
- QObject.__init__(self)
- self._log = squid.logging.get_logger(self.__class__.__name__)
- self.camera: AbstractCamera = camera
- self.microcontroller = microcontroller
- self.stage = stage
- self.objectiveStore = objectiveStore
- self.channelConfigurationManager = channelConfigurationManager
- self.liveController = liveController
- self.autofocusController = autofocusController
- self.imageDisplayWindow = imageDisplayWindow
- self.tracker = tracking.Tracker_Image()
-
- self.tracking_time_interval_s = 0
-
- self.display_resolution_scaling = Acquisition.IMAGE_DISPLAY_SCALING_FACTOR
- self.counter = 0
- self.experiment_ID = None
- self.base_path = None
- self.selected_configurations = []
-
- self.flag_stage_tracking_enabled = True
- self.flag_AF_enabled = False
- self.flag_save_image = False
- self.flag_stop_tracking_requested = False
-
- self.pixel_size_um = None
- self.objective = None
-
- def start_tracking(self):
-
- # save pre-tracking configuration
- self._log.info("start tracking")
- self.configuration_before_running_tracking = self.liveController.currentConfiguration
-
- # stop live
- if self.liveController.is_live:
- self.was_live_before_tracking = True
- self.liveController.stop_live() # @@@ to do: also uncheck the live button
- else:
- self.was_live_before_tracking = False
-
- # disable callback
- if self.camera.get_callbacks_enabled():
- self.camera_callback_was_enabled_before_tracking = True
- self.camera.enable_callbacks(False)
- else:
- self.camera_callback_was_enabled_before_tracking = False
-
- # hide roi selector
- self.imageDisplayWindow.hide_ROI_selector()
-
- # run tracking
- self.flag_stop_tracking_requested = False
- # create a QThread object
- try:
- if self.thread.isRunning():
- self._log.info("*** previous tracking thread is still running ***")
- self.thread.terminate()
- self.thread.wait()
- self._log.info("*** previous tracking threaded manually stopped ***")
- except:
- pass
- self.thread = QThread()
- # create a worker object
- self.trackingWorker = TrackingWorker(self)
- # move the worker to the thread
- self.trackingWorker.moveToThread(self.thread)
- # connect signals and slots
- self.thread.started.connect(self.trackingWorker.run)
- self.trackingWorker.finished.connect(self._on_tracking_stopped)
- self.trackingWorker.finished.connect(self.trackingWorker.deleteLater)
- self.trackingWorker.finished.connect(self.thread.quit)
- self.trackingWorker.image_to_display.connect(self.slot_image_to_display)
- self.trackingWorker.image_to_display_multi.connect(self.slot_image_to_display_multi)
- self.trackingWorker.signal_current_configuration.connect(self.slot_current_configuration)
- # self.thread.finished.connect(self.thread.deleteLater)
- self.thread.finished.connect(self.thread.quit)
- # start the thread
- self.thread.start()
-
- def _on_tracking_stopped(self):
-
- # restore the previous selected mode
- self.signal_current_configuration.emit(self.configuration_before_running_tracking)
- self.liveController.set_microscope_mode(self.configuration_before_running_tracking)
-
- # re-enable callback
- if self.camera_callback_was_enabled_before_tracking:
- self.camera.enable_callbacks(True)
- self.camera_callback_was_enabled_before_tracking = False
-
- # re-enable live if it's previously on
- if self.was_live_before_tracking:
- self.liveController.start_live()
-
- # show ROI selector
- self.imageDisplayWindow.show_ROI_selector()
-
- # emit the acquisition finished signal to enable the UI
- self.signal_tracking_stopped.emit()
- QApplication.processEvents()
-
- def start_new_experiment(self, experiment_ID): # @@@ to do: change name to prepare_folder_for_new_experiment
- # generate unique experiment ID
- self.experiment_ID = experiment_ID + "_" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S.%f")
- self.recording_start_time = time.time()
- # create a new folder
- try:
- utils.ensure_directory_exists(os.path.join(self.base_path, self.experiment_ID))
- self.channelConfigurationManager.save_current_configuration_to_path(
- self.objectiveStore.current_objective,
- os.path.join(self.base_path, self.experiment_ID) + "/configurations.xml",
- ) # save the configuration for the experiment
- except:
- self._log.info("error in making a new folder")
- pass
-
- def set_selected_configurations(self, selected_configurations_name):
- self.selected_configurations = []
- for configuration_name in selected_configurations_name:
- config = self.channelConfigurationManager.get_channel_configuration_by_name(
- self.objectiveStore.current_objective, configuration_name
- )
- if config:
- self.selected_configurations.append(config)
-
- def toggle_stage_tracking(self, state):
- self.flag_stage_tracking_enabled = state > 0
- self._log.info("set stage tracking enabled to " + str(self.flag_stage_tracking_enabled))
-
- def toggel_enable_af(self, state):
- self.flag_AF_enabled = state > 0
- self._log.info("set af enabled to " + str(self.flag_AF_enabled))
-
- def toggel_save_images(self, state):
- self.flag_save_image = state > 0
- self._log.info("set save images to " + str(self.flag_save_image))
-
- def set_base_path(self, path):
- self.base_path = path
-
- def stop_tracking(self):
- self.flag_stop_tracking_requested = True
- self._log.info("stop tracking requested")
-
- def slot_image_to_display(self, image):
- self.image_to_display.emit(image)
-
- def slot_image_to_display_multi(self, image, illumination_source):
- self.image_to_display_multi.emit(image, illumination_source)
-
- def slot_current_configuration(self, configuration):
- self.signal_current_configuration.emit(configuration)
-
- def update_pixel_size(self, pixel_size_um):
- self.pixel_size_um = pixel_size_um
-
- def update_tracker_selection(self, tracker_str):
- self.tracker.update_tracker_type(tracker_str)
-
- def set_tracking_time_interval(self, time_interval):
- self.tracking_time_interval_s = time_interval
-
- def update_image_resizing_factor(self, image_resizing_factor):
- self.image_resizing_factor = image_resizing_factor
- self._log.info("update tracking image resizing factor to " + str(self.image_resizing_factor))
- self.pixel_size_um_scaled = self.pixel_size_um / self.image_resizing_factor
-
-
-class TrackingWorker(QObject):
- finished = Signal()
- image_to_display = Signal(np.ndarray)
- image_to_display_multi = Signal(np.ndarray, int)
- signal_current_configuration = Signal(ChannelMode)
-
- def __init__(self, trackingController: TrackingController):
- QObject.__init__(self)
- self._log = squid.logging.get_logger(self.__class__.__name__)
- self.trackingController = trackingController
-
- self.camera: AbstractCamera = self.trackingController.camera
- self.stage = self.trackingController.stage
- self.microcontroller = self.trackingController.microcontroller
- self.liveController = self.trackingController.liveController
- self.autofocusController = self.trackingController.autofocusController
- self.channelConfigurationManager = self.trackingController.channelConfigurationManager
- self.imageDisplayWindow = self.trackingController.imageDisplayWindow
- self.display_resolution_scaling = self.trackingController.display_resolution_scaling
- self.counter = self.trackingController.counter
- self.experiment_ID = self.trackingController.experiment_ID
- self.base_path = self.trackingController.base_path
- self.selected_configurations = self.trackingController.selected_configurations
- self.tracker = trackingController.tracker
-
- self.number_of_selected_configurations = len(self.selected_configurations)
-
- self.image_saver = ImageSaver_Tracking(
- base_path=os.path.join(self.base_path, self.experiment_ID), image_format="bmp"
- )
-
- def _select_config(self, config: ChannelMode):
- self.signal_current_configuration.emit(config)
- # TODO(imo): replace with illumination controller.
- self.liveController.set_microscope_mode(config)
- self.microcontroller.wait_till_operation_is_completed()
- self.liveController.turn_on_illumination() # keep illumination on for single-configuration acquisition
- self.microcontroller.wait_till_operation_is_completed()
-
- def run(self):
-
- tracking_frame_counter = 0
- t0 = time.time()
-
- # save metadata
- self.txt_file = open(os.path.join(self.base_path, self.experiment_ID, "metadata.txt"), "w+")
- self.txt_file.write("t0: " + datetime.now().strftime("%Y-%m-%d_%H-%M-%S.%f") + "\n")
- self.txt_file.write("objective: " + self.trackingController.objective + "\n")
- self.txt_file.close()
-
- # create a file for logging
- self.csv_file = open(os.path.join(self.base_path, self.experiment_ID, "track.csv"), "w+")
- self.csv_file.write(
- "dt (s), x_stage (mm), y_stage (mm), z_stage (mm), x_image (mm), y_image(mm), image_filename\n"
- )
-
- # reset tracker
- self.tracker.reset()
-
- # get the manually selected roi
- init_roi = self.imageDisplayWindow.get_roi_bounding_box()
- self.tracker.set_roi_bbox(init_roi)
-
- # tracking loop
- while not self.trackingController.flag_stop_tracking_requested:
- self._log.info("tracking_frame_counter: " + str(tracking_frame_counter))
- if tracking_frame_counter == 0:
- is_first_frame = True
- else:
- is_first_frame = False
-
- # timestamp
- timestamp_last_frame = time.time()
-
- # switch to the tracking config
- config = self.selected_configurations[0]
-
- # do autofocus
- if self.trackingController.flag_AF_enabled and tracking_frame_counter > 1:
- # do autofocus
- self._log.info(">>> autofocus")
- self.autofocusController.autofocus()
- self.autofocusController.wait_till_autofocus_has_completed()
- self._log.info(">>> autofocus completed")
-
- # get current position
- pos = self.stage.get_pos()
-
- # grab an image
- config = self.selected_configurations[0]
- if self.number_of_selected_configurations > 1:
- self._select_config(config)
- self.camera.send_trigger()
- camera_frame = self.camera.read_camera_frame()
- image = camera_frame.frame
- t = camera_frame.timestamp
- if self.number_of_selected_configurations > 1:
- self.liveController.turn_off_illumination() # illumination is kept on only for single-configuration acquisition
- image = np.squeeze(image)
- # get image size
- image_shape = image.shape
- image_center = np.array([image_shape[1] * 0.5, image_shape[0] * 0.5])
-
- # image the remaining configurations
- for config_ in self.selected_configurations[1:]:
- self._select_config(config_)
-
- self.camera.send_trigger()
- image_ = self.camera.read_frame()
- # TODO(imo): use illumination controller
- self.liveController.turn_off_illumination()
- image_ = np.squeeze(image_)
- # display image
- image_to_display_ = utils.crop_image(
- image_,
- round(image_.shape[1] * self.liveController.display_resolution_scaling),
- round(image_.shape[0] * self.liveController.display_resolution_scaling),
- )
- self.image_to_display_multi.emit(image_to_display_, config_.illumination_source)
- # save image
- if self.trackingController.flag_save_image:
- if camera_frame.is_color():
- image_ = cv2.cvtColor(image_, cv2.COLOR_RGB2BGR)
- self.image_saver.enqueue(image_, tracking_frame_counter, str(config_.name))
-
- # track
- object_found, centroid, rect_pts = self.tracker.track(image, None, is_first_frame=is_first_frame)
- if not object_found:
- self._log.error("tracker: object not found")
- break
- in_plane_position_error_pixel = image_center - centroid
- in_plane_position_error_mm = (
- in_plane_position_error_pixel * self.trackingController.pixel_size_um_scaled / 1000
- )
- x_error_mm = in_plane_position_error_mm[0]
- y_error_mm = in_plane_position_error_mm[1]
-
- # display the new bounding box and the image
- self.imageDisplayWindow.update_bounding_box(rect_pts)
- self.imageDisplayWindow.display_image(image)
-
- # move
- if self.trackingController.flag_stage_tracking_enabled:
- # TODO(imo): This needs testing!
- self.stage.move_x(x_error_mm)
- self.stage.move_y(y_error_mm)
-
- # save image
- if self.trackingController.flag_save_image:
- self.image_saver.enqueue(image, tracking_frame_counter, str(config.name))
-
- # save position data
- self.csv_file.write(
- str(t)
- + ","
- + str(pos.x_mm)
- + ","
- + str(pos.y_mm)
- + ","
- + str(pos.z_mm)
- + ","
- + str(x_error_mm)
- + ","
- + str(y_error_mm)
- + ","
- + str(tracking_frame_counter)
- + "\n"
- )
- if tracking_frame_counter % 100 == 0:
- self.csv_file.flush()
-
- # wait till tracking interval has elapsed
- while time.time() - timestamp_last_frame < self.trackingController.tracking_time_interval_s:
- time.sleep(0.005)
-
- # increment counter
- tracking_frame_counter = tracking_frame_counter + 1
-
- # tracking terminated
- self.csv_file.close()
- self.image_saver.close()
- self.finished.emit()
-
-
class ImageDisplayWindow(QMainWindow):
image_click_coordinates = Signal(int, int, int, int)
@@ -1220,13 +781,6 @@ def display_image(self, image):
self.first_image = False
self.btn_line_profiler.setEnabled(True)
- if ENABLE_TRACKING:
- image = np.copy(image)
- self.image_height, self.image_width = image.shape[:2]
- if self.draw_rectangle:
- cv2.rectangle(image, self.ptRect1, self.ptRect2, (255, 255, 255), 4)
- self.draw_rectangle = False
-
info = np.iinfo(image.dtype) if np.issubdtype(image.dtype, np.integer) else np.finfo(image.dtype)
min_val, max_val = info.min, info.max
@@ -1681,71 +1235,6 @@ def handle_mouse_click(self, evt):
return
-class ImageArrayDisplayWindow(QMainWindow):
-
- def __init__(self, window_title=""):
- super().__init__()
- self.setWindowTitle(window_title)
- self.setWindowFlags(self.windowFlags() | Qt.CustomizeWindowHint)
- self.setWindowFlags(self.windowFlags() & ~Qt.WindowCloseButtonHint)
- self.widget = QWidget()
-
- # interpret image data as row-major instead of col-major
- pg.setConfigOptions(imageAxisOrder="row-major")
-
- self.graphics_widget_1 = pg.GraphicsLayoutWidget()
- self.graphics_widget_1.view = self.graphics_widget_1.addViewBox()
- self.graphics_widget_1.view.setAspectLocked(True)
- self.graphics_widget_1.img = pg.ImageItem(border="w")
- self.graphics_widget_1.view.addItem(self.graphics_widget_1.img)
- self.graphics_widget_1.view.invertY()
-
- self.graphics_widget_2 = pg.GraphicsLayoutWidget()
- self.graphics_widget_2.view = self.graphics_widget_2.addViewBox()
- self.graphics_widget_2.view.setAspectLocked(True)
- self.graphics_widget_2.img = pg.ImageItem(border="w")
- self.graphics_widget_2.view.addItem(self.graphics_widget_2.img)
- self.graphics_widget_2.view.invertY()
-
- self.graphics_widget_3 = pg.GraphicsLayoutWidget()
- self.graphics_widget_3.view = self.graphics_widget_3.addViewBox()
- self.graphics_widget_3.view.setAspectLocked(True)
- self.graphics_widget_3.img = pg.ImageItem(border="w")
- self.graphics_widget_3.view.addItem(self.graphics_widget_3.img)
- self.graphics_widget_3.view.invertY()
-
- self.graphics_widget_4 = pg.GraphicsLayoutWidget()
- self.graphics_widget_4.view = self.graphics_widget_4.addViewBox()
- self.graphics_widget_4.view.setAspectLocked(True)
- self.graphics_widget_4.img = pg.ImageItem(border="w")
- self.graphics_widget_4.view.addItem(self.graphics_widget_4.img)
- self.graphics_widget_4.view.invertY()
- ## Layout
- layout = QGridLayout()
- layout.addWidget(self.graphics_widget_1, 0, 0)
- layout.addWidget(self.graphics_widget_2, 0, 1)
- layout.addWidget(self.graphics_widget_3, 1, 0)
- layout.addWidget(self.graphics_widget_4, 1, 1)
- self.widget.setLayout(layout)
- self.setCentralWidget(self.widget)
-
- # set window size
- desktopWidget = QDesktopWidget()
- width = min(desktopWidget.height() * 0.9, 1000) # @@@TO MOVE@@@#
- height = width
- self.setFixedSize(int(width), int(height))
-
- def display_image(self, image, illumination_source):
- if illumination_source < 11:
- self.graphics_widget_1.img.setImage(image, autoLevels=False)
- elif illumination_source == 11:
- self.graphics_widget_2.img.setImage(image, autoLevels=False)
- elif illumination_source == 12:
- self.graphics_widget_3.img.setImage(image, autoLevels=False)
- elif illumination_source == 13:
- self.graphics_widget_4.img.setImage(image, autoLevels=False)
-
-
from scipy.interpolate import SmoothBivariateSpline, RBFInterpolator
diff --git a/software/control/core/live_controller.py b/software/control/core/live_controller.py
index 150bad055..86118a69e 100644
--- a/software/control/core/live_controller.py
+++ b/software/control/core/live_controller.py
@@ -20,7 +20,6 @@ def __init__(
camera: AbstractCamera,
control_illumination: bool = True,
use_internal_timer_for_hardware_trigger: bool = True,
- for_displacement_measurement: bool = False,
):
self._log = squid.logging.get_logger(self.__class__.__name__)
self.microscope = microscope
@@ -33,7 +32,6 @@ def __init__(
self.use_internal_timer_for_hardware_trigger = (
use_internal_timer_for_hardware_trigger # use Timer vs timer in the MCU
)
- self.for_displacement_measurement = for_displacement_measurement
self.fps_trigger = 1
self.timer_trigger_interval = (1.0 / self.fps_trigger) * 1000
@@ -177,9 +175,6 @@ def start_live(self):
):
self.camera.enable_callbacks(True) # in case it's disabled e.g. by the laser AF controller
self._start_triggerred_acquisition()
- # if controlling the laser displacement measurement camera
- if self.for_displacement_measurement:
- self.microscope.low_level_drivers.microcontroller.set_pin_level(MCU_PINS.AF_LASER, 1)
def stop_live(self):
if self.is_live:
@@ -194,9 +189,6 @@ def stop_live(self):
self._stop_triggerred_acquisition()
if self.control_illumination:
self.turn_off_illumination()
- # if controlling the laser displacement measurement camera
- if self.for_displacement_measurement:
- self.microscope.low_level_drivers.microcontroller.set_pin_level(MCU_PINS.AF_LASER, 0)
def _trigger_acquisition_timer_fn(self):
if self.trigger_acquisition():
diff --git a/software/control/core/multi_point_controller.py b/software/control/core/multi_point_controller.py
index 8b04c1728..7c98ec046 100644
--- a/software/control/core/multi_point_controller.py
+++ b/software/control/core/multi_point_controller.py
@@ -25,7 +25,7 @@
from control.core.objective_store import ObjectiveStore
from control.microcontroller import Microcontroller
from control.piezo import PiezoStage
-from squid.abc import CameraFrame, AbstractCamera, AbstractStage
+from squid.abc import AbstractCamera, AbstractStage
import squid.logging
diff --git a/software/control/core/stream_handler.py b/software/control/core/stream_handler.py
index 54e62aa19..b5fb15559 100644
--- a/software/control/core/stream_handler.py
+++ b/software/control/core/stream_handler.py
@@ -3,7 +3,6 @@
from typing import Callable
import numpy as np
-import cv2
from control import utils
import control._def
diff --git a/software/control/core_PDAF.py b/software/control/core_PDAF.py
deleted file mode 100644
index 935955b81..000000000
--- a/software/control/core_PDAF.py
+++ /dev/null
@@ -1,334 +0,0 @@
-# set QT_API environment variable
-import os
-
-os.environ["QT_API"] = "pyqt5"
-
-# qt libraries
-from qtpy.QtCore import *
-from qtpy.QtWidgets import *
-
-import control.utils as utils
-from control._def import *
-
-import time
-import numpy as np
-import cv2
-from datetime import datetime
-
-import skimage # pip3 install -U scikit-image
-import skimage.registration
-
-import squid.camera.utils
-from squid.abc import AbstractCamera
-
-
-class PDAFController(QObject):
-
- # input: stream from camera 1, stream from camera 2
- # input: from internal_states shared variables
- # output: amount of defocus, which may be read by or emitted to focusTrackingController (that manages focus tracking on/off, PID coefficients)
-
- def __init__(self, internal_states):
- QObject.__init__(self)
- self.coefficient_shift2defocus = 1
- self.registration_upsample_factor = 5
- self.image1_received = False
- self.image2_received = False
- self.locked = False
- self.shared_variables = internal_states
-
- def register_image_from_camera_1(self, image):
- if self.locked == True:
- return
- self.image1 = np.copy(image)
- self.image1_received = True
- if self.image2_received:
- self.calculate_defocus()
-
- def register_image_from_camera_2(self, image):
- if self.locked == True:
- return
- self.image2 = np.copy(image)
- self.image2 = np.fliplr(self.image2) # can be flipud depending on camera orientation
- self.image2_received = True
- if self.image1_received:
- self.calculate_defocus()
-
- def calculate_defocus(self):
- self.locked = True
- # cropping parameters
- self.x = self.shared_variables.x
- self.y = self.shared_variables.y
- self.w = self.shared_variables.w * 2 # double check which dimension to multiply
- self.h = self.shared_variables.h
- # crop
- self.image1 = self.image1[
- (self.y - int(self.h / 2)) : (self.y + int(self.h / 2)),
- (self.x - int(self.w / 2)) : (self.x + int(self.w / 2)),
- ]
- self.image2 = self.image2[
- (self.y - int(self.h / 2)) : (self.y + int(self.h / 2)),
- (self.x - int(self.w / 2)) : (self.x + int(self.w / 2)),
- ] # additional offsets may need to be added
- shift = self._compute_shift_from_image_pair()
- self.defocus = shift * self.coefficient_shift2defocus
- self.image1_received = False
- self.image2_received = False
- self.locked = False
-
- def _compute_shift_from_image_pair(self):
- # method 1: calculate 2D cross correlation -> find peak or centroid
- """
- I1 = np.array(self.image1,dtype=np.int)
- I2 = np.array(self.image2,dtype=np.int)
- I1 = I1 - np.mean(I1)
- I2 = I2 - np.mean(I2)
- xcorr = cv2.filter2D(I1,cv2.CV_32F,I2)
- cv2.imshow('xcorr',np.array(255*xcorr/np.max(xcorr),dtype=np.uint8))
- cv2.waitKey(15)
- """
- # method 2: use skimage.registration.phase_cross_correlation
- shifts, error, phasediff = skimage.registration.phase_cross_correlation(
- self.image1, self.image2, upsample_factor=self.registration_upsample_factor, space="real"
- )
- print(shifts) # for debugging
- return shifts[0] # can be shifts[1] - depending on camera orientation
-
- def close(self):
- pass
-
-
-class TwoCamerasPDAFCalibrationController(QObject):
-
- acquisitionFinished = Signal()
- image_to_display_camera1 = Signal(np.ndarray)
- image_to_display_camera2 = Signal(np.ndarray)
- signal_current_configuration = Signal(Configuration)
-
- z_pos = Signal(float)
-
- def __init__(
- self,
- camera1: AbstractCamera,
- camera2: AbstractCamera,
- navigationController,
- liveController1,
- liveController2,
- configurationManager=None,
- ):
- QObject.__init__(self)
-
- self.camera1: AbstractCamera = camera1
- self.camera2: AbstractCamera = camera2
- self.navigationController = navigationController
- self.liveController1 = liveController1
- self.liveController2 = liveController2
- self.configurationManager = configurationManager
- self.NZ = 1
- self.Nt = 1
- self.deltaZ = Acquisition.DZ / 1000
- self.deltaZ_usteps = round((Acquisition.DZ / 1000) * Motion.STEPS_PER_MM_Z)
- self.crop_width = Acquisition.CROP_WIDTH
- self.crop_height = Acquisition.CROP_HEIGHT
- self.display_resolution_scaling = Acquisition.IMAGE_DISPLAY_SCALING_FACTOR
- self.counter = 0
- self.experiment_ID = None
- self.base_path = None
-
- def set_NX(self, N):
- self.NX = N
-
- def set_NY(self, N):
- self.NY = N
-
- def set_NZ(self, N):
- self.NZ = N
-
- def set_Nt(self, N):
- self.Nt = N
-
- def set_deltaX(self, delta):
- self.deltaX = delta
- self.deltaX_usteps = round(delta * Motion.STEPS_PER_MM_XY)
-
- def set_deltaY(self, delta):
- self.deltaY = delta
- self.deltaY_usteps = round(delta * Motion.STEPS_PER_MM_XY)
-
- def set_deltaZ(self, delta_um):
- self.deltaZ = delta_um / 1000
- self.deltaZ_usteps = round((delta_um / 1000) * Motion.STEPS_PER_MM_Z)
-
- def set_deltat(self, delta):
- self.deltat = delta
-
- def set_af_flag(self, flag):
- self.do_autofocus = flag
-
- def set_crop(self, crop_width, crop_height):
- self.crop_width = crop_width
- self.crop_height = crop_height
-
- def set_base_path(self, path):
- self.base_path = path
-
- def start_new_experiment(self, experiment_ID): # @@@ to do: change name to prepare_folder_for_new_experiment
- # generate unique experiment ID
- self.experiment_ID = experiment_ID + "_" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S.%f")
- self.recording_start_time = time.time()
- # create a new folder
- try:
- utils.ensure_directory_exists(os.path.join(self.base_path, self.experiment_ID))
- if self.configurationManager:
- self.configurationManager.write_configuration(
- os.path.join(self.base_path, self.experiment_ID) + "/configurations.xml"
- ) # save the configuration for the experiment
- except:
- pass
-
- def set_selected_configurations(self, selected_configurations_name):
- self.selected_configurations = []
- for configuration_name in selected_configurations_name:
- self.selected_configurations.append(
- next(
- (config for config in self.configurationManager.configurations if config.name == configuration_name)
- )
- )
-
- def run_acquisition(self): # @@@ to do: change name to run_experiment
- print("start multipoint")
-
- # stop live
- if self.liveController1.is_live:
- self.liveController1.was_live_before_multipoint = True
- self.liveController1.stop_live() # @@@ to do: also uncheck the live button
- else:
- self.liveController1.was_live_before_multipoint = False
- # stop live
- if self.liveController2.is_live:
- self.liveController2.was_live_before_multipoint = True
- self.liveController2.stop_live() # @@@ to do: also uncheck the live button
- else:
- self.liveController2.was_live_before_multipoint = False
-
- # disable callback
- if self.camera1.get_callbacks_enabled():
- self.camera1.callback_was_enabled_before_multipoint = True
- self.camera1.stop_streaming()
- self.camera1.enable_callbacks(False)
- self.camera1.start_streaming() # @@@ to do: absorb stop/start streaming into enable/disable callback - add a flag is_streaming to the camera class
- else:
- self.camera1.callback_was_enabled_before_multipoint = False
- # disable callback
- if self.camera2.get_callbacks_enabled():
- self.camera2.callback_was_enabled_before_multipoint = True
- self.camera2.stop_streaming()
- self.camera2.enable_callbacks(False)
- self.camera2.start_streaming() # @@@ to do: absorb stop/start streaming into enable/disable callback - add a flag is_streaming to the camera class
- else:
- self.camera2.callback_was_enabled_before_multipoint = False
-
- for self.time_point in range(self.Nt):
- self._run_multipoint_single()
-
- # re-enable callback
- if self.camera1.callback_was_enabled_before_multipoint:
- self.camera1.stop_streaming()
- self.camera1.enable_callbacks(True)
- self.camera1.start_streaming()
- self.camera1.callback_was_enabled_before_multipoint = False
- # re-enable callback
- if self.camera2.callback_was_enabled_before_multipoint:
- self.camera2.stop_streaming()
- self.camera2.enable_callbacks(True)
- self.camera2.start_streaming()
- self.camera2.callback_was_enabled_before_multipoint = False
-
- if self.liveController1.was_live_before_multipoint:
- self.liveController1.start_live()
- if self.liveController2.was_live_before_multipoint:
- self.liveController2.start_live()
-
- # emit acquisitionFinished signal
- self.acquisitionFinished.emit()
- QApplication.processEvents()
-
- def _run_multipoint_single(self):
- # for each time point, create a new folder
- current_path = os.path.join(self.base_path, self.experiment_ID, str(self.time_point))
- os.mkdir(current_path)
-
- # z-stack
- for k in range(self.NZ):
- file_ID = str(k)
- if self.configurationManager:
- # iterate through selected modes
- for config in self.selected_configurations:
- self.signal_current_configuration.emit(config)
- self.camera1.send_trigger()
- image = self.camera1.read_frame()
- image = utils.crop_image(image, self.crop_width, self.crop_height)
- saving_path = os.path.join(
- current_path, "camera1_" + file_ID + str(config.name) + "." + Acquisition.IMAGE_FORMAT
- )
- image_to_display = utils.crop_image(
- image,
- round(self.crop_width * self.liveController1.display_resolution_scaling),
- round(self.crop_height * self.liveController1.display_resolution_scaling),
- )
- self.image_to_display_camera1.emit(image_to_display)
- if self.camera1.is_color:
- image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
- cv2.imwrite(saving_path, image)
-
- self.camera2.send_trigger()
- image = self.camera2.read_frame()
- image = utils.crop_image(image, self.crop_width, self.crop_height)
- saving_path = os.path.join(
- current_path, "camera2_" + file_ID + str(config.name) + "." + Acquisition.IMAGE_FORMAT
- )
- image_to_display = utils.crop_image(
- image,
- round(self.crop_width * self.liveController2.display_resolution_scaling),
- round(self.crop_height * self.liveController2.display_resolution_scaling),
- )
- self.image_to_display_camera2.emit(image_to_display)
- if self.camera2.is_color:
- image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
- cv2.imwrite(saving_path, image)
- QApplication.processEvents()
- else:
- self.camera1.send_trigger()
- image = self.camera1.read_frame()
- image = utils.crop_image(image, self.crop_width, self.crop_height)
- saving_path = os.path.join(current_path, "camera1_" + file_ID + "." + Acquisition.IMAGE_FORMAT)
- image_to_display = utils.crop_image(
- image,
- round(self.crop_width * self.liveController1.display_resolution_scaling),
- round(self.crop_height * self.liveController1.display_resolution_scaling),
- )
- self.image_to_display_camera1.emit(image_to_display)
- if self.camera1.is_color:
- image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
- cv2.imwrite(saving_path, image)
-
- self.camera2.send_trigger()
- image = self.camera2.read_frame()
- image = utils.crop_image(image, self.crop_width, self.crop_height)
- saving_path = os.path.join(current_path, "camera2_" + file_ID + "." + Acquisition.IMAGE_FORMAT)
- image_to_display = utils.crop_image(
- image,
- round(self.crop_width * self.liveController2.display_resolution_scaling),
- round(self.crop_height * self.liveController2.display_resolution_scaling),
- )
- self.image_to_display_camera2.emit(image_to_display)
- if self.camera2.is_color:
- image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
- cv2.imwrite(saving_path, image)
- QApplication.processEvents()
- # move z
- if k < self.NZ - 1:
- self.navigationController.move_z_usteps(self.deltaZ_usteps)
-
- # move z back
- self.navigationController.move_z_usteps(-self.deltaZ_usteps * (self.NZ - 1))
diff --git a/software/control/core_displacement_measurement.py b/software/control/core_displacement_measurement.py
deleted file mode 100644
index c86794d3b..000000000
--- a/software/control/core_displacement_measurement.py
+++ /dev/null
@@ -1,71 +0,0 @@
-# set QT_API environment variable
-import os
-
-os.environ["QT_API"] = "pyqt5"
-import qtpy
-
-# qt libraries
-from qtpy.QtCore import *
-from qtpy.QtWidgets import *
-from qtpy.QtGui import *
-
-import control.utils as utils
-from control._def import *
-
-import time
-import numpy as np
-import cv2
-
-
-class DisplacementMeasurementController(QObject):
-
- signal_readings = Signal(list)
- signal_plots = Signal(np.ndarray, np.ndarray)
-
- def __init__(self, x_offset=0, y_offset=0, x_scaling=1, y_scaling=1, N_average=1, N=10000):
-
- QObject.__init__(self)
- self.x_offset = x_offset
- self.y_offset = y_offset
- self.x_scaling = x_scaling
- self.y_scaling = y_scaling
- self.N_average = N_average
- self.N = N # length of array to emit
- self.t_array = np.array([])
- self.x_array = np.array([])
- self.y_array = np.array([])
-
- def update_measurement(self, image):
-
- t = time.time()
-
- if len(image.shape) == 3:
- image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
-
- h, w = image.shape
- x, y = np.meshgrid(range(w), range(h))
- I = image.astype(float)
- I = I - np.amin(I)
- I[I / np.amax(I) < 0.2] = 0
- x = np.sum(x * I) / np.sum(I)
- y = np.sum(y * I) / np.sum(I)
-
- x = x - self.x_offset
- y = y - self.y_offset
- x = x * self.x_scaling
- y = y * self.y_scaling
-
- self.t_array = np.append(self.t_array, t)
- self.x_array = np.append(self.x_array, x)
- self.y_array = np.append(self.y_array, y)
-
- self.signal_plots.emit(self.t_array[-self.N :], np.vstack((self.x_array[-self.N :], self.y_array[-self.N :])))
- self.signal_readings.emit([np.mean(self.x_array[-self.N_average :]), np.mean(self.y_array[-self.N_average :])])
-
- def update_settings(self, x_offset, y_offset, x_scaling, y_scaling, N_average, N):
- self.N = N
- self.N_average = N_average
- self.x_offset = x_offset
- self.y_offset = y_offset
- self.x_scaling = x_scaling
- self.y_scaling = y_scaling
diff --git a/software/control/core_platereader.py b/software/control/core_platereader.py
deleted file mode 100644
index 547320e44..000000000
--- a/software/control/core_platereader.py
+++ /dev/null
@@ -1,383 +0,0 @@
-# set QT_API environment variable
-import os
-
-os.environ["QT_API"] = "pyqt5"
-import qtpy
-
-# qt libraries
-from qtpy.QtCore import *
-from qtpy.QtWidgets import *
-from qtpy.QtGui import *
-
-import control.utils as utils
-from control._def import *
-import control.tracking as tracking
-from control.core import *
-
-from queue import Queue
-from threading import Thread, Lock
-import time
-import numpy as np
-import pyqtgraph as pg
-import cv2
-from datetime import datetime
-
-from lxml import etree as ET
-from pathlib import Path
-import control.utils_config as utils_config
-
-import math
-
-
-class PlateReadingWorker(QObject):
-
- finished = Signal()
- image_to_display = Signal(np.ndarray)
- image_to_display_multi = Signal(np.ndarray, int)
- signal_current_configuration = Signal(Configuration)
-
- def __init__(self, plateReadingController):
- QObject.__init__(self)
- self.plateReadingController = plateReadingController
-
- self.camera = self.plateReadingController.camera
- self.microcontroller = self.plateReadingController.microcontroller
- self.plateReaderNavigationController = self.plateReadingController.plateReaderNavigationController
- self.liveController = self.plateReadingController.liveController
- self.autofocusController = self.plateReadingController.autofocusController
- self.configurationManager = self.plateReadingController.configurationManager
- self.NX = self.plateReadingController.NX
- self.NY = self.plateReadingController.NY
- self.NZ = self.plateReadingController.NZ
- self.Nt = self.plateReadingController.Nt
- self.deltaX = self.plateReadingController.deltaX
- self.deltaX_usteps = self.plateReadingController.deltaX_usteps
- self.deltaY = self.plateReadingController.deltaY
- self.deltaY_usteps = self.plateReadingController.deltaY_usteps
- self.deltaZ = self.plateReadingController.deltaZ
- self.deltaZ_usteps = self.plateReadingController.deltaZ_usteps
- self.dt = self.plateReadingController.deltat
- self.do_autofocus = self.plateReadingController.do_autofocus
- self.crop_width = self.plateReadingController.crop_width
- self.crop_height = self.plateReadingController.crop_height
- self.display_resolution_scaling = self.plateReadingController.display_resolution_scaling
- self.counter = self.plateReadingController.counter
- self.experiment_ID = self.plateReadingController.experiment_ID
- self.base_path = self.plateReadingController.base_path
- self.timestamp_acquisition_started = self.plateReadingController.timestamp_acquisition_started
- self.time_point = 0
- self.abort_acquisition_requested = False
- self.selected_configurations = self.plateReadingController.selected_configurations
- self.selected_columns = self.plateReadingController.selected_columns
-
- def run(self):
- self.abort_acquisition_requested = False
- self.plateReaderNavigationController.is_scanning = True
- while self.time_point < self.Nt and self.abort_acquisition_requested == False:
- # continuous acquisition
- if self.dt == 0:
- self.run_single_time_point()
- self.time_point = self.time_point + 1
- # timed acquisition
- else:
- self.run_single_time_point()
- self.time_point = self.time_point + 1
- # check if the acquisition has taken longer than dt or integer multiples of dt; if so, skip the next time point(s)
- while time.time() > self.timestamp_acquisition_started + self.time_point * self.dt:
- print("skip time point " + str(self.time_point + 1))
- self.time_point = self.time_point + 1
- if self.time_point == self.Nt:
- break # no waiting after taking the last time point
- # wait until it's time to do the next acquisition
- while time.time() < self.timestamp_acquisition_started + self.time_point * self.dt:
- time.sleep(0.05)
- self.plateReaderNavigationController.is_scanning = False
- self.finished.emit()
-
- def wait_till_operation_is_completed(self):
- while self.microcontroller.is_busy():
- time.sleep(SLEEP_TIME_S)
-
- def run_single_time_point(self):
- self.FOV_counter = 0
- column_counter = 0
- print("multipoint acquisition - time point " + str(self.time_point + 1))
-
- # for each time point, create a new folder
- current_path = os.path.join(self.base_path, self.experiment_ID, str(self.time_point))
- utils.ensure_directory_exists(current_path)
-
- # run homing
- self.plateReaderNavigationController.home()
- self.wait_till_operation_is_completed()
-
- # row scan direction
- row_scan_direction = 1 # 1: A -> H, 0: H -> A
-
- # go through columns
- for column in self.selected_columns:
-
- # increment counter
- column_counter = column_counter + 1
-
- # move to the current column
- self.plateReaderNavigationController.moveto_column(column - 1)
- self.wait_till_operation_is_completed()
-
- """
- # row homing
- if column_counter > 1:
- self.plateReaderNavigationController.home_y()
- self.wait_till_operation_is_completed()
- """
-
- # go through rows
- for row in range(PLATE_READER.NUMBER_OF_ROWS):
-
- if row_scan_direction == 0: # reverse scan:
- row = PLATE_READER.NUMBER_OF_ROWS - 1 - row
-
- row_str = chr(ord("A") + row)
- file_ID = row_str + str(column)
-
- # move to the selected row
- self.plateReaderNavigationController.moveto_row(row)
- self.wait_till_operation_is_completed()
- time.sleep(SCAN_STABILIZATION_TIME_MS_Y / 1000)
-
- # AF
- if (
- (self.NZ == 1)
- and (self.do_autofocus)
- and (self.FOV_counter % Acquisition.NUMBER_OF_FOVS_PER_AF == 0)
- ):
- configuration_name_AF = "BF LED matrix full"
- config_AF = next(
- (
- config
- for config in self.configurationManager.configurations
- if config.name == configuration_name_AF
- )
- )
- self.signal_current_configuration.emit(config_AF)
- self.autofocusController.autofocus()
- self.autofocusController.wait_till_autofocus_has_completed()
-
- # z stack
- for k in range(self.NZ):
-
- if self.NZ > 1:
- # update file ID
- file_ID = file_ID + "_" + str(k)
- # maneuver for achieving uniform step size and repeatability when using open-loop control
- self.plateReaderNavigationController.move_z_usteps(80)
- self.wait_till_operation_is_completed()
- self.plateReaderNavigationController.move_z_usteps(-80)
- self.wait_till_operation_is_completed()
- time.sleep(SCAN_STABILIZATION_TIME_MS_Z / 1000)
-
- # iterate through selected modes
- for config in self.selected_configurations:
- self.signal_current_configuration.emit(config)
- self.wait_till_operation_is_completed()
- self.liveController.turn_on_illumination()
- self.wait_till_operation_is_completed()
- self.camera.send_trigger()
- image = self.camera.read_frame()
- self.liveController.turn_off_illumination()
- image = utils.crop_image(image, self.crop_width, self.crop_height)
- saving_path = os.path.join(
- current_path, file_ID + "_" + str(config.name) + "." + Acquisition.IMAGE_FORMAT
- )
- # self.image_to_display.emit(cv2.resize(image,(round(self.crop_width*self.display_resolution_scaling), round(self.crop_height*self.display_resolution_scaling)),cv2.INTER_LINEAR))
- # image_to_display = utils.crop_image(image,round(self.crop_width*self.liveController.display_resolution_scaling), round(self.crop_height*self.liveController.display_resolution_scaling))
- image_to_display = utils.crop_image(image, round(self.crop_width), round(self.crop_height))
- self.image_to_display.emit(image_to_display)
- self.image_to_display_multi.emit(image_to_display, config.illumination_source)
- if self.camera.is_color:
- image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
- cv2.imwrite(saving_path, image)
- QApplication.processEvents()
-
- if self.NZ > 1:
- # move z
- if k < self.NZ - 1:
- self.plateReaderNavigationController.move_z_usteps(self.deltaZ_usteps)
- self.wait_till_operation_is_completed()
- time.sleep(SCAN_STABILIZATION_TIME_MS_Z / 1000)
-
- if self.NZ > 1:
- # move z back
- self.plateReaderNavigationController.move_z_usteps(-self.deltaZ_usteps * (self.NZ - 1))
- self.wait_till_operation_is_completed()
-
- if self.abort_acquisition_requested:
- return
-
- # update row scan direction
- row_scan_direction = 1 - row_scan_direction
-
-
-class PlateReadingController(QObject):
-
- acquisitionFinished = Signal()
- image_to_display = Signal(np.ndarray)
- image_to_display_multi = Signal(np.ndarray, int)
- signal_current_configuration = Signal(Configuration)
-
- def __init__(
- self, camera, plateReaderNavigationController, liveController, autofocusController, configurationManager
- ):
- QObject.__init__(self)
-
- self.camera = camera
- self.microcontroller = plateReaderNavigationController.microcontroller # to move to gui for transparency
- self.plateReaderNavigationController = plateReaderNavigationController
- self.liveController = liveController
- self.autofocusController = autofocusController
- self.configurationManager = configurationManager
- self.NX = 1
- self.NY = 1
- self.NZ = 1
- self.Nt = 1
- mm_per_ustep_X = SCREW_PITCH_X_MM / (self.plateReaderNavigationController.x_microstepping * FULLSTEPS_PER_REV_X)
- mm_per_ustep_Y = SCREW_PITCH_Y_MM / (self.plateReaderNavigationController.y_microstepping * FULLSTEPS_PER_REV_Y)
- mm_per_ustep_Z = SCREW_PITCH_Z_MM / (self.plateReaderNavigationController.z_microstepping * FULLSTEPS_PER_REV_Z)
- self.deltaX = Acquisition.DX
- self.deltaX_usteps = round(self.deltaX / mm_per_ustep_X)
- self.deltaY = Acquisition.DY
- self.deltaY_usteps = round(self.deltaY / mm_per_ustep_Y)
- self.deltaZ = Acquisition.DZ / 1000
- self.deltaZ_usteps = round(self.deltaZ / mm_per_ustep_Z)
- self.deltat = 0
- self.do_autofocus = False
- self.crop_width = Acquisition.CROP_WIDTH
- self.crop_height = Acquisition.CROP_HEIGHT
- self.display_resolution_scaling = Acquisition.IMAGE_DISPLAY_SCALING_FACTOR
- self.counter = 0
- self.experiment_ID = None
- self.base_path = None
- self.selected_configurations = []
- self.selected_columns = []
-
- def set_NZ(self, N):
- self.NZ = N
-
- def set_Nt(self, N):
- self.Nt = N
-
- def set_deltaZ(self, delta_um):
- mm_per_ustep_Z = SCREW_PITCH_Z_MM / (self.plateReaderNavigationController.z_microstepping * FULLSTEPS_PER_REV_Z)
- self.deltaZ = delta_um / 1000
- self.deltaZ_usteps = round((delta_um / 1000) / mm_per_ustep_Z)
-
- def set_deltat(self, delta):
- self.deltat = delta
-
- def set_af_flag(self, flag):
- self.do_autofocus = flag
-
- def set_crop(self, crop_width, crop_height):
- self.crop_width = crop_width
- self.crop_height = crop_height
-
- def set_base_path(self, path):
- self.base_path = path
-
- def start_new_experiment(self, experiment_ID): # @@@ to do: change name to prepare_folder_for_new_experiment
- # generate unique experiment ID
- self.experiment_ID = experiment_ID + "_" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S.%f")
- self.recording_start_time = time.time()
- # create a new folder
- try:
- os.mkdir(os.path.join(self.base_path, self.experiment_ID))
- self.configurationManager.write_configuration(
- os.path.join(self.base_path, self.experiment_ID) + "/configurations.xml"
- ) # save the configuration for the experiment
- except:
- pass
-
- def set_selected_configurations(self, selected_configurations_name):
- self.selected_configurations = []
- for configuration_name in selected_configurations_name:
- self.selected_configurations.append(
- next(
- (config for config in self.configurationManager.configurations if config.name == configuration_name)
- )
- )
-
- def set_selected_columns(self, selected_columns):
- selected_columns.sort()
- self.selected_columns = selected_columns
-
- def run_acquisition(self): # @@@ to do: change name to run_experiment
- print("start plate reading")
- # save the current microscope configuration
- self.configuration_before_running_multipoint = self.liveController.currentConfiguration
- # stop live
- if self.liveController.is_live:
- self.liveController.was_live_before_multipoint = True
- self.liveController.stop_live() # @@@ to do: also uncheck the live button
- else:
- self.liveController.was_live_before_multipoint = False
- # disable callback
- if self.camera.callback_is_enabled:
- self.camera.callback_was_enabled_before_multipoint = True
- self.camera.stop_streaming()
- self.camera.disable_callback()
- self.camera.start_streaming() # @@@ to do: absorb stop/start streaming into enable/disable callback - add a flag is_streaming to the camera class
- else:
- self.camera.callback_was_enabled_before_multipoint = False
-
- # run the acquisition
- self.timestamp_acquisition_started = time.time()
- # create a QThread object
- self.thread = QThread()
- # create a worker object
- self.plateReadingWorker = PlateReadingWorker(self)
- # move the worker to the thread
- self.plateReadingWorker.moveToThread(self.thread)
- # connect signals and slots
- self.thread.started.connect(self.plateReadingWorker.run)
- self.plateReadingWorker.finished.connect(self._on_acquisition_completed)
- self.plateReadingWorker.finished.connect(self.plateReadingWorker.deleteLater)
- self.plateReadingWorker.finished.connect(self.thread.quit)
- self.plateReadingWorker.image_to_display.connect(self.slot_image_to_display)
- self.plateReadingWorker.image_to_display_multi.connect(self.slot_image_to_display_multi)
- self.plateReadingWorker.signal_current_configuration.connect(
- self.slot_current_configuration, type=Qt.BlockingQueuedConnection
- )
- self.thread.finished.connect(self.thread.deleteLater)
- # start the thread
- self.thread.start()
-
- def stop_acquisition(self):
- self.plateReadingWorker.abort_acquisition_requested = True
-
- def _on_acquisition_completed(self):
- # restore the previous selected mode
- self.signal_current_configuration.emit(self.configuration_before_running_multipoint)
-
- # re-enable callback
- if self.camera.callback_was_enabled_before_multipoint:
- self.camera.stop_streaming()
- self.camera.enable_callback()
- self.camera.start_streaming()
- self.camera.callback_was_enabled_before_multipoint = False
-
- # re-enable live if it's previously on
- if self.liveController.was_live_before_multipoint:
- self.liveController.start_live()
-
- # emit the acquisition finished signal to enable the UI
- self.acquisitionFinished.emit()
- QApplication.processEvents()
-
- def slot_image_to_display(self, image):
- self.image_to_display.emit(image)
-
- def slot_image_to_display_multi(self, image, illumination_source):
- self.image_to_display_multi.emit(image, illumination_source)
-
- def slot_current_configuration(self, configuration):
- self.signal_current_configuration.emit(configuration)
diff --git a/software/control/core_usbspectrometer.py b/software/control/core_usbspectrometer.py
deleted file mode 100644
index 3adb4f131..000000000
--- a/software/control/core_usbspectrometer.py
+++ /dev/null
@@ -1,162 +0,0 @@
-# set QT_API environment variable
-import os
-
-os.environ["QT_API"] = "pyqt5"
-import qtpy
-
-# qt libraries
-from qtpy.QtCore import *
-from qtpy.QtWidgets import *
-from qtpy.QtGui import *
-
-import control.utils as utils
-from control._def import *
-import control.tracking as tracking
-
-from queue import Queue, Full
-from threading import Thread, Lock
-import time
-import numpy as np
-import pyqtgraph as pg
-import cv2
-from datetime import datetime
-
-from lxml import etree as ET
-from pathlib import Path
-import control.utils_config as utils_config
-
-import math
-import json
-import pandas as pd
-
-
-class SpectrumStreamHandler(QObject):
-
- spectrum_to_display = Signal(np.ndarray)
- spectrum_to_write = Signal(np.ndarray)
- signal_new_spectrum_received = Signal()
-
- def __init__(self):
- QObject.__init__(self)
- self.fps_display = 30
- self.fps_save = 1
- self.timestamp_last_display = 0
- self.timestamp_last_save = 0
-
- self.save_spectrum_flag = False
-
- # for fps measurement
- self.timestamp_last = 0
- self.counter = 0
- self.fps_real = 0
-
- def start_recording(self):
- self.save_spectrum_flag = True
-
- def stop_recording(self):
- self.save_spectrum_flag = False
-
- def set_display_fps(self, fps):
- self.fps_display = fps
-
- def set_save_fps(self, fps):
- self.fps_save = fps
-
- def on_new_measurement(self, data):
- self.signal_new_spectrum_received.emit()
- # measure real fps
- timestamp_now = round(time.time())
- if timestamp_now == self.timestamp_last:
- self.counter = self.counter + 1
- else:
- self.timestamp_last = timestamp_now
- self.fps_real = self.counter
- self.counter = 0
- print("real spectrometer fps is " + str(self.fps_real))
- # send image to display
- time_now = time.time()
- if time_now - self.timestamp_last_display >= 1 / self.fps_display:
- self.spectrum_to_display.emit(data)
- self.timestamp_last_display = time_now
- # send image to write
- if self.save_spectrum_flag and time_now - self.timestamp_last_save >= 1 / self.fps_save:
- self.spectrum_to_write.emit(data)
- self.timestamp_last_save = time_now
-
-
-class SpectrumSaver(QObject):
-
- stop_recording = Signal()
-
- def __init__(self):
- QObject.__init__(self)
- self.base_path = "./"
- self.experiment_ID = ""
- self.max_num_file_per_folder = 1000
- self.queue = Queue(10) # max 10 items in the queue
- self.stop_signal_received = False
- self.thread = Thread(target=self.process_queue)
- self.thread.start()
- self.counter = 0
- self.recording_start_time = 0
- self.recording_time_limit = -1
-
- def process_queue(self):
- while True:
- # stop the thread if stop signal is received
- if self.stop_signal_received:
- return
- # process the queue
- try:
- data = self.queue.get(timeout=0.1)
- folder_ID = int(self.counter / self.max_num_file_per_folder)
- file_ID = int(self.counter % self.max_num_file_per_folder)
- # create a new folder
- if file_ID == 0:
- utils.ensure_directory_exists(os.path.join(self.base_path, self.experiment_ID, str(folder_ID)))
-
- saving_path = os.path.join(self.base_path, self.experiment_ID, str(folder_ID), str(file_ID) + ".csv")
- np.savetxt(saving_path, data, delimiter=",")
-
- self.counter = self.counter + 1
- self.queue.task_done()
- except Exception:
- pass
-
- def enqueue(self, data):
- try:
- self.queue.put_nowait(data)
- if (self.recording_time_limit > 0) and (
- time.time() - self.recording_start_time >= self.recording_time_limit
- ):
- self.stop_recording.emit()
- # when using self.queue.put(str_), program can be slowed down despite multithreading because of the block and the GIL
- except Full:
- print("spectrum saver queue is full, spectrum discarded")
-
- def set_base_path(self, path):
- self.base_path = path
-
- def set_recording_time_limit(self, time_limit):
- self.recording_time_limit = time_limit
-
- def start_new_experiment(self, experiment_ID, add_timestamp=True):
- if add_timestamp:
- # generate unique experiment ID
- self.experiment_ID = experiment_ID + "_spectrum_" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S.%f")
- else:
- self.experiment_ID = experiment_ID
- self.recording_start_time = time.time()
- # create a new folder
- try:
- os.mkdir(os.path.join(self.base_path, self.experiment_ID))
- # to do: save configuration
- except FileExistsError:
- pass
- # reset the counter
- self.counter = 0
-
- def close(self):
- self.queue.join()
- self.stop_signal_received = True
- self.thread.join()
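The display/save throttling in `SpectrumStreamHandler.on_new_measurement` above is plain timestamp comparison: an event is forwarded only if at least `1 / fps` seconds have passed since the last one. A stripped-down sketch of that rate limiter (the fps value and loop are illustrative):

```
# Sketch of the timestamp-based rate limiting used by SpectrumStreamHandler
# above; fps_display and the driving loop are illustrative.
import time

fps_display = 30
timestamp_last_display = 0.0

def on_new_measurement(data):
    global timestamp_last_display
    now = time.time()
    if now - timestamp_last_display >= 1 / fps_display:
        print("forwarding to display:", data)  # stands in for spectrum_to_display.emit(data)
        timestamp_last_display = now

for i in range(100):
    on_new_measurement(i)   # only a few of these get through in 0.1 s at 30 fps
    time.sleep(0.001)
```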
diff --git a/software/control/core_volumetric_imaging.py b/software/control/core_volumetric_imaging.py
deleted file mode 100644
index c82b76d9a..000000000
--- a/software/control/core_volumetric_imaging.py
+++ /dev/null
@@ -1,185 +0,0 @@
-# set QT_API environment variable
-import os
-
-os.environ["QT_API"] = "pyqt5"
-import qtpy
-
-# qt libraries
-from qtpy.QtCore import *
-from qtpy.QtWidgets import *
-from qtpy.QtGui import *
-
-import control.utils as utils
-from control._def import *
-import control.tracking as tracking
-
-from queue import Queue
-from threading import Thread, Lock
-import time
-import numpy as np
-import pyqtgraph as pg
-import cv2
-from datetime import datetime
-
-from lxml import etree as ET
-from pathlib import Path
-import control.utils_config as utils_config
-
-
-class StreamHandler(QObject):
-
- image_to_display = Signal(np.ndarray)
- packet_image_to_write = Signal(np.ndarray, int, float)
- packet_image_for_tracking = Signal(np.ndarray, int, float)
- packet_image_for_array_display = Signal(np.ndarray, int)
- signal_new_frame_received = Signal()
-
- def __init__(
- self, crop_width=Acquisition.CROP_WIDTH, crop_height=Acquisition.CROP_HEIGHT, display_resolution_scaling=0.5
- ):
- QObject.__init__(self)
- self.fps_display = 1
- self.fps_save = 1
- self.fps_track = 1
- self.timestamp_last_display = 0
- self.timestamp_last_save = 0
- self.timestamp_last_track = 0
-
- self.crop_width = crop_width
- self.crop_height = crop_height
- self.display_resolution_scaling = display_resolution_scaling
-
- self.save_image_flag = False
- self.track_flag = False
- self.handler_busy = False
-
- # for fps measurement
- self.timestamp_last = 0
- self.counter = 0
- self.fps_real = 0
-
- def start_recording(self):
- self.save_image_flag = True
-
- def stop_recording(self):
- self.save_image_flag = False
-
- def start_tracking(self):
- self.track_flag = True
-
- def stop_tracking(self):
- self.track_flag = False
-
- def set_display_fps(self, fps):
- self.fps_display = fps
-
- def set_save_fps(self, fps):
- self.fps_save = fps
-
- def set_crop(self, crop_width, crop_height):
- self.crop_width = crop_width
- self.crop_height = crop_height
-
- def set_display_resolution_scaling(self, display_resolution_scaling):
- self.display_resolution_scaling = display_resolution_scaling / 100
- print(self.display_resolution_scaling)
-
- def on_new_frame(self, camera):
-
- self.handler_busy = True
- self.signal_new_frame_received.emit() # self.liveController.turn_off_illumination()
-
- # measure real fps
- timestamp_now = round(time.time())
- if timestamp_now == self.timestamp_last:
- self.counter = self.counter + 1
- else:
- self.timestamp_last = timestamp_now
- self.fps_real = self.counter
- self.counter = 0
- print("real camera fps is " + str(self.fps_real))
-
- # crop image
- image_cropped = utils.crop_image(camera.current_frame, self.crop_width, self.crop_height)
- image_cropped = np.squeeze(image_cropped)
-
- # send image to display
- time_now = time.time()
- if time_now - self.timestamp_last_display >= 1 / self.fps_display:
- # self.image_to_display.emit(cv2.resize(image_cropped,(round(self.crop_width*self.display_resolution_scaling), round(self.crop_height*self.display_resolution_scaling)),cv2.INTER_LINEAR))
- self.image_to_display.emit(
- utils.crop_image(
- image_cropped,
- round(self.crop_width * self.display_resolution_scaling),
- round(self.crop_height * self.display_resolution_scaling),
- )
- )
- self.timestamp_last_display = time_now
-
- # send image to array display
- self.packet_image_for_array_display.emit(
- image_cropped,
- (camera.frame_ID - camera.frame_ID_offset_hardware_trigger - 1) % VOLUMETRIC_IMAGING.NUM_PLANES_PER_VOLUME,
- )
-
- # send image to write
- if self.save_image_flag and time_now - self.timestamp_last_save >= 1 / self.fps_save:
- if camera.is_color:
- image_cropped = cv2.cvtColor(image_cropped, cv2.COLOR_RGB2BGR)
- self.packet_image_to_write.emit(image_cropped, camera.frame_ID, camera.timestamp)
- self.timestamp_last_save = time_now
-
- # send image to track
- if self.track_flag and time_now - self.timestamp_last_track >= 1 / self.fps_track:
- # track is a blocking operation - it needs to be handled off this thread, hence the signal emission below
- # @@@ will cropping before emitting the signal lead to speedup?
- self.packet_image_for_tracking.emit(image_cropped, camera.frame_ID, camera.timestamp)
- self.timestamp_last_track = time_now
-
- self.handler_busy = False
-
-
-class ImageArrayDisplayWindow(QMainWindow):
-
- def __init__(self, window_title=""):
- super().__init__()
- self.setWindowTitle(window_title)
- self.setWindowFlags(self.windowFlags() | Qt.CustomizeWindowHint)
- self.setWindowFlags(self.windowFlags() & ~Qt.WindowCloseButtonHint)
- self.widget = QWidget()
-
- # interpret image data as row-major instead of col-major
- pg.setConfigOptions(imageAxisOrder="row-major")
-
- self.sub_windows = []
- for i in range(9):
- self.sub_windows.append(pg.GraphicsLayoutWidget())
- self.sub_windows[i].view = self.sub_windows[i].addViewBox(enableMouse=True)
- self.sub_windows[i].img = pg.ImageItem(border="w")
- self.sub_windows[i].view.setAspectLocked(True)
- self.sub_windows[i].view.addItem(self.sub_windows[i].img)
-
- ## Layout
- layout = QGridLayout()
- layout.addWidget(self.sub_windows[0], 0, 0)
- layout.addWidget(self.sub_windows[1], 0, 1)
- layout.addWidget(self.sub_windows[2], 0, 2)
- layout.addWidget(self.sub_windows[3], 1, 0)
- layout.addWidget(self.sub_windows[4], 1, 1)
- layout.addWidget(self.sub_windows[5], 1, 2)
- layout.addWidget(self.sub_windows[6], 2, 0)
- layout.addWidget(self.sub_windows[7], 2, 1)
- layout.addWidget(self.sub_windows[8], 2, 2)
- self.widget.setLayout(layout)
- self.setCentralWidget(self.widget)
-
- # set window size
- desktopWidget = QDesktopWidget()
- width = int(min(desktopWidget.height() * 0.9, 1000)) # @@@TO MOVE@@@#
- height = width
- self.setFixedSize(width, height)
-
- def display_image(self, image, i):
- if i < 9:
- self.sub_windows[i].img.setImage(image, autoLevels=False)
- self.sub_windows[i].view.autoRange(padding=0)
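The plane index emitted to the array display above comes from modular arithmetic on the hardware-triggered frame counter; with the `NUM_PLANES_PER_VOLUME = 20` constant that this change removes from `_def.py`, the mapping works out as follows:

```
# Sketch of the frame-to-plane mapping used by StreamHandler above.
# NUM_PLANES_PER_VOLUME = 20 matches the constant removed from _def.py.
NUM_PLANES_PER_VOLUME = 20

def plane_index(frame_id, frame_id_offset_hardware_trigger):
    # frames are numbered from 1 after the trigger offset, hence the -1
    return (frame_id - frame_id_offset_hardware_trigger - 1) % NUM_PLANES_PER_VOLUME

assert plane_index(1, 0) == 0     # first frame of a volume
assert plane_index(20, 0) == 19   # last plane
assert plane_index(21, 0) == 0    # wraps to the next volume
```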
diff --git a/software/control/gui_hcs.py b/software/control/gui_hcs.py
index 488d4af20..b427db8f2 100644
--- a/software/control/gui_hcs.py
+++ b/software/control/gui_hcs.py
@@ -14,7 +14,6 @@
)
os.environ["QT_API"] = "pyqt5"
-import serial
import time
from typing import Any, Optional
import numpy as np
@@ -42,12 +41,10 @@
)
from control.core.objective_store import ObjectiveStore
from control.core.stream_handler import StreamHandler
-from control.lighting import LightSourceType, IntensityControlMode, ShutterControlMode, IlluminationController
from control.microcontroller import Microcontroller
from control.microscope import Microscope
from control.utils_config import ChannelMode
-from squid.abc import AbstractCamera, AbstractStage, AbstractFilterWheelController
-import control.lighting
+from squid.abc import AbstractCamera, AbstractStage
import control.microscope
import control.widgets as widgets
import pyqtgraph.dockarea as dock
@@ -59,10 +56,6 @@
log = squid.logging.get_logger(__name__)
-if USE_PRIOR_STAGE:
- import squid.stage.prior
-else:
- import squid.stage.cephla
from control.piezo import PiezoStage
if USE_XERYON:
@@ -75,8 +68,6 @@
import control.microcontroller as microcontroller
import control.serial_peripherals as serial_peripherals
-if SUPPORT_LASER_AUTOFOCUS:
- import control.core_displacement_measurement as core_displacement_measurement
SINGLE_WINDOW = True # set to False to use separate windows for display and control
@@ -300,9 +291,6 @@ def __init__(
self.liveController_focus_camera: Optional[AbstractCamera] = None
self.streamHandler_focus_camera: Optional[StreamHandler] = None
self.imageDisplayWindow_focus: Optional[core.ImageDisplayWindow] = None
- self.displacementMeasurementController: Optional[
- core_displacement_measurement.DisplacementMeasurementController
- ] = None
self.laserAutofocusController: Optional[LaserAutofocusController] = None
if SUPPORT_LASER_AUTOFOCUS:
@@ -311,7 +299,6 @@ def __init__(
accept_new_frame_fn=lambda: self.liveController_focus_camera.is_live
)
self.imageDisplayWindow_focus = core.ImageDisplayWindow(show_LUT=False, autoLevels=False)
- self.displacementMeasurementController = core_displacement_measurement.DisplacementMeasurementController()
self.laserAutofocusController = LaserAutofocusController(
self.microcontroller,
self.camera_focus,
@@ -334,7 +321,6 @@ def __init__(
self.autofocusController: AutoFocusController = None
self.imageSaver: core.ImageSaver = core.ImageSaver()
self.imageDisplay: core.ImageDisplay = core.ImageDisplay()
- self.trackingController: core.TrackingController = None
self.navigationViewer: core.NavigationViewer = None
self.scanCoordinates: Optional[ScanCoordinates] = None
self.load_objects(is_simulation=is_simulation)
@@ -357,14 +343,11 @@ def __init__(
self.objectivesWidget: Optional[widgets.ObjectivesWidget] = None
self.filterControllerWidget: Optional[widgets.FilterControllerWidget] = None
self.squidFilterWidget: Optional[widgets.SquidFilterWidget] = None
- self.recordingControlWidget: Optional[widgets.RecordingWidget] = None
self.wellplateFormatWidget: Optional[widgets.WellplateFormatWidget] = None
self.wellSelectionWidget: Optional[widgets.WellSelectionWidget] = None
self.focusMapWidget: Optional[widgets.FocusMapWidget] = None
self.cameraSettingWidget_focus_camera: Optional[widgets.CameraSettingsWidget] = None
self.laserAutofocusSettingWidget: Optional[widgets.LaserAutofocusSettingWidget] = None
- self.waveformDisplay: Optional[widgets.WaveformDisplay] = None
- self.displacementMeasurementWidget: Optional[widgets.DisplacementMeasurementWidget] = None
self.laserAutofocusControlWidget: Optional[widgets.LaserAutofocusControlWidget] = None
self.fluidicsWidget: Optional[widgets.FluidicsWidget] = None
self.flexibleMultiPointWidget: Optional[widgets.FlexibleMultiPointWidget] = None
@@ -372,12 +355,10 @@ def __init__(
self.templateMultiPointWidget: Optional[TemplateMultiPointWidget] = None
self.multiPointWithFluidicsWidget: Optional[widgets.MultiPointWithFluidicsWidget] = None
self.sampleSettingsWidget: Optional[widgets.SampleSettingsWidget] = None
- self.trackingControlWidget: Optional[widgets.TrackingControllerWidget] = None
self.napariLiveWidget: Optional[widgets.NapariLiveWidget] = None
self.imageDisplayWindow: Optional[core.ImageDisplayWindow] = None
self.imageDisplayWindow_focus: Optional[core.ImageDisplayWindow] = None
self.napariMultiChannelWidget: Optional[widgets.NapariMultiChannelWidget] = None
- self.imageArrayDisplayWindow: Optional[core.ImageArrayDisplayWindow] = None
self.zPlotWidget: Optional[widgets.SurfacePlotWidget] = None
self.recordTabWidget: QTabWidget = QTabWidget()
@@ -446,17 +427,6 @@ def load_objects(self, is_simulation):
self.autofocusController = QtAutoFocusController(
self.camera, self.stage, self.liveController, self.microcontroller, self.nl5
)
- if ENABLE_TRACKING:
- self.trackingController = core.TrackingController(
- self.camera,
- self.microcontroller,
- self.stage,
- self.objectiveStore,
- self.channelConfigurationManager,
- self.liveController,
- self.autofocusController,
- self.imageDisplayWindow,
- )
if WELLPLATE_FORMAT == "glass slide" and IS_HCS:
self.navigationViewer = core.NavigationViewer(self.objectiveStore, self.camera, sample="4 glass slide")
else:
@@ -610,7 +580,6 @@ def load_widgets(self):
self.emission_filter_wheel, self.liveController
)
- self.recordingControlWidget = widgets.RecordingWidget(self.streamHandler, self.imageSaver)
self.wellplateFormatWidget = widgets.WellplateFormatWidget(
self.stage, self.navigationViewer, self.streamHandler, self.liveController
)
@@ -644,10 +613,6 @@ def load_widgets(self):
self.laserAutofocusController,
stretch=False,
) # ,show_display_options=True)
- self.waveformDisplay = widgets.WaveformDisplay(N=1000, include_x=True, include_y=False)
- self.displacementMeasurementWidget = widgets.DisplacementMeasurementWidget(
- self.displacementMeasurementController, self.waveformDisplay
- )
self.laserAutofocusControlWidget: widgets.LaserAutofocusControlWidget = widgets.LaserAutofocusControlWidget(
self.laserAutofocusController, self.liveController
)
@@ -658,13 +623,9 @@ def load_widgets(self):
self.imageDisplayTabs = QTabWidget(parent=self)
if self.live_only_mode:
- if ENABLE_TRACKING:
- self.imageDisplayWindow = core.ImageDisplayWindow(self.liveController, self.contrastManager)
- self.imageDisplayWindow.show_ROI_selector()
- else:
- self.imageDisplayWindow = core.ImageDisplayWindow(
- self.liveController, self.contrastManager, show_LUT=True, autoLevels=True
- )
+ self.imageDisplayWindow = core.ImageDisplayWindow(
+ self.liveController, self.contrastManager, show_LUT=True, autoLevels=True
+ )
self.imageDisplayTabs = self.imageDisplayWindow.widget
self.napariMosaicDisplayWidget = None
else:
@@ -715,14 +676,6 @@ def load_widgets(self):
)
self.sampleSettingsWidget = widgets.SampleSettingsWidget(self.objectivesWidget, self.wellplateFormatWidget)
- if ENABLE_TRACKING:
- self.trackingControlWidget = widgets.TrackingControllerWidget(
- self.trackingController,
- self.objectiveStore,
- self.channelConfigurationManager,
- show_configurations=TRACKING_SHOW_MICROSCOPE_CONFIGURATIONS,
- )
-
self.setupRecordTabWidget()
self.setupCameraTabWidget()
@@ -739,24 +692,16 @@ def setupImageDisplayTabs(self):
)
self.imageDisplayTabs.addTab(self.napariLiveWidget, "Live View")
else:
- if ENABLE_TRACKING:
- self.imageDisplayWindow = core.ImageDisplayWindow(self.liveController, self.contrastManager)
- self.imageDisplayWindow.show_ROI_selector()
- else:
- self.imageDisplayWindow = core.ImageDisplayWindow(
- self.liveController, self.contrastManager, show_LUT=True, autoLevels=True
- )
+ self.imageDisplayWindow = core.ImageDisplayWindow(
+ self.liveController, self.contrastManager, show_LUT=True, autoLevels=True
+ )
self.imageDisplayTabs.addTab(self.imageDisplayWindow.widget, "Live View")
if not self.live_only_mode:
- if USE_NAPARI_FOR_MULTIPOINT:
- self.napariMultiChannelWidget = widgets.NapariMultiChannelWidget(
- self.objectiveStore, self.camera, self.contrastManager
- )
- self.imageDisplayTabs.addTab(self.napariMultiChannelWidget, "Multichannel Acquisition")
- else:
- self.imageArrayDisplayWindow = core.ImageArrayDisplayWindow()
- self.imageDisplayTabs.addTab(self.imageArrayDisplayWindow.widget, "Multichannel Acquisition")
+ self.napariMultiChannelWidget = widgets.NapariMultiChannelWidget(
+ self.objectiveStore, self.camera, self.contrastManager
+ )
+ self.imageDisplayTabs.addTab(self.napariMultiChannelWidget, "Multichannel Acquisition")
if USE_NAPARI_FOR_MOSAIC_DISPLAY:
self.napariMosaicDisplayWidget = widgets.NapariMosaicDisplayWidget(
@@ -791,25 +736,11 @@ def setupImageDisplayTabs(self):
dock_laserfocus_liveController.setStretch(x=100, y=100)
dock_laserfocus_liveController.setFixedWidth(self.laserAutofocusSettingWidget.minimumSizeHint().width())
- dock_waveform = dock.Dock("Displacement Measurement", autoOrientation=False)
- dock_waveform.showTitleBar()
- dock_waveform.addWidget(self.waveformDisplay)
- dock_waveform.setStretch(x=100, y=40)
-
- dock_displayMeasurement = dock.Dock("Displacement Measurement Control", autoOrientation=False)
- dock_displayMeasurement.showTitleBar()
- dock_displayMeasurement.addWidget(self.displacementMeasurementWidget)
- dock_displayMeasurement.setStretch(x=100, y=40)
- dock_displayMeasurement.setFixedWidth(self.displacementMeasurementWidget.minimumSizeHint().width())
-
laserfocus_dockArea = dock.DockArea()
laserfocus_dockArea.addDock(dock_laserfocus_image_display)
laserfocus_dockArea.addDock(
dock_laserfocus_liveController, "right", relativeTo=dock_laserfocus_image_display
)
- if SHOW_LEGACY_DISPLACEMENT_MEASUREMENT_WINDOWS:
- laserfocus_dockArea.addDock(dock_waveform, "bottom", relativeTo=dock_laserfocus_liveController)
- laserfocus_dockArea.addDock(dock_displayMeasurement, "bottom", relativeTo=dock_waveform)
self.imageDisplayTabs.addTab(laserfocus_dockArea, self.LASER_BASED_FOCUS_TAB_NAME)
@@ -825,10 +756,6 @@ def setupRecordTabWidget(self):
self.recordTabWidget.addTab(self.templateMultiPointWidget, "Template Multipoint")
if RUN_FLUIDICS:
self.recordTabWidget.addTab(self.multiPointWithFluidicsWidget, "Multipoint with Fluidics")
- if ENABLE_TRACKING:
- self.recordTabWidget.addTab(self.trackingControlWidget, "Tracking")
- if ENABLE_RECORDING:
- self.recordTabWidget.addTab(self.recordingControlWidget, "Simple Recording")
self.recordTabWidget.currentChanged.connect(lambda: self.resizeCurrentTab(self.recordTabWidget))
self.resizeCurrentTab(self.recordTabWidget)
@@ -1079,13 +1006,6 @@ def slot_settings_changed_laser_af():
)
self.streamHandler_focus_camera.image_to_display.connect(self.imageDisplayWindow_focus.display_image)
- self.streamHandler_focus_camera.image_to_display.connect(
- self.displacementMeasurementController.update_measurement
- )
- self.displacementMeasurementController.signal_plots.connect(self.waveformDisplay.plot)
- self.displacementMeasurementController.signal_readings.connect(
- self.displacementMeasurementWidget.display_readings
- )
self.laserAutofocusController.image_to_display.connect(self.imageDisplayWindow_focus.display_image)
# Add connection for piezo position updates
@@ -1180,54 +1100,51 @@ def makeNapariConnections(self):
if not self.live_only_mode:
# Setup multichannel widget connections
- if USE_NAPARI_FOR_MULTIPOINT:
- self.napari_connections["napariMultiChannelWidget"] = [
- (self.multipointController.napari_layers_init, self.napariMultiChannelWidget.initLayers),
- (self.multipointController.napari_layers_update, self.napariMultiChannelWidget.updateLayers),
- ]
+ self.napari_connections["napariMultiChannelWidget"] = [
+ (self.multipointController.napari_layers_init, self.napariMultiChannelWidget.initLayers),
+ (self.multipointController.napari_layers_update, self.napariMultiChannelWidget.updateLayers),
+ ]
- if ENABLE_FLEXIBLE_MULTIPOINT:
- self.napari_connections["napariMultiChannelWidget"].extend(
- [
- (
- self.flexibleMultiPointWidget.signal_acquisition_channels,
- self.napariMultiChannelWidget.initChannels,
- ),
- (
- self.flexibleMultiPointWidget.signal_acquisition_shape,
- self.napariMultiChannelWidget.initLayersShape,
- ),
- ]
- )
+ if ENABLE_FLEXIBLE_MULTIPOINT:
+ self.napari_connections["napariMultiChannelWidget"].extend(
+ [
+ (
+ self.flexibleMultiPointWidget.signal_acquisition_channels,
+ self.napariMultiChannelWidget.initChannels,
+ ),
+ (
+ self.flexibleMultiPointWidget.signal_acquisition_shape,
+ self.napariMultiChannelWidget.initLayersShape,
+ ),
+ ]
+ )
- if ENABLE_WELLPLATE_MULTIPOINT:
- self.napari_connections["napariMultiChannelWidget"].extend(
- [
- (
- self.wellplateMultiPointWidget.signal_acquisition_channels,
- self.napariMultiChannelWidget.initChannels,
- ),
- (
- self.wellplateMultiPointWidget.signal_acquisition_shape,
- self.napariMultiChannelWidget.initLayersShape,
- ),
- ]
- )
- if RUN_FLUIDICS:
- self.napari_connections["napariMultiChannelWidget"].extend(
- [
- (
- self.multiPointWithFluidicsWidget.signal_acquisition_channels,
- self.napariMultiChannelWidget.initChannels,
- ),
- (
- self.multiPointWithFluidicsWidget.signal_acquisition_shape,
- self.napariMultiChannelWidget.initLayersShape,
- ),
- ]
- )
- else:
- self.multipointController.image_to_display_multi.connect(self.imageArrayDisplayWindow.display_image)
+ if ENABLE_WELLPLATE_MULTIPOINT:
+ self.napari_connections["napariMultiChannelWidget"].extend(
+ [
+ (
+ self.wellplateMultiPointWidget.signal_acquisition_channels,
+ self.napariMultiChannelWidget.initChannels,
+ ),
+ (
+ self.wellplateMultiPointWidget.signal_acquisition_shape,
+ self.napariMultiChannelWidget.initLayersShape,
+ ),
+ ]
+ )
+ if RUN_FLUIDICS:
+ self.napari_connections["napariMultiChannelWidget"].extend(
+ [
+ (
+ self.multiPointWithFluidicsWidget.signal_acquisition_channels,
+ self.napariMultiChannelWidget.initChannels,
+ ),
+ (
+ self.multiPointWithFluidicsWidget.signal_acquisition_shape,
+ self.napariMultiChannelWidget.initLayersShape,
+ ),
+ ]
+ )
# Setup mosaic display widget connections
if USE_NAPARI_FOR_MOSAIC_DISPLAY:
@@ -1340,11 +1257,8 @@ def setAcquisitionDisplayTabs(self, selected_configurations, Nz):
print(configs)
if USE_NAPARI_FOR_MOSAIC_DISPLAY and Nz == 1:
self.imageDisplayTabs.setCurrentWidget(self.napariMosaicDisplayWidget)
-
- elif USE_NAPARI_FOR_MULTIPOINT:
- self.imageDisplayTabs.setCurrentWidget(self.napariMultiChannelWidget)
else:
- self.imageDisplayTabs.setCurrentIndex(0)
+ self.imageDisplayTabs.setCurrentWidget(self.napariMultiChannelWidget)
def openLedMatrixSettings(self):
if SUPPORT_SCIMICROSCOPY_LED_ARRAY:
@@ -1644,7 +1558,6 @@ def closeEvent(self, event):
self.imageDisplay.close()
if not SINGLE_WINDOW:
self.imageDisplayWindow.close()
- self.imageArrayDisplayWindow.close()
self.tabbedImageDisplayWindow.close()
self.microcontroller.close()
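`makeNapariConnections` above builds `self.napari_connections` as lists of `(signal, slot)` pairs so they can be connected (and later torn down) in a single loop. A minimal sketch of that registry pattern (PyQt5 assumed; `SignalSource` and the names here are hypothetical):

```
# Sketch of the (signal, slot) registry pattern used by makeNapariConnections
# above; SignalSource and all names here are hypothetical.
from PyQt5.QtCore import QObject, pyqtSignal

class SignalSource(QObject):
    updated = pyqtSignal(str)

source = SignalSource()
napari_connections = {
    "exampleWidget": [
        (source.updated, lambda text: print("slot received:", text)),
    ],
}

# connect every registered pair in one pass
for pairs in napari_connections.values():
    for signal, slot in pairs:
        signal.connect(slot)

source.updated.emit("hello")  # prints: slot received: hello
```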
diff --git a/software/control/microscope.py b/software/control/microscope.py
index c319277d1..03bead508 100644
--- a/software/control/microscope.py
+++ b/software/control/microscope.py
@@ -1,4 +1,3 @@
-import serial
from typing import Optional, TypeVar
import control._def
@@ -15,7 +14,6 @@
from control.piezo import PiezoStage
from control.serial_peripherals import SciMicroscopyLEDArray
from squid.abc import CameraAcquisitionMode, AbstractCamera, AbstractStage, AbstractFilterWheelController
-from squid.stage.utils import move_z_axis_to_safety_position
from squid.stage.cephla import CephlaStage
from squid.stage.prior import PriorStage
import control.celesta
@@ -26,7 +24,6 @@
import squid.config
import squid.filter_wheel_controller.utils
import squid.logging
-import squid.stage.cephla
import squid.stage.utils
if control._def.USE_XERYON:
@@ -338,7 +335,6 @@ def __init__(
microscope=self,
camera=self.addons.camera_focus,
control_illumination=False,
- for_displacement_measurement=True,
)
self.live_controller: LiveController = LiveController(microscope=self, camera=self.camera)
diff --git a/software/control/processing_handler.py b/software/control/processing_handler.py
deleted file mode 100644
index 086765bad..000000000
--- a/software/control/processing_handler.py
+++ /dev/null
@@ -1,95 +0,0 @@
-import threading
-import queue
-import numpy as np
-import pandas as pd
-import control.utils as utils
-
-
-def default_image_preprocessor(image, callable_list):
- """
- :param image: ndarray representing an image
- :param callable_list: List of dictionaries in the form {'func': callable,
- 'args': list of positional args, 'kwargs': dict of keyword args}. The function
- should take an image ndarray as its first positional argument,
- and the image should
- not be included in the collection of args/kwargs
- :return: Image with the elements of callable_list applied in sequence
- """
- output_image = np.copy(image)
- for c in callable_list:
- output_image = c["func"](output_image, *c["args"], **c["kwargs"])
- return output_image
-
-
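`default_image_preprocessor` above consumes a list of `{'func', 'args', 'kwargs'}` dicts, applying each callable to the image in sequence. A runnable sketch of that format (the blur/clip steps are illustrative, not taken from the original):

```
# Sketch of the callable_list format consumed by default_image_preprocessor
# above; the GaussianBlur/clip steps are illustrative.
import cv2
import numpy as np

image = np.random.randint(0, 256, (64, 64), dtype=np.uint8)
callable_list = [
    {"func": cv2.GaussianBlur, "args": [(5, 5), 0], "kwargs": {}},  # cv2.GaussianBlur(img, ksize, sigmaX)
    {"func": np.clip, "args": [16, 240], "kwargs": {}},             # np.clip(img, min, max)
]

output_image = np.copy(image)
for c in callable_list:
    output_image = c["func"](output_image, *c["args"], **c["kwargs"])
print(output_image.min(), output_image.max())  # stays within [16, 240]
```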
-class ProcessingHandler:
- """
- :brief: Handler class for parallelizing FOV processing. GENERAL NOTE:
- REMEMBER TO PASS COPIES OF IMAGES WHEN QUEUEING THEM FOR PROCESSING
- """
-
- def __init__(self):
- self.processing_queue = queue.Queue() # elements in this queue are
- # dicts in the form
- # {'function': callable, 'args':list
- # of positional arguments to pass,
- # 'kwargs': dict of kwargs to pass}
- # a dict in the form {'function':'end'}
- # will cause processing to terminate
- # the function called should return
- # a dict in the same form it received,
- # in appropriate form to pass to the
- # upload queue
-
- self.upload_queue = queue.Queue() # elements in this queue are
- # dicts in the form
- # {'function': callable, 'args':list
- # of positional arguments to pass,
- # 'kwargs': dict of kwargs to pass}
- # a dict in the form {'function':'end'}
- # will cause the uploading to terminate
- self.processing_thread = None
- self.uploading_thread = None
-
- def processing_queue_handler(self, queue_timeout=None):
- while True:
- processing_task = None
- try:
- processing_task = self.processing_queue.get(timeout=queue_timeout)
- except queue.Empty:
- break
- if processing_task["function"] == "end":
- self.processing_queue.task_done()
- break
- else:
- upload_task = processing_task["function"](*processing_task["args"], **processing_task["kwargs"])
- self.upload_queue.put(upload_task)
- self.processing_queue.task_done()
-
- def upload_queue_handler(self, queue_timeout=None):
- while True:
- upload_task = None
- try:
- upload_task = self.upload_queue.get(timeout=queue_timeout)
- except queue.Empty:
- break
- if upload_task["function"] == "end":
- self.upload_queue.task_done()
- break
- else:
- upload_task["function"](*upload_task["args"], **upload_task["kwargs"])
- self.upload_queue.task_done()
-
- def start_processing(self, queue_timeout=None):
- self.processing_thread = threading.Thread(target=self.processing_queue_handler, args=[queue_timeout])
- self.processing_thread.start()
-
- def start_uploading(self, queue_timeout=None):
- self.uploading_thread = threading.Thread(target=self.upload_queue_handler, args=[queue_timeout])
- self.uploading_thread.start()
-
- def end_uploading(self, *args, **kwargs):
- return {"function": "end"}
-
- def end_processing(self):
- self.processing_queue.put({"function": self.end_uploading, "args": [], "kwargs": {}})
- self.processing_queue.put({"function": "end"})
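End to end, the two queues above form a pipeline: each processing task returns an upload task in the same `{'function', 'args', 'kwargs'}` format, and a `{'function': 'end'}` sentinel shuts each stage down. A condensed sketch of that protocol:

```
# Condensed sketch of the two-stage queue protocol used by ProcessingHandler
# above: each processing task returns an upload task in the same
# {'function', 'args', 'kwargs'} format; {'function': 'end'} ends a stage.
import queue
import threading

processing_queue = queue.Queue()
upload_queue = queue.Queue()

def process(value):
    # a processing task returns the next task, destined for the upload queue
    return {"function": print, "args": ["uploaded:", value], "kwargs": {}}

def processing_worker():
    while True:
        task = processing_queue.get()
        if task["function"] == "end":
            upload_queue.put({"function": "end"})  # propagate shutdown downstream
            break
        upload_queue.put(task["function"](*task["args"], **task["kwargs"]))

def upload_worker():
    while True:
        task = upload_queue.get()
        if task["function"] == "end":
            break
        task["function"](*task["args"], **task["kwargs"])

threads = [threading.Thread(target=processing_worker), threading.Thread(target=upload_worker)]
for t in threads:
    t.start()
processing_queue.put({"function": process, "args": [42], "kwargs": {}})
processing_queue.put({"function": "end"})
for t in threads:
    t.join()  # prints "uploaded: 42"
```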
diff --git a/software/control/spectrometer_oceanoptics.py b/software/control/spectrometer_oceanoptics.py
deleted file mode 100644
index 9bf197ce0..000000000
--- a/software/control/spectrometer_oceanoptics.py
+++ /dev/null
@@ -1,131 +0,0 @@
-import argparse
-import cv2
-import time
-import numpy as np
-import threading
-
-try:
- import seabreeze as sb
- import seabreeze.spectrometers
-except ImportError:
- print("seabreeze import error")
-
-# installation: $ pip3 install seabreeze
-# installation: $ seabreeze_os_setup
-
-from control._def import *
-
-
-class Spectrometer(object):
-
- def __init__(self, sn=None):
- if sn is None:
- self.spectrometer = sb.spectrometers.Spectrometer.from_first_available()
- else:
- self.spectrometer = sb.spectrometers.Spectrometer.from_serial_number(sn)
-
- self.new_data_callback_external = None
-
- self.streaming_started = False
- self.streaming_paused = False
- self.stop_streaming = False
- self.is_reading_spectrum = False
-
- self.thread_streaming = threading.Thread(target=self.stream, daemon=True)
-
- def set_integration_time_ms(self, integration_time_ms):
- self.spectrometer.integration_time_micros(int(1000 * integration_time_ms))
-
- def read_spectrum(self, correct_dark_counts=False, correct_nonlinearity=False):
- self.is_reading_spectrum = True
- data = self.spectrometer.spectrum(correct_dark_counts, correct_nonlinearity)
- self.is_reading_spectrum = False
- return data
-
- def set_callback(self, function):
- self.new_data_callback_external = function
-
- def start_streaming(self):
- if not self.streaming_started:
- self.streaming_started = True
- self.streaming_paused = False
- self.thread_streaming.start()
- else:
- self.streaming_paused = False
-
- def pause_streaming(self):
- self.streaming_paused = True
-
- def resume_streaming(self):
- self.streaming_paused = False
-
- def stream(self):
- while not self.stop_streaming:
- if self.streaming_paused:
- time.sleep(0.05)
- continue
- # avoid conflict
- while self.is_reading_spectrum:
- time.sleep(0.05)
- if self.new_data_callback_external is not None:
- self.new_data_callback_external(self.read_spectrum())
-
- def close(self):
- if self.streaming_started:
- self.stop_streaming = True
- self.thread_streaming.join()
- self.spectrometer.close()
-
-
-class Spectrometer_Simulation(object):
-
- def __init__(self, sn=None):
- self.new_data_callback_external = None
- self.streaming_started = False
- self.stop_streaming = False
- self.streaming_paused = False
- self.is_reading_spectrum = False
- self.thread_streaming = threading.Thread(target=self.stream, daemon=True)
-
- def set_integration_time_ms(self, integration_time_ms):
- pass
-
- def read_spectrum(self, correct_dark_counts=False, correct_nonlinearity=False):
- N = 4096
- wavelength = np.linspace(400, 1100, N)
- intensity = np.random.randint(0, 65536, N)
- return np.stack((wavelength, intensity))
-
- def set_callback(self, function):
- self.new_data_callback_external = function
-
- def start_streaming(self):
- if not self.streaming_started:
- self.streaming_started = True
- self.streaming_paused = False
- self.thread_streaming.start()
- else:
- self.streaming_paused = False
-
- def pause_streaming(self):
- self.streaming_paused = True
-
- def resume_streaming(self):
- self.streaming_paused = False
-
- def stream(self):
- while not self.stop_streaming:
- if self.streaming_paused:
- time.sleep(0.05)
- continue
- # avoid conflict
- while self.is_reading_spectrum:
- time.sleep(0.05)
- if self.new_data_callback_external is not None:
- print("read spectrum...")
- self.new_data_callback_external(self.read_spectrum())
-
- def close(self):
- if self.streaming_started:
- self.stop_streaming = True
- self.thread_streaming.join()
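The streaming design above is a daemon thread that polls `read_spectrum()` and hands each result to a registered callback. A hardware-free sketch of that poll-and-callback loop (the fake spectrometer is a stand-in for the seabreeze device):

```
# Sketch of the poll-and-callback streaming pattern used above, with a fake
# read_spectrum() so it runs without hardware or seabreeze installed.
import threading
import time
import numpy as np

class FakeSpectrometer:
    def __init__(self):
        self.callback = None
        self.stop = False
        self.thread = threading.Thread(target=self.stream, daemon=True)

    def read_spectrum(self):
        wavelength = np.linspace(400, 1100, 8)
        intensity = np.random.randint(0, 65536, 8)
        return np.stack((wavelength, intensity))

    def stream(self):
        while not self.stop:
            if self.callback is not None:
                self.callback(self.read_spectrum())
            time.sleep(0.05)

spec = FakeSpectrometer()
spec.callback = lambda data: print("spectrum shape:", data.shape)
spec.thread.start()
time.sleep(0.2)
spec.stop = True
spec.thread.join()
```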
diff --git a/software/control/stitcher.py b/software/control/stitcher.py
deleted file mode 100644
index e58148cf2..000000000
--- a/software/control/stitcher.py
+++ /dev/null
@@ -1,1946 +0,0 @@
-# napari + stitching libs
-import os
-import sys
-from control._def import *
-from qtpy.QtCore import *
-
-import psutil
-import shutil
-import random
-import json
-import time
-import math
-from datetime import datetime
-from lxml import etree
-import numpy as np
-import pandas as pd
-import cv2
-import dask.array as da
-from dask.array import from_zarr
-from dask_image.imread import imread as dask_imread
-from skimage.registration import phase_cross_correlation
-import ome_zarr
-import zarr
-from tifffile import TiffWriter
-from aicsimageio.writers import OmeTiffWriter
-from aicsimageio.writers import OmeZarrWriter
-from aicsimageio import types
-from basicpy import BaSiC
-
-
-class Stitcher(QThread, QObject):
-
- update_progress = Signal(int, int)
- getting_flatfields = Signal()
- starting_stitching = Signal()
- starting_saving = Signal(bool)
- finished_saving = Signal(str, object)
-
- def __init__(
- self,
- input_folder,
- output_name="",
- output_format=".ome.zarr",
- apply_flatfield=0,
- use_registration=0,
- registration_channel="",
- registration_z_level=0,
- flexible=True,
- ):
- QThread.__init__(self)
- QObject.__init__(self)
- self.input_folder = input_folder
- self.image_folder = None
- self.output_name = output_name + output_format
- self.apply_flatfield = apply_flatfield
- self.use_registration = use_registration
- if use_registration:
- self.registration_channel = registration_channel
- self.registration_z_level = registration_z_level
-
- self.selected_modes = self.extract_selected_modes(self.input_folder)
- self.acquisition_params = self.extract_acquisition_parameters(self.input_folder)
- self.time_points = self.get_time_points(self.input_folder)
- print("timepoints:", self.time_points)
- self.is_reversed = self.determine_directions(self.input_folder) # init: top to bottom, left to right
- print(self.is_reversed)
- self.is_wellplate = IS_HCS
- self.flexible = flexible
- self.pixel_size_um = 1.0
- self.init_stitching_parameters()
- # self.overlap_percent = Acquisition.OVERLAP_PERCENT
-
- def init_stitching_parameters(self):
- self.is_rgb = {}
- self.regions = []
- self.channel_names = []
- self.mono_channel_names = []
- self.channel_colors = []
- self.num_z = self.num_c = 1
- self.num_cols = self.num_rows = 1
- self.input_height = self.input_width = 0
- self.num_pyramid_levels = 5
- self.v_shift = self.h_shift = (0, 0)
- self.max_x_overlap = self.max_y_overlap = 0
- self.flatfields = {}
- self.stitching_data = {}
- self.tczyx_shape = (
- len(self.time_points),
- self.num_c,
- self.num_z,
- self.num_rows * self.input_height,
- self.num_cols * self.input_width,
- )
- self.stitched_images = None
- self.chunks = None
- self.dtype = np.uint16
-
- def get_time_points(self, input_folder):
- try: # detects directories named as integers, representing time points.
- time_points = [
- d for d in os.listdir(input_folder) if os.path.isdir(os.path.join(input_folder, d)) and d.isdigit()
- ]
- time_points.sort(key=int)
- return time_points
- except Exception as e:
- print(f"Error detecting time points: {e}")
- return ["0"]
-
- def extract_selected_modes(self, input_folder):
- try:
- configs_path = os.path.join(input_folder, "configurations.xml")
- tree = etree.parse(configs_path)
- root = tree.getroot()
- selected_modes = {}
- for mode in root.findall(".//mode"):
- if mode.get("Selected") == "1":
- mode_id = mode.get("ID")
- selected_modes[mode_id] = {
- "Name": mode.get("Name"),
- "ExposureTime": mode.get("ExposureTime"),
- "AnalogGain": mode.get("AnalogGain"),
- "IlluminationSource": mode.get("IlluminationSource"),
- "IlluminationIntensity": mode.get("IlluminationIntensity"),
- }
- return selected_modes
- except Exception as e:
- print(f"Error reading selected modes: {e}")
-
- def extract_acquisition_parameters(self, input_folder):
- acquisition_params_path = os.path.join(input_folder, "acquisition parameters.json")
- with open(acquisition_params_path, "r") as file:
- acquisition_params = json.load(file)
- return acquisition_params
-
- def extract_wavelength(self, name):
- # Split the string and find the wavelength number immediately after "Fluorescence"
- parts = name.split()
- if "Fluorescence" in parts:
- index = parts.index("Fluorescence") + 1
- if index < len(parts):
- return parts[index].split()[0] # Assuming '488 nm Ex' and taking '488'
- for color in ["R", "G", "B"]:
- if color in parts:
- return color
- return None
-
- def determine_directions(self, input_folder):
- # return {'rows': self.acquisition_params.get("row direction", False),
- # 'cols': self.acquisition_params.get("col direction", False),
- # 'z-planes': False}
- coordinates = pd.read_csv(os.path.join(input_folder, self.time_points[0], "coordinates.csv"))
- try:
- first_region = coordinates["region"].unique()[0]
- coordinates = coordinates[coordinates["region"] == first_region]
- self.is_wellplate = True
- except Exception as e:
- print("no coordinates.csv well data:", e)
- self.is_wellplate = False
- i_rev = not coordinates.sort_values(by="i")["y (mm)"].is_monotonic_increasing
- j_rev = not coordinates.sort_values(by="j")["x (mm)"].is_monotonic_increasing
- k_rev = not coordinates.sort_values(by="z_level")["z (um)"].is_monotonic_increasing
- return {"rows": i_rev, "cols": j_rev, "z-planes": k_rev}
-
- def parse_filenames(self, time_point):
- # Initialize directories and read files
- self.image_folder = os.path.join(self.input_folder, str(time_point))
- print("stitching image folder:", self.image_folder)
- self.init_stitching_parameters()
-
- all_files = os.listdir(self.image_folder)
- sorted_input_files = sorted(
- [
- filename
- for filename in all_files
- if filename.endswith((".bmp", ".tiff")) and "focus_camera" not in filename
- ]
- )
- if not sorted_input_files:
- raise Exception("No valid files found in directory.")
-
- first_filename = sorted_input_files[0]
- try:
- first_region, first_i, first_j, first_k, channel_name = os.path.splitext(first_filename)[0].split("_", 4)
- first_k = int(first_k)
- print("region_i_j_k_channel_name: ", os.path.splitext(first_filename)[0])
- self.is_wellplate = True
- except ValueError as ve:
- first_i, first_j, first_k, channel_name = os.path.splitext(first_filename)[0].split("_", 3)
- print("i_j_k_channel_name: ", os.path.splitext(first_filename)[0])
- self.is_wellplate = False
-
- input_extension = os.path.splitext(sorted_input_files[0])[1]
- max_i, max_j, max_k = 0, 0, 0
- regions, channel_names = set(), set()
-
- for filename in sorted_input_files:
- if self.is_wellplate:
- region, i, j, k, channel_name = os.path.splitext(filename)[0].split("_", 4)
- else:
- region = "0"
- i, j, k, channel_name = os.path.splitext(filename)[0].split("_", 3)
-
- channel_name = channel_name.replace("_", " ").replace("full ", "full_")
- i, j, k = int(i), int(j), int(k)
-
- regions.add(region)
- channel_names.add(channel_name)
- max_i, max_j, max_k = max(max_i, i), max(max_j, j), max(max_k, k)
-
- tile_info = {
- "filepath": os.path.join(self.image_folder, filename),
- "region": region,
- "channel": channel_name,
- "z_level": k,
- "row": i,
- "col": j,
- }
- self.stitching_data.setdefault(region, {}).setdefault(channel_name, {}).setdefault(k, {}).setdefault(
- (i, j), tile_info
- )
-
- self.regions = sorted(regions)
- self.channel_names = sorted(channel_names)
- self.num_z, self.num_cols, self.num_rows = max_k + 1, max_j + 1, max_i + 1
-
- first_coord = (
- f"{self.regions[0]}_{first_i}_{first_j}_{first_k}_"
- if self.is_wellplate
- else f"{first_i}_{first_j}_{first_k}_"
- )
- found_dims = False
- mono_channel_names = []
-
- for channel in self.channel_names:
- filename = first_coord + channel.replace(" ", "_") + input_extension
- image = dask_imread(os.path.join(self.image_folder, filename))[0]
-
- if not found_dims:
- self.dtype = np.dtype(image.dtype)
- self.input_height, self.input_width = image.shape[:2]
- self.chunks = (1, 1, 1, self.input_height // 2, self.input_width // 2)
- found_dims = True
- print("chunks", self.chunks)
-
- if len(image.shape) == 3:
- self.is_rgb[channel] = True
- channel = channel.split("_")[0]
- mono_channel_names.extend([f"{channel}_R", f"{channel}_G", f"{channel}_B"])
- else:
- self.is_rgb[channel] = False
- mono_channel_names.append(channel)
-
- self.mono_channel_names = mono_channel_names
- self.num_c = len(mono_channel_names)
- self.channel_colors = [
- CHANNEL_COLORS_MAP.get(self.extract_wavelength(name), {"hex": 0xFFFFFF})["hex"]
- for name in self.mono_channel_names
- ]
- print(self.mono_channel_names)
- print(self.regions)
-
- def get_flatfields(self, progress_callback=None):
- def process_images(images, channel_name):
- images = np.array(images)
- basic = BaSiC(get_darkfield=False, smoothness_flatfield=1)
- basic.fit(images)
- channel_index = self.mono_channel_names.index(channel_name)
- self.flatfields[channel_index] = basic.flatfield
- if progress_callback:
- progress_callback(channel_index + 1, self.num_c)
-
- # Iterate only over the channels you need to process
- for channel in self.channel_names:
- all_tiles = []
- # Collect tiles from all ROIs and z-levels for the current channel
- for roi in self.regions:
- for z_level in self.stitching_data[roi][channel]:
- for row_col, tile_info in self.stitching_data[roi][channel][z_level].items():
- all_tiles.append(tile_info)
-
- # Shuffle and select a subset of tiles for flatfield calculation
- random.shuffle(all_tiles)
- selected_tiles = all_tiles[: min(32, len(all_tiles))]
-
- if self.is_rgb[channel]:
- # Process each color channel if the channel is RGB
- images_r = [dask_imread(tile["filepath"])[0][:, :, 0] for tile in selected_tiles]
- images_g = [dask_imread(tile["filepath"])[0][:, :, 1] for tile in selected_tiles]
- images_b = [dask_imread(tile["filepath"])[0][:, :, 2] for tile in selected_tiles]
- channel = channel.split("_")[0]
- process_images(images_r, channel + "_R")
- process_images(images_g, channel + "_G")
- process_images(images_b, channel + "_B")
- else:
- # Process monochrome images
- images = [dask_imread(tile["filepath"])[0] for tile in selected_tiles]
- process_images(images, channel)
-
- def normalize_image(self, img):
- img_min, img_max = img.min(), img.max()
- img_normalized = (img - img_min) / (img_max - img_min)
- scale_factor = np.iinfo(self.dtype).max if np.issubdtype(self.dtype, np.integer) else 1
- return (img_normalized * scale_factor).astype(self.dtype)
-
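`normalize_image` above rescales to the full range of the output dtype via a min-max stretch. The same arithmetic in isolation (assumes `img_max > img_min`, as the original does):

```
# Sketch of the dtype-aware min-max stretch used by normalize_image above.
# Assumes the image is not constant (img_max > img_min), as the original does.
import numpy as np

def normalize_image(img, dtype=np.uint16):
    img_min, img_max = img.min(), img.max()
    img_normalized = (img - img_min) / (img_max - img_min)       # float in [0, 1]
    scale = np.iinfo(dtype).max if np.issubdtype(dtype, np.integer) else 1
    return (img_normalized * scale).astype(dtype)

print(normalize_image(np.array([[0, 5], [10, 20]])))  # spans 0 ... 65535
```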
- def visualize_image(self, img1, img2, title):
- if title == "horizontal":
- combined_image = np.hstack((img1, img2))
- else:
- combined_image = np.vstack((img1, img2))
- cv2.imwrite(f"{self.input_folder}/{title}.png", combined_image)
-
- def calculate_horizontal_shift(self, img1_path, img2_path, max_overlap, margin_ratio=0.2):
- try:
- img1 = dask_imread(img1_path)[0].compute()
- img2 = dask_imread(img2_path)[0].compute()
- img1 = self.normalize_image(img1)
- img2 = self.normalize_image(img2)
-
- margin = int(self.input_height * margin_ratio)
- img1_overlap = (img1[margin:-margin, -max_overlap:]).astype(self.dtype)
- img2_overlap = (img2[margin:-margin, :max_overlap]).astype(self.dtype)
-
- self.visualize_image(img1_overlap, img2_overlap, "horizontal")
- shift, error, diffphase = phase_cross_correlation(img1_overlap, img2_overlap, upsample_factor=10)
- return round(shift[0]), round(shift[1] - img1_overlap.shape[1])
- except Exception as e:
- print(f"Error calculating horizontal shift: {e}")
- return (0, 0)
-
- def calculate_vertical_shift(self, img1_path, img2_path, max_overlap, margin_ratio=0.2):
- try:
- img1 = dask_imread(img1_path)[0].compute()
- img2 = dask_imread(img2_path)[0].compute()
- img1 = self.normalize_image(img1)
- img2 = self.normalize_image(img2)
-
- margin = int(self.input_width * margin_ratio)
- img1_overlap = (img1[-max_overlap:, margin:-margin]).astype(self.dtype)
- img2_overlap = (img2[:max_overlap, margin:-margin]).astype(self.dtype)
-
- self.visualize_image(img1_overlap, img2_overlap, "vertical")
- shift, error, diffphase = phase_cross_correlation(img1_overlap, img2_overlap, upsample_factor=10)
- return round(shift[0] - img1_overlap.shape[0]), round(shift[1])
- except Exception as e:
- print(f"Error calculating vertical shift: {e}")
- return (0, 0)
-
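Both shift routines above delegate the actual registration to `skimage.registration.phase_cross_correlation` on the overlapping strips of neighbouring tiles. A self-contained sketch on synthetic data with a known offset:

```
# Sketch of phase-correlation registration as used by the shift routines
# above, on synthetic data with a known circular offset.
import numpy as np
from skimage.registration import phase_cross_correlation

rng = np.random.default_rng(0)
reference = rng.random((200, 200))
moving = np.roll(reference, shift=(3, -5), axis=(0, 1))  # offset by (+3, -5)

# The returned shift is what must be applied to `moving` to register it
# with `reference`, so it comes back as roughly (-3, +5).
shift, error, diffphase = phase_cross_correlation(reference, moving, upsample_factor=10)
print(shift)
```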
- def calculate_shifts(self, roi=""):
- roi = self.regions[0] if roi not in self.regions else roi
- self.registration_channel = (
- self.registration_channel if self.registration_channel in self.channel_names else self.channel_names[0]
- )
-
- # Calculate estimated overlap from acquisition parameters
- dx_mm = self.acquisition_params["dx(mm)"]
- dy_mm = self.acquisition_params["dy(mm)"]
- obj_mag = self.acquisition_params["objective"]["magnification"]
- obj_tube_lens_mm = self.acquisition_params["objective"]["tube_lens_f_mm"]
- sensor_pixel_size_um = self.acquisition_params["sensor_pixel_size_um"]
- tube_lens_mm = self.acquisition_params["tube_lens_mm"]
-
- obj_focal_length_mm = obj_tube_lens_mm / obj_mag
- actual_mag = tube_lens_mm / obj_focal_length_mm
- self.pixel_size_um = sensor_pixel_size_um / actual_mag
- print("pixel_size_um:", self.pixel_size_um)
-
- dx_pixels = dx_mm * 1000 / self.pixel_size_um
- dy_pixels = dy_mm * 1000 / self.pixel_size_um
- print("dy_pixels", dy_pixels, ", dx_pixels:", dx_pixels)
-
- self.max_x_overlap = round(abs(self.input_width - dx_pixels) / 2)
- self.max_y_overlap = round(abs(self.input_height - dy_pixels) / 2)
- print(
- "objective calculated - vertical overlap:", self.max_y_overlap, ", horizontal overlap:", self.max_x_overlap
- )
-
- col_left, col_right = (self.num_cols - 1) // 2, (self.num_cols - 1) // 2 + 1
- if self.is_reversed["cols"]:
- col_left, col_right = col_right, col_left
-
- row_top, row_bottom = (self.num_rows - 1) // 2, (self.num_rows - 1) // 2 + 1
- if self.is_reversed["rows"]:
- row_top, row_bottom = row_bottom, row_top
-
- img1_path = img2_path_vertical = img2_path_horizontal = None
- for (row, col), tile_info in self.stitching_data[roi][self.registration_channel][
- self.registration_z_level
- ].items():
- if col == col_left and row == row_top:
- img1_path = tile_info["filepath"]
- elif col == col_left and row == row_bottom:
- img2_path_vertical = tile_info["filepath"]
- elif col == col_right and row == row_top:
- img2_path_horizontal = tile_info["filepath"]
-
- if img1_path is None:
- raise Exception(
- f"No input file found for c:{self.registration_channel} k:{self.registration_z_level} "
- f"j:{col_left} i:{row_top}"
- )
-
- self.v_shift = (
- self.calculate_vertical_shift(img1_path, img2_path_vertical, self.max_y_overlap)
- if self.max_y_overlap > 0 and img2_path_vertical and img1_path != img2_path_vertical
- else (0, 0)
- )
- self.h_shift = (
- self.calculate_horizontal_shift(img1_path, img2_path_horizontal, self.max_x_overlap)
- if self.max_x_overlap > 0 and img2_path_horizontal and img1_path != img2_path_horizontal
- else (0, 0)
- )
- print("vertical shift:", self.v_shift, ", horizontal shift:", self.h_shift)
-
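The pixel size computed at the top of `calculate_shifts` above follows from the objective's design focal length and the tube lens actually installed. A worked example of that arithmetic (all values illustrative, not from a real configuration):

```
# Worked example of the pixel-size arithmetic in calculate_shifts above;
# all values are illustrative, not taken from a real configuration.
obj_mag = 20                  # objective magnification
obj_tube_lens_mm = 180        # tube lens the objective is designed for
tube_lens_mm = 50             # tube lens actually installed
sensor_pixel_size_um = 2.4

obj_focal_length_mm = obj_tube_lens_mm / obj_mag    # 9.0 mm
actual_mag = tube_lens_mm / obj_focal_length_mm     # ~5.56x
pixel_size_um = sensor_pixel_size_um / actual_mag   # ~0.432 um per pixel
print(round(pixel_size_um, 3))
```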
- def calculate_dynamic_shifts(self, roi, channel, z_level, row, col):
- h_shift, v_shift = self.h_shift, self.v_shift
-
- # Check for left neighbor
- if (row, col - 1) in self.stitching_data[roi][channel][z_level]:
- left_tile_path = self.stitching_data[roi][channel][z_level][row, col - 1]["filepath"]
- current_tile_path = self.stitching_data[roi][channel][z_level][row, col]["filepath"]
- # Calculate horizontal shift
- new_h_shift = self.calculate_horizontal_shift(left_tile_path, current_tile_path, abs(self.h_shift[1]))
-
- # Check if the new horizontal shift is within 5% of the precomputed shift
- if self.h_shift == (0, 0) or (
- 0.95 * abs(self.h_shift[1]) <= abs(new_h_shift[1]) <= 1.05 * abs(self.h_shift[1])
- and 0.95 * abs(self.h_shift[0]) <= abs(new_h_shift[0]) <= 1.05 * abs(self.h_shift[0])
- ):
- print("new h shift", new_h_shift, h_shift)
- h_shift = new_h_shift
-
- # Check for top neighbor
- if (row - 1, col) in self.stitching_data[roi][channel][z_level]:
- top_tile_path = self.stitching_data[roi][channel][z_level][row - 1, col]["filepath"]
- current_tile_path = self.stitching_data[roi][channel][z_level][row, col]["filepath"]
- # Calculate vertical shift
- new_v_shift = self.calculate_vertical_shift(top_tile_path, current_tile_path, abs(self.v_shift[0]))
-
- # Check if the new vertical shift is within 5% of the precomputed shift
- if self.v_shift == (0, 0) or (
- 0.95 * abs(self.v_shift[0]) <= abs(new_v_shift[0]) <= 1.05 * abs(self.v_shift[0])
- and 0.95 * abs(self.v_shift[1]) <= abs(new_v_shift[1]) <= 1.05 * abs(self.v_shift[1])
- ):
- print("new v shift", new_v_shift, v_shift)
- v_shift = new_v_shift
-
- return h_shift, v_shift
-
- def init_output(self, time_point, region_id):
- output_folder = os.path.join(self.input_folder, f"{time_point}_stitched")
- os.makedirs(output_folder, exist_ok=True)
- self.output_path = os.path.join(
- output_folder, f"{region_id}_{self.output_name}" if self.is_wellplate else self.output_name
- )
-
- x_max = (
- self.input_width
- + ((self.num_cols - 1) * (self.input_width + self.h_shift[1])) # horizontal width with overlap
- + abs((self.num_rows - 1) * self.v_shift[1])
- ) # horizontal shift from vertical registration
- y_max = (
- self.input_height
- + ((self.num_rows - 1) * (self.input_height + self.v_shift[0])) # vertical height with overlap
- + abs((self.num_cols - 1) * self.h_shift[0])
- ) # vertical shift from horizontal registration
- if self.use_registration and DYNAMIC_REGISTRATION:
- y_max = round(y_max * 1.05)
- x_max = round(x_max * 1.05)
- size = max(y_max, x_max)
- num_levels = 1
-
- # Get the number of rows and columns
- if self.is_wellplate and STITCH_COMPLETE_ACQUISITION:
- rows, columns = self.get_rows_and_columns()
- self.num_pyramid_levels = math.ceil(np.log2(max(x_max, y_max) / 1024 * max(len(rows), len(columns))))
- else:
- self.num_pyramid_levels = math.ceil(np.log2(max(x_max, y_max) / 1024))
- print("num_pyramid_levels", self.num_pyramid_levels)
-
- tczyx_shape = (1, self.num_c, self.num_z, y_max, x_max)
- self.tczyx_shape = tczyx_shape
- print(f"(t:{time_point}, roi:{region_id}) output shape: {tczyx_shape}")
- return da.zeros(tczyx_shape, dtype=self.dtype, chunks=self.chunks)
-
- def stitch_images(self, time_point, roi, progress_callback=None):
- self.stitched_images = self.init_output(time_point, roi)
- total_tiles = sum(
- len(z_data) for channel_data in self.stitching_data[roi].values() for z_data in channel_data.values()
- )
- processed_tiles = 0
-
- for z_level in range(self.num_z):
-
- for row in range(self.num_rows):
- row = self.num_rows - 1 - row if self.is_reversed["rows"] else row
-
- for col in range(self.num_cols):
- col = self.num_cols - 1 - col if self.is_reversed["cols"] else col
-
- if self.use_registration and DYNAMIC_REGISTRATION and z_level == self.registration_z_level:
- if (row, col) in self.stitching_data[roi][self.registration_channel][z_level]:
- tile_info = self.stitching_data[roi][self.registration_channel][z_level][(row, col)]
- self.h_shift, self.v_shift = self.calculate_dynamic_shifts(
- roi, self.registration_channel, z_level, row, col
- )
-
- # Now apply the same shifts to all channels
- for channel in self.channel_names:
- if (row, col) in self.stitching_data[roi][channel][z_level]:
- tile_info = self.stitching_data[roi][channel][z_level][(row, col)]
- tile = dask_imread(tile_info["filepath"])[0]
- # tile = tile[:, ::-1]
- if self.is_rgb[channel]:
- for color_idx, color in enumerate(["R", "G", "B"]):
- tile_color = tile[:, :, color_idx]
- color_channel = f"{channel}_{color}"
- self.stitch_single_image(
- tile_color, z_level, self.mono_channel_names.index(color_channel), row, col
- )
- processed_tiles += 1
- else:
- self.stitch_single_image(
- tile, z_level, self.mono_channel_names.index(channel), row, col
- )
- processed_tiles += 1
- if progress_callback is not None:
- progress_callback(processed_tiles, total_tiles)
-
- def stitch_single_image(self, tile, z_level, channel_idx, row, col):
- # print(tile.shape)
- if self.apply_flatfield:
- tile = (
- (tile / self.flatfields[channel_idx])
- .clip(min=np.iinfo(self.dtype).min, max=np.iinfo(self.dtype).max)
- .astype(self.dtype)
- )
- # Determine crop for tile edges
- top_crop = max(0, (-self.v_shift[0] // 2) - abs(self.h_shift[0]) // 2) if row > 0 else 0
- bottom_crop = max(0, (-self.v_shift[0] // 2) - abs(self.h_shift[0]) // 2) if row < self.num_rows - 1 else 0
- left_crop = max(0, (-self.h_shift[1] // 2) - abs(self.v_shift[1]) // 2) if col > 0 else 0
- right_crop = max(0, (-self.h_shift[1] // 2) - abs(self.v_shift[1]) // 2) if col < self.num_cols - 1 else 0
-
- tile = tile[top_crop : tile.shape[0] - bottom_crop, left_crop : tile.shape[1] - right_crop]
-
- # Initialize starting coordinates based on tile position and shift
- y = row * (self.input_height + self.v_shift[0]) + top_crop
- if self.h_shift[0] < 0:
- y -= (self.num_cols - 1 - col) * self.h_shift[0] # Moves up if negative
- else:
- y += col * self.h_shift[0] # Moves down if positive
-
- x = col * (self.input_width + self.h_shift[1]) + left_crop
- if self.v_shift[1] < 0:
- x -= (self.num_rows - 1 - row) * self.v_shift[1] # Moves left if negative
- else:
- x += row * self.v_shift[1] # Moves right if positive
-
- # Place cropped tile on the stitched image canvas
- self.stitched_images[0, channel_idx, z_level, y : y + tile.shape[0], x : x + tile.shape[1]] = tile
- # print(f" col:{col}, \trow:{row},\ty:{y}-{y+tile.shape[0]}, \tx:{x}-{x+tile.shape[-1]}")
-
- def save_as_ome_tiff(self):
- dz_um = self.acquisition_params.get("dz(um)", None)
- sensor_pixel_size_um = self.acquisition_params.get("sensor_pixel_size_um", None)
- dims = "TCZYX"
- # if self.is_rgb:
- # dims += "S"
-
- ome_metadata = OmeTiffWriter.build_ome(
- image_name=[os.path.basename(self.output_path)],
- data_shapes=[self.stitched_images.shape],
- data_types=[self.stitched_images.dtype],
- dimension_order=[dims],
- channel_names=[self.mono_channel_names],
- physical_pixel_sizes=[types.PhysicalPixelSizes(dz_um, self.pixel_size_um, self.pixel_size_um)],
- # is_rgb=self.is_rgb
- # channel colors
- )
- OmeTiffWriter.save(
- data=self.stitched_images,
- uri=self.output_path,
- ome_xml=ome_metadata,
- dimension_order=[dims],
- # channel colors / names
- )
- self.stitched_images = None
-
- def save_as_ome_zarr(self):
- dz_um = self.acquisition_params.get("dz(um)", None)
- sensor_pixel_size_um = self.acquisition_params.get("sensor_pixel_size_um", None)
- dims = "TCZYX"
- intensity_min = np.iinfo(self.dtype).min
- intensity_max = np.iinfo(self.dtype).max
- channel_minmax = [(intensity_min, intensity_max)] * self.num_c
- for i in range(self.num_c):
- print(
- f"Channel {i}:",
- self.mono_channel_names[i],
- " \tColor:",
- self.channel_colors[i],
- " \tPixel Range:",
- channel_minmax[i],
- )
-
- zarr_writer = OmeZarrWriter(self.output_path)
- zarr_writer.build_ome(
- size_z=self.num_z,
- image_name=os.path.basename(self.output_path),
- channel_names=self.mono_channel_names,
- channel_colors=self.channel_colors,
- channel_minmax=channel_minmax,
- )
- zarr_writer.write_image(
- image_data=self.stitched_images,
- image_name=os.path.basename(self.output_path),
- physical_pixel_sizes=types.PhysicalPixelSizes(dz_um, self.pixel_size_um, self.pixel_size_um),
- channel_names=self.mono_channel_names,
- channel_colors=self.channel_colors,
- dimension_order=dims,
- scale_num_levels=self.num_pyramid_levels,
- chunk_dims=self.chunks,
- )
- self.stitched_images = None
-
- def create_complete_ome_zarr(self):
- """Creates a complete OME-ZARR with proper channel metadata."""
- final_path = os.path.join(
- self.input_folder, self.output_name.replace(".ome.zarr", "") + "_complete_acquisition.ome.zarr"
- )
- if len(self.time_points) == 1:
- zarr_path = os.path.join(self.input_folder, f"0_stitched", self.output_name)
- # final_path = zarr_path
- shutil.copytree(zarr_path, final_path)
- else:
- store = ome_zarr.io.parse_url(final_path, mode="w").store
- root_group = zarr.group(store=store)
- intensity_min = np.iinfo(self.dtype).min
- intensity_max = np.iinfo(self.dtype).max
-
- data = self.load_and_merge_timepoints()
- ome_zarr.writer.write_image(
- image=data,
- group=root_group,
- axes="tczyx",
- channel_names=self.mono_channel_names,
- storage_options=dict(chunks=self.chunks),
- )
-
- channel_info = [
- {
- "label": self.mono_channel_names[i],
- "color": f"{self.channel_colors[i]:06X}",
- "window": {"start": intensity_min, "end": intensity_max},
- "active": True,
- }
- for i in range(self.num_c)
- ]
-
- # Assign the channel metadata to the image group
- root_group.attrs["omero"] = {"channels": channel_info}
-
- print(f"Data saved in OME-ZARR format at: {final_path}")
- root = zarr.open(final_path, mode="r")
- print(root.tree())
- print(dict(root.attrs))
- self.finished_saving.emit(final_path, self.dtype)
-
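-    # Illustrative sketch: the bare ome-zarr-py call that the complete-acquisition
-    # writers below build on (hypothetical path and shape; uses the module's
-    # existing imports; not called by the pipeline).
-    @staticmethod
-    def _sketch_minimal_ome_zarr_write():
-        store = ome_zarr.io.parse_url("demo.ome.zarr", mode="w").store  # hypothetical path
-        root = zarr.group(store=store)
-        ome_zarr.writer.write_image(
-            image=np.zeros((1, 1, 1, 256, 256), dtype=np.uint16),  # TCZYX
-            group=root,
-            axes="tczyx",
-        )
-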
- def create_hcs_ome_zarr(self):
- """Creates a hierarchical Zarr file in the HCS OME-ZARR format for visualization in napari."""
- hcs_path = os.path.join(
- self.input_folder, self.output_name.replace(".ome.zarr", "") + "_complete_acquisition.ome.zarr"
- )
- if len(self.time_points) == 1 and len(self.regions) == 1:
- stitched_zarr_path = os.path.join(self.input_folder, f"0_stitched", f"{self.regions[0]}_{self.output_name}")
- # hcs_path = stitched_zarr_path # replace next line with this if no copy wanted
- shutil.copytree(stitched_zarr_path, hcs_path)
- else:
- store = ome_zarr.io.parse_url(hcs_path, mode="w").store
- root_group = zarr.group(store=store)
-
- # Retrieve row and column information for plate metadata
- rows, columns = self.get_rows_and_columns()
- well_paths = [f"{well_id[0]}/{well_id[1:]}" for well_id in sorted(self.regions)]
- print(well_paths)
- ome_zarr.writer.write_plate_metadata(root_group, rows, [str(col) for col in columns], well_paths)
-
- # Loop over each well and save its data
- for well_id in self.regions:
- row, col = well_id[0], well_id[1:]
- row_group = root_group.require_group(row)
- well_group = row_group.require_group(col)
- self.write_well_and_metadata(well_id, well_group)
-
- print(f"Data saved in HCS OME-ZARR format at: {hcs_path}")
-
- print("HCS root attributes:")
- root = zarr.open(hcs_path, mode="r")
- print(root.tree())
- print(dict(root.attrs))
-
- self.finished_saving.emit(hcs_path, self.dtype)
-
- def write_well_and_metadata(self, well_id, well_group):
- """Process and save data for a single well across all timepoints."""
- # Load data from precomputed Zarrs for each timepoint
- data = self.load_and_merge_timepoints(well_id)
- intensity_min = np.iinfo(self.dtype).min
- intensity_max = np.iinfo(self.dtype).max
- # dataset = well_group.create_dataset("data", data=data, chunks=(1, 1, 1, self.input_height, self.input_width), dtype=data.dtype)
- field_paths = ["0"] # Assuming single field of view
- ome_zarr.writer.write_well_metadata(well_group, field_paths)
- for fi, field in enumerate(field_paths):
- image_group = well_group.require_group(str(field))
- ome_zarr.writer.write_image(
- image=data,
- group=image_group,
- axes="tczyx",
- channel_names=self.mono_channel_names,
- storage_options=dict(chunks=self.chunks),
- )
- channel_info = [
- {
- "label": self.mono_channel_names[c],
- "color": f"{self.channel_colors[c]:06X}",
- "window": {"start": intensity_min, "end": intensity_max},
- "active": True,
- }
- for c in range(self.num_c)
- ]
-
- image_group.attrs["omero"] = {"channels": channel_info}
-
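-    # Illustrative sketch: the HCS layout written above, reduced to a hypothetical
-    # one-row, two-well plate - plate metadata at the root, well metadata per
-    # row/column group (not called by the pipeline).
-    @staticmethod
-    def _sketch_minimal_hcs_layout():
-        store = ome_zarr.io.parse_url("plate.ome.zarr", mode="w").store  # hypothetical path
-        root = zarr.group(store=store)
-        ome_zarr.writer.write_plate_metadata(root, rows=["A"], columns=["1", "2"], wells=["A/1", "A/2"])
-        for path in ["A/1", "A/2"]:
-            row, col = path.split("/")
-            well = root.require_group(row).require_group(col)
-            ome_zarr.writer.write_well_metadata(well, images=[{"path": "0"}])
-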
- def pad_to_largest(self, array, target_shape):
- if array.shape == target_shape:
- return array
- pad_widths = [(0, max(0, ts - s)) for s, ts in zip(array.shape, target_shape)]
- return da.pad(array, pad_widths, mode="constant", constant_values=0)
-
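-    # Worked example: pad_to_largest() only ever zero-pads at the high end of each
-    # axis, e.g. da.pad(da.zeros((2, 3)), [(0, 1), (0, 2)], mode="constant",
-    # constant_values=0) has shape (3, 5).
-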
- def load_and_merge_timepoints(self, well_id=""):
- """Load and merge data for a well from Zarr files for each timepoint."""
- t_data = []
- t_shapes = []
- for t in self.time_points:
- if self.is_wellplate:
- filepath = f"{well_id}_{self.output_name}"
- else:
- filepath = f"{self.output_name}"
- zarr_path = os.path.join(self.input_folder, f"{t}_stitched", filepath)
- print(f"t:{t} well:{well_id}, \t{zarr_path}")
- z = zarr.open(zarr_path, mode="r")
-            # z["0"] holds the full-resolution stitched array for this timepoint
- t_array = da.from_zarr(z["0"], chunks=self.chunks)
- t_data.append(t_array)
- t_shapes.append(t_array.shape)
-
- # Concatenate arrays along the existing time axis if multiple timepoints are present
- if len(t_data) > 1:
- max_shape = tuple(max(s) for s in zip(*t_shapes))
- padded_data = [self.pad_to_largest(t, max_shape) for t in t_data]
- data = da.concatenate(padded_data, axis=0)
- print(f"(merged timepoints, well:{well_id}) output shape: {data.shape}")
- return data
- elif len(t_data) == 1:
- data = t_data[0]
- return data
- else:
-            raise ValueError("No data loaded from any timepoint.")
-
- def get_rows_and_columns(self):
- """Utility to extract rows and columns from well identifiers."""
- rows = set()
- columns = set()
- for well_id in self.regions:
- rows.add(well_id[0]) # Assuming well_id like 'A1'
- columns.add(int(well_id[1:]))
- return sorted(rows), sorted(columns)
-
- def run(self):
- # Main stitching logic
- stime = time.time()
- try:
- for time_point in self.time_points:
- ttime = time.time()
- print(f"starting t:{time_point}...")
- self.parse_filenames(time_point)
-
-                if self.apply_flatfield:
-                    print("getting flatfields...")
-                    self.getting_flatfields.emit()
-                    self.get_flatfields(progress_callback=self.update_progress.emit)
-                    print("time to calculate flatfields", time.time() - ttime)
-
- if self.use_registration:
- shtime = time.time()
- print(f"calculating shifts...")
- self.calculate_shifts()
- print("time to calculate shifts", time.time() - shtime)
-
- for well in self.regions:
- wtime = time.time()
- self.starting_stitching.emit()
- print(f"\nstarting stitching...")
- self.stitch_images(time_point, well, progress_callback=self.update_progress.emit)
-
- sttime = time.time()
- print("time to stitch well", sttime - wtime)
-
- self.starting_saving.emit(not STITCH_COMPLETE_ACQUISITION)
- print(f"saving...")
- if ".ome.tiff" in self.output_path:
- self.save_as_ome_tiff()
- else:
- self.save_as_ome_zarr()
-
- print("time to save stitched well", time.time() - sttime)
- print("time per well", time.time() - wtime)
- if well != "0":
- print(f"...done saving well:{well}")
- print(f"...finished t:{time_point}")
- print("time per timepoint", time.time() - ttime)
-
- if STITCH_COMPLETE_ACQUISITION and not self.flexible and ".ome.zarr" in self.output_name:
- self.starting_saving.emit(True)
- scatime = time.time()
- if self.is_wellplate:
- self.create_hcs_ome_zarr()
- print(f"...done saving complete hcs successfully")
- else:
- self.create_complete_ome_zarr()
- print(f"...done saving complete successfully")
- print("time to save merged wells and timepoints", time.time() - scatime)
- else:
- self.finished_saving.emit(self.output_path, self.dtype)
- print("total time to stitch + save:", time.time() - stime)
-
- except Exception as e:
- print("time before error", time.time() - stime)
- print(f"error While Stitching: {e}")
-
-
-class CoordinateStitcher(QThread, QObject):
- update_progress = Signal(int, int)
- getting_flatfields = Signal()
- starting_stitching = Signal()
- starting_saving = Signal(bool)
- finished_saving = Signal(str, object)
-
- def __init__(
- self,
- input_folder,
- output_name="",
- output_format=".ome.zarr",
- apply_flatfield=0,
- use_registration=0,
- registration_channel="",
- registration_z_level=0,
- overlap_percent=0,
- ):
- super().__init__()
- self.input_folder = input_folder
- self.output_name = output_name + output_format
- self.output_format = output_format
- self.apply_flatfield = apply_flatfield
- self.use_registration = use_registration
- if use_registration:
- self.registration_channel = registration_channel
- self.registration_z_level = registration_z_level
- self.coordinates_df = None
- self.pixel_size_um = None
- self.acquisition_params = None
- self.time_points = []
- self.regions = []
- self.overlap_percent = overlap_percent
- self.scan_pattern = FOV_PATTERN
- self.init_stitching_parameters()
-
- def init_stitching_parameters(self):
- self.is_rgb = {}
- self.channel_names = []
- self.mono_channel_names = []
- self.channel_colors = []
- self.num_z = self.num_c = self.num_t = 1
- self.input_height = self.input_width = 0
- self.num_pyramid_levels = 5
- self.flatfields = {}
- self.stitching_data = {}
- self.dtype = np.uint16
- self.chunks = None
- self.h_shift = (0, 0)
- if self.scan_pattern == "S-Pattern":
- self.h_shift_rev = (0, 0)
- self.h_shift_rev_odd = 0 # 0 reverse even rows, 1 reverse odd rows
- self.v_shift = (0, 0)
- self.x_positions = set()
- self.y_positions = set()
-
- def get_time_points(self):
- self.time_points = [
- d
- for d in os.listdir(self.input_folder)
- if os.path.isdir(os.path.join(self.input_folder, d)) and d.isdigit()
- ]
- self.time_points.sort(key=int)
- return self.time_points
-
- def extract_acquisition_parameters(self):
-        acquisition_params_path = os.path.join(self.input_folder, "acquisition parameters.json")
-        with open(acquisition_params_path, "r") as file:
- self.acquisition_params = json.load(file)
-
- def get_pixel_size_from_params(self):
- obj_mag = self.acquisition_params["objective"]["magnification"]
- obj_tube_lens_mm = self.acquisition_params["objective"]["tube_lens_f_mm"]
- sensor_pixel_size_um = self.acquisition_params["sensor_pixel_size_um"]
- tube_lens_mm = self.acquisition_params["tube_lens_mm"]
-
- obj_focal_length_mm = obj_tube_lens_mm / obj_mag
- actual_mag = tube_lens_mm / obj_focal_length_mm
- self.pixel_size_um = sensor_pixel_size_um / actual_mag
- print("pixel_size_um:", self.pixel_size_um)
-
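-    # Worked example (hypothetical optics): a 20x objective specified for a 180 mm
-    # tube lens has f_obj = 180 / 20 = 9 mm; on a 50 mm tube lens the actual
-    # magnification is 50 / 9 ~= 5.56x, so a 3.76 um sensor pixel maps to
-    # 3.76 / 5.56 ~= 0.68 um at the sample plane.
-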
- def parse_filenames(self):
- self.extract_acquisition_parameters()
- self.get_pixel_size_from_params()
-
- self.stitching_data = {}
- self.regions = set()
- self.channel_names = set()
- max_z = 0
- max_fov = 0
-
- for t, time_point in enumerate(self.time_points):
- image_folder = os.path.join(self.input_folder, str(time_point))
- coordinates_path = os.path.join(self.input_folder, time_point, "coordinates.csv")
- coordinates_df = pd.read_csv(coordinates_path)
-
- print(f"Processing timepoint {time_point}, image folder: {image_folder}")
-
- image_files = sorted(
- [f for f in os.listdir(image_folder) if f.endswith((".bmp", ".tiff")) and "focus_camera" not in f]
- )
-
- if not image_files:
- raise Exception(f"No valid files found in directory for timepoint {time_point}.")
-
- for file in image_files:
- parts = file.split("_", 3)
- region, fov, z_level, channel = parts[0], int(parts[1]), int(parts[2]), os.path.splitext(parts[3])[0]
- channel = channel.replace("_", " ").replace("full ", "full_")
-
- coord_row = coordinates_df[
- (coordinates_df["region"] == region)
- & (coordinates_df["fov"] == fov)
- & (coordinates_df["z_level"] == z_level)
- ]
-
- if coord_row.empty:
- print(f"Warning: No matching coordinates found for file {file}")
- continue
-
- coord_row = coord_row.iloc[0]
-
- key = (t, region, fov, z_level, channel)
- self.stitching_data[key] = {
- "filepath": os.path.join(image_folder, file),
- "x": coord_row["x (mm)"],
- "y": coord_row["y (mm)"],
- "z": coord_row["z (um)"],
- "channel": channel,
- "z_level": z_level,
- "region": region,
- "fov_idx": fov,
- "t": t,
- }
-
- self.regions.add(region)
- self.channel_names.add(channel)
- max_z = max(max_z, z_level)
- max_fov = max(max_fov, fov)
-
- self.regions = sorted(self.regions)
- self.channel_names = sorted(self.channel_names)
- self.num_t = len(self.time_points)
- self.num_z = max_z + 1
- self.num_fovs_per_region = max_fov + 1
-
- # Set up image parameters based on the first image
- first_key = list(self.stitching_data.keys())[0]
- first_region = self.stitching_data[first_key]["region"]
- first_fov = self.stitching_data[first_key]["fov_idx"]
- first_z_level = self.stitching_data[first_key]["z_level"]
- first_image = dask_imread(self.stitching_data[first_key]["filepath"])[0]
-
- self.dtype = first_image.dtype
- if len(first_image.shape) == 2:
- self.input_height, self.input_width = first_image.shape
- elif len(first_image.shape) == 3:
- self.input_height, self.input_width = first_image.shape[:2]
- else:
- raise ValueError(f"Unexpected image shape: {first_image.shape}")
- self.chunks = (1, 1, 1, 512, 512)
-
- # Set up final monochrome channels
- self.mono_channel_names = []
- for channel in self.channel_names:
- channel_key = (t, first_region, first_fov, first_z_level, channel)
- channel_image = dask_imread(self.stitching_data[channel_key]["filepath"])[0]
- if len(channel_image.shape) == 3 and channel_image.shape[2] == 3:
- self.is_rgb[channel] = True
- channel = channel.split("_")[0]
- self.mono_channel_names.extend([f"{channel}_R", f"{channel}_G", f"{channel}_B"])
- else:
- self.is_rgb[channel] = False
- self.mono_channel_names.append(channel)
- self.num_c = len(self.mono_channel_names)
- self.channel_colors = [self.get_channel_color(name) for name in self.mono_channel_names]
-
- print(f"FOV dimensions: {self.input_height}x{self.input_width}")
- print(f"{self.num_z} Z levels, {self.num_t} Time points")
- print(f"{self.num_c} Channels: {self.mono_channel_names}")
- print(f"{len(self.regions)} Regions: {self.regions}")
-
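-    # Worked example (hypothetical file name) of the split performed above:
-    #   "A1_0_2_Fluorescence_488_nm_Ex.tiff".split("_", 3)
-    #       -> ["A1", "0", "2", "Fluorescence_488_nm_Ex.tiff"]
-    # giving region="A1", fov=0, z_level=2, and a channel that becomes
-    # "Fluorescence 488 nm Ex" after the replace() calls.
-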
- def get_channel_color(self, channel_name):
- color_map = {
- "405": 0x0000FF, # Blue
- "488": 0x00FF00, # Green
- "561": 0xFFCF00, # Yellow
- "638": 0xFF0000, # Red
- "730": 0x770000, # Dark Red"
- "_B": 0x0000FF, # Blue
- "_G": 0x00FF00, # Green
- "_R": 0xFF0000, # Red
- }
- for key in color_map:
- if key in channel_name:
- return color_map[key]
- return 0xFFFFFF # Default to white if no match found
-
- def calculate_output_dimensions(self, region):
- region_data = [tile_info for key, tile_info in self.stitching_data.items() if key[1] == region]
-
- if not region_data:
- raise ValueError(f"No data found for region {region}")
-
- self.x_positions = sorted(set(tile_info["x"] for tile_info in region_data))
- self.y_positions = sorted(set(tile_info["y"] for tile_info in region_data))
-
- if self.use_registration: # Add extra space for shifts
- num_cols = len(self.x_positions)
- num_rows = len(self.y_positions)
-
- if self.scan_pattern == "S-Pattern":
- max_h_shift = (max(self.h_shift[0], self.h_shift_rev[0]), max(self.h_shift[1], self.h_shift_rev[1]))
- else:
- max_h_shift = self.h_shift
-
- width_pixels = int(
- self.input_width + ((num_cols - 1) * (self.input_width + max_h_shift[1]))
- ) # horizontal width with overlap
- width_pixels += abs((num_rows - 1) * self.v_shift[1]) # horizontal shift from vertical registration
- height_pixels = int(
- self.input_height + ((num_rows - 1) * (self.input_height + self.v_shift[0]))
- ) # vertical height with overlap
- height_pixels += abs((num_cols - 1) * max_h_shift[0]) # vertical shift from horizontal registration
-
- else: # Use coordinates shifts
- width_mm = max(self.x_positions) - min(self.x_positions) + (self.input_width * self.pixel_size_um / 1000)
- height_mm = max(self.y_positions) - min(self.y_positions) + (self.input_height * self.pixel_size_um / 1000)
-
- width_pixels = int(np.ceil(width_mm * 1000 / self.pixel_size_um))
- height_pixels = int(np.ceil(height_mm * 1000 / self.pixel_size_um))
-
-        # Round up to a multiple of 4, then add a 4-pixel safety margin
-        width_pixels = ((width_pixels + 3) & ~3) + 4
-        height_pixels = ((height_pixels + 3) & ~3) + 4
-
- # Get the number of rows and columns
- if len(self.regions) > 1:
- rows, columns = self.get_rows_and_columns()
- max_dimension = max(len(rows), len(columns))
- else:
- max_dimension = 1
-
- # Calculate the number of pyramid levels
- self.num_pyramid_levels = math.ceil(np.log2(max(width_pixels, height_pixels) / 1024 * max_dimension))
- print("# Pyramid levels:", self.num_pyramid_levels)
- return width_pixels, height_pixels
-
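-    # Worked example of the pyramid-level formula above: for a 20,000 px stitched
-    # side and a single region, ceil(log2(20000 / 1024 * 1)) = ceil(4.29) = 5 levels.
-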
- def init_output(self, region):
- width, height = self.calculate_output_dimensions(region)
- self.output_shape = (self.num_t, self.num_c, self.num_z, height, width)
- print(f"Output shape for region {region}: {self.output_shape}")
- return da.zeros(self.output_shape, dtype=self.dtype, chunks=self.chunks)
-
- def get_flatfields(self, progress_callback=None):
- def process_images(images, channel_name):
- if images.size == 0:
- print(f"WARNING: No images found for channel {channel_name}")
- return
-
-            if images.ndim not in (3, 4):
-                raise ValueError(
-                    f"Images must be a 3- or 4-dimensional array with shape (N, Y, X) or (N, Z, Y, X). Got shape {images.shape}"
-                )
-
- basic = BaSiC(get_darkfield=False, smoothness_flatfield=1)
- basic.fit(images)
- channel_index = self.mono_channel_names.index(channel_name)
- self.flatfields[channel_index] = basic.flatfield
- if progress_callback:
- progress_callback(channel_index + 1, self.num_c)
-
- for channel in self.channel_names:
- print(f"Calculating {channel} flatfield...")
- images = []
- for t in self.time_points:
- time_images = [
- dask_imread(tile["filepath"])[0]
- for key, tile in self.stitching_data.items()
- if tile["channel"] == channel and key[0] == int(t)
- ]
- if not time_images:
- print(f"WARNING: No images found for channel {channel} at timepoint {t}")
- continue
- random.shuffle(time_images)
- selected_tiles = time_images[: min(32, len(time_images))]
- images.extend(selected_tiles)
-
- if not images:
- print(f"WARNING: No images found for channel {channel} across all timepoints")
- continue
-
- images = np.array(images)
-
- if images.ndim == 3:
- # Images are in the shape (N, Y, X)
- process_images(images, channel)
- elif images.ndim == 4:
- if images.shape[-1] == 3:
- # Images are in the shape (N, Y, X, 3) for RGB images
- images_r = images[..., 0]
- images_g = images[..., 1]
- images_b = images[..., 2]
- channel = channel.split("_")[0]
- process_images(images_r, channel + "_R")
- process_images(images_g, channel + "_G")
- process_images(images_b, channel + "_B")
- else:
- # Images are in the shape (N, Z, Y, X)
- process_images(images, channel)
- else:
- raise ValueError(f"Unexpected number of dimensions in images array: {images.ndim}")
-
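-    # Illustrative sketch: the core basicpy fit performed above, on a stand-in
-    # random stack (same constructor arguments as get_flatfields(); shapes are
-    # hypothetical; not called by the pipeline).
-    @staticmethod
-    def _sketch_minimal_basic_fit():
-        tiles = np.random.rand(32, 512, 512).astype(np.float32)  # (N, Y, X) stack of tiles
-        basic = BaSiC(get_darkfield=False, smoothness_flatfield=1)
-        basic.fit(tiles)
-        return tiles / basic.flatfield  # per-pixel illumination correction
-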
- def calculate_shifts(self, region):
- region_data = [v for k, v in self.stitching_data.items() if k[1] == region]
-
- # Get unique x and y positions
- x_positions = sorted(set(tile["x"] for tile in region_data))
- y_positions = sorted(set(tile["y"] for tile in region_data))
-
- # Initialize shifts
- self.h_shift = (0, 0)
- self.v_shift = (0, 0)
-
- # Set registration channel if not already set
- if not self.registration_channel:
- self.registration_channel = self.channel_names[0]
- elif self.registration_channel not in self.channel_names:
- print(
- f"Warning: Specified registration channel '{self.registration_channel}' not found. Using {self.channel_names[0]}."
- )
- self.registration_channel = self.channel_names[0]
-
- max_x_overlap = round(self.input_width * self.overlap_percent / 2 / 100)
- max_y_overlap = round(self.input_height * self.overlap_percent / 2 / 100)
- print(f"Expected shifts - Horizontal: {(0, -max_x_overlap)}, Vertical: {(-max_y_overlap , 0)}")
-
- # Find center positions
- center_x_index = (len(x_positions) - 1) // 2
- center_y_index = (len(y_positions) - 1) // 2
-
- center_x = x_positions[center_x_index]
- center_y = y_positions[center_y_index]
-
- right_x = None
- bottom_y = None
-
- # Calculate horizontal shift
- if center_x_index + 1 < len(x_positions):
- right_x = x_positions[center_x_index + 1]
- center_tile = self.get_tile(
- region, center_x, center_y, self.registration_channel, self.registration_z_level
- )
- right_tile = self.get_tile(region, right_x, center_y, self.registration_channel, self.registration_z_level)
-
- if center_tile is not None and right_tile is not None:
- self.h_shift = self.calculate_horizontal_shift(center_tile, right_tile, max_x_overlap)
- else:
- print(f"Warning: Missing tiles for horizontal shift calculation in region {region}.")
-
- # Calculate vertical shift
- if center_y_index + 1 < len(y_positions):
- bottom_y = y_positions[center_y_index + 1]
- center_tile = self.get_tile(
- region, center_x, center_y, self.registration_channel, self.registration_z_level
- )
- bottom_tile = self.get_tile(
- region, center_x, bottom_y, self.registration_channel, self.registration_z_level
- )
-
- if center_tile is not None and bottom_tile is not None:
- self.v_shift = self.calculate_vertical_shift(center_tile, bottom_tile, max_y_overlap)
- else:
- print(f"Warning: Missing tiles for vertical shift calculation in region {region}.")
-
-        if self.scan_pattern == "S-Pattern" and right_x is not None and bottom_y is not None:
- center_tile = self.get_tile(
- region, center_x, bottom_y, self.registration_channel, self.registration_z_level
- )
- right_tile = self.get_tile(region, right_x, bottom_y, self.registration_channel, self.registration_z_level)
-
- if center_tile is not None and right_tile is not None:
- self.h_shift_rev = self.calculate_horizontal_shift(center_tile, right_tile, max_x_overlap)
- self.h_shift_rev_odd = center_y_index % 2 == 0
- print(f"Bi-Directional Horizontal Shift - Reverse Horizontal: {self.h_shift_rev}")
- else:
- print(f"Warning: Missing tiles for reverse horizontal shift calculation in region {region}.")
-
- print(f"Calculated Uni-Directional Shifts - Horizontal: {self.h_shift}, Vertical: {self.v_shift}")
-
- def calculate_horizontal_shift(self, img1, img2, max_overlap):
- img1 = self.normalize_image(img1)
- img2 = self.normalize_image(img2)
-
- margin = int(img1.shape[0] * 0.2) # 20% margin
- img1_overlap = img1[margin:-margin, -max_overlap:]
- img2_overlap = img2[margin:-margin, :max_overlap]
-
- self.visualize_image(img1_overlap, img2_overlap, "horizontal")
-
- shift, error, diffphase = phase_cross_correlation(img1_overlap, img2_overlap, upsample_factor=10)
- return round(shift[0]), round(shift[1] - img1_overlap.shape[1])
-
- def calculate_vertical_shift(self, img1, img2, max_overlap):
- img1 = self.normalize_image(img1)
- img2 = self.normalize_image(img2)
-
- margin = int(img1.shape[1] * 0.2) # 20% margin
- img1_overlap = img1[-max_overlap:, margin:-margin]
- img2_overlap = img2[:max_overlap, margin:-margin]
-
- self.visualize_image(img1_overlap, img2_overlap, "vertical")
-
- shift, error, diffphase = phase_cross_correlation(img1_overlap, img2_overlap, upsample_factor=10)
- return round(shift[0] - img1_overlap.shape[0]), round(shift[1])
-
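-    # Illustrative sketch: the registration primitive behind both shift
-    # calculations, demonstrated on a synthetic pair (not called by the pipeline).
-    @staticmethod
-    def _sketch_phase_correlation():
-        ref = np.random.rand(256, 256)
-        moving = np.roll(ref, shift=(5, -3), axis=(0, 1))
-        shift, error, diffphase = phase_cross_correlation(ref, moving, upsample_factor=10)
-        return shift  # (row, col) shift registering `moving` to `ref`, here ~(-5, 3)
-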
- def get_tile(self, region, x, y, channel, z_level):
- for key, value in self.stitching_data.items():
- if (
- key[1] == region
- and value["x"] == x
- and value["y"] == y
- and value["channel"] == channel
- and value["z_level"] == z_level
- ):
- try:
- return dask_imread(value["filepath"])[0]
- except FileNotFoundError:
- print(f"Warning: Tile file not found: {value['filepath']}")
- return None
- print(f"Warning: No matching tile found for region {region}, x={x}, y={y}, channel={channel}, z={z_level}")
- return None
-
- def normalize_image(self, img):
- img_min, img_max = img.min(), img.max()
- img_normalized = (img - img_min) / (img_max - img_min)
- scale_factor = np.iinfo(self.dtype).max if np.issubdtype(self.dtype, np.integer) else 1
- return (img_normalized * scale_factor).astype(self.dtype)
-
- def visualize_image(self, img1, img2, title):
- try:
- # Ensure images are numpy arrays
- img1 = np.asarray(img1)
- img2 = np.asarray(img2)
-
- if title == "horizontal":
- combined_image = np.hstack((img1, img2))
- else:
- combined_image = np.vstack((img1, img2))
-
- # Convert to uint8 for saving as PNG
- combined_image_uint8 = (combined_image / np.iinfo(self.dtype).max * 255).astype(np.uint8)
-
- cv2.imwrite(f"{self.input_folder}/{title}.png", combined_image_uint8)
-
- print(f"Saved {title}.png successfully")
- except Exception as e:
- print(f"Error in visualize_image: {e}")
-
- def stitch_and_save_region(self, region, progress_callback=None):
- stitched_images = self.init_output(region) # sets self.x_positions, self.y_positions
- region_data = {k: v for k, v in self.stitching_data.items() if k[1] == region}
- total_tiles = len(region_data)
- processed_tiles = 0
-
- x_min = min(self.x_positions)
- y_min = min(self.y_positions)
-
- for key, tile_info in region_data.items():
- t, _, fov, z_level, channel = key
- tile = dask_imread(tile_info["filepath"])[0]
- if self.use_registration:
- self.col_index = self.x_positions.index(tile_info["x"])
- self.row_index = self.y_positions.index(tile_info["y"])
-
- if self.scan_pattern == "S-Pattern" and self.row_index % 2 == self.h_shift_rev_odd:
- h_shift = self.h_shift_rev
- else:
- h_shift = self.h_shift
-
- # Initialize starting coordinates based on tile position and shift
- x_pixel = int(self.col_index * (self.input_width + h_shift[1]))
- y_pixel = int(self.row_index * (self.input_height + self.v_shift[0]))
-
- # Apply horizontal shift effect on y-coordinate
- if h_shift[0] < 0:
- y_pixel += int(
- (len(self.x_positions) - 1 - self.col_index) * abs(h_shift[0])
- ) # Fov moves up as cols go right
- else:
- y_pixel += int(self.col_index * h_shift[0]) # Fov moves down as cols go right
-
- # Apply vertical shift effect on x-coordinate
- if self.v_shift[1] < 0:
- x_pixel += int(
- (len(self.y_positions) - 1 - self.row_index) * abs(self.v_shift[1])
- ) # Fov moves left as rows go down
- else:
- x_pixel += int(self.row_index * self.v_shift[1]) # Fov moves right as rows go down
-
- else:
- # Calculate base position
- x_pixel = int((tile_info["x"] - x_min) * 1000 / self.pixel_size_um)
- y_pixel = int((tile_info["y"] - y_min) * 1000 / self.pixel_size_um)
-
- self.place_tile(stitched_images, tile, x_pixel, y_pixel, z_level, channel, t)
-
- processed_tiles += 1
- if progress_callback:
- progress_callback(processed_tiles, total_tiles)
-
- self.starting_saving.emit(False)
- if len(self.regions) > 1:
- self.save_region_to_hcs_ome_zarr(region, stitched_images)
- else:
-            # self.save_as_ome_zarr(region, stitched_images)
-            self.save_region_to_ome_zarr(
-                region, stitched_images
-            )  # known issue: the main GUI lags and can disconnect when saving starts
-
- def place_tile(self, stitched_images, tile, x_pixel, y_pixel, z_level, channel, t):
- if len(tile.shape) == 2:
- # Handle 2D grayscale image
- channel_idx = self.mono_channel_names.index(channel)
- self.place_single_channel_tile(stitched_images, tile, x_pixel, y_pixel, z_level, channel_idx, t)
-
- elif len(tile.shape) == 3:
- if tile.shape[2] == 3:
- # Handle RGB image
- channel = channel.split("_")[0]
- for i, color in enumerate(["R", "G", "B"]):
- channel_idx = self.mono_channel_names.index(f"{channel}_{color}")
- self.place_single_channel_tile(
- stitched_images, tile[:, :, i], x_pixel, y_pixel, z_level, channel_idx, t
- )
- elif tile.shape[0] == 1:
- channel_idx = self.mono_channel_names.index(channel)
- self.place_single_channel_tile(stitched_images, tile[0], x_pixel, y_pixel, z_level, channel_idx, t)
- else:
- raise ValueError(f"Unexpected tile shape: {tile.shape}")
-
- def place_single_channel_tile(self, stitched_images, tile, x_pixel, y_pixel, z_level, channel_idx, t):
- if len(stitched_images.shape) != 5:
- raise ValueError(
- f"Unexpected stitched_images shape: {stitched_images.shape}. Expected 5D array (t, c, z, y, x)."
- )
-
- if self.apply_flatfield:
- tile = self.apply_flatfield_correction(tile, channel_idx)
-
- if self.use_registration:
- if self.scan_pattern == "S-Pattern" and self.row_index % 2 == self.h_shift_rev_odd:
- h_shift = self.h_shift_rev
- else:
- h_shift = self.h_shift
-
- # Determine crop for tile edges
- top_crop = max(0, (-self.v_shift[0] // 2) - abs(h_shift[0]) // 2) if self.row_index > 0 else 0 # if y
- bottom_crop = (
- max(0, (-self.v_shift[0] // 2) - abs(h_shift[0]) // 2)
- if self.row_index < len(self.y_positions) - 1
- else 0
- )
- left_crop = max(0, (-h_shift[1] // 2) - abs(self.v_shift[1]) // 2) if self.col_index > 0 else 0
- right_crop = (
- max(0, (-h_shift[1] // 2) - abs(self.v_shift[1]) // 2)
- if self.col_index < len(self.x_positions) - 1
- else 0
- )
-
- # Apply cropping to the tile
- tile = tile[top_crop : tile.shape[0] - bottom_crop, left_crop : tile.shape[1] - right_crop]
-
- # Adjust x_pixel and y_pixel based on cropping
- x_pixel += left_crop
- y_pixel += top_crop
-
- y_end = min(y_pixel + tile.shape[0], stitched_images.shape[3])
- x_end = min(x_pixel + tile.shape[1], stitched_images.shape[4])
-
- try:
- stitched_images[t, channel_idx, z_level, y_pixel:y_end, x_pixel:x_end] = tile[
- : y_end - y_pixel, : x_end - x_pixel
- ]
- except Exception as e:
- print(f"ERROR: Failed to place tile. Details: {str(e)}")
- print(
- f"DEBUG: t:{t}, channel_idx:{channel_idx}, z_level:{z_level}, y:{y_pixel}-{y_end}, x:{x_pixel}-{x_end}"
- )
- print(f"DEBUG: tile slice shape: {tile[:y_end-y_pixel, :x_end-x_pixel].shape}")
- raise
-
- def apply_flatfield_correction(self, tile, channel_idx):
- if channel_idx in self.flatfields:
- return (
- (tile / self.flatfields[channel_idx])
- .clip(min=np.iinfo(self.dtype).min, max=np.iinfo(self.dtype).max)
- .astype(self.dtype)
- )
- return tile
-
- def generate_pyramid(self, image, num_levels):
- pyramid = [image]
- for level in range(1, num_levels):
- scale_factor = 2**level
- factors = {0: 1, 1: 1, 2: 1, 3: scale_factor, 4: scale_factor}
- if isinstance(image, da.Array):
- downsampled = da.coarsen(np.mean, image, factors, trim_excess=True)
- else:
- block_size = (1, 1, 1, scale_factor, scale_factor)
- downsampled = downscale_local_mean(image, block_size)
- pyramid.append(downsampled)
- return pyramid
-
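-    # Illustrative sketch: each pyramid level above is a block mean that halves Y
-    # and X while leaving T/C/Z untouched; a toy da.coarsen example (not called by
-    # the pipeline):
-    @staticmethod
-    def _sketch_block_mean_downsample():
-        x = da.arange(16, dtype=float).reshape(4, 4)
-        y = da.coarsen(np.mean, x, {0: 2, 1: 2})  # 2x2 block means
-        return y.compute()  # [[2.5, 4.5], [10.5, 12.5]]
-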
- def save_region_to_hcs_ome_zarr(self, region, stitched_images):
- output_path = os.path.join(self.input_folder, self.output_name)
- store = ome_zarr.io.parse_url(output_path, mode="a").store
- root = zarr.group(store=store)
-
- row, col = region[0], region[1:]
- row_group = root.require_group(row)
- well_group = row_group.require_group(col)
-
- if "well" not in well_group.attrs:
- well_metadata = {
- "images": [{"path": "0", "acquisition": 0}],
- }
- ome_zarr.writer.write_well_metadata(well_group, well_metadata["images"])
-
- image_group = well_group.require_group("0")
-
- pyramid = self.generate_pyramid(stitched_images, self.num_pyramid_levels)
- coordinate_transformations = [
- [
- {
- "type": "scale",
- "scale": [
- 1,
- 1,
- self.acquisition_params.get("dz(um)", 1),
- self.pixel_size_um * (2**i),
- self.pixel_size_um * (2**i),
- ],
- }
- ]
- for i in range(self.num_pyramid_levels)
- ]
-
- axes = [
- {"name": "t", "type": "time", "unit": "second"},
- {"name": "c", "type": "channel"},
- {"name": "z", "type": "space", "unit": "micrometer"},
- {"name": "y", "type": "space", "unit": "micrometer"},
- {"name": "x", "type": "space", "unit": "micrometer"},
- ]
-
- # Prepare channels metadata
- omero_channels = [
- {
- "label": name,
- "color": f"{color:06X}",
- "window": {"start": 0, "end": np.iinfo(self.dtype).max, "min": 0, "max": np.iinfo(self.dtype).max},
- }
- for name, color in zip(self.mono_channel_names, self.channel_colors)
- ]
-
- omero = {"name": f"{region}", "version": "0.4", "channels": omero_channels}
-
- image_group.attrs["omero"] = omero
-
- # Write the multiscale image data and metadata
- ome_zarr.writer.write_multiscale(
- pyramid=pyramid,
- group=image_group,
- chunks=self.chunks,
- axes=axes,
- coordinate_transformations=coordinate_transformations,
- storage_options=dict(chunks=self.chunks),
- name=f"{region}",
- )
-
- def save_as_ome_zarr(self, region, stitched_images):
- output_path = os.path.join(self.input_folder, self.output_name)
- dz_um = self.acquisition_params.get("dz(um)", None)
- sensor_pixel_size_um = self.acquisition_params.get("sensor_pixel_size_um", None)
- channel_minmax = [(np.iinfo(self.dtype).min, np.iinfo(self.dtype).max)] * self.num_c
- for i in range(self.num_c):
- print(
- f"Channel {i}:",
- self.mono_channel_names[i],
- " \tColor:",
- self.channel_colors[i],
- " \tPixel Range:",
- channel_minmax[i],
- )
-
- zarr_writer = OmeZarrWriter(output_path)
- zarr_writer.build_ome(
- size_z=self.num_z,
- image_name=region,
- channel_names=self.mono_channel_names,
- channel_colors=self.channel_colors,
- channel_minmax=channel_minmax,
- )
- zarr_writer.write_image(
- image_data=stitched_images,
- image_name=region,
- physical_pixel_sizes=types.PhysicalPixelSizes(dz_um, self.pixel_size_um, self.pixel_size_um),
- channel_names=self.mono_channel_names,
- channel_colors=self.channel_colors,
- dimension_order="TCZYX",
- scale_num_levels=self.num_pyramid_levels,
- chunk_dims=self.chunks,
- )
-
- def save_region_to_ome_zarr(self, region, stitched_images):
- output_path = os.path.join(self.input_folder, self.output_name)
- store = ome_zarr.io.parse_url(output_path, mode="a").store
- root = zarr.group(store=store)
-
- # Generate the pyramid
- pyramid = self.generate_pyramid(stitched_images, self.num_pyramid_levels)
-
- datasets = []
- for i in range(self.num_pyramid_levels):
- scale = 2**i
- datasets.append(
- {
- "path": str(i),
- "coordinateTransformations": [
- {
- "type": "scale",
- "scale": [
- 1,
- 1,
- self.acquisition_params.get("dz(um)", 1),
- self.pixel_size_um * scale,
- self.pixel_size_um * scale,
- ],
- }
- ],
- }
- )
-
- axes = [
- {"name": "t", "type": "time", "unit": "second"},
- {"name": "c", "type": "channel"},
- {"name": "z", "type": "space", "unit": "micrometer"},
- {"name": "y", "type": "space", "unit": "micrometer"},
- {"name": "x", "type": "space", "unit": "micrometer"},
- ]
-
- ome_zarr.writer.write_multiscales_metadata(root, datasets, axes=axes, name="stitched_image")
-
- omero = {
- "name": "stitched_image",
- "version": "0.4",
- "channels": [
- {
- "label": name,
- "color": f"{color:06X}",
- "window": {"start": 0, "end": np.iinfo(self.dtype).max, "min": 0, "max": np.iinfo(self.dtype).max},
- }
- for name, color in zip(self.mono_channel_names, self.channel_colors)
- ],
- }
- root.attrs["omero"] = omero
-
- coordinate_transformations = [dataset["coordinateTransformations"] for dataset in datasets]
-
- ome_zarr.writer.write_multiscale(
- pyramid=pyramid,
- group=root,
- axes="tczyx",
- coordinate_transformations=coordinate_transformations,
- storage_options=dict(chunks=self.chunks),
- )
-
- def write_stitched_plate_metadata(self):
- output_path = os.path.join(self.input_folder, self.output_name)
- store = ome_zarr.io.parse_url(output_path, mode="a").store
- root = zarr.group(store=store)
-
- rows, columns = self.get_rows_and_columns()
- well_paths = [f"{well_id[0]}/{well_id[1:]}" for well_id in sorted(self.regions)]
-
- plate_metadata = {
- "name": "Stitched Plate",
- "rows": [{"name": row} for row in rows],
- "columns": [{"name": col} for col in columns],
- "wells": [
- {"path": path, "rowIndex": rows.index(path[0]), "columnIndex": columns.index(path[2:])}
- for path in well_paths
- ],
- "field_count": 1,
- "acquisitions": [{"id": 0, "maximumfieldcount": 1, "name": "Stitched Acquisition"}],
- }
-
- ome_zarr.writer.write_plate_metadata(
- root,
- rows=[row["name"] for row in plate_metadata["rows"]],
- columns=[col["name"] for col in plate_metadata["columns"]],
- wells=plate_metadata["wells"],
- acquisitions=plate_metadata["acquisitions"],
- name=plate_metadata["name"],
- field_count=plate_metadata["field_count"],
- )
-
- def get_rows_and_columns(self):
- rows = sorted(set(region[0] for region in self.regions))
- columns = sorted(set(region[1:] for region in self.regions))
- return rows, columns
-
- def create_ome_tiff(self, stitched_images):
- output_path = os.path.join(self.input_folder, self.output_name)
-
- with TiffWriter(output_path, bigtiff=True, ome=True) as tif:
- tif.write(
- data=stitched_images,
- shape=stitched_images.shape,
- dtype=self.dtype,
- photometric="minisblack",
- planarconfig="separate",
- metadata={
- "axes": "TCZYX",
- "Channel": {"Name": self.mono_channel_names},
- "SignificantBits": stitched_images.dtype.itemsize * 8,
- "Pixels": {
- "PhysicalSizeX": self.pixel_size_um,
- "PhysicalSizeXUnit": "µm",
- "PhysicalSizeY": self.pixel_size_um,
- "PhysicalSizeYUnit": "µm",
- "PhysicalSizeZ": self.acquisition_params.get("dz(um)", 1.0),
- "PhysicalSizeZUnit": "µm",
- },
- },
- )
-
- print(f"Data saved in OME-TIFF format at: {output_path}")
- self.finished_saving.emit(output_path, self.dtype)
-
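-    # Illustrative sketch: the tifffile call above, reduced to a toy example with
-    # a hypothetical path, shape, and channel name (not called by the pipeline).
-    @staticmethod
-    def _sketch_minimal_ome_tiff_write():
-        data = np.zeros((1, 1, 1, 64, 64), dtype=np.uint16)  # TCZYX
-        with TiffWriter("toy.ome.tiff", bigtiff=True, ome=True) as tif:
-            tif.write(data, metadata={"axes": "TCZYX", "Channel": {"Name": ["BF"]}})
-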
-    def run(self):
-        stime = time.time()
-        try:
-            self.get_time_points()
-            self.parse_filenames()
-
-            if self.apply_flatfield:
-                print("Calculating flatfields...")
-                self.getting_flatfields.emit()
-                self.get_flatfields(progress_callback=self.update_progress.emit)
-                print("time to calculate flatfields", time.time() - stime)
-
-            if self.num_fovs_per_region > 1:
-                self.run_regions()
-            else:
-                self.run_fovs()  # only displays one fov per region even though all fovs are saved in zarr with metadata
-
-        except Exception as e:
-            print("time before error", time.time() - stime)
-            print(f"Error while stitching: {e}")
-            raise
-
- def run_regions(self):
- stime = time.time()
- if len(self.regions) > 1:
- self.write_stitched_plate_metadata()
-
- if self.use_registration:
- print(f"\nCalculating shifts for region {self.regions[0]}...")
- self.calculate_shifts(self.regions[0])
-
- for region in self.regions:
- wtime = time.time()
-
- # if self.use_registration:
- # print(f"\nCalculating shifts for region {region}...")
- # self.calculate_shifts(region)
-
- self.starting_stitching.emit()
- print(f"\nstarting stitching for region {region}...")
- self.stitch_and_save_region(region, progress_callback=self.update_progress.emit)
-
- sttime = time.time()
- print(f"time to stitch and save region {region}", time.time() - wtime)
- print(f"...done with region:{region}")
-
-        if self.output_format.endswith(".ome.tiff"):
-            # NOTE: stitch_and_save_region() never assigns self.stitched_images,
-            # so this branch expects the last stitched array to be stored on self first
-            self.create_ome_tiff(self.stitched_images)
- else:
- output_path = os.path.join(self.input_folder, self.output_name)
- print(f"Data saved in OME-ZARR format at: {output_path}")
- self.print_zarr_structure(output_path)
-
- self.finished_saving.emit(os.path.join(self.input_folder, self.output_name), self.dtype)
- print("total time to stitch + save:", time.time() - stime)
-
-    # ________________________________________________________________________________________________________________________________
-    # run_fovs: save FOVs directly into the final HCS OME-ZARR
-    #
-    # known issue:
-    # only one FOV per region is shown when a region has multiple FOVs
-    # - (fix metadata? translation, scale, path, multiscale?)
-    # channels, well + plate metadata, z-stack shape, and time-point shape are correct in napari
-
- def run_fovs(self):
- stime = time.time()
- self.starting_stitching.emit()
-
- output_path = os.path.join(self.input_folder, self.output_name)
- store = ome_zarr.io.parse_url(output_path, mode="a").store
- root = zarr.group(store=store)
-
- self.write_fov_plate_metadata(root)
-
- total_fovs = sum(
- len(set([k[2] for k in self.stitching_data.keys() if k[1] == region])) for region in self.regions
- )
- processed_fovs = 0
-
- for region in self.regions:
- region_data = {k: v for k, v in self.stitching_data.items() if k[1] == region}
- well_group = self.write_fov_well_metadata(root, region)
-
- for fov_idx in range(self.num_fovs_per_region):
- fov_data = {k: v for k, v in region_data.items() if k[2] == fov_idx}
-
- if not fov_data:
- continue # Skip if no data for this FOV index
-
- tcz_fov = self.compile_single_fov_data(fov_data)
- self.write_fov_to_zarr(well_group, tcz_fov, fov_idx, fov_data)
- processed_fovs += 1
- self.update_progress.emit(processed_fovs, total_fovs)
-
- omero_channels = [
- {
- "label": name,
- "color": f"{color:06X}",
- "window": {"start": 0, "end": np.iinfo(self.dtype).max, "min": 0, "max": np.iinfo(self.dtype).max},
- }
- for name, color in zip(self.mono_channel_names, self.channel_colors)
- ]
-
- omero = {"name": "hcs-acquisition", "version": "0.4", "channels": omero_channels}
-
- root.attrs["omero"] = omero
-
- print(f"Data saved in OME-ZARR format at: {output_path}")
- self.print_zarr_structure(output_path)
- self.finished_saving.emit(output_path, self.dtype)
-
- print("total time to save FOVs:", time.time() - stime)
-
- def compile_single_fov_data(self, fov_data):
- # Initialize a 5D array to hold all the data for this FOV
- tcz_fov = np.zeros((self.num_t, self.num_c, self.num_z, self.input_height, self.input_width), dtype=self.dtype)
-
- for key, scan_info in fov_data.items():
- t, _, _, z_level, channel = key
- image = dask_imread(scan_info["filepath"])[0]
-
- if self.apply_flatfield:
- channel_idx = self.mono_channel_names.index(channel)
- image = self.apply_flatfield_correction(image, channel_idx)
-
- if len(image.shape) == 3 and image.shape[2] == 3: # RGB image
- channel = channel.split("_")[0]
- for i, color in enumerate(["R", "G", "B"]):
- c_idx = self.mono_channel_names.index(f"{channel}_{color}")
- tcz_fov[t, c_idx, z_level] = image[:, :, i]
- else: # Grayscale image
- c_idx = self.mono_channel_names.index(channel)
- tcz_fov[t, c_idx, z_level] = image
-
- return da.from_array(tcz_fov, chunks=self.chunks)
-
- def write_fov_plate_metadata(self, root):
- rows, columns = self.get_rows_and_columns()
- well_paths = [f"{well_id[0]}/{well_id[1:]}" for well_id in sorted(self.regions)]
-
- plate_metadata = {
- "name": "Sample",
- "rows": [{"name": row} for row in rows],
- "columns": [{"name": col} for col in columns],
- "wells": [
- {"path": path, "rowIndex": rows.index(path[0]), "columnIndex": columns.index(path[2:])}
- for path in well_paths
- ],
- "field_count": self.num_fovs_per_region * len(self.regions),
- "acquisitions": [
- {"id": 0, "maximumfieldcount": self.num_fovs_per_region, "name": "Multipoint Acquisition"}
- ],
- }
-
- ome_zarr.writer.write_plate_metadata(
- root,
- rows=[row["name"] for row in plate_metadata["rows"]],
- columns=[col["name"] for col in plate_metadata["columns"]],
- wells=plate_metadata["wells"],
- acquisitions=plate_metadata["acquisitions"],
- name=plate_metadata["name"],
- field_count=plate_metadata["field_count"],
- )
-
- def write_fov_well_metadata(self, root, region):
- row, col = region[0], region[1:]
- row_group = root.require_group(row)
- well_group = row_group.require_group(col)
-
- if "well" not in well_group.attrs:
- well_metadata = {
- "images": [{"path": str(fov_idx), "acquisition": 0} for fov_idx in range(self.num_fovs_per_region)]
- }
- ome_zarr.writer.write_well_metadata(well_group, well_metadata["images"])
- return well_group
-
- def write_fov_to_zarr(self, well_group, tcz_fov, fov_idx, fov_data):
- axes = [
- {"name": "t", "type": "time", "unit": "second"},
- {"name": "c", "type": "channel"},
- {"name": "z", "type": "space", "unit": "micrometer"},
- {"name": "y", "type": "space", "unit": "micrometer"},
- {"name": "x", "type": "space", "unit": "micrometer"},
- ]
-
- # Generate pyramid levels
- pyramid = self.generate_pyramid(tcz_fov, self.num_pyramid_levels)
-
- # Get the position of the FOV (use the first scan in fov_data)
- first_scan = next(iter(fov_data.values()))
- x_mm, y_mm = first_scan["x"], first_scan["y"]
-
- # Get the z positions
- z_positions = sorted(set(scan_info["z"] for scan_info in fov_data.values()))
- z_min = min(z_positions)
- dz = self.acquisition_params.get("dz(um)", 1.0)
-
- # Create coordinate transformations for each pyramid level
- coordinate_transformations = []
- for level in range(len(pyramid)):
- scale_factor = 2**level
- coordinate_transformations.append(
- [
- {
- "type": "scale",
- "scale": [1, 1, dz, self.pixel_size_um * scale_factor, self.pixel_size_um * scale_factor],
- },
- {"type": "translation", "translation": [0, 0, z_min, y_mm * 1000, x_mm * 1000]},
- ]
- )
-
- image_group = well_group.require_group(str(fov_idx))
-
- # Prepare datasets for multiscales metadata
- datasets = [
- {"path": str(i), "coordinateTransformations": coord_trans}
- for i, coord_trans in enumerate(coordinate_transformations)
- ]
-
- # Write multiscales metadata
- ome_zarr.writer.write_multiscales_metadata(
- group=image_group,
- datasets=datasets,
- axes=axes,
- name=f"FOV_{fov_idx}", # This will be passed as part of **metadata
- )
-
- # Write the actual data
- ome_zarr.writer.write_multiscale(
- pyramid=pyramid,
- group=image_group,
- axes=axes,
- coordinate_transformations=coordinate_transformations,
- storage_options=dict(chunks=self.chunks),
- )
-
- # Add OMERO metadata
- omero_channels = [
- {
- "label": name,
- "color": f"{color:06X}",
- "window": {"start": 0, "end": np.iinfo(self.dtype).max, "min": 0, "max": np.iinfo(self.dtype).max},
- }
- for name, color in zip(self.mono_channel_names, self.channel_colors)
- ]
-
- omero = {"name": f"FOV_{fov_idx}", "version": "0.4", "channels": omero_channels}
-
- image_group.attrs["omero"] = omero
-
- def print_zarr_structure(self, path, indent=""):
- root = zarr.open(path, mode="r")
- print(f"Zarr Tree and Metadata for: {path}")
- print(root.tree())
- print(dict(root.attrs))
diff --git a/software/control/tracking.py b/software/control/tracking.py
deleted file mode 100755
index 93e224b53..000000000
--- a/software/control/tracking.py
+++ /dev/null
@@ -1,234 +0,0 @@
-import control.utils_.image_processing as image_processing
-import numpy as np
-from os.path import realpath, dirname, join
-
-try:
-    import torch
-    from control.DaSiamRPN.code.net import SiamRPNvot
-    from control.DaSiamRPN.code import vot
-    from control.DaSiamRPN.code.utils import get_axis_aligned_bbox, cxy_wh_2_rect
-    from control.DaSiamRPN.code.run_SiamRPN import SiamRPN_init, SiamRPN_track
-except Exception as e:
-    print(e)
-    print("Warning: DaSiamRPN is not available!")
-from control._def import Tracking
-import cv2
-
-
-class Tracker_Image(object):
- """
- SLOTS: update_tracker_type, Connected to: Tracking Widget
- """
-
- def __init__(self):
-        # Available trackers. The basic OpenCV set is always registered; the
-        # extended OpenCV-Contrib set replaces it below when available.
- self.OPENCV_OBJECT_TRACKERS = {
- "csrt": cv2.legacy.TrackerCSRT_create,
- "kcf": cv2.legacy.TrackerKCF_create,
- "mil": cv2.legacy.TrackerMIL_create,
- }
- try:
- self.OPENCV_OBJECT_TRACKERS = {
- "csrt": cv2.legacy.TrackerCSRT_create,
- "kcf": cv2.legacy.TrackerKCF_create,
- "boosting": cv2.legacy.TrackerBoosting_create,
- "mil": cv2.legacy.TrackerMIL_create,
- "tld": cv2.legacy.TrackerTLD_create,
- "medianflow": cv2.legacy.TrackerMedianFlow_create,
- "mosse": cv2.legacy.TrackerMOSSE_create,
- }
-        except Exception:
- print("Warning: OpenCV-Contrib trackers unavailable!")
-
- # Neural Net based trackers
- self.NEURALNETTRACKERS = {"daSiamRPN": []}
- try:
- # load net
- self.net = SiamRPNvot()
- self.net.load_state_dict(
- torch.load(join(realpath(dirname(__file__)), "DaSiamRPN", "code", "SiamRPNOTB.model"))
- )
- self.net.eval().cuda()
- print("Finished loading net ...")
- except Exception as e:
- print(e)
- print("No neural net model found ...")
- print("reverting to default OpenCV tracker")
-
- # Image Tracker type
- self.tracker_type = Tracking.DEFAULT_TRACKER
- # Init method for tracker
- self.init_method = Tracking.DEFAULT_INIT_METHOD
- # Create the tracker
- self.create_tracker()
-
- # Centroid of object from the image
- self.centroid_image = None # (2,1)
- self.bbox = None
- self.rect_pts = None
- self.roi_bbox = None
- self.origin = np.array([0, 0])
-
- self.isCentroidFound = False
- self.trackerActive = False
- self.searchArea = None
- self.is_color = None
-
- def track(self, image, thresh_image, is_first_frame=False):
-
- # case 1: initialize the tracker
-        if is_first_frame or not self.trackerActive:
- # tracker initialization - using ROI
- if self.init_method == "roi":
- self.bbox = tuple(self.roi_bbox)
- self.centroid_image = self.centroid_from_bbox(self.bbox)
- self.isCentroidFound = True
- # tracker initialization - using thresholded image
- else:
- self.isCentroidFound, self.centroid_image, self.bbox = image_processing.find_centroid_basic_Rect(
- thresh_image
- )
- self.bbox = image_processing.scale_square_bbox(self.bbox, Tracking.BBOX_SCALE_FACTOR, square=True)
- # initialize the tracker
- if self.bbox is not None:
- print("Starting tracker with initial bbox: {}".format(self.bbox))
- self._initialize_tracker(image, self.centroid_image, self.bbox)
- self.trackerActive = True
- self.rect_pts = self.rectpts_from_bbox(self.bbox)
-
- # case 2: continue tracking an object using tracking
- else:
- # Find centroid using the tracking.
- objectFound, self.bbox = self._update_tracker(image, thresh_image) # (x,y,w,h)
- if objectFound:
- self.isCentroidFound = True
- self.centroid_image = self.centroid_from_bbox(self.bbox) + self.origin
- self.bbox = np.array(self.bbox)
- self.bbox[0], self.bbox[1] = self.bbox[0] + self.origin[0], self.bbox[1] + self.origin[1]
- self.rect_pts = self.rectpts_from_bbox(self.bbox)
- else:
- print("No object found ...")
- self.isCentroidFound = False
- self.trackerActive = False
- return self.isCentroidFound, self.centroid_image, self.rect_pts
-
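-    # Illustrative sketch: the bare OpenCV tracker cycle that track() wraps -
-    # create, init with an (x, y, w, h) box, then update per frame (the blank
-    # frame is a placeholder; assumes opencv-contrib; not called by the pipeline).
-    @staticmethod
-    def _sketch_opencv_tracker_cycle():
-        tracker = cv2.legacy.TrackerCSRT_create()
-        frame = np.zeros((480, 640, 3), dtype=np.uint8)
-        tracker.init(frame, (100, 100, 50, 50))  # (x, y, w, h)
-        ok, bbox = tracker.update(frame)  # success flag + updated (x, y, w, h)
-        return ok, bbox
-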
- def reset(self):
- print("Reset image tracker state")
- self.is_first_frame = True
- self.trackerActive = False
- self.isCentroidFound = False
-
- def create_tracker(self):
- if self.tracker_type in self.OPENCV_OBJECT_TRACKERS.keys():
- self.tracker = self.OPENCV_OBJECT_TRACKERS[self.tracker_type]()
- elif self.tracker_type in self.NEURALNETTRACKERS.keys():
- print("Using {} tracker".format(self.tracker_type))
- pass
-
- def _initialize_tracker(self, image, centroid, bbox):
- bbox = tuple(int(x) for x in bbox)
-        # record whether the image is color (3-channel) or grayscale
-        self.is_color = len(image.shape) == 3
- # Initialize the OpenCV based tracker
- if self.tracker_type in self.OPENCV_OBJECT_TRACKERS.keys():
- print("Initializing openCV tracker")
- print(self.tracker_type)
- print(bbox)
-            if not self.is_color:
- image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
- self.create_tracker() # for a new track, just calling self.tracker.init(image,bbox) is not sufficient, this line needs to be called
- self.tracker.init(image, bbox)
- # Initialize Neural Net based Tracker
- elif self.tracker_type in self.NEURALNETTRACKERS.keys():
- # Initialize the tracker with this centroid position
- print("Initializing with daSiamRPN tracker")
- target_pos, target_sz = np.array([centroid[0], centroid[1]]), np.array([bbox[2], bbox[3]])
-            if not self.is_color:
- image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
- self.state = SiamRPN_init(image, target_pos, target_sz, self.net)
- print("daSiamRPN tracker initialized")
- else:
- pass
-
- def _update_tracker(self, image, thresh_image):
- # Input: image or thresh_image
- # Output: new_bbox based on tracking
- new_bbox = None
- # tracking w/ openCV tracker
- if self.tracker_type in self.OPENCV_OBJECT_TRACKERS.keys():
- self.origin = np.array([0, 0])
-            # (x, y, w, h)
-            if not self.is_color:
- image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
- ok, new_bbox = self.tracker.update(image)
- return ok, new_bbox
- # tracking w/ the neural network-based tracker
- elif self.tracker_type in self.NEURALNETTRACKERS.keys():
- self.origin = np.array([0, 0])
-            if not self.is_color:
- image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
- self.state = SiamRPN_track(self.state, image)
- ok = True
- if ok:
- # (x,y,w,h)
- new_bbox = cxy_wh_2_rect(self.state["target_pos"], self.state["target_sz"])
- new_bbox = [int(l) for l in new_bbox]
- # print('Updated daSiamRPN tracker')
- return ok, new_bbox
-        # tracking w/ nearest neighbour using the thresholded image
-        else:
-            # If no tracker is specified, use basic thresholding and
-            # nearest-neighbour tracking, i.e. look for objects in a search region
-            # near the last detected centroid
-
- # Get the latest thresholded image from the queue
- # thresh_image =
- pts, thresh_image_cropped = image_processing.crop(thresh_image, self.centroid_image, self.searchArea)
- self.origin = pts[0]
- isCentroidFound, centroid, new_bbox = image_processing.find_centroid_basic_Rect(thresh_image_cropped)
- return isCentroidFound, new_bbox
- # @@@ Can add additional methods here for future tracker implementations
-
- # Signal from Tracking Widget connects to this Function
- def update_tracker_type(self, tracker_type):
- self.tracker_type = tracker_type
- print("set tracker set to {}".format(self.tracker_type))
- # self.create_tracker()
-
- def update_init_method(self, method):
- self.init_method = method
- print("Tracking init method set to : {}".format(self.init_method))
-
- def centroid_from_bbox(self, bbox):
- # Coordinates of the object centroid are taken as the center of the bounding box
- assert len(bbox) == 4
- cx = int(bbox[0] + bbox[2] / 2)
- cy = int(bbox[1] + bbox[3] / 2)
- centroid = np.array([cx, cy])
- return centroid
-
- def rectpts_from_bbox(self, bbox):
- if self.bbox is not None:
- pts = np.array([[bbox[0], bbox[1]], [bbox[0] + bbox[2], bbox[1] + bbox[3]]], dtype="int")
- else:
- pts = None
- return pts
-
- def update_searchArea(self, value):
- self.searchArea = value
-
- def set_roi_bbox(self, bbox):
- # Updates roi bbox from ImageDisplayWindow
- self.roi_bbox = bbox
- print("Rec bbox from ImageDisplay: {}".format(self.roi_bbox))
diff --git a/software/control/utils_/image_processing.py b/software/control/utils_/image_processing.py
deleted file mode 100644
index 8621b3543..000000000
--- a/software/control/utils_/image_processing.py
+++ /dev/null
@@ -1,294 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-Created on Mon May 7 19:44:40 2018
-
-@author: Francois and Deepak
-"""
-
-import numpy as np
-import cv2
-from scipy.ndimage.filters import laplace
-from numpy import std, square, mean
-
-# color is an HSV triplet (3 values: h, s, v)
-
-
-def default_lower_HSV(color):
- c = [0, 100, 100]
- c[0] = np.max([color[0] - 10, 0])
- c[1] = np.max([color[1] - 40, 0])
- c[2] = np.max([color[2] - 40, 0])
- return np.array(c, dtype="uint8")
-
-
-def default_upper_HSV(color):
- c = [0, 255, 255]
- c[0] = np.min([color[0] + 10, 178])
- c[1] = np.min([color[1] + 40, 255])
- c[2] = np.min([color[2] + 40, 255])
- return np.array(c, dtype="uint8")
-
-
-def threshold_image(image_BGR, LOWER, UPPER):
- image_HSV = cv2.cvtColor(image_BGR, cv2.COLOR_BGR2HSV)
-    imgMask = cv2.inRange(image_HSV, LOWER, UPPER)  # tracked object shown in white; inRange already returns a 0/255 uint8 mask
- imgMask = cv2.erode(
- imgMask, None, iterations=2
- ) # Do a series of erosions and dilations on the thresholded image to reduce smaller blobs
- imgMask = cv2.dilate(imgMask, None, iterations=2)
-
- return imgMask
-
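-# Illustrative sketch: a toy end-to-end use of the HSV helpers above on a
-# hypothetical red blob (demonstration only; not called elsewhere).
-def _demo_threshold_red_blob():
-    frame = np.zeros((120, 160, 3), dtype=np.uint8)
-    frame[40:80, 60:100] = (0, 0, 255)  # red square, BGR
-    red_hsv = cv2.cvtColor(np.uint8([[[0, 0, 255]]]), cv2.COLOR_BGR2HSV)[0, 0].astype(int)
-    return threshold_image(frame, default_lower_HSV(red_hsv), default_upper_HSV(red_hsv))
-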
-
-def threshold_image_gray(image_gray, LOWER, UPPER):
-    imgMask = 255 * np.array((image_gray >= LOWER) & (image_gray <= UPPER), dtype="uint8")  # 0/255 mask, matching threshold_image
-
- # imgMask = cv2.inRange(cv2.UMat(image_gray), LOWER, UPPER) #The tracked object will be in white
- imgMask = cv2.erode(
- imgMask, None, iterations=2
- ) # Do a series of erosions and dilations on the thresholded image to reduce smaller blobs
- imgMask = cv2.dilate(imgMask, None, iterations=2)
-
- return imgMask
-
-
-def bgr2gray(image_BGR):
- return cv2.cvtColor(image_BGR, cv2.COLOR_BGR2GRAY)
-
-
-def crop(image, center, imSize):  # center is the vector [x, y]
-    imH, imW, *rest = image.shape  # image.shape = (rows -> height, columns -> width, ...)
- xmin = max(10, center[0] - int(imSize))
- xmax = min(imW - 10, center[0] + int(imSize))
- ymin = max(10, center[1] - int(imSize))
- ymax = min(imH - 10, center[1] + int(imSize))
- return np.array([[xmin, ymin], [xmax, ymax]]), np.array(image[ymin:ymax, xmin:xmax])
-
-
-def crop_image(image, crop_width, crop_height):
- image_height = image.shape[0]
- image_width = image.shape[1]
- roi_left = int(max(image_width / 2 - crop_width / 2, 0))
- roi_right = int(min(image_width / 2 + crop_width / 2, image_width))
- roi_top = int(max(image_height / 2 - crop_height / 2, 0))
- roi_bottom = int(min(image_height / 2 + crop_height / 2, image_height))
- image_cropped = image[roi_top:roi_bottom, roi_left:roi_right]
- image_cropped_height = image_cropped.shape[0]
- image_cropped_width = image_cropped.shape[1]
- return image_cropped, image_cropped_width, image_cropped_height
-
-
-def get_bbox(cnt):
- return cv2.boundingRect(cnt)
-
-
-def find_centroid_enhanced(image, last_centroid):
-    # findContours takes an 8-bit, single-channel image
-    # and looks for white objects on a black background.
-    # This collects all contours in the thresholded image and keeps the centroid
-    # that maximizes the tracking metric: contour area / (1 + dist_to_prev_centroid**2)
- contours = cv2.findContours(image, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]
- centroid = False
- isCentroidFound = False
- if len(contours) > 0:
- all_centroid = []
- dist = []
- for cnt in contours:
- M = cv2.moments(cnt)
- if M["m00"] != 0:
- cx = int(M["m10"] / M["m00"])
- cy = int(M["m01"] / M["m00"])
- centroid = np.array([cx, cy])
- isCentroidFound = True
- all_centroid.append(centroid)
-                dist.append(cv2.contourArea(cnt) / (1 + np.sum((centroid - last_centroid) ** 2)))
-
- if isCentroidFound:
- ind = dist.index(max(dist))
- centroid = all_centroid[ind]
-
- return isCentroidFound, centroid
-
-
-def find_centroid_enhanced_Rect(image, last_centroid):
-    # findContours takes an 8-bit, single-channel image
-    # findContours looks for white objects on a black background
-    # This checks all contours in the thresholded image and picks the centroid that maximizes a tracking metric
-    # Tracking metric: contour area / (1 + squared distance to the previous centroid)
-    imH, imW = image.shape  # needed below to clamp the bounding box to the image
-    contours = cv2.findContours(image, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]
- centroid = False
- isCentroidFound = False
- rect = False
- if len(contours) > 0:
- all_centroid = []
- dist = []
- for cnt in contours:
- M = cv2.moments(cnt)
- if M["m00"] != 0:
- cx = int(M["m10"] / M["m00"])
- cy = int(M["m01"] / M["m00"])
- centroid = np.array([cx, cy])
- isCentroidFound = True
- all_centroid.append(centroid)
-                dist.append(cv2.contourArea(cnt) / (1 + np.sum((centroid - last_centroid) ** 2)))
-
- if isCentroidFound:
- ind = dist.index(max(dist))
- centroid = all_centroid[ind]
- cnt = contours[ind]
- xmin, ymin, width, height = cv2.boundingRect(cnt)
- xmin = max(0, xmin)
- ymin = max(0, ymin)
-            width = min(width, imW - xmin)
-            height = min(height, imH - ymin)
- rect = (xmin, ymin, width, height)
-
- return isCentroidFound, centroid, rect
-
-
-def find_centroid_basic(image):
-    # findContours takes an 8-bit, single-channel image
-    # findContours looks for white objects on a black background
- # This finds the centroid with the maximum area in the current frame
- contours = cv2.findContours(image, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]
- centroid = False
- isCentroidFound = False
- if len(contours) > 0:
- cnt = max(contours, key=cv2.contourArea)
- M = cv2.moments(cnt)
- if M["m00"] != 0:
- cx = int(M["m10"] / M["m00"])
- cy = int(M["m01"] / M["m00"])
- centroid = np.array([cx, cy])
- isCentroidFound = True
- return isCentroidFound, centroid
-
-
-def find_centroid_basic_Rect(image):
-    # findContours takes an 8-bit, single-channel image
-    # findContours looks for white objects on a black background
-    # This finds the centroid with the maximum area in the current frame and also the bounding rectangle. - DK 2018_12_12
- imH, imW = image.shape
- contours = cv2.findContours(image, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]
- centroid = False
- isCentroidFound = False
- bbox = None
- if len(contours) > 0:
- # Find contour with max area
- cnt = max(contours, key=cv2.contourArea)
- M = cv2.moments(cnt)
-
- if M["m00"] != 0:
- # Centroid coordinates
- cx = int(M["m10"] / M["m00"])
- cy = int(M["m01"] / M["m00"])
- centroid = np.array([cx, cy])
- isCentroidFound = True
-
- # Find the bounding rectangle
- xmin, ymin, width, height = cv2.boundingRect(cnt)
- xmin = max(0, xmin)
- ymin = max(0, ymin)
- width = min(width, imW - xmin)
- height = min(height, imH - ymin)
-
- bbox = (xmin, ymin, width, height)
-
- return isCentroidFound, centroid, bbox
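
Taken together, the HSV helpers and the centroid finders above form the color-tracking pipeline: build bounds around a reference color, threshold, then locate the largest blob. A minimal sketch of that composition, assuming the helpers above are importable; the synthetic frame and the square's size are illustrative only:

```
import numpy as np
import cv2

# Synthetic 100x100 BGR frame with a 20x20 red square (illustrative only)
frame = np.zeros((100, 100, 3), dtype="uint8")
frame[40:60, 40:60] = (0, 0, 255)  # red in BGR

# Derive HSV bounds around pure red, threshold, and locate the blob
red_hsv = cv2.cvtColor(np.uint8([[[0, 0, 255]]]), cv2.COLOR_BGR2HSV)[0][0]
mask = threshold_image(frame, default_lower_HSV(red_hsv), default_upper_HSV(red_hsv))
found, centroid, bbox = find_centroid_basic_Rect(mask)
print(found, centroid, bbox)  # True, approximately [50 50], a bbox near (40, 40, 20, 20)
```
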
-
-
-def scale_square_bbox(bbox, scale_factor, square=True):
-
- xmin, ymin, width, height = bbox
-
-    if square:
- min_dim = min(width, height)
- width, height = min_dim, min_dim
-
- new_width, new_height = int(scale_factor * width), int(scale_factor * height)
-
- new_xmin = xmin - (new_width - width) / 2
- new_ymin = ymin - (new_height - height) / 2
-
- new_bbox = (new_xmin, new_ymin, new_width, new_height)
- return new_bbox
-
-
-def get_image_center_width(image):
- ImShape = image.shape
- ImH, ImW = ImShape[0], ImShape[1]
- return np.array([ImW * 0.5, ImH * 0.5]), ImW
-
-
-def get_image_height_width(image):
- ImShape = image.shape
- ImH, ImW = ImShape[0], ImShape[1]
- return ImH, ImW
-
-
-def get_image_top_center_width(image):
- ImShape = image.shape
-    ImH, ImW = ImShape[0], ImShape[1]
- return np.array([ImW * 0.5, 0.25 * ImH]), ImW
-
-
-def YTracking_Objective_Function(image, color):
- # variance method
- if image.size != 0:
- if color:
- image = bgr2gray(image)
-        _, std_dev = cv2.meanStdDev(image)
-        return std_dev[0][0] ** 2
- else:
- return 0
-
-
-def calculate_focus_measure(image):
- if len(image.shape) == 3:
- image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) # optional
- lap = cv2.Laplacian(image, cv2.CV_16S)
- focus_measure = mean(square(lap))
- return focus_measure
-
-
-# test part
-if __name__ == "__main__":
-    # Quick manual check of HSV thresholding on primary colors
- rouge = np.array([[[255, 0, 0]]], dtype="uint8")
- vert = np.array([[[0, 255, 0]]], dtype="uint8")
- bleu = np.array([[[0, 0, 255]]], dtype="uint8")
-
- rouge_HSV = cv2.cvtColor(rouge, cv2.COLOR_RGB2HSV)[0][0]
- vert_HSV = cv2.cvtColor(vert, cv2.COLOR_RGB2HSV)[0][0]
- bleu_HSV = cv2.cvtColor(bleu, cv2.COLOR_RGB2HSV)[0][0]
-
- img = cv2.imread("C:/Users/Francois/Documents/11-Stage_3A/6-Code_Python/ConsoleWheel/test/rouge.jpg")
- print(img)
- img2 = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
-
- couleur = bleu_HSV
- LOWER = default_lower_HSV(couleur)
- UPPER = default_upper_HSV(couleur)
-
- img3 = threshold_image(img2, LOWER, UPPER)
- cv2.imshow("image", img3)
- cv2.waitKey(0)
- cv2.destroyAllWindows()
-
-# for more than one tracked object
-"""
-def find_centroid_many(image,contour_area_min,contour_area_max):
- contours = cv2.findContours(image, cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)[-2]
- count=0
- last_centroids=[]
- for j in range(len(contours)):
- cnt = contours[j]
-        if cv2.contourArea(contours[j])>contour_area_min and cv2.contourArea(contours[j])<contour_area_max:
-    if len(coord_list) > 3:
- coord_list[1] = str(Ny - 1 - int(coord_list[1]))
- else:
- coord_list[0] = str(Ny - 1 - int(coord_list[0]))
-
- inverted_y_filename = "_".join([*coord_list, channel_name]) + "." + extension
- inverted_y_filepath = filepath.replace(filename, inverted_y_filename)
- return inverted_y_filepath
-
-
-def invert_y_in_folder(fovs_path, channel_names, Ny):
- """Given a folder with FOVs, channel names, and Ny, inverts the y-indices of all of them"""
-
- for channel in channel_names:
- channel = channel.replace(" ", "_")
- filepaths = list(glob(os.path.join(fovs_path, "*_*_*_" + channel + ".*")))
-        # First pass: rename every file to its y-inverted name plus a temporary
-        # suffix, so swapped pairs cannot collide with files not yet renamed
-        for path in filepaths:
-            inv_y_filepath = get_inverted_y_filepath(path, channel, Ny)
-            os.rename(path, inv_y_filepath + "._inverted")
-        # Second pass: strip the temporary suffix; the set of inverted names equals
-        # the set of original names, so iterating the original paths covers every file
-        for path in filepaths:
-            os.rename(path + "._inverted", path)
-
-
-def invert_y_in_slide(slide_path):
- Ny = get_ny(slide_path)
- time_indices = get_time_indices(slide_path)
- channels = get_channels(slide_path)
- for t in time_indices:
- fovs_path = os.path.join(slide_path, str(t))
- invert_y_in_folder(fovs_path, channels, Ny)
-
- # invert the y-index in the CSV too
- coord_csv_path = os.path.join(fovs_path, "coordinates.csv")
- coord_df = pd.read_csv(coord_csv_path)
- coord_df["i"] = (Ny - 1) - coord_df["i"]
- coord_df.to_csv(coord_csv_path, index=False)
-
-
-if __name__ == "__main__":
- if len(sys.argv) <= 1:
- print("Must provide a path to a slide folder.")
- exit()
- invert_y_in_slide(sys.argv[1])
- print("Inverted all i/y-indices in " + sys.argv[1])
diff --git a/software/tools/script_stitch_slide.py b/software/tools/script_stitch_slide.py
deleted file mode 100644
index 250354ede..000000000
--- a/software/tools/script_stitch_slide.py
+++ /dev/null
@@ -1,186 +0,0 @@
-import json
-import os
-from glob import glob
-from lxml import etree as ET
-import cv2
-from stitcher import stitch_slide, compute_overlap_percent
-import sys
-
-
-def get_pixel_size(
- slide_path,
- default_pixel_size=1.85,
- default_tube_lens_mm=50.0,
- default_objective_tube_lens_mm=180.0,
- default_magnification=20.0,
-):
- parameter_path = os.path.join(slide_path, "acquisition parameters.json")
- parameters = {}
- with open(parameter_path, "r") as f:
- parameters = json.load(f)
- try:
- tube_lens_mm = float(parameters["tube_lens_mm"])
- except KeyError:
- tube_lens_mm = default_tube_lens_mm
- try:
- pixel_size_um = float(parameters["sensor_pixel_size_um"])
- except KeyError:
- pixel_size_um = default_pixel_size
- try:
- objective_tube_lens_mm = float(parameters["objective"]["tube_lens_f_mm"])
- except KeyError:
- objective_tube_lens_mm = default_objective_tube_lens_mm
- try:
- magnification = float(parameters["objective"]["magnification"])
- except KeyError:
- magnification = default_magnification
-
- pixel_size_xy = pixel_size_um / (magnification / (objective_tube_lens_mm / tube_lens_mm))
-
- return pixel_size_xy
-
-
-def get_overlap(slide_path, **kwargs):
- sample_fov_path = os.path.join(slide_path, "0/*0_0_0_*.*")
- sample_fov_path = glob(sample_fov_path)[0]
- sample_fov_shape = cv2.imread(sample_fov_path).shape
- fov_width = sample_fov_shape[1]
- fov_height = sample_fov_shape[0]
-
- pixel_size_xy = get_pixel_size(slide_path, **kwargs)
-
- parameter_path = os.path.join(slide_path, "acquisition parameters.json")
- parameters = {}
- with open(parameter_path, "r") as f:
- parameters = json.load(f)
-
- dx = float(parameters["dx(mm)"]) * 1000.0
- dy = float(parameters["dy(mm)"]) * 1000.0
-
- overlap_percent = compute_overlap_percent(dx, dy, fov_width, fov_height, pixel_size_xy)
-
- return overlap_percent
-
-
-def get_time_indices(slide_path):
-
- parameter_path = os.path.join(slide_path, "acquisition parameters.json")
- parameters = {}
- with open(parameter_path, "r") as f:
- parameters = json.load(f)
-
- time_indices = list(range(int(parameters["Nt"])))
- return time_indices
-
-
-def get_channels(slide_path):
- config_xml_tree_root = ET.parse(os.path.join(slide_path, "configurations.xml")).getroot()
- channel_names = []
- for mode in config_xml_tree_root.iter("mode"):
- if mode.get("Selected") == "1":
- channel_names.append(mode.get("Name").replace(" ", "_"))
- return channel_names
-
-
-def get_z_indices(slide_path):
- parameter_path = os.path.join(slide_path, "acquisition parameters.json")
- parameters = {}
- with open(parameter_path, "r") as f:
- parameters = json.load(f)
-
- z_indices = list(range(int(parameters["Nz"])))
- return z_indices
-
-
-def get_coord_names(slide_path):
- sample_fovs_path = os.path.join(slide_path, "0/*_0_0_0_*.*")
- sample_fovs = glob(sample_fovs_path)
- coord_names = []
- for fov in sample_fovs:
- filename = fov.split("/")[-1]
- coord_name = filename.split("_0_")[0]
- coord_names.append(coord_name + "_")
- coord_names = list(set(coord_names))
- if len(coord_names) == 0:
- coord_names = [""]
- return coord_names
-
-
-def stitch_slide_from_path(slide_path, **kwargs):
- time_indices = get_time_indices(slide_path)
- z_indices = get_z_indices(slide_path)
- channels = get_channels(slide_path)
- coord_names = get_coord_names(slide_path)
- overlap_percent = get_overlap(slide_path, **kwargs)
-
- recompute_overlap = overlap_percent > 10
-
- stitch_slide(
- slide_path,
- time_indices,
- channels,
- z_indices,
- coord_names,
- overlap_percent=overlap_percent,
- reg_threshold=0.30,
- avg_displacement_threshold=2.50,
- abs_displacement_threshold=3.50,
- tile_downsampling=1.0,
- recompute_overlap=recompute_overlap,
- )
-
-
-def print_usage():
- usage_str = """
-    Stitches images using Fiji. NOTE: the y-indexing of images must go from bottom to top, which is only the case for recent versions of Squid.
-
- Usage (to be run from software directory in your Squid install):
-
- python tools/script_stitch_slide.py PATH_TO_SLIDE_FOLDER [--sensor-size SENSOR_PIXEL_SIZE_UM] [--tube-lens TUBE_LENS_MM] [--objective-tube-lens OBJECTIVE_TUBE_LENS_MM] [--magnification MAGNIFICATION] [--help]
-
- OPTIONAL PARAMETERS:
- --help/-h : Prints this and exits.
-
- --sensor-size : Sensor pixel size in um
-    --tube-lens : Your tube lens's focal length in mm (separate from the
-            objective's design tube lens focal length)
- --objective-tube-lens : Your objective's tube lens focal length in mm
- --magnification : Your objective's listed magnification
-
-    The script will first try to read these parameters from acquisition parameters.json, but will fall back to your provided values if it can't.
- """
-
- print(usage_str)
-
-
-if __name__ == "__main__":
- if len(sys.argv) < 2:
- print("No slide path name provided!")
- print_usage()
- exit()
-
- parameter_names = {
- "--sensor-size": "default_pixel_size",
- "--tube-lens": "default_tube_lens_mm",
- "--objective-tube-lens": "default_objective_tube_lens_mm",
- "--magnification": "default_magnification",
- }
-
- param_list = list(parameter_names.keys())
-
- user_kwargs = {}
-
- if "--help" in sys.argv or "-h" in sys.argv:
- print_usage()
- exit()
-
- for i in range(len(sys.argv)):
- if sys.argv[i] in param_list:
- try:
- arg_value = float(sys.argv[i + 1])
- user_kwargs[parameter_names[sys.argv[i]]] = arg_value
- except (IndexError, ValueError):
- print("Malformed argument, exiting.")
- exit()
-
- stitch_slide_from_path(sys.argv[1], **user_kwargs)
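
For orientation, the parameter handling above reduces to two formulas: the effective pixel size from the optics, and the tile overlap from the stage step. A worked example with illustrative numbers (not tool defaults):

```
# All numbers below are illustrative: a 20x objective designed for a 180 mm
# tube lens, mounted on a 50 mm tube lens, with 1.85 um sensor pixels
sensor_pixel_um = 1.85
magnification = 20.0
objective_tube_lens_mm = 180.0
tube_lens_mm = 50.0

# A shorter tube lens scales the nominal magnification down proportionally
effective_mag = magnification / (objective_tube_lens_mm / tube_lens_mm)  # ~5.56x
pixel_size_xy = sensor_pixel_um / effective_mag  # 0.333 um/px

# Overlap between neighboring tiles: FOV width minus the stage step in pixels
fov_width_px = 3000
dx_um = 900.0  # 0.9 mm stage step
shift_px = dx_um / pixel_size_xy  # ~2703 px
overlap_percent = max(0.0, fov_width_px - shift_px) * 100.0 / fov_width_px
print(round(pixel_size_xy, 3), round(overlap_percent, 1))  # 0.333 9.9
```
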
diff --git a/software/tools/stitcher.py b/software/tools/stitcher.py
deleted file mode 100644
index d2373ab52..000000000
--- a/software/tools/stitcher.py
+++ /dev/null
@@ -1,438 +0,0 @@
-import cv2
-import imagej
-import scyjava
-import os
-import tifffile
-from glob import glob
-import numpy as np
-import multiprocessing as mp
-
-JVM_MAX_MEMORY_GB = 4.0
-
-
-def compute_overlap_percent(deltaX, deltaY, image_width, image_height, pixel_size_xy, min_overlap=0):
- """Helper function to calculate percent overlap between images in
- a grid"""
- shift_x = deltaX / pixel_size_xy
- shift_y = deltaY / pixel_size_xy
- overlap_x = max(0, image_width - shift_x)
- overlap_y = max(0, image_height - shift_y)
- overlap_x = overlap_x * 100.0 / image_width
- overlap_y = overlap_y * 100.0 / image_height
- overlap = max(min_overlap, overlap_x, overlap_y)
- return overlap
-
-
-def stitch_slide_mp(*args, **kwargs):
-    """Run stitch_slide in a separate spawned process and return the Process handle"""
-    ctx = mp.get_context("spawn")
- stitch_process = ctx.Process(target=stitch_slide, args=args, kwargs=kwargs)
- stitch_process.start()
- return stitch_process
-
-
-def migrate_tile_config(
- fovs_path, coord_name, channel_name_source, z_index_source, channel_name_target, z_index_target
-):
- channel_name_source = channel_name_source.replace(" ", "_")
- channel_name_target = channel_name_target.replace(" ", "_")
-
- if z_index_source == z_index_target and channel_name_source == channel_name_target:
- raise RuntimeError("Source and target for channel/z-index migration are the same!")
-
- tile_conf_name_source = (
- "TileConfiguration_COORD_"
- + coord_name
- + "_Z_"
- + str(z_index_source)
- + "_"
- + channel_name_source
- + ".registered.txt"
- )
- tile_conf_name_target = (
- "TileConfiguration_COORD_"
- + coord_name
- + "_Z_"
- + str(z_index_target)
- + "_"
- + channel_name_target
- + ".registered.txt"
- )
- tile_config_source_path = os.path.join(fovs_path, tile_conf_name_source)
-
- if not os.path.isfile(tile_config_source_path):
- tile_config_source_path = tile_config_source_path.replace(".registered.txt", ".txt")
-
- assert os.path.isfile(tile_config_source_path)
-
- tile_config_target_path = os.path.join(fovs_path, tile_conf_name_target)
-
-    with open(tile_config_source_path, "r") as tile_conf_source, open(tile_config_target_path, "w") as tile_conf_target:
-        for line in tile_conf_source:
-            # Pass comment, dimension, and blank lines through unchanged
-            if line.startswith("#") or line.startswith("dim") or len(line) <= 1:
-                tile_conf_target.write(line)
-                continue
-            # Point each tile entry at the target z-index/channel filenames
-            line_to_write = line.replace(
-                "_" + str(z_index_source) + "_" + channel_name_source,
-                "_" + str(z_index_target) + "_" + channel_name_target,
-            )
-            tile_conf_target.write(line_to_write)
-
- return tile_conf_name_target
-
-
-def stitch_slide(
- slide_path,
- time_indices,
- channels,
- z_indices,
- coord_names=[""],
- overlap_percent=10,
- reg_threshold=0.30,
- avg_displacement_threshold=2.50,
- abs_displacement_threshold=3.50,
- tile_downsampling=0.5,
- recompute_overlap=False,
- **kwargs
-):
- st = Stitcher()
- st.stitch_slide(
- slide_path,
- time_indices,
- channels,
- z_indices,
- coord_names,
- overlap_percent,
- reg_threshold,
- avg_displacement_threshold,
- abs_displacement_threshold,
- tile_downsampling,
- recompute_overlap,
- **kwargs
- )
-
-
-class Stitcher:
-    def __init__(self):
-        # JVM options must be configured before imagej.init starts the JVM
-        scyjava.config.add_option("-Xmx" + str(int(JVM_MAX_MEMORY_GB)) + "g")
-        self.ij = imagej.init("sc.fiji:fiji", mode="headless")
-
- def stitch_slide(
- self,
- slide_path,
- time_indices,
- channels,
- z_indices,
- coord_names=[""],
- overlap_percent=10,
- reg_threshold=0.30,
- avg_displacement_threshold=2.50,
- abs_displacement_threshold=3.50,
- tile_downsampling=0.5,
- recompute_overlap=False,
- **kwargs
- ):
- for time_index in time_indices:
- self.stitch_single_time_point(
- slide_path,
- time_index,
- channels,
- z_indices,
- coord_names,
- overlap_percent,
- reg_threshold,
- avg_displacement_threshold,
- abs_displacement_threshold,
- tile_downsampling,
- recompute_overlap,
- **kwargs
- )
-
- def stitch_single_time_point(
- self,
- slide_path,
- time_index,
- channels,
- z_indices,
- coord_names=[""],
- overlap_percent=10,
- reg_threshold=0.30,
- avg_displacement_threshold=2.50,
- abs_displacement_threshold=3.50,
- tile_downsampling=0.5,
- recompute_overlap=False,
- **kwargs
- ):
- fovs_path = os.path.join(slide_path, str(time_index))
- for coord_name in coord_names:
- already_registered = False
- registered_z_index = None
- registered_channel_name = None
- for channel_name in channels:
- for z_index in z_indices:
- if already_registered:
- migrate_tile_config(
- fovs_path,
- coord_name,
- registered_channel_name,
- registered_z_index,
- channel_name.replace(" ", "_"),
- z_index,
- )
- output_dir = self.stitch_single_channel_from_tile_config(
- fovs_path, channel_name, z_index, coord_name
- )
- combine_stitched_channels(output_dir, **kwargs)
- else:
- output_dir = self.stitch_single_channel(
- fovs_path,
- channel_name,
- z_index,
- coord_name,
- overlap_percent,
- reg_threshold,
- avg_displacement_threshold,
- abs_displacement_threshold,
- tile_downsampling,
- recompute_overlap,
- )
- combine_stitched_channels(output_dir, **kwargs)
- if not already_registered:
- already_registered = True
- registered_z_index = z_index
- registered_channel_name = channel_name.replace(" ", "_")
-
- def stitch_single_channel_from_tile_config(self, fovs_path, channel_name, z_index, coord_name):
- """
- Stitches images using grid/collection stitching, reading registered
- positions from a tile configuration path that has been migrated from an
- already-registered channel/z-level at the same coordinate name
- """
- channel_name = channel_name.replace(" ", "_")
- tile_conf_name = (
- "TileConfiguration_COORD_" + coord_name + "_Z_" + str(z_index) + "_" + channel_name + ".registered.txt"
- )
- assert os.path.isfile(os.path.join(fovs_path, tile_conf_name))
-
- stitching_output_dir = "COORD_" + coord_name + "_Z_" + str(z_index) + "_" + channel_name + "_stitched/"
-
- stitching_output_dir = os.path.join(fovs_path, stitching_output_dir)
-
- os.makedirs(stitching_output_dir, exist_ok=True)
-
- stitching_params = {
- "type": "Positions from file",
- "order": "Defined by TileConfiguration",
- "fusion_mode": "Linear Blending",
- "ignore_z_stage": True,
- "downsample_tiles": False,
- "directory": fovs_path,
- "layout_file": tile_conf_name,
- "fusion_method": "Linear Blending",
- "regression_threshold": "0.30",
- "max/avg_displacement_threshold": "2.50",
- "absolute_displacement_threshold": "3.50",
- "compute_overlap": False,
- "computation_parameters": "Save computation time (but use more RAM)",
- "image_output": "Write to disk",
- "output_directory": stitching_output_dir,
- }
-
- plugin = "Grid/Collection stitching"
-
- self.ij.py.run_plugin(plugin, stitching_params)
-
- return stitching_output_dir
-
- def stitch_single_channel(
- self,
- fovs_path,
- channel_name,
- z_index,
- coord_name="",
- overlap_percent=10,
- reg_threshold=0.30,
- avg_displacement_threshold=2.50,
- abs_displacement_threshold=3.50,
- tile_downsampling=0.5,
- recompute_overlap=False,
- ):
- """
- Stitches images using grid/collection stitching with filename-defined
- positions following the format that squid saves multipoint acquisitions
- in. Requires that the filename-indicated grid positions go top-to-bottom
- on the y axis and left-to-right on the x axis (this is handled by
- the MultiPointController code in control/core.py). Must be passed
- the folder containing the image files.
- """
- channel_name = channel_name.replace(" ", "_")
-
- file_search_name = coord_name + "0_0_" + str(z_index) + "_" + channel_name + ".*"
-
- ext_glob = list(glob(os.path.join(fovs_path, file_search_name)))
-
- file_ext = ext_glob[0].split(".")[-1]
-
- y_length_pattern = coord_name + "*_0_" + str(z_index) + "_" + channel_name + "." + file_ext
-
- x_length_pattern = coord_name + "0_*_" + str(z_index) + "_" + channel_name + "." + file_ext
-
- grid_size_y = len(list(glob(os.path.join(fovs_path, y_length_pattern))))
-
- grid_size_x = len(list(glob(os.path.join(fovs_path, x_length_pattern))))
-
- stitching_filename_pattern = coord_name + "{y}_{x}_" + str(z_index) + "_" + channel_name + "." + file_ext
-
- stitching_output_dir = "COORD_" + coord_name + "_Z_" + str(z_index) + "_" + channel_name + "_stitched/"
-
- tile_conf_name = "TileConfiguration_COORD_" + coord_name + "_Z_" + str(z_index) + "_" + channel_name + ".txt"
-
- stitching_output_dir = os.path.join(fovs_path, stitching_output_dir)
-
- os.makedirs(stitching_output_dir, exist_ok=True)
-
- sample_tile_name = coord_name + "0_0_" + str(z_index) + "_" + channel_name + "." + file_ext
- sample_tile_shape = cv2.imread(os.path.join(fovs_path, sample_tile_name)).shape
-
- tile_downsampled_width = int(sample_tile_shape[1] * tile_downsampling)
- tile_downsampled_height = int(sample_tile_shape[0] * tile_downsampling)
- stitching_params = {
- "type": "Filename defined position",
- "order": "Defined by filename",
- "fusion_mode": "Linear Blending",
- "grid_size_x": grid_size_x,
- "grid_size_y": grid_size_y,
- "first_file_index_x": str(0),
- "first_file_index_y": str(0),
- "ignore_z_stage": True,
- "downsample_tiles": False,
- "tile_overlap": overlap_percent,
- "directory": fovs_path,
- "file_names": stitching_filename_pattern,
- "output_textfile_name": tile_conf_name,
- "fusion_method": "Linear Blending",
- "regression_threshold": str(reg_threshold),
- "max/avg_displacement_threshold": str(avg_displacement_threshold),
- "absolute_displacement_threshold": str(abs_displacement_threshold),
- "compute_overlap": recompute_overlap,
- "computation_parameters": "Save computation time (but use more RAM)",
- "image_output": "Write to disk",
- "output_directory": stitching_output_dir, # ,
- #'x':str(tile_downsampling),
- #'y':str(tile_downsampling),
- #'width':str(tile_downsampled_width),
- #'height':str(tile_downsampled_height),
- #'interpolation':'Bicubic average'
- }
-
- plugin = "Grid/Collection stitching"
-
- self.ij.py.run_plugin(plugin, stitching_params)
-
- return stitching_output_dir
-
-
-def images_identical(im_1, im_2):
- """Return True if two opencv arrays are exactly the same"""
- return im_1.shape == im_2.shape and not (np.bitwise_xor(im_1, im_2).any())
-
-
-def combine_stitched_channels(
- stitched_image_folder_path, write_multiscale_tiff=False, pixel_size_um=1.0, tile_side_length=1024, subresolutions=3
-):
- """Combines the three channel images created into one TIFF. Currently
- not recommended to run this with multiscale TIFF enabled, combining
- all channels/z-levels in one region of the acquisition into one OME-TIFF
- to be done later."""
-
-    # Fiji's "Write to disk" fusion names its outputs img_t1_z1_c1..c3, with no file extension
-    c1 = cv2.imread(os.path.join(stitched_image_folder_path, "img_t1_z1_c1"))
-
- c2 = cv2.imread(os.path.join(stitched_image_folder_path, "img_t1_z1_c2"))
-
- c3 = cv2.imread(os.path.join(stitched_image_folder_path, "img_t1_z1_c3"))
-
- combine_to_mono = False
-
- if c2 is None or c3 is None:
- combine_to_mono = True
-
- if write_multiscale_tiff:
- output_path = os.path.join(stitched_image_folder_path, "stitched_img.ome.tif")
- else:
- output_path = os.path.join(stitched_image_folder_path, "stitched_img.tif")
-
- if not combine_to_mono:
- if images_identical(c1, c2) and images_identical(c2, c3):
- combine_to_mono = True
-
- if not combine_to_mono:
- c1 = c1[:, :, 0]
- c2 = c2[:, :, 1]
- c3 = c3[:, :, 2]
- if write_multiscale_tiff:
- data = np.stack((c1, c2, c3), axis=0)
- else:
- data = np.stack((c1, c2, c3), axis=-1)
- axes = "CYX"
- channels = {"Name": ["Channel 1", "Channel 2", "Channel 3"]}
- else:
- data = c1[:, :, 0]
- axes = "YX"
- channels = None
-
- metadata = {
- "axes": axes,
- "SignificantBits": 16 if data.dtype == np.uint8 else 8,
- "PhysicalSizeX": pixel_size_um,
- "PhysicalSizeY": pixel_size_um,
- "PhysicalSizeXUnit": "um",
- "PhysicalSizeYUnit": "um",
- }
- if channels is not None:
- metadata["Channel"] = channels
-
- options = dict(
- photometric="rgb" if not combine_to_mono else "minisblack",
- tile=(tile_side_length, tile_side_length),
- compression="jpeg",
- resolutionunit="CENTIMETER",
- maxworkers=2,
- )
-
- if write_multiscale_tiff:
- with tifffile.TiffWriter(output_path, bigtiff=True) as tif:
- tif.write(
- data,
- subifds=subresolutions,
- resolution=(1e4 / pixel_size_um, 1e4 / pixel_size_um),
- metadata=metadata,
- **options
- )
- for level in range(subresolutions):
- mag = 2 ** (level + 1)
- if combine_to_mono:
- subdata = data[::mag, ::mag]
- else:
- subdata = data[:, ::mag, ::mag]
- tif.write(
-                    subdata, subfiletype=1, resolution=(1e4 / mag / pixel_size_um, 1e4 / mag / pixel_size_um), **options
- )
-
- if combine_to_mono:
- thumbnail = (data[::8, ::8] >> 2).astype("uint8")
- else:
- thumbnail = (data[0, ::8, ::8] >> 2).astype("uint8")
- tif.write(thumbnail, metadata={"Name": "thumbnail"})
- else:
- cv2.imwrite(output_path, data)
-
- channel_files = [os.path.join(stitched_image_folder_path, "img_t1_z1_c") + str(i + 1) for i in range(3)]
-
- for filename in channel_files:
- try:
- os.remove(filename)
- except FileNotFoundError:
- pass
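
A hypothetical end-to-end driver for the module above, showing how the pieces compose; the slide path, channel name, and numbers are placeholders rather than values from the repository:

```
from stitcher import stitch_slide_mp, compute_overlap_percent

if __name__ == "__main__":  # required, since stitch_slide_mp spawns a child process
    overlap = compute_overlap_percent(
        deltaX=900.0,  # stage steps in um (placeholder)
        deltaY=900.0,
        image_width=3000,  # FOV size in px (placeholder)
        image_height=3000,
        pixel_size_xy=0.333,  # um/px (placeholder)
    )

    # stitch_slide_mp runs the stitch in a separate spawned process, keeping the
    # headless Fiji JVM isolated from the caller; join the handle when finished
    stitch_process = stitch_slide_mp(
        "/path/to/slide",  # placeholder slide folder
        [0],  # time_indices
        ["BF LED matrix full"],  # channels
        [0],  # z_indices
        overlap_percent=overlap,
    )
    stitch_process.join()
```
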