Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions examples/example_data/LICENSE.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,14 @@ The dataset contained in the file:
- AuBalls_700ms_30nmStep_3_6SS_filter.cxi

is sourced from https://cxidb.org/id-65.html, and was made available by the original authors under the CC0 Public Domain Dedication Waiver. This data was deposited into the CXIDB by Stefano Marchesini.

The dataset contained in the file:

- PETRAIII_P25_Near_Field_Ptycho.cxi

is sourced from from [this](http://dx.doi.org/10.5281/zenodo.17899482) Zenodo upload, and was collected at the P25 beamline of the PETRA III light source at DESY. The following list of experiment participants were involved:


Nazanin Samadi, Aknur Karabay, Pengju Sheng, Canrong Qiu, Kathryn Spiers, Wenhui Xu, Abraham Levitan, and Manuel Guizar-Sicairos.

The dataset is made available under a CC BY 4.0 License, defined at https://creativecommons.org/licenses/by/4.0/.
56 changes: 56 additions & 0 deletions examples/near_field_ptycho.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import cdtools
from matplotlib import pyplot as plt

filename = 'example_data/PETRAIII_P25_Near_Field_Ptycho.cxi'
dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(filename)

dataset.inspect()
plt.show()

# Setting near_field equal to True uses an angular spectrum propagator in
# lieu of the default Fourier-transform propagator for far-field ptychography.
#
# If propagation_distance is not set, it assumes that the geometry is
# a standard near-field geometry with flat illumination wavefronts, and
# pulls the sample to detector distance from dataset.distance
#
# If propagation_distance is set, it assumes a Fresnel scaling theorem
# geometry with:
#
# - distance (from the dataset): The sample-to-detector distance
# - propagation_distance: The focus-to-sample distance
#
model = cdtools.models.FancyPtycho.from_dataset(
dataset,
n_modes=1,
near_field=True,
propagation_distance=3.65e-3, # 3.65 downstream from focus
units='um', # Set the units for the live plots
obj_view_crop=-35,
)

device = 'cuda'
model.to(device=device)
dataset.get_as(device=device)

model.inspect(dataset)

recon = cdtools.reconstructors.AdamReconstructor(model, dataset)

for loss in recon.optimize(100, lr=0.04, batch_size=10):
print(model.report())
# Plotting is expensive, so we only do it every tenth epoch
if model.epoch % 10 == 0:
model.inspect(dataset)

for loss in recon.optimize(50, lr=0.005, batch_size=50):
print(model.report())
if model.epoch % 10 == 0:
model.inspect(dataset)

# This orthogonalizes the recovered probe modes
model.tidy_probes()

model.inspect(dataset)
model.compare(dataset)
plt.show()
153 changes: 122 additions & 31 deletions src/cdtools/models/fancy_ptycho.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,10 @@ def __init__(self,
exponentiate_obj=False,
phase_only=False,
dtype=t.float32,
obj_view_crop=0
obj_view_crop=0,
near_field=False,
angular_spectrum_propagator=None,
inv_angular_spectrum_propagator=None,
):

super(FancyPtycho, self).__init__()
Expand Down Expand Up @@ -80,6 +83,25 @@ def __init__(self,
self.register_buffer('phase_only',
t.as_tensor(phase_only, dtype=bool))

self.register_buffer('near_field',
t.as_tensor(near_field, dtype=bool))

if angular_spectrum_propagator is None:
self.angular_spectrum_propagator = None
else:
self.register_buffer(
'angular_spectrum_propagator',
t.as_tensor(angular_spectrum_propagator, dtype=t.complex64)
)

if inv_angular_spectrum_propagator is None:
self.inv_angular_spectrum_propagator = None
else:
self.register_buffer(
'inv_angular_spectrum_propagator',
t.as_tensor(inv_angular_spectrum_propagator, dtype=t.complex64)
)

# Not sure how to make this a buffer...
self.units = units

Expand Down Expand Up @@ -230,6 +252,7 @@ def from_dataset(cls,
phase_only=False,
obj_view_crop=None,
obj_padding=200,
near_field=False,
):

wavelength = dataset.wavelength
Expand All @@ -247,16 +270,86 @@ def from_dataset(cls,

dataset.get_as(*get_as_args[0], **get_as_args[1])

# Then, generate the probe geometry from the dataset
ewg = tools.initializers.exit_wave_geometry
obj_basis = ewg(
det_basis,
det_shape,
wavelength,
distance,
oversampling=oversampling,
)
if not near_field:
# Then, generate the probe geometry from the dataset
ewg = tools.initializers.exit_wave_geometry
obj_basis = ewg(
det_basis,
det_shape,
wavelength,
distance,
oversampling=oversampling,
)

probe = tools.initializers.SHARP_style_probe(
dataset,
propagation_distance=propagation_distance,
oversampling=oversampling,
)
angular_spectrum_propagator=None
inv_angular_spectrum_propagator=None

else:
if propagation_distance is None or propagation_distance==0:
# In this case, we assume that we're genuinely in a near
# field geometry, such that z_eff = z and there is no
# magnification
obj_basis = t.as_tensor(det_basis) / oversampling
angular_spectrum_propagator = \
tools.propagators.generate_generalized_angular_spectrum_propagator(
[d*oversampling for d in det_shape],
obj_basis,
wavelength,
np.array([0,0,distance]),
)
inv_angular_spectrum_propagator = \
t.conj(angular_spectrum_propagator)
inv_angular_spectrum_propagator_init = t.conj(
tools.propagators.generate_generalized_angular_spectrum_propagator(
det_shape,
obj_basis,
wavelength,
np.array([0,0,distance]),
)
)
else:
# In this case, we assume that we're in a projection geometry
# with a z_eff based on propagation_distance and a nonzero
# magnification
M = (propagation_distance + distance) / propagation_distance
z_eff = distance / M

obj_basis = t.as_tensor(det_basis) / (oversampling * M)
angular_spectrum_propagator = \
tools.propagators.generate_generalized_angular_spectrum_propagator(
[d * oversampling for d in det_shape],
obj_basis,
wavelength,
np.array([0,0,z_eff]),
)
inv_angular_spectrum_propagator = t.conj(
angular_spectrum_propagator)
inv_angular_spectrum_propagator_init = t.conj(
tools.propagators.generate_generalized_angular_spectrum_propagator(
det_shape,
obj_basis,
wavelength,
np.array([0,0,z_eff]),
)
)

backward_propagator = lambda wavefields: \
tools.propagators.near_field(
wavefields,
inv_angular_spectrum_propagator_init
)

probe = tools.initializers.SHARP_style_near_field_probe(
dataset,
backward_propagator=backward_propagator,
oversampling=oversampling,
)

if hasattr(dataset, 'sample_info') and \
dataset.sample_info is not None and \
'orientation' in dataset.sample_info:
Expand Down Expand Up @@ -289,21 +382,6 @@ def from_dataset(cls,
padding=obj_padding,
)

# Finally, initialize the probe and object using this information
if probe_shape is None:
probe = tools.initializers.SHARP_style_probe(
dataset,
propagation_distance=propagation_distance,
oversampling=oversampling,
)
else:
probe = tools.initializers.gaussian_probe(
dataset,
obj_basis,
probe_shape,
propagation_distance=propagation_distance,
)

if hasattr(dataset, 'background') and dataset.background is not None:
background = t.sqrt(dataset.background)
else:
Expand Down Expand Up @@ -436,7 +514,10 @@ def from_dataset(cls,
simulate_finite_pixels=simulate_finite_pixels,
phase_only=phase_only,
exponentiate_obj=exponentiate_obj,
obj_view_crop=obj_view_crop
obj_view_crop=obj_view_crop,
near_field=near_field,
angular_spectrum_propagator=angular_spectrum_propagator,
inv_angular_spectrum_propagator=inv_angular_spectrum_propagator,
)


Expand Down Expand Up @@ -537,16 +618,26 @@ def interaction(self, index, translations, *args):
probe_support=self.probe_support)

return exit_waves


def forward_propagator(self, wavefields):
return tools.propagators.far_field(wavefields)
if self.near_field:
return tools.propagators.near_field(
wavefields, self.angular_spectrum_propagator
)
else:
return tools.propagators.far_field(wavefields)


def backward_propagator(self, wavefields):
return tools.propagators.inverse_far_field(wavefields)

if self.near_field:
return tools.propagators.near_field(
wavefields, self.inverse_angular_spectrum_propagator
)
else:
return tools.propagators.inverse_far_field(wavefields)


def measurement(self, wavefields):
return tools.measurements.quadratic_background(
wavefields,
Expand Down Expand Up @@ -792,7 +883,7 @@ def get_probes(idx):
values=values,
fig=fig,
units=self.units,
basis=self.obj_basis,
basis=self.probe_basis,
nanomap_colorbar_title='Total Probe Intensity',
cmap=cmap,
**kwargs),
Expand Down
16 changes: 10 additions & 6 deletions src/cdtools/tools/image_processing/image_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,19 +321,23 @@ def convolve_1d(image, kernel, dim=0, fftshift_kernel=True):
return conv_im


def fourier_upsample(ims, preserve_mean=False):
def fourier_upsample(ims, upsample_factor=2, preserve_mean=False):
# If preserve_mean is true, it preserves the mean pixel intensity
# otherwise, it preserves the total summed intensity
upsampled = t.zeros(ims.shape[:-2]+(2*ims.shape[-2],2*ims.shape[-1]),

upsampled = t.zeros(ims.shape[:-2]+(upsample_factor*ims.shape[-2],
upsample_factor*ims.shape[-1]),
dtype=ims.dtype,
device=ims.device)
left = [ims.shape[-2]//2,ims.shape[-1]//2]
right = [ims.shape[-2]//2+ims.shape[-2],
ims.shape[-1]//2+ims.shape[-1]]

left = [((upsample_factor-1)*ims.shape[-2])//2,
((upsample_factor-1)*ims.shape[-1])//2]
right = [left[0]+ims.shape[-2],
left[1]+ims.shape[-1]]

upsampled[...,left[0]:right[0],left[1]:right[1]] = propagators.far_field(ims)
if preserve_mean:
upsampled *= 2
upsampled *= upsample_factor
return propagators.inverse_far_field(upsampled)


Expand Down
Loading