diff --git a/py4DSTEM/io/datastructure/py4dstem/datacube.py b/py4DSTEM/io/datastructure/py4dstem/datacube.py
index 8dd8c19d6..01dc274eb 100644
--- a/py4DSTEM/io/datastructure/py4dstem/datacube.py
+++ b/py4DSTEM/io/datastructure/py4dstem/datacube.py
@@ -37,7 +37,8 @@ def __init__(
         Q_pixel_size: Optional[Union[float,list]] = 1,
         Q_pixel_units: Optional[Union[str,list]] = 'pixels',
         slicelabels: Optional[Union[bool,list]] = None,
-        calibration: Optional = None,
+        calibration: Optional[Calibration] = None,
+        stack_pointer = None,
         ):
         """
         Accepts:
@@ -100,6 +101,12 @@ def __init__(
             self.tree['calibration'].set_Q_pixel_size( Q_pixel_size )
             self.tree['calibration'].set_Q_pixel_units( Q_pixel_units )
 
+        # Add attribute of stack pointer for Dask related stuff
+        # Tacking this here for now
+        # this can also be used as a quick check for
+        self.stack_pointer = stack_pointer
+
+
diff --git a/py4DSTEM/io/datastructure/py4dstem/datacube_fns.py b/py4DSTEM/io/datastructure/py4dstem/datacube_fns.py
index c245b9590..2050c4f94 100644
--- a/py4DSTEM/io/datastructure/py4dstem/datacube_fns.py
+++ b/py4DSTEM/io/datastructure/py4dstem/datacube_fns.py
@@ -777,11 +777,15 @@ def find_Bragg_disks(
     CUDA = False,
     CUDA_batched = True,
     distributed = None,
+    dask = False,
+    dask_params = None,
     _qt_progress_bar = None,
     name = 'braggvectors',
     returncalc = True,
+
+    **kwargs
     ):
     """
     Finds the Bragg disks by cross correlation with `template`.
@@ -879,6 +883,7 @@ def find_Bragg_disks(
             processing
         if distributed is None, which is the default, processing will be in serial
+        dask (bool): if True, use dask ... TODO
        _qt_progress_bar (QProgressBar instance): used only by the GUI for serial
            execution
        name (str): name for the output BraggVectors
@@ -924,8 +929,11 @@ def find_Bragg_disks(
        CUDA = CUDA,
        CUDA_batched = CUDA_batched,
        distributed = distributed,
+       dask = dask,
+       dask_params = dask_params,
        _qt_progress_bar = _qt_progress_bar,
+       **kwargs
    )
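A minimal usage sketch of the new dask pass-through, assuming a DataCube `datacube` and a probe template `probe` already exist; the `dask_client_params` key and worker count are illustrative assumptions, and with 'dc_dask' still mapped to a placeholder (see diskdetection.py below) this path is not yet runnable end to end.

# hedged sketch: assumes `datacube` (DataCube) and `probe` (probe template) are already defined
peaks = datacube.find_Bragg_disks(
    template = probe,
    dask = True,                                              # route detection through the dask code path
    dask_params = {'dask_client_params': {'n_workers': 4}},   # merged into the call via kws.update(dask_params)
)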
diff --git a/py4DSTEM/io/native/legacy/read_v0_12.py b/py4DSTEM/io/native/legacy/read_v0_12.py
index 9e0e82ac4..cc70c8be8 100644
--- a/py4DSTEM/io/native/legacy/read_v0_12.py
+++ b/py4DSTEM/io/native/legacy/read_v0_12.py
@@ -1,8 +1,12 @@
 # Reader for py4DSTEM v0.12 files
 
+from inspect import stack
 import h5py
 import numpy as np
 from os.path import splitext, exists
+
+import dask.array as da
+
 from py4DSTEM.io.native.legacy.read_utils import is_py4DSTEM_file, get_py4DSTEM_topgroups, get_py4DSTEM_version, version_is_geq
 from py4DSTEM.io.native.legacy.read_utils_v0_12 import get_py4DSTEM_dataobject_info
 from py4DSTEM.io.datastructure import DataCube
@@ -12,6 +16,7 @@
 from py4DSTEM.io.datastructure import PointListArray
 from py4DSTEM import tqdmnd
 
+
 def read_v0_12(fp, **kwargs):
     """
     File reader for files written by py4DSTEM v0.12. Precise behavior is detemined by which
@@ -287,8 +292,14 @@ def get_datacube_from_grp(g,mem='RAM',binfactor=1,bindtype=None):
     elif (mem, binfactor) == ("MEMMAP", 1):
         data = g['data']
         stack_pointer = None
+    elif (mem, binfactor) == ("DASK", 1):
+        stack_pointer = g['data']
+        shape = g['data'].shape
+
+        data = da.from_array(stack_pointer, chunks=(1,1,shape[2], shape[3]))
+
     name = g.name.split('/')[-1]
-    return DataCube(data=data,name=name)
+    return DataCube(data=data,name=name, stack_pointer=stack_pointer)
 
 
 def get_diffractionslice_from_grp(g):
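The new ("DASK", 1) branch wraps the on-disk HDF5 dataset in a lazily evaluated dask array, chunked one diffraction pattern at a time. A standalone sketch of the same construction, with a hypothetical file path and group path:

import h5py
import dask.array as da

f = h5py.File("py4dstem_v0_12_file.h5", "r")                    # hypothetical file path
g = f["4DSTEM_experiment/data/datacubes/datacube_0"]            # hypothetical group path
stack_pointer = g["data"]                                       # h5py dataset; data stays on disk
shape = stack_pointer.shape
data = da.from_array(stack_pointer, chunks=(1, 1, shape[2], shape[3]))  # one diffraction pattern per chunk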
diff --git a/py4DSTEM/process/diskdetection/__init__.py b/py4DSTEM/process/diskdetection/__init__.py
index ea16f0f74..8cb5a7de4 100644
--- a/py4DSTEM/process/diskdetection/__init__.py
+++ b/py4DSTEM/process/diskdetection/__init__.py
@@ -2,5 +2,4 @@
 from py4DSTEM.process.diskdetection.braggvectormap import *
 #from .diskdetection_aiml import *
-#from .diskdetection_parallel_new import *
-
+from py4DSTEM.process.diskdetection.diskdetection_parallel_new import *
diff --git a/py4DSTEM/process/diskdetection/diskdetection.py b/py4DSTEM/process/diskdetection/diskdetection.py
index 3b69245d9..c6deb8977 100644
--- a/py4DSTEM/process/diskdetection/diskdetection.py
+++ b/py4DSTEM/process/diskdetection/diskdetection.py
@@ -4,6 +4,7 @@
 import numpy as np
 from scipy.ndimage import gaussian_filter
 
+
 from py4DSTEM.io.datastructure.py4dstem import DataCube, QPoints, BraggVectors
 from py4DSTEM.process.utils.get_maxima_2D import get_maxima_2D
 from py4DSTEM.process.utils.cross_correlate import get_cross_correlation_FT
@@ -13,6 +14,7 @@
 
 
+
 def find_Bragg_disks(
     data,
     template,
@@ -34,9 +36,11 @@ def find_Bragg_disks(
     CUDA = False,
     CUDA_batched = True,
     distributed = None,
+    dask : bool = False,
+    dask_params : dict = None,
     _qt_progress_bar = None,
-    ):
+    **kws):
     """
     Finds the Bragg disks in the diffraction patterns represented by `data` by
     cross/phase correlatin with `template`.
@@ -53,10 +57,10 @@ def find_Bragg_disks(
     and returns a instance or length N list of instances of QPoints
 
     For disk detection on a full DataCube, the calculation can be performed
-    on the CPU, GPU or a cluster. By default the CPU is used. If `CUDA` is set
-    to True, tries to use the GPU. If `CUDA_batched` is also set to True,
-    batches the FFT/IFFT computations on the GPU. For distribution to a cluster,
-    distributed must be set to a dictionary, with contents describing how
+    on the CPU, GPU, or using dask or ipyparallel. By default the CPU is used.
+    If `CUDA` is set to True, tries to use the GPU. If `CUDA_batched` is also set
+    to True, batches the FFT/IFFT computations on the GPU. For distribution to a
+    cluster, distributed must be set to a dictionary, with contents describing how
     distributed processing should be performed - see below for details.
 
@@ -141,6 +145,9 @@ def find_Bragg_disks(
             processing
         if distributed is None, which is the default, processing will be in serial
+        dask (bool): if True, indicates dask should be used. dask_params must then
+            be a dictionary with arguments to pass to the dask detection function.
+            Valid arguments are (...). See docstring for (...) for details.
         _qt_progress_bar (QProgressBar instance): used only by the GUI for serial
             execution
 
@@ -153,6 +160,8 @@ def find_Bragg_disks(
         - a (DataCube,rx,ry) 3-tuple, returns a list of QPoints instances
     """
+    # TODO add checks about ensuring Dask and Cuda aren't both passed, i.e. ensure
+    # the user knows the behaviour
 
     # parse args
@@ -196,11 +205,13 @@ def find_Bragg_disks(
                 mode = 'dc_GPU'
             else:
                 mode = 'dc_GPU_batched'
+        elif dask:
+            mode = 'dc_dask'
         else:
             x = _parse_distributed(distributed)
             connect, data_file, cluster_path, distributed_mode = x
             if distributed_mode == 'dask':
-                mode = 'dc_dask'
+                mode = 'dc_dask_old'
             elif distributed_mode == 'ipyparallel':
                 mode = 'dc_ipyparallel'
             else:
@@ -222,6 +233,9 @@ def find_Bragg_disks(
         kws['connect'] = connect
         kws['data_file'] = data_file
         kws['cluster_path'] = cluster_path
+    # dask kwargs
+    if dask_params is not None:
+        kws.update(dask_params)
 
     # run and return
     ans = fn(
@@ -243,7 +257,8 @@ def find_Bragg_disks(
     return ans
 
 
-
+# TODO add an extra skeleton func which imports beta_parallel and returns it if dask_cuda is added
+# TODO add MLAI at some point
 def _get_function_dictionary():
 
     d = {
@@ -252,14 +267,19 @@ def _get_function_dictionary():
         "dc_CPU" : _find_Bragg_disks_CPU,
         "dc_GPU" : _find_Bragg_disks_CUDA_unbatched,
         "dc_GPU_batched" : _find_Bragg_disks_CUDA_batched,
-        "dc_dask" : _find_Bragg_disks_dask,
+        "dc_dask_old" : _find_Bragg_disks_dask,
+        # "dc_dask" : beta_parallel_disk_detection,
+        "dc_dask" : place_holder,
+
         "dc_ipyparallel" : _find_Bragg_disks_ipp,
     }
 
     return d
 
-
-
+# TODO change the name to something better
+def place_holder():
+    from .diskdetection_parallel_new import beta_parallel_disk_detection
+    return beta_parallel_disk_detection
 
 
 # Single diffraction pattern
@@ -721,6 +741,7 @@ def _parse_distributed(distributed):
 
     elif "dask" in distributed:
         mode = 'dask'
+        print(type(distributed))
         if "client" in distributed["dask"]:
             connect = distributed["dask"]["client"]
         else:
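Note that the dispatcher above invokes the selected function (`ans = fn(...)`), while `place_holder` returns `beta_parallel_disk_detection` rather than calling it. A forwarding wrapper along these lines (a sketch only, assuming a `(data, template, **kwargs)` calling convention) would let the 'dc_dask' entry run:

def _find_Bragg_disks_dask_new(data, template, **kwargs):
    # deferred import keeps dask an optional dependency
    from .diskdetection_parallel_new import beta_parallel_disk_detection
    return beta_parallel_disk_detection(data, template, **kwargs)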
diff --git a/py4DSTEM/process/diskdetection/diskdetection_parallel_new.py b/py4DSTEM/process/diskdetection/diskdetection_parallel_new.py
index 8d76a2bbc..6210dec66 100644
--- a/py4DSTEM/process/diskdetection/diskdetection_parallel_new.py
+++ b/py4DSTEM/process/diskdetection/diskdetection_parallel_new.py
@@ -7,8 +7,12 @@
 from dask import delayed
 import dask
 #import dask.bag as db
-from py4DSTEM.io.datastructure import PointListArray, PointList
-from py4DSTEM.process.diskdetection.diskdetection import _find_Bragg_disks_single_DP_FK
+
+
+from py4DSTEM.io.datastructure.py4dstem import DataCube, QPoints, BraggVectors, PointListArray, PointList
+
+from py4DSTEM.process.diskdetection.diskdetection import _find_Bragg_disks_single
+
 from py4DSTEM.io import PointListArray, PointList, datastructure
 import time
 from dask.diagnostics import ProgressBar
@@ -50,7 +54,6 @@ def register_dill_serializer():
     register_serialization_family('dill', dill_dumps, dill_loads)
     return None
 
-
 #### END OF SERAILISERS ####
 
 
@@ -62,7 +65,7 @@ def register_dill_serializer():
 # TODO add ML-AI version
 def _find_Bragg_disks_single_DP_FK_dask_wrapper(arr, *args,**kwargs):
     # THis is needed as _find_Bragg_disks_single_DP_FK takes 2D array these arrays have the wrong shape
-    return _find_Bragg_disks_single_DP_FK(arr[0,0], *args, **kwargs)
+    return _find_Bragg_disks_single(arr[0,0], *args, **kwargs)
 
 #### END OF DASK WRAPPER FUNCTIONS ####
 
@@ -75,7 +78,7 @@ def _find_Bragg_disks_single_DP_FK_dask_wrapper(arr, *args,**kwargs):
 
 def beta_parallel_disk_detection(dataset,
                                  probe,
-                                 #rxmin=None, # these would allow selecting a sub section
+                                 #rxmin=None, # these would allow selecting a sub section # probably not a useful case
                                  #rxmax=None,
                                  #rymin=None,
                                  #rymax=None,
@@ -125,13 +128,21 @@ def beta_parallel_disk_detection(dataset,
     # ... dask stuff.
     #TODO add assert statements and other checks. Think about reordering opperations
 
+    ## adding an assert statement to make sure peaks is not passed as a keyword argument
+    assert 'peaks' not in kwargs, "peaks must not be passed as a keyword argument"
+
+    # Check to see if a dask client has been passed.
+    # if no client passed
     if dask_client == None:
+        # if parameters are passed, create a cluster with them and pass it to a dask client
         if dask_client_params !=None:
 
             dask.config.set({'distributed.worker.memory.spill': False,
                 'distributed.worker.memory.target': False})
             cluster = LocalCluster(**dask_client_params)
             dask_client = Client(cluster, **dask_client_params)
+
+        # if no parameters are passed, create a client with some default values
         else:
             # AUTO MAGICALLY SET?
             # LET DASK SET?
@@ -154,8 +165,10 @@
             pass
 
 
-    # Probe stuff
+    #### Probe stuff
+    # check that the probe shape is correct.
     assert (probe.shape == dataset.data.shape[2:]), "Probe and Diffraction Pattern Shapes are Mismatched"
+
     if probe_type != "FT":
         #TODO clean up and pull out redudant parts
         #if probe.dtype != (np.complex128 or np.complex64 or np.complex256):
@@ -192,7 +205,7 @@
     # loop over the dataset_delayed and create a delayed function of
     for x in np.ndindex(dataset_delayed.shape):
         temp = delayed(_find_Bragg_disks_single_DP_FK_dask_wrapper)(dataset_delayed[x],
-            probe_kernel_FT=dask_probe_delayed[0,0],
+            template=dask_probe_delayed[0,0],
             #probe_kernel_FT=delayed_probe_kernel_FT,
             *args, **kwargs) #passing through args from earlier or should I use
             #corrPower=corrPower,
@@ -207,28 +220,38 @@
 
     output = dask_client.gather(_temp_peaks) # gather the future objects
 
-    coords = [('qx',float),('qy',float),('intensity',float)]
-    peaks = PointListArray(coordinates=coords, shape=dataset.data.shape[:-2])
-
-    #temp_peaks[0][0]
+    dtype = [('qx',float),('qy',float),('intensity',float)]
+    peaks = PointListArray(dtype=dtype, shape=dataset.data.shape[:-2])
 
+    # operating over a list, so we need the index (0->count) and need to re-create the probe positions (0->rx, 0->ry);
+    # count is the length of the list
     for (count,(rx, ry)) in zip([i for i in range(dataset.data[...,0,0].size)],np.ndindex(dataset.data.shape[:-2])):
         #peaks.get_pointlist(rx, ry).add_pointlist(temp_peaks[0][count])
         #peaks.get_pointlist(rx, ry).add_pointlist(output[count][0])
-        peaks.get_pointlist(rx, ry).add_pointlist(output[count])
+        peaks.get_pointlist(rx, ry).add(output[count])
+
+
+    # create a BraggVectors obj
+    braggvectors = BraggVectors(dataset.Rshape, dataset.Qshape)
+    # populate the uncalibrated vectors with the detected peaks
+    braggvectors._v_uncal = peaks
+
+
+    # TODO Remove ability to return the client
+    # TODO RE-VISIT IF NEEDED TO RETURN
 
-    # Clean up
+    # Clean up dask related stuff
     dask_client.cancel(_temp_peaks) # removes from the dask workers
     del _temp_peaks # deletes the object
     if close_dask_client:
         dask_client.close()
-        return peaks
+        return braggvectors
     elif close_dask_client == False and return_dask_client == True:
-        return peaks, dask_client
+        return braggvectors, dask_client
     elif close_dask_client and return_dask_client == False:
-        return peaks
+        return braggvectors
     else:
         print('Dask Client in unknown state, this may result in unpredicitable behaviour later')
-        return peaks
+        return braggvectors
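beta_parallel_disk_detection can also be driven with an explicit client. A hedged sketch, assuming `dataset` (a DataCube) and `probe` are already prepared; the worker counts are arbitrary:

from dask.distributed import Client, LocalCluster

cluster = LocalCluster(n_workers=4, threads_per_worker=1)   # small local cluster for testing
client = Client(cluster)
braggvectors, client = beta_parallel_disk_detection(
    dataset, probe,
    dask_client = client,            # reuse the existing client
    close_dask_client = False,
    return_dask_client = True,
)
client.close()                        # shut down once detection is finished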
diff --git a/py4DSTEM/test/dask/diskdetection.py b/py4DSTEM/test/dask/diskdetection.py
new file mode 100644
index 000000000..c91ed52b4
--- /dev/null
+++ b/py4DSTEM/test/dask/diskdetection.py
@@ -0,0 +1,15 @@
+# Test dask disk detection functionality
+
+# Device use cases:
+# - local machine
+# - cluster
+
+# Storage use cases:
+# - as dask array
+# - as mem map
+# - in RAM
+
+# Future cases:
+# - GPU + dask
+
+
diff --git a/py4DSTEM/test/dask/io.py b/py4DSTEM/test/dask/io.py
new file mode 100644
index 000000000..5aca1e131
--- /dev/null
+++ b/py4DSTEM/test/dask/io.py
@@ -0,0 +1,11 @@
+# Test dask i/o functionality
+
+# Cases:
+# - load a datacube "normally", i.e. into memory, and then convert it to a dask array
+# - load a datacube directly from .h5 to a mapped dask array
+# - load a datacube into a numpy memmap, and then work on that as a dask array
+
+
+
+
+
diff --git a/py4DSTEM/test/dask/virtualimage.py b/py4DSTEM/test/dask/virtualimage.py
new file mode 100644
index 000000000..45dcee431
--- /dev/null
+++ b/py4DSTEM/test/dask/virtualimage.py
@@ -0,0 +1,11 @@
+# Test virtual imaging with dask
+
+
+# Do speed testing
+
+
+# Use cases to test:
+# - no center shifting
+# - center shifting
+
+
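The first io.py case (load into memory, then convert to a dask array) and the virtual-imaging test can be sketched together; the array sizes and detector window below are arbitrary stand-ins:

import numpy as np
import dask.array as da

data = np.random.rand(8, 8, 64, 64).astype(np.float32)    # stand-in 4D-STEM data, (Rx, Ry, Qx, Qy)
lazy = da.from_array(data, chunks=(1, 1, 64, 64))          # one diffraction pattern per chunk
virtual_image = lazy[:, :, 24:40, 24:40].sum(axis=(2, 3)).compute()  # simple bright-field-style sum, no center shifting
print(virtual_image.shape)                                 # (8, 8)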