diff --git a/src/dolphin/goldstein.py b/src/dolphin/goldstein.py index a6735aef..19232a35 100644 --- a/src/dolphin/goldstein.py +++ b/src/dolphin/goldstein.py @@ -81,32 +81,47 @@ def patch_goldstein_filter( return weight * data def apply_goldstein_filter(data: NDArray[np.complex64]) -> np.ndarray: - # Create an empty array for the output - out = np.zeros(data.shape, dtype=np.complex64) - empty_mask = np.isnan(data) | (np.angle(data) == 0) - # ignore processing for a chunks + # Mark invalid pixels (NaN or zero magnitude) + empty_mask = np.isnan(data) | (data == 0) + # ignore processing for empty chunks if np.all(empty_mask): return data - # Create the weight matrix + + nrows, ncols = data.shape + step = psize // 2 + + # Pad on all sides with reflection to handle edges without artifacts. + # Padding of step ensures original (0,0) gets weight from overlapping windows. + pad_top = step + pad_left = step + pad_bottom = step + (step - (nrows % step)) % step + pad_right = step + (step - (ncols % step)) % step + data_padded = np.pad( + data, ((pad_top, pad_bottom), (pad_left, pad_right)), mode="reflect" + ) + + # Create output arrays matching padded size + out = np.zeros(data_padded.shape, dtype=np.complex64) + weight_sum = np.zeros(data_padded.shape, dtype=np.float64) weight_matrix = make_weight(psize, psize) - # Iterate over windows of the data - for i in range(0, data.shape[0] - psize, psize // 2): - for j in range(0, data.shape[1] - psize, psize // 2): - # Create processing windows - data_window = data[i : i + psize, j : j + psize] - weight_window = weight_matrix[ - : data_window.shape[0], : data_window.shape[1] - ] - # Apply the filter to the window + + # Iterate over windows using full psize windows + padded_rows, padded_cols = data_padded.shape + for i in range(0, padded_rows - psize + 1, step): + for j in range(0, padded_cols - psize + 1, step): + data_window = data_padded[i : i + psize, j : j + psize] filtered_window = patch_goldstein_filter( - data_window, weight_window, psize + data_window, weight_matrix, psize ) - # Add the result to the output array - slice_i = slice(i, min(i + psize, out.shape[0])) - slice_j = slice(j, min(j + psize, out.shape[1])) - out[slice_i, slice_j] += filtered_window[ - : slice_i.stop - slice_i.start, : slice_j.stop - slice_j.start - ] + out[i : i + psize, j : j + psize] += filtered_window + weight_sum[i : i + psize, j : j + psize] += weight_matrix + + # Normalize by accumulated weights + valid = weight_sum > 0 + out[valid] /= weight_sum[valid] + + # Crop back to original size and apply empty mask + out = out[pad_top : pad_top + nrows, pad_left : pad_left + ncols] out[empty_mask] = 0 return out