diff --git a/src/obstools/airmass.py b/src/obstools/airmass.py index eb3f4b0..186ffb4 100644 --- a/src/obstools/airmass.py +++ b/src/obstools/airmass.py @@ -41,15 +41,14 @@ def refractive_index(h_gp): delta = 2.93e-4 rho = atmosphere(h_gp) - n = 1. + delta * (rho / RHO0) - return n + return 1. + delta * (rho / RHO0) class Atmopshere(object): pass -def atmosphere(H_gp): # class StandardAtmosphere +def atmosphere(H_gp): # class StandardAtmosphere """ US Standard Atmosphere, 1976 As published by NOAA, NASA, and USAF @@ -69,13 +68,16 @@ def atmosphere(H_gp): # class StandardAtmosphere if isinstance(H_gp, (float, int)): H_gp = np.array([H_gp]) - regions = [(0. <= H_gp) & (H_gp <= 11e3), - (11e3 < H_gp) & (H_gp <= 20e3), - (20e3 < H_gp) & (H_gp <= 32e3), - (32e3 < H_gp) & (H_gp <= 47e3), - (47e3 < H_gp) & (H_gp <= 51e3), - (51e3 < H_gp) & (H_gp <= 71e3), - (71e3 < H_gp) & (H_gp <= 84852.)] + regions = [ + (H_gp >= 0.0) & (H_gp <= 11e3), + (H_gp > 11e3) & (H_gp <= 20e3), + (H_gp > 20e3) & (H_gp <= 32e3), + (H_gp > 32e3) & (H_gp <= 47e3), + (H_gp > 47e3) & (H_gp <= 51e3), + (H_gp > 51e3) & (H_gp <= 71e3), + (H_gp > 71e3) & (H_gp <= 84852.0), + ] + expressions = [lambda x: RHO0 * (1. - x / 44330.94) ** 4.25587615, lambda x: RHO0 * 0.29707755 * np.exp((11e3 - x) / 6341.62), @@ -362,8 +364,7 @@ def delM(z, h): cos_delphi = (4 * (rhm * np.cos(im)) ** 2 - (delh * np.sin(im)) ** 2) / ( 4 * (rhm * np.cos(im)) ** 2 + (delh * np.sin(im)) ** 2) - dM = rho * np.sqrt(rh * rh + rhp * rhp - 2 * rh * rhp * cos_delphi) - return dM + return rho * np.sqrt(rh * rh + rhp * rhp - 2 * rh * rhp * cos_delphi) H = np.arange(0., Hmax, delh) X = np.empty(Z.shape) diff --git a/src/obstools/aps/ApertureCollections.py b/src/obstools/aps/ApertureCollections.py index d6f06b4..b55d7db 100644 --- a/src/obstools/aps/ApertureCollections.py +++ b/src/obstools/aps/ApertureCollections.py @@ -356,7 +356,7 @@ def __init__(self, widths=None, heights=None, angles=None, **kws): def __str__(self): # FIXME: better repr with widths, heights, angles - return '%s of shape %s' % (self.__class__.__name__, self.shape) + return f'{self.__class__.__name__} of shape {self.shape}' # def __repr__(self): # return str(self) @@ -518,26 +518,19 @@ def append(self, aps=None, **props): print('#' * 300) return - if not self.size: - concatenate = lambda o, a: a - # if the Collection was initialized as empty, set the new properties as current - else: - concatenate = props.concatenate - + concatenate = props.concatenate if self.size else (lambda o, a: a) # embed() oprops = self._properties._original # Find which properties differ and update those for key, val in props.items(): - if (not key in oprops) \ - or (not np.array_equal(oprops[key], props[ - key])): # `np.array_equal` here flags the empty properties as being unequal to the new ones, whereas `np.all` evaluates as True under the same conditions + if key not in oprops or not np.array_equal(oprops[key], props[key]): # `np.array_equal` here flags the empty properties as being unequal to the new ones, whereas `np.all` evaluates as True under the same conditions new = concatenate(self[key], val) # print( '8'*88 ) # print('APPEND', key, self[key], val ) # print( 'NEW:', new ) - setter = getattr(self, 'set_%s' % key) + setter = getattr(self, f'set_{key}') setter(new) # print( ) @@ -581,8 +574,7 @@ def area(self, idx=...): def area_between(self, idxs): """return the area enclosed between the two apertures given by idxs""" A = np.array(self.area(idxs), ndmin=2) - area = np.abs(np.subtract(*A.T)) - return area + return np.abs(np.subtract(*A.T)) def center_proximity(self, position, idx=...): """ @@ -645,7 +637,7 @@ def edge_proximity(self, position, idx=...): # ============================================================================================== # TODO: maybe maxe this a propetry?? def add_to_axes(self, ax=None): - if not self in ax.collections: + if self not in ax.collections: # print( 'Adding collection to axis' ) # necessary for apertures to map correctly to data positions @@ -888,10 +880,7 @@ def __init__(self, apertures, ax, **kws): # @expose.args( pre='='*100, post='?'*100 ) def make_segments(self, radii): - if radii.size: - return [list(zip((r, r), (0, 1))) for r in radii] - else: - return [] + return [list(zip((r, r), (0, 1))) for r in radii] if radii.size else [] def update_from(self, aps): @@ -1113,12 +1102,7 @@ def append(self, aps=None, **props): # can all be changed at a specific index. # FIXME: SMELLY CODE!!!!!!!!!!!! - if not self.size: - concatenate = lambda o, a: a - # HACK! if the Collection was initialized as empty, set the new properties as current - else: - concatenate = PropertyManager.concatenate - + concatenate = PropertyManager.concatenate if self.size else (lambda o, a: a) # for key in self._properties.__broadcast__: for key, val in props.items(): @@ -1134,7 +1118,7 @@ def append(self, aps=None, **props): print(e) embed() - setter = getattr(self, 'set_%s' % key) + setter = getattr(self, f'set_{key}') print() # try: @@ -1148,7 +1132,7 @@ def within_allowed_range(self, r): def resize(self, relative_motion, idx=..., ): print('RESIZING!', relative_motion, self.radii, idx) - if not relative_motion is None: + if relative_motion is not None: rnew = self.radii rnew[idx] += relative_motion if not self.within_allowed_range(rnew): diff --git a/src/obstools/ephemeris.py b/src/obstools/ephemeris.py index 21fa191..1079626 100644 --- a/src/obstools/ephemeris.py +++ b/src/obstools/ephemeris.py @@ -153,12 +153,11 @@ def rephase(phase, offset, *data): phase %= 1 if data is None: return phase - else: - data = np.array(cosort(phase, *data)) + data = np.array(cosort(phase, *data)) - phase = data[0] - data = data[1:] - return phase, data + phase = data[0] + data = data[1:] + return phase, data def phase_splitter(ph, *data, **kws): diff --git a/src/obstools/image/calibration.py b/src/obstools/image/calibration.py index 9fa51fa..861df02 100644 --- a/src/obstools/image/calibration.py +++ b/src/obstools/image/calibration.py @@ -15,9 +15,7 @@ def __init__(self, name): self.name = f'_{name}' def __get__(self, instance, owner): - if instance is None: - return self - return getattr(instance, self.name) + return self if instance is None else getattr(instance, self.name) def __set__(self, instance, value): if value is keep: diff --git a/src/obstools/image/mosaic.py b/src/obstools/image/mosaic.py index 4c05edb..e0c9bda 100644 --- a/src/obstools/image/mosaic.py +++ b/src/obstools/image/mosaic.py @@ -87,7 +87,7 @@ def plot_transformed_image(ax, image, fov=None, p=(0, 0, 0), frame=True, frame_kws = dict(fc='none', lw=0.5, ec='0.5', alpha=kws.get('alpha')) if isinstance(frame, dict): - frame_kws.update(frame) + frame_kws |= frame ax.add_patch( Rectangle(xy - half_pixel_size, *fov[::-1], np.degrees(theta), @@ -256,7 +256,7 @@ def plot_image(self, image=None, fov=None, p=(0, 0, 0), name=None, # # if not isinstance(image, SkyImage): - if not image.__class__.__name__ == 'SkyImage': + if image.__class__.__name__ != 'SkyImage': image = SkyImage(image, fov) # @@ -373,7 +373,7 @@ def mark_target(self, xy, name, colour='forestgreen', arrow_size=10, def label_image(self, name='', p=(0, 0, 0), fov=(0, 0), **kws): # default args for init _kws = {} - _kws.update(self.label_props) + _kws |= self.label_props _kws.update(kws) return self.ax.text(*ulc(p, fov), name, rotation=np.degrees(p[-1]), diff --git a/src/obstools/image/registration.py b/src/obstools/image/registration.py index 77bac81..dd26f2b 100644 --- a/src/obstools/image/registration.py +++ b/src/obstools/image/registration.py @@ -298,10 +298,7 @@ def _checks(self, p, xy=None, *args, **kws): return self._check_params(p), self._check_grid(xy) def _check_params(self, p): - if (p is None) or (p == ()): - # default parameter values for evaluation - return np.zeros(self.dof) - return p + return np.zeros(self.dof) if (p is None) or (p == ()) else p def _check_grid(self, grid): # @@ -373,9 +370,8 @@ def plot(self, grid=100, show_xy=True, show_peak=True, **kws): grid = duplicate_if_scalar(grid, self.n_dims, raises=False) if grid.size == self.n_dims: grid = self._auto_grid(grid) - else: - if (grid.ndim != 3) or (grid.shape[-1] != self.n_dims): - raise ValueError('Invalid grid') + elif (grid.ndim != 3) or (grid.shape[-1] != self.n_dims): + raise ValueError('Invalid grid') # compute model values z = self((), grid) @@ -691,15 +687,13 @@ def display_multitab(images, fovs, params, coords): import more_itertools as mit ui = MplMultiTab() - for i, (image, fov, p, yx) in enumerate(zip(images, fovs, params, coords)): + for image, fov, p, yx in zip(images, fovs, params, coords): xy = yx[:, ::-1] # roto_translate_yx(yx, np.r_[-p[:2], 0])[:, ::-1] ex = mit.interleave((0, 0), fov) im = ImageDisplay(image, extent=list(ex)) im.ax.plot(*xy.T, 'kx', ms=5) ui.add_tab(im.figure) plt.close(im.figure) - # if i == 1: - # break return ui @@ -1102,20 +1096,18 @@ def _measure_positions_offsets(xy, centres, d_cut=None): out_new = (d > d_cut) out_new = np.ma.getdata(out_new) | np.ma.getmask(out_new) - changed = (outliers != out_new).any() - if changed: - out = out_new - xym[out] = np.ma.masked - n_out = out.sum() + if not (changed := (outliers != out_new).any()): + break - if n_out / n_points > 0.5: - raise Exception('Too many outliers!!') + out = out_new + xym[out] = np.ma.masked + n_out = out.sum() - logger.info('Ignoring %i/%i (%.1f%%) values with |δr| > %.3f', - n_out, n_points, (n_out / n_points) * 100, d_cut) - else: - break + if n_out / n_points > 0.5: + raise Exception('Too many outliers!!') + logger.info('Ignoring %i/%i (%.1f%%) values with |δr| > %.3f', + n_out, n_points, (n_out / n_points) * 100, d_cut) return centres, xy_shifted.std(0), xy_offsets.squeeze(), outliers @@ -1385,7 +1377,7 @@ def plot(self, ax=None, p=(0, 0, 0), scale='fov', frame=True, **kws): frame_kws = dict(fc='none', lw=1, ec='0.5', alpha=kws.get('alpha')) if isinstance(frame, dict): - frame_kws.update(frame) + frame_kws |= frame *xy, theta = p frame = Rectangle(np.subtract(xy, half_pixel_size), *urc, @@ -1652,11 +1644,7 @@ def from_images(cls, images, fovs, angles=(), ridx=None, plot=False, # message cls.logger.info('Aligning %i images on image %i', n, ridx) - if len(angles): - angles = np.array(angles) - angles[ridx] # relative angles - else: - angles = np.zeros(n) - + angles = np.array(angles) - angles[ridx] if len(angles) else np.zeros(n) reg = cls(**find_kws) for i in indices: reg(images[i], fovs[i], angles[i], plot=plot) @@ -2566,8 +2554,7 @@ def mosaic(self, names=(), **kws): def get_rotation(self): # transform pixel to ICRS coordinate h = self.hdu[0].header - theta = np.pi / 2 - np.arctan(-h['CD1_1'] / h['CD1_2']) - return theta + return np.pi / 2 - np.arctan(-h['CD1_1'] / h['CD1_2']) # todo: def proper_motion_correction(self, coords): diff --git a/src/obstools/image/segmentation/core.py b/src/obstools/image/segmentation/core.py index e71c831..2a1beab 100644 --- a/src/obstools/image/segmentation/core.py +++ b/src/obstools/image/segmentation/core.py @@ -198,17 +198,16 @@ def select_rect_pad(segm, image, start, shape): def inside_segment(coords, sub, grid): b = [] ogrid = grid[0, :, 0], grid[1, 0, :] - for j, (g, f) in enumerate(zip(ogrid, coords)): + for g, f in zip(ogrid, coords): bi = np.digitize(f, g - 0.5) b.append(bi) mask = (sub == 0) - if np.equal(grid.shape[1:], b).any() or np.equal(0, b).any(): - inside = False - else: - inside = not mask[b[0], b[1]] - - return inside + return ( + False + if np.equal(grid.shape[1:], b).any() or np.equal(0, b).any() + else not mask[b[0], b[1]] + ) def get_masking_flags(arrays, masked): @@ -466,9 +465,10 @@ def grow(self, labels, inc=1): # z + np.array([-1, 1], ndmin=3).T urc = np.add(self.urc(labels), inc) # .clip(None, self.seg.shape) llc = np.add(self.llc(labels), -inc).clip(0) - slices = [tuple(slice(*i) for i in yxix) - for yxix in zip(*np.swapaxes([llc, urc], -1, 0))] - return slices + return [ + tuple(slice(*i) for i in yxix) + for yxix in zip(*np.swapaxes([llc, urc], -1, 0)) + ] # def around_centroids(self, image, size, labels=None): # com = self.seg.centroid(image, labels) @@ -609,9 +609,7 @@ def __init_subclass__(cls, **kwargs): for stat in cls._supported: method = MaskedStatistic(getattr(ndimage, stat)) setattr(cls, stat, method) - # also add aliases for convenience - alias = cls._aliases.get(stat) - if alias: + if alias := cls._aliases.get(stat): setattr(cls, alias, method) @@ -641,12 +639,7 @@ def __init__(self, func): def __get__(self, seg, objtype=None): - if seg is None: # called from class - return self - - # bind this class to the seg instance from whence the lookup came. - # Essentially this binds the first argument `seg` in `__call__` below - return types.MethodType(self, seg) + return self if seg is None else types.MethodType(self, seg) def __call__(self, seg, image, labels=None): # handling of masked pixels for all statistical methods done here @@ -1072,7 +1065,7 @@ def slices(self): which is very nice. """ s = {0: (slice(None),) * self.data.ndim} - s.update(zip(self.labels, SegmentationImage.slices.fget(self))) + s |= zip(self.labels, SegmentationImage.slices.fget(self)) return Sliced(s) def sliced(self, label): @@ -1113,11 +1106,7 @@ def heights(self): @lazyproperty def max_label(self): - if len(self.labels): - return super().max_label - else: - # otherwise `np.max` borks with empty sequence - return 0 + return super().max_label if len(self.labels) else 0 def make_cmap(self, background_color='#000000', random_state=random_state): # this function fails for all zero data since `make_random_cmap` @@ -1209,7 +1198,7 @@ def has_labels(self, labels, allow_zero=False): invalid = np.setdiff1d(labels, valid) if len(invalid): - raise ValueError('Invalid label(s): %s' % str(tuple(invalid))) + raise ValueError(f'Invalid label(s): {tuple(invalid)}') return labels @@ -1555,14 +1544,13 @@ def _relabel_masked(self, image, masked_pixels_label=None): masked_pixels_label = self.max_label + 1 masked_pixels_label = int(masked_pixels_label) - if np.ma.is_masked(image): - # ignore masked pixels - seg_data = self.data.copy() - seg_data[image.mask] = masked_pixels_label - # this label will not be used for statistic computation - return seg_data - else: + if not np.ma.is_masked(image): return self.data + # ignore masked pixels + seg_data = self.data.copy() + seg_data[image.mask] = masked_pixels_label + # this label will not be used for statistic computation + return seg_data def thumbnails(self, image=None, labels=None, masked=False): """ @@ -1858,8 +1846,7 @@ def dilate(self, iterations=1, connectivity=4, labels=None, mask=None, masks = self.to_binary(labels, expand=True) if structure is None: - d = {4: 1, 8: 2}.get(connectivity) - if d: + if d := {4: 1, 8: 2}.get(connectivity): structure = ndimage.generate_binary_structure(2, d) else: raise ValueError('Invalid connectivity={0}. ' @@ -1881,9 +1868,8 @@ def dilate(self, iterations=1, connectivity=4, labels=None, mask=None, if copy: return self.__class__(data) - else: - self.data = data - return self + self.data = data + return self def auto_dilate(self, image, labels=None, dmax=5, sigma=3): # @@ -2268,7 +2254,7 @@ def format_term(self, show_labels=True, frame=True, origin=0, cmap=None): from motley import codes origin = int(origin) - assert origin in (0, 1) + assert origin in {0, 1} # re-orient data o = 1 if origin else -1 @@ -2342,10 +2328,7 @@ def format_term(self, show_labels=True, frame=True, origin=0, cmap=None): # create string i0 = 0 - im = '' - if frame: - im = motley.underline(' ' * ((nc + 1) * nm)) + '\n' + BORDER - + im = motley.underline(' ' * ((nc + 1) * nm)) + '\n' + BORDER if frame else '' for i, mrk in sorted(marks.items(), key=lambda _: _[0]): im += ' ' * ((i - i0 - 1) * nm) + mrk i0 = i diff --git a/src/obstools/image/segmentation/detect.py b/src/obstools/image/segmentation/detect.py index 84d9327..af77434 100644 --- a/src/obstools/image/segmentation/detect.py +++ b/src/obstools/image/segmentation/detect.py @@ -41,7 +41,7 @@ def make_border_mask(data, edge_cutoffs): if len(edge_cutoffs) == 4: return _make_border_mask(data, *edge_cutoffs) - raise ValueError('Invalid edge_cutoffs %s' % edge_cutoffs) + raise ValueError(f'Invalid edge_cutoffs {edge_cutoffs}') def detect(image, mask=False, background=None, snr=3., npixels=7, @@ -95,7 +95,7 @@ def detect(image, mask=False, background=None, snr=3., npixels=7, logger.debug('No objects detected') return np.zeros_like(image, bool) - if deblend and not no_sources: + if deblend: from photutils import deblend_sources seg = deblend_sources(image, seg, npixels) diff --git a/src/obstools/image/segmentation/trace.py b/src/obstools/image/segmentation/trace.py index 7d09a3f..9383847 100644 --- a/src/obstools/image/segmentation/trace.py +++ b/src/obstools/image/segmentation/trace.py @@ -95,9 +95,7 @@ def trace_boundary(b, stop=int(1e4)): if step_size == 1: boundary.append(boundary[-1] + EDGES[tuple(mv)]) - # check if we are done. Jacob's stopping criterion - done = (current == start).all() and (mv == (0, -1)).all() - if done: + if done := (current == start).all() and (mv == (0, -1)).all(): # close the perimeter perimeter += step_size break diff --git a/src/obstools/image/transforms.py b/src/obstools/image/transforms.py index 155e8af..bd61491 100644 --- a/src/obstools/image/transforms.py +++ b/src/obstools/image/transforms.py @@ -57,7 +57,7 @@ def rigid(xy, p): if (np.ndim(xy) < 2) or (np.shape(xy)[-1] != 2): raise ValueError('Invalid dimensions for coordinate array `xy`') - if not len(p) == 3: + if len(p) != 3: raise ValueError('Invalid parameter array for rigid transform `xy`') return rotate(xy, p[-1]) + p[:2] diff --git a/src/obstools/image/utils.py b/src/obstools/image/utils.py index e5bc9c1..d969760 100644 --- a/src/obstools/image/utils.py +++ b/src/obstools/image/utils.py @@ -39,11 +39,7 @@ def table_coords(coo, ix_fit, ix_scale, ix_loc): frame=False, align='^', col_borders='', cell_whitespace=0) tt.colourise(m, fg=cols) - # ts = tt.add_colourbar(str(tt), ('fit|', 'scale|', 'loc|')) - - # join tables - tbl = Table([[str(cootbl), str(tt)]], frame=False, col_borders='') - return tbl + return Table([[str(cootbl), str(tt)]], frame=False, col_borders='') def table_cdist(sdist, window, _print=False): diff --git a/src/obstools/lc/ascii.py b/src/obstools/lc/ascii.py index 301aa43..afc23ff 100644 --- a/src/obstools/lc/ascii.py +++ b/src/obstools/lc/ascii.py @@ -36,15 +36,14 @@ def parse_format_spec(fmt): - mo = FORMATSPEC_SRE.match(fmt) - if mo: + if mo := FORMATSPEC_SRE.match(fmt): return mo.groups() # width, precision, dtype = else: raise ValueError('Nope!') def format_list(data, fmt='%g', width=8, sep=','): - lfmt = '%-{}s'.format(width) * len(data) + lfmt = f'%-{width}s' * len(data) s = lfmt % tuple(np.char.mod(fmt + sep, data)) return s[::-1].replace(',', ' ', 1)[::-1].join('[]') @@ -65,9 +64,7 @@ def header_info_block(name, info): def get_name(o): - if isinstance(o, Callable): - return o.__name__ - return str(o) + return o.__name__ if isinstance(o, Callable) else str(o) def check_column_widths(names, formats): @@ -128,10 +125,7 @@ def hstack_string(a, b, whitespace=1): bl = b.splitlines() w = max(map(len, al)) + whitespace - return ''.join( - '{: <{}s}{}\n'.format(aa, w, bb) - for i, (aa, bb) in enumerate(zip(al, bl)) - ) + return ''.join('{: <{}s}{}\n'.format(aa, w, bb) for aa, bb in zip(al, bl)) def get_column_info(nstars, has_oflag): @@ -180,9 +174,9 @@ def get_column_info(nstars, has_oflag): formats.extend(col_fmt_per_star) # prepend comment str in such a way as to not screw up alignment with data - units = ['[%s]' % u for u in units] - units[0] = '# ' + units[0] - names[0] = '# ' + names[0] + units = [f'[{u}]' for u in units] + units[0] = f'# {units[0]}' + names[0] = f'# {names[0]}' return names, units, formats, col_info @@ -213,17 +207,18 @@ def make_header(obj_name, shape_info, has_oflag, meta={}): col_widths, col_fmt_head, col_fmt_data = make_column_format(names, formats) # make header - title = 'Light Curve for %s' % obj_name + title = f'Light Curve for {obj_name}' # table shape info - lines = ['# ' + header_info_block(title, shape_info)] + lines = [f'# {header_info_block(title, shape_info)}'] # column descriptions lines.append(header_info_block('columns', col_info)) # header blocks for additional meta data - for sec_name, info in meta.items(): - lines.append(header_info_block(sec_name, info)) + lines.extend( + header_info_block(sec_name, info) for sec_name, info in meta.items() + ) # header as commented string # prepend comment character diff --git a/src/obstools/modelling/core.py b/src/obstools/modelling/core.py index 8240341..3d2d79f 100644 --- a/src/obstools/modelling/core.py +++ b/src/obstools/modelling/core.py @@ -153,11 +153,7 @@ def __call__(self, p, data, grid=None, sigma=None): # point # - if sigma is None: - sigma_term = 0 - else: - sigma_term = np.log(sigma).sum() - + sigma_term = 0 if sigma is None else np.log(sigma).sum() return (- data.size * LN2PI_2 # # TODO: einsum here for mahalanobis distance term - 0.5 * self.wrss(p, data, grid, stddev) @@ -427,8 +423,13 @@ def aicc(self, p, data, *args, **kws): # then the formula for AICc is as follows. k = len(p) n = data.size - return 2 * (k + (k * k + k) / (n - k - 1) - - self.ln_likelihood(p, data, *args, **kws)) + return 2 * ( + ( + k + + (k**2 + k) / (n - k - 1) + - self.ln_likelihood(p, data, *args, **kws) + ) + ) # "If the assumption that the model is univariate and linear with normal # residuals does not hold, then the formula for AICc will generally be # different from the formula above. For some models, the precise formula @@ -517,14 +518,7 @@ def pre_fit(self, loss, p0, data, *args, **kws): # nested parameters: flatten prior to minimize, re-structure post-fit # TODO: move to HandleParameters Mixin?? - if isinstance(p0, Parameters): - p0 = p0.flattened - else: - # need to convert to float since we are using p0 dtype to type - # cast the results vector for structured parameters and user might - # have passed an array-like of integers - p0 = p0.astype(float) - + p0 = p0.flattened if isinstance(p0, Parameters) else p0.astype(float) # check that call works. This check here so that we can identify # potential problems with the function call / arguments before entering # the optimization routine. Any potential errors that occur here will @@ -582,8 +576,7 @@ def _fit(self, loss, p0, data, *args, **kws): msg = result.message if success: - unchanged = np.allclose(p, p0) - if unchanged: + if unchanged := np.allclose(p, p0): # TODO: maybe also warn if any close ? self.logger.warning('"Converged" parameter vector is ' 'identical to initial guess: %s', p0) @@ -962,12 +955,7 @@ def get_dtype(self, keys=all): if keys is all: keys = self.models.keys() - dtype = [] - for key in keys: - dtype.append( - self._adapt_dtype(self.models[key], ()) - ) - return dtype + return [self._adapt_dtype(self.models[key], ()) for key in keys] def _adapt_dtype(self, model, out_shape): # adapt the dtype of a component model so that it can be used with @@ -976,20 +964,15 @@ def _adapt_dtype(self, model, out_shape): # used for more than one key (label) to be represented by a 2D array. # make sure size in a tuple - if out_shape == 1: - out_shape = () - else: - out_shape = int2tup(out_shape) - + out_shape = () if out_shape == 1 else int2tup(out_shape) dt = model.get_dtype() - if len(dt) == 1: # simple model - name, base, dof = dt[0] - dof = int2tup(dof) - # extend shape of dtype - return model.name, base, out_shape + dof - else: # compound model + if len(dt) != 1: # structured dtype - nest! return model.name, dt, out_shape + name, base, dof = dt[0] + dof = int2tup(dof) + # extend shape of dtype + return model.name, base, out_shape + dof def _results_container(self, keys=all, dtype=None, fill=np.nan, shape=(), type_=Parameters): diff --git a/src/obstools/modelling/image/core.py b/src/obstools/modelling/image/core.py index cb540a9..72de007 100644 --- a/src/obstools/modelling/image/core.py +++ b/src/obstools/modelling/image/core.py @@ -120,8 +120,8 @@ def set_models(self, models): if isinstance(models, Model): models = [models] - n_models = len(models) if not isinstance(models, MutableMapping): + n_models = len(models) if n_models not in (0, self.seg.nlabels): raise ValueError("Mapping from segments to models is not " "1-to-1") @@ -850,16 +850,14 @@ def create_apertures(self, appars, coords=None, sky=False, fallbacks=None): if coords is None: coords = fcoords # use fit coordinates for aperture positions - if sky: - sx, sy = sigma_xy - rxsky = sx * self.rsky - rysky = sy * self.rsky[1] - return [EllipticalAnnulus(coo, *rxsky, rysky, theta) - for coo in coords[:, 1::-1]] - - else: + if not sky: return [EllipticalAperture(coo, rx, ry, theta) for coo in coords[:, 1::-1]] + sx, sy = sigma_xy + rxsky = sx * self.rsky + rysky = sy * self.rsky[1] + return [EllipticalAnnulus(coo, *rxsky, rysky, theta) + for coo in coords[:, 1::-1]] # TODO: use_fit_coords # TODO: handle bright and faint seperately here @@ -879,11 +877,11 @@ def check_aps_sky(self, i, rsky, rmax): rskyin, rskyout = rsky info = 'Frame {:d}, rin={:.1f}, rout={:.1f}' if np.isnan(rsky).any(): - self.logger.warning('Nans in sky apertures: ' + info, i, *rsky) + self.logger.warning(f'Nans in sky apertures: {info}', i, *rsky) if rskyin > rskyout: - self.logger.warning('rskyin > rskyout: ' + info, i, *rsky) + self.logger.warning(f'rskyin > rskyout: {info}', i, *rsky) if rskyin > rmax: - self.logger.warning('Large sky apertures: ' + info, i, *rsky) + self.logger.warning(f'Large sky apertures: {info}', i, *rsky) # class ImageModeller(SegmentedImageModel, ModellingResultsMixin): diff --git a/src/obstools/modelling/image/diagnostics.py b/src/obstools/modelling/image/diagnostics.py index 95de794..b216b23 100644 --- a/src/obstools/modelling/image/diagnostics.py +++ b/src/obstools/modelling/image/diagnostics.py @@ -62,7 +62,7 @@ def plot_modelled_image(model, image, params, seg=None, residual_mask=False, ImageDisplay(model(params), ax=axes[2 - ovr], title='Model') # residuals - if not (residual_mask is None or residual_mask is False): + if residual_mask is not None and residual_mask is not False: image = np.ma.MaskedArray(image, residual_mask) residuals = model.residuals(params, image) @@ -80,9 +80,7 @@ def plot_modelled_image(model, image, params, seg=None, residual_mask=False, im.ax.text(0, -0.45, s.expandtabs(), transform=im.ax.transAxes, fontsize=12) - # segmentation - seg = seg or getattr(model, 'seg', None) - if seg: + if seg := seg or getattr(model, 'seg', None): if overlay_segments: axes[1].add_collection(seg.get_contours()) seg.draw_labels(axes[1]) @@ -209,7 +207,7 @@ def plot_cross_section(model, p, data, grid=None, std=None, yscale=1, def gof_text(stat, name, xpos=0): v = stat(p, data, grid, std) s = sci_repr(v, latex=True).strip('$') - txt = '$%s = %s$' % (name, s) + txt = f'${name} = {s}$' # print(txt) return axTxt.text(xpos, 0, txt, fontsize=14, va='top', transform=axTxt.transAxes) diff --git a/src/obstools/modelling/lm_compat.py b/src/obstools/modelling/lm_compat.py index 4355e1b..6cdc1f2 100644 --- a/src/obstools/modelling/lm_compat.py +++ b/src/obstools/modelling/lm_compat.py @@ -59,7 +59,7 @@ class lmMixin(): def fit(self, p0, data, grid, data_stddev=None, **kws): - self.logger.debug('Guessed: (%s)' % ', '.join(map(decimal_repr, p0))) + self.logger.debug(f"Guessed: ({', '.join(map(decimal_repr, p0))})") params = self._set_param_values(p0) params = self._constrain_params(params, z0=(0, np.inf)) @@ -76,8 +76,7 @@ def fit(self, p0, data, grid, data_stddev=None, **kws): plsq = result.params p, punc = np.transpose([(p.value, p.stderr) for p in plsq.values()]) - bad = np.allclose(p, p0) - if bad: # model "converged" to the initial values + if bad := np.allclose(p, p0): self.logger.warning('%s fit did not converge!', self) self.logger.debug('input parameters identical to output') @@ -153,13 +152,13 @@ def lmModelFactory(base_, method_names, param_names): class lmConvertMeta(type): """Constructor that creates the converted class""" - def __new__(meta, name, bases, namespace): + def __new__(cls, name, bases, namespace): for base in bases: for mn, method in inspect.getmembers(base, inspect.isfunction): if (mn in method_names): # decorate the method namespace[mn] = convert_params(method) - return type.__new__(meta, name, bases, namespace) + return type.__new__(cls, name, bases, namespace) class lmCompatModel(lmMixin, base_, metaclass=lmConvertMeta): params = make_params(param_names) diff --git a/src/obstools/modelling/parameters.py b/src/obstools/modelling/parameters.py index 3c5eee5..472fbd8 100644 --- a/src/obstools/modelling/parameters.py +++ b/src/obstools/modelling/parameters.py @@ -24,10 +24,7 @@ def echo(*_): def get_shape(data): - if isinstance(data, Parameters): - return data.npar - else: - return np.shape(data) + return data.npar if isinstance(data, Parameters) else np.shape(data) def _walk_dtype_size(obj): @@ -116,7 +113,7 @@ def type_assertion(obj, allow_types=any): # print('allow types', allow_types) if not isinstance(obj, allow_types): - raise TypeError('%s type objects are not supported' % type(obj)) + raise TypeError(f'{type(obj)} type objects are not supported') @staticmethod def asscalar(key, val): @@ -160,17 +157,14 @@ def walk(obj, call=echo, flat=False, with_keys=True, container_out) if flat: yield from gen + elif with_keys: + yield key, container_out(gen) # map( else: - if with_keys: - yield key, container_out(gen) # map( - else: - yield container_out(gen) + yield container_out(gen) + elif with_keys: + yield call(key, item) else: - # switch caller here to call(item) if with_keys is False - if with_keys: - yield call(key, item) - else: - yield call(item) + yield call(item) # default helper singleton @@ -223,12 +217,11 @@ def __new__(cls, data=None, base_dtype=float, **kws): if data is not None: if isinstance(data, dict): return cls.__new__(cls, None, base_dtype, **data) - else: - # use case: Parameters([1, 2, 3, 4, 5]) - # use the `numpy.rec.array` to allow for construction from a - # wide variety of compatible objects - obj = np.rec.array(data) - return obj.view(cls) # view as Parameters object + # use case: Parameters([1, 2, 3, 4, 5]) + # use the `numpy.rec.array` to allow for construction from a + # wide variety of compatible objects + obj = np.rec.array(data) + return obj.view(cls) # view as Parameters object # first we have to construct the dtype by walking the (possibly nested) # kws that define the data structure. @@ -299,9 +292,9 @@ def __str__(self): s = pformat_dict(self.to_dict()) indent = ' ' * (len(cls_name) + 1) s = s.replace('\n', '\n' + indent) - return '%s(%s)' % (cls_name, s) + return f'{cls_name}({s})' else: - return '%s(%s)' % (cls_name, super().__str__()) + return f'{cls_name}({super().__str__()})' def __repr__(self): return self.__str__() @@ -321,10 +314,7 @@ def to_dict(self, attr=False, flat=False): dict """ - dict_ = dict - if attr: - dict_ = AttrReadItem - + dict_ = AttrReadItem if attr else dict # _par_help.asscalar # return dict_(_par_help.walk(self, echo, flat, container_out=dict_)) @@ -431,7 +421,7 @@ class Prior(rv_frozen): # DistRepr # `scipy.stats._distn_infrastructure.rv_frozen` def __repr__(self): - return self.dist.name.title() + 'Prior' + str(self.args) + return f'{self.dist.name.title()}Prior{str(self.args)}' def __str__(self): # here look at `dist` attribute to determine symbol diff --git a/src/obstools/modelling/psf/models.py b/src/obstools/modelling/psf/models.py index 4aad090..85cae59 100644 --- a/src/obstools/modelling/psf/models.py +++ b/src/obstools/modelling/psf/models.py @@ -163,8 +163,7 @@ def __repr__(self): def fit(self, data, grid, std=None, **kws): """guess p0 and fit""" p0 = self.p0guess(data) - result = leastsq(self.objective, p0, args=(data, grid, std), **kws) - return result + return leastsq(self.objective, p0, args=(data, grid, std), **kws) def param_hint(self, data, grid=None, std=None): """Return a guess of the fitting parameters based on the data""" @@ -211,14 +210,7 @@ def rss(self, p, data, grid): def wrs(self, p, data, grid, std=None): """weighted squared residuals""" - if std is None: - return self.rs(p, data, grid) - w = self.rs(p, data, grid) / std - # if np.isnan(w).any(): - # from IPython import embed - # embed() - # raise SystemExit - return w + return self.rs(p, data, grid) if std is None else self.rs(p, data, grid) / std def fwrs(self, p, data, grid, std=None): """weighted squared residuals flattened""" @@ -230,7 +222,7 @@ def wrss(self, p, data, grid, std=None): def validate(self, p, *args): """validate parameter values. To be overwritten by sub-class""" - return all([vf(p) for vf in self.validations]) + return all(vf(p) for vf in self.validations) def add_validation(self, func): if not isinstance(func, Callable): @@ -410,8 +402,7 @@ def reparameterize(self, p): ratio = min(sigx, sigy) / max(sigx, sigy) ellipticity = np.sqrt(1 - ratio ** 2) fwhm = self.get_fwhm(p) - par_alt = sigx, sigy, cov, theta, ellipticity, fwhm - return par_alt + return sigx, sigy, cov, theta, ellipticity, fwhm def integrate(self, p): """ @@ -467,16 +458,17 @@ def get_description(self, p, offset=(0, 0)): ellipticity = np.sqrt(1 - ratio ** 2) coo = x + offset[0], y + offset[1] - pdict = {'coo': coo, - 'flux': counts, - 'peak': z + d, - 'sky_mean': d, - 'fwhm': fwhm, - 'sigma_xy': (sigx, sigy), - 'theta': np.degrees(theta), - 'ratio': ratio, - 'ellipticity': ellipticity} - return pdict + return { + 'coo': coo, + 'flux': counts, + 'peak': z + d, + 'sky_mean': d, + 'fwhm': fwhm, + 'sigma_xy': (sigx, sigy), + 'theta': np.degrees(theta), + 'ratio': ratio, + 'ellipticity': ellipticity, + } def coeff(self, covariance_matrix): """ @@ -513,9 +505,7 @@ def covariance_matrix(self, p): def precision_matrix(self, p): _, _, _, a, b, c, _ = p - P = np.array([[a, -b], - [-b, c]]) * 2 - return P + return np.array([[a, -b], [-b, c]]) * 2 def get_sigma_xy(self, p): covm = self.covariance_matrix(p) @@ -670,9 +660,9 @@ def __call__(self, p, grid): w = pf * (erf(f * td0) - erf(f * td1)) return amp * np.prod(w, axis=0) - def _eig2cov(w0, w1, theta): + def _eig2cov(self, w1, theta): M = _rot_mat(float(theta)) - w0 * M[:, 0] + w1 * M[:, 1] + self * M[:, 0] + w1 * M[:, 1] def residuals(self, p, data, grid): return np.square(data - self(p, grid)) @@ -709,9 +699,7 @@ def eigenvecs(self, var, cov, eigvals): varx, vary = var cov2 = cov * cov - m0 = np.sqrt(np.square(vary * vary - e0) / cov2 + 1) - return m0 - m1 = np.sqrt(1 - m0 * m0) + return np.sqrt(np.square(vary * vary - e0) / cov2 + 1) def eigenvecs2(self, var, cor, eigvals): """eigenvectorss of precision matrix from variance, correlation, eigenvalues""" @@ -720,9 +708,7 @@ def eigenvecs2(self, var, cor, eigvals): varx, vary = var cov2 = cov * cov - m0 = np.sqrt(np.square(vary * vary - e0) / cov2 + 1) - return m0 - m1 = np.sqrt(1 - m0 * m0) + return np.sqrt(np.square(vary * vary - e0) / cov2 + 1) def rss(self, p, data, grid): return np.square(self(p, grid) - data) @@ -782,7 +768,7 @@ def __init__(self, psf=None, algorithm=None, caching=True, hints=True, self._print = _print # @profile() - def __call__(self, grid, data): # TODO: make grid optional... + def __call__(self, grid, data): # TODO: make grid optional... """Fits the PSF model to the data on the grid given the input coordinates xy0""" psf = self.psf @@ -813,7 +799,7 @@ def __call__(self, grid, data): # TODO: make grid optional... return else: if self._print: - print('\nSuccessfully fit {} function to stellar profile.'.format(psf.F.__name__)) + print(f'\nSuccessfully fit {psf.F.__name__} function to stellar profile.') if self.caching: # update cache with these parameters i = self.call_count % self.max_cache_size # wrap! diff --git a/src/obstools/modelling/utils.py b/src/obstools/modelling/utils.py index 04310de..9b07400 100644 --- a/src/obstools/modelling/utils.py +++ b/src/obstools/modelling/utils.py @@ -3,6 +3,4 @@ def prod(x): """Product of a list of numbers; ~40x faster vs np.prod for Python tuples""" - if len(x) == 0: - return 1 - return ftl.reduce(op.mul, x) + return 1 if len(x) == 0 else ftl.reduce(op.mul, x) diff --git a/src/obstools/phot/__init__.py b/src/obstools/phot/__init__.py index 30f10aa..d199acb 100644 --- a/src/obstools/phot/__init__.py +++ b/src/obstools/phot/__init__.py @@ -2,7 +2,7 @@ # create module level logger logbase = 'phot' -logname = '{}.{}'.format(logbase, __name__) +logname = f'{logbase}.{__name__}' logger = logging.getLogger(logname) #__name__ logger.setLevel(logging.DEBUG) diff --git a/src/obstools/phot/campaign.py b/src/obstools/phot/campaign.py index 6ed7356..e752ce5 100644 --- a/src/obstools/phot/campaign.py +++ b/src/obstools/phot/campaign.py @@ -215,17 +215,7 @@ def sampler(self): raise ValueError('Cannot create image sampler for data with ' f'{self.ndim} dimensions.') - # ensure NE orientation - data = self.oriented - - # make sure we pass 3d data to sampler. This is a hack so we can use - # the sampler to get thumbnails from data that is a 2d image, - # eg. master flats. The 'sample' will just be the image itself. - - if self.ndim == 2: - # insert axis in front - data = self.data[None] - + data = self.data[None] if self.ndim == 2 else self.oriented return BootstrapResample(data) @ftl.lru_cache() @@ -681,48 +671,6 @@ def coalign_dss(self, depth=10, sample_stat='median', reference_index=0, # _, better = imr.refine(plot=plot) return dss - # group observations by telescope / instrument - # groups, indices = self.group_by('telescope', 'instrument', - # return_index=True) - - # start with the group having the most observations - - # create data containers - # n = len(self) - # images = np.empty(n, 'O') - # params = np.empty((n, 3)) - # fovs = np.empty((n, 2)) - # coords = np.empty(n, 'O') - # ng = len(groups) - # aligned_on = np.empty(ng, int) - # matchers = np.empty(ng, 'O') - - # # For each image group, align images wrt each other - # # ensure that `params`, `fovs` etc maintains the same order as `self` - # for i, (gid, run) in enumerate(groups.items()): - # idx = indices[gid] - # m = matchers[i] = run.coalign(depth, sample_stat, plot=plot, - # **find_kws) - - # aligned_on[i] = idx[m.idx] - - # try: - - # # - # dss = ImageRegisterDSS(self[reference_index].coords, fov_dss, - # **find_kws) - - # for i, gid in enumerate(groups.keys()): - # mo = matchers[i] - # theta = self[aligned_on[i]].get_rotation() - # p = dss.match_points(mo.yx, mo.fov, theta) - # params[indices[gid]] += p - # except: - # from IPython import embed - # embed() - - return dss - def close(self): # close all files self.calls('_file.close') diff --git a/src/obstools/phot/diagnostics.py b/src/obstools/phot/diagnostics.py index 6174c68..948f39c 100644 --- a/src/obstools/phot/diagnostics.py +++ b/src/obstools/phot/diagnostics.py @@ -318,11 +318,9 @@ def scatter_density_grid(features, centres=None, axes=None, auto_lim_axes=False, def new_diagnostics(coords, rcoo, Appars, optstat): - figs = {} # coordinate diagnostics fig = plot_coord_moves(coords, rcoo) - figs['coords.moves'] = fig - + figs = {'coords.moves': fig} # fig = plot_coord_scatter(coords, rcoo) # figs['coords.scatter'] = fig # fig = plot_coord_walk(coords) @@ -734,20 +732,18 @@ def plot_aperture_flux(fitspath, proc, tracker): star_labels = list(map('{0:d}: ({1[1]:3.1f}, {1[0]:3.1f})'.format, tracker.segm.labels, tracker.rcoo)) - figs = { + return { 'lc.aps.opt': plot_lc(t, flux, flxStd, star_labels, '(Optimal)'), - 'lc.aps.bg': plot_lc(t, fluxBG, flxBGStd, star_labels, '(BG)') + 'lc.aps.bg': plot_lc(t, fluxBG, flxBGStd, star_labels, '(BG)'), } - return figs - def plot_lc(t, flux, flxStd, labels, description='', max_errorbars=200): logger.info('plotting lc aps: %s', description) # no more than 200 error bars so we don't clutter the plot error_every = flxStd.shape[1] // int(max_errorbars) - title = 'Aperture flux %s' % description + title = f'Aperture flux {description}' # plot with frame number at bottom t0 = t[0].to_datetime() diff --git a/src/obstools/phot/gui.py b/src/obstools/phot/gui.py index 4491461..f861e0e 100644 --- a/src/obstools/phot/gui.py +++ b/src/obstools/phot/gui.py @@ -96,8 +96,7 @@ def binary_contours(b): Z = g(X[:-1], Y[:-1]) gen = QuadContourGenerator(X[:-1], Y[:-1], Z, None, False, 0) - c = gen.create_contour(0) - return c + return gen.create_contour(0) class Ellipse(_Ellipse): @@ -158,7 +157,7 @@ def create_apertures(self, **props): props = propList[i] props.setdefault('animated', self.use_blit) - for j in range(self.n_groups): + for _ in range(self.n_groups): # color = next(self.) # props.update(ec=color) aps = kls[i](**props) @@ -218,10 +217,7 @@ def update(self, i, draw=True): draw_list.append(self.marks) appars = self.appars[i] - skypars = None - if self.skypars is not None: - skypars = self.skypars[i] - + skypars = self.skypars[i] if self.skypars is not None else None art = self.update_apertures(i, coo.T, appars, skypars) draw_list.append(art) @@ -299,10 +295,7 @@ def update(self, i, draw=True): draw_list.append(self.marks) appars = self.appars[i] - skypars = None - if self.skypars is not None: - skypars = self.skypars[i] - + skypars = self.skypars[i] if self.skypars is not None else None art = self.update_apertures(i, coo.T, appars, skypars) draw_list.append(art) @@ -388,13 +381,11 @@ def __init__(self, data, coords, tracker, mdlr, apdata, residu=None, **kws): self.connect() def _slider_move(self, x, y): - draw_list = [] - # art = self.markers + self.aps + (self.windows, ) - for mrk in flatten((self.markers, self.aps, self.windows)): - if mrk.get_visible(): - draw_list.append(mrk) - - return draw_list + return [ + mrk + for mrk in flatten((self.markers, self.aps, self.windows)) + if mrk.get_visible() + ] # def toggle_windows(self, label): @@ -411,12 +402,7 @@ def init_figure(self, **kws): # create axes if required if ax is None: - if autoscale_figure: - # automatically determine the figure size based on the data - figsize = self.guess_figsize(self.data) - else: - figsize = None - + figsize = self.guess_figsize(self.data) if autoscale_figure else None fig = plt.figure(figsize=figsize) self._gs = gs = GridSpec(2, 5, @@ -427,14 +413,14 @@ def init_figure(self, **kws): ax = fig.add_subplot(gs[0, :]) - # axes = self.init_axes(fig) + # axes = self.init_axes(fig) # else: # axes = namedtuple('AxesContainer', ('image',))(ax) self.divider = make_axes_locatable(ax) # ax = axes.image # set the axes title if given - if not title is None: + if title is not None: ax.set_title(title) # setup coordinate display @@ -579,10 +565,8 @@ def update(self, i, draw=True): # contour tracking regions if self.outlines is not None: - segments = [] off = trk.offset[::-1] - for seg in self.outlineData: - segments.append(seg + off) + segments = [seg + off for seg in self.outlineData] self.outlines.set_segments(segments) # if self.use_blit: @@ -750,10 +734,7 @@ def show_outlines(self, **kws): im = data[e[1, 0]:e[1, 1], e[0, 0]:e[0, 1]] contours = binary_contours(im) - for c in contours: - outlines.append(c + e[:, 0] - 0.5) - - + outlines.extend(c + e[:, 0] - 0.5 for c in contours) col = LineCollection(outlines, **kws) self.ax.add_collection(col) diff --git a/src/obstools/phot/proc.py b/src/obstools/phot/proc.py index 05cd33b..a23bac0 100644 --- a/src/obstools/phot/proc.py +++ b/src/obstools/phot/proc.py @@ -224,10 +224,7 @@ def std_ccd(counts, npix, counts_bg, npixbg): def opt_factory(p): - if len(p) == 1: - cls = CircleOptimizer - else: - cls = EllipseOptimizer + cls = CircleOptimizer if len(p) == 1 else EllipseOptimizer # initialize return cls() @@ -440,42 +437,37 @@ def catch(self, *args, **kws): """ # exceptions like moths to the flame abort = self.fail_counter.get_value() >= self.max_fail - if not abort: - try: - result = self.run(*args, **kws) - except Exception as err: - # logs full trace by default - i = args[0] - self.status[i] = self.FAIL - nfail = self.fail_counter.inc() - logger = logging.getLogger(self.name) - logger.exception('Processing failed at frame %i. (%i/%i)', - i, nfail, self.max_fail) - - # check if we are beyond exception threshold - if nfail >= self.max_fail: - logger.critical('Exception threshold reached!') - # self.logger.critical('Exception threshold reached!') - else: - i = args[0] - self.status[i] = self.SUCCESS - return result # finally clause executes before this returns - - finally: - # log progress - counter = self.counter - if counter: - n = counter.inc() - if self.progLog: - self.progLog.update(n) - - # if there was a KeyboardInterrupt, it will be raised at this point - else: + if abort: # doing this here (instead of inside the except clause) avoids # duplication by chained exception traceback when logging raise AbortCompute( 'Number of exceptions larger than threshold of %i' % self.max_fail) + try: + result = self.run(*args, **kws) + except Exception as err: + # logs full trace by default + i = args[0] + self.status[i] = self.FAIL + nfail = self.fail_counter.inc() + logger = logging.getLogger(self.name) + logger.exception('Processing failed at frame %i. (%i/%i)', + i, nfail, self.max_fail) + + # check if we are beyond exception threshold + if nfail >= self.max_fail: + logger.critical('Exception threshold reached!') + # self.logger.critical('Exception threshold reached!') + else: + i = args[0] + self.status[i] = self.SUCCESS + return result # finally clause executes before this returns + + finally: + if counter := self.counter: + n = counter.inc() + if self.progLog: + self.progLog.update(n) def report(self): # not_done, = np.where(self.status == 0) @@ -691,45 +683,36 @@ def optimal_aperture_photometry(self, i, data, residu, coords, tracker, p = r.x if flag != 1: # there was an error or no convergence - if prevr is not None and prevr.success: - # use bright star appars for faint stars (if available) if - # optimization failed for this group - p = prevr.x - else: - # no convergence for this opt or previous. fall back to p0 - p = p0 - + p = prevr.x if prevr is not None and prevr.success else p0 # update to fallback values opt.update(coords[ix], *p, sky_width, sky_buf, r_sky_min) skip_opt = True - # if fit didn't converge for bright stars, it won't for the - # fainter ones. save some time by skipping opt + # if fit didn't converge for bright stars, it won't for the + # fainter ones. save some time by skipping opt # get apertures aps, aps_sky = opt - if flag is not None: # ie. optimization was at least attempted - # save appars - if len(p) == 1: # circle - a, = b, = p - theta = 0 - a_sky_in = opt.ap_sky.r_in - a_sky_out = b_sky_out = opt.ap_sky.r_out - - else: # ellipse - a, b, theta = p - a_sky_in = opt.ap_sky.a_in - a_sky_out = opt.ap_sky.a_out - b_sky_out = opt.ap_sky.b_out - - else: + if flag is None: # no optimization attempted # use the radii, angle of the previous group for photometry on # remaining groups aps.positions = coords[ix] aps_sky.positions = coords[ix] + elif len(p) == 1: # circle + a, = b, = p + theta = 0 + a_sky_in = opt.ap_sky.r_in + a_sky_out = b_sky_out = opt.ap_sky.r_out + + else: # ellipse + a, b, theta = p + a_sky_in = opt.ap_sky.a_in + a_sky_out = opt.ap_sky.a_out + b_sky_out = opt.ap_sky.b_out + # save appars self._appars[i, g] = list(zip( (a, b, theta), diff --git a/src/obstools/phot/tracking/core.py b/src/obstools/phot/tracking/core.py index 9be013d..81777ad 100644 --- a/src/obstools/phot/tracking/core.py +++ b/src/obstools/phot/tracking/core.py @@ -285,20 +285,18 @@ def _measure_positions_offsets(xy, centres, d_cut=None): out_new = (d > d_cut) out_new = np.ma.getdata(out_new) | np.ma.getmask(out_new) - changed = (out != out_new).any() - if changed: - out = out_new - xym[out] = np.ma.masked - n_out = out.sum() + if not (changed := (out != out_new).any()): + break - if n_out / n_points > 0.5: - raise Exception('Too many outliers!!') + out = out_new + xym[out] = np.ma.masked + n_out = out.sum() - logger.info('Ignoring %i/%i (%.3f%%) values with |δr| > %.1f', - n_out, n_points, (n_out / n_points) * 100, d_cut) - else: - break + if n_out / n_points > 0.5: + raise Exception('Too many outliers!!') + logger.info('Ignoring %i/%i (%.3f%%) values with |δr| > %.1f', + n_out, n_points, (n_out / n_points) * 100, d_cut) return centres, xy_shifted.std(0), xy_offsets.squeeze(), out @@ -515,9 +513,7 @@ def use_labels(self, labels): self._ignore_labels = np.setdiff1d(self.segm.labels, labels) def resolve_labels(self, labels=None): - if labels is None: - return self.use_labels - return self.segm.has_labels(labels) + return self.use_labels if labels is None else self.segm.has_labels(labels) @property def nlabels(self): @@ -1319,11 +1315,10 @@ def track(self, index, image, mask=None): self.measurements[index] = yx # weights - if self.snr_weighting or self.snr_cut: - if (self._weights is None) or (count // self._update_weights_every): - self._weights = self.get_snr_weights(image) - # TODO: maybe aggregate weights and use mean ?? - + if (self.snr_weighting or self.snr_cut) and ( + (self._weights is None) or (count // self._update_weights_every) + ): + self._weights = self.get_snr_weights(image) # # update relative positions from CoM measures # check here if measurements of cluster centres are good enough! if ((count + 1) % self._update_rvec_every) == 0 and count > 0 and \ @@ -1422,7 +1417,7 @@ def measure_star_locations(self, image, mask=None, start_indices=(0, 0)): # '/media/Oceanus/UCT/Observing/data/Feb_2015/ # MASTER_J0614-2725/20150225.001.fits' : 23 - if not ((mask is None) or (self.masks.bad_pixels is None)): + if mask is not None and self.masks.bad_pixels is not None: mask |= self.masks.bad_pixels else: mask = self.masks.bad_pixels # may be None @@ -1434,7 +1429,7 @@ def measure_star_locations(self, image, mask=None, start_indices=(0, 0)): if np.ma.is_masked(start_indices): raise ValueError('Start indices cannot be masked array') - if not start_indices.dtype.kind == 'i': + if start_indices.dtype.kind != 'i': start_indices = start_indices.round().astype(int) yx = self._measure_star_locations(image, mask, start_indices) @@ -1450,8 +1445,6 @@ def measure_star_locations(self, image, mask=None, start_indices=(0, 0)): def _measure_star_locations(self, image, mask, start_indices): seg = self.segm.select_subset(start_indices, image.shape) - xy = seg.com_bg(image, self.use_labels, mask, None) - # note this is ~2x faster than padding image and mask. can probably # be made even faster by using __slots__ # todo: check if using grid + offset then com is faster @@ -1467,7 +1460,7 @@ def _measure_star_locations(self, image, mask, start_indices): # good = ~self.is_bad(com) # self.coms[i, good] = com[good] - return xy # + start_indices + return seg.com_bg(image, self.use_labels, mask, None) # def check_measurement(self, xy): @@ -1683,21 +1676,14 @@ def flux_sort(self, fluxes): def pprint(self): from motley.table import Table - # FIXME: not working - - # TODO: say what is being ignored - # TODO: include uncertainties - - # coo = [:, ::-1] # , dtype='O') - tbl = Table(self.rcoo_xy, - col_headers=list('xy'), - col_head_props=dict(bg='g'), - row_headers=self.use_labels, - # number_rows=True, - align='>', # easier to read when right aligned - ) - - return tbl + return Table( + self.rcoo_xy, + col_headers=list('xy'), + col_head_props=dict(bg='g'), + row_headers=self.use_labels, + # number_rows=True, + align='>', # easier to read when right aligned + ) def sdist(self): coo = self.rcoo @@ -1868,9 +1854,8 @@ def compute_offset(self, xy, weights): # shift calculated as snr weighted mean of individual CoM shifts xym = np.ma.MaskedArray(xy, np.isnan(xy)) δ = (self.rcoo[self.use_labels - 1] - xym) - offset = np.ma.average(δ, 0, weights) # this offset already relative to the global segmentation - return offset + return np.ma.average(δ, 0, weights) def update_rvec_point(self, coo, weights=None): # TODO: bayesian_update diff --git a/src/obstools/phot/utils.py b/src/obstools/phot/utils.py index a007b12..7205371 100644 --- a/src/obstools/phot/utils.py +++ b/src/obstools/phot/utils.py @@ -63,7 +63,7 @@ def create(self, end): def progress(self, state, info=None): if self.needs_update(state): bar = self.get_bar(state) - self.logger.info('Progress: %s' % bar) + self.logger.info(f'Progress: {bar}') # class ProgressPrinter(ProgressBar): diff --git a/src/obstools/plan/limits.py b/src/obstools/plan/limits.py index 77f5f56..7efd449 100644 --- a/src/obstools/plan/limits.py +++ b/src/obstools/plan/limits.py @@ -175,10 +175,7 @@ def get(self, where, which): def get_visible_ha(self, dec, where='both', which='hard'): if where == 'both': - lims = [] - for where_ in _EW: - lims.append(self.get_visible_ha(dec, where_, which)) - return lims + return [self.get_visible_ha(dec, where_, which) for where_ in _EW] where, which = _checks(where, which) pp = self.interpolators[where][which] @@ -189,11 +186,7 @@ def plot(self, ax=None, which='both', hard_kws=None, soft_kws=None, **kws): if ax is None: fig, ax = plt.subplots() - if which == 'both': - which = _HS - else: - which = _check(which, _HS) - + which = _HS if which == 'both' else _check(which, _HS) data = {} for key in which: east, west = LIMITS[key][self.tel].values() diff --git a/src/obstools/plan/skytracks.py b/src/obstools/plan/skytracks.py index 86c6695..9baa362 100644 --- a/src/obstools/plan/skytracks.py +++ b/src/obstools/plan/skytracks.py @@ -207,9 +207,7 @@ def sidereal_transform(t, longitude): def short_name(name): - if jparser.search(name): - return jparser.shorten(name) - return name + return jparser.shorten(name) if jparser.search(name) else name def set_visible(artists, state=True): @@ -474,7 +472,7 @@ def annotate(self, ax, **kws): # decide whether to add one label or two per segment labels = [] - for i, (i0, i1) in enumerate(zip(first, last)): + for i0, i1 in zip(first, last): # determine the slope and length of curve segments at all points # within axes x, yy = t[i0:i1 + 1], y[i0:i1 + 1] @@ -706,10 +704,7 @@ class SeczFormatter(TransFormatter): def __call__(self, x, pos=None): # ignore negative numbers (below horizon) - if (x < 0): - return '' - - return TransFormatter.__call__(self, x, pos) + return '' if (x < 0) else TransFormatter.__call__(self, x, pos) class Clock(LoggingMixin): diff --git a/src/obstools/rkcat.py b/src/obstools/rkcat.py index c929069..370d1a3 100644 --- a/src/obstools/rkcat.py +++ b/src/obstools/rkcat.py @@ -148,7 +148,7 @@ def iter_lines(filename): """ with open(filename, 'r') as fp: - for i, line in enumerate(fp): + for line in fp: if not line.startswith(('\n', '-')): yield line.strip('\n') @@ -200,7 +200,7 @@ def read_ascii(filename, mask_missing=False): n = len(lines) // 2 ncols = len(column_slices) * 2 # dtypes = tuple('U%s' % w for w in widths) - data = np.empty((n, ncols), 'U%s' % widths.max()) # somewhat inefficient + data = np.empty((n, ncols), f'U{widths.max()}') ix_odd = np.arange(1, ncols, 2) ix_even = np.arange(0, ncols, 2) @@ -355,7 +355,7 @@ def select_by_type(self, *types, exclude=()): # select_object_types # l = np.zeros(a2D.shape, bool) # f = np.zeros(len(tbl), 'U2') - for i, t in enumerate(types): + for t in types: match = (a2D == t) l[match.any(1)] = False diff --git a/src/obstools/transients/master.py b/src/obstools/transients/master.py index 48278cf..a41ca2e 100644 --- a/src/obstools/transients/master.py +++ b/src/obstools/transients/master.py @@ -165,8 +165,7 @@ def get_coords(name): try: return jparser.to_ra_dec_float(name) except ValueError as err: - match = RGX_BAD_JCOO.search(name) - if match: + if match := RGX_BAD_JCOO.search(name): prefix, hms, dms = np.split( RGX_BAD_JCOO.search(name).groups(), [1, 6]) warnings.warn(f'Bad value for arcseconds in name: {name!r}') @@ -179,15 +178,12 @@ def get_coords(name): def get_mag(m): - mo = RGX_NR.match(m) - if mo: - return float(m) - return -1 + return float(m) if (mo := RGX_NR.match(m)) else -1 def write_ascii(filename, data): col_names = list(data.dtype.fields.keys()) - col_names[0] = '# ' + col_names[0] + col_names[0] = f'# {col_names[0]}' data = np.vstack([col_names, data]).astype('U') fmt = list(map('%-{}s'.format, np.char.str_len(data).max(0))) @@ -207,10 +203,7 @@ def get_date(date): 'Mrt': 'Mar', 'Nar': 'Mar', 'Avg': 'Aug'}) - mfmt = '%b' - if len(month) > 3: - mfmt = '%B' - + mfmt = '%B' if len(month) > 3 else '%b' s = f'{yr} {month}' cur = day.strip('X') for mult in (24, 60, 60): diff --git a/src/obstools/utils.py b/src/obstools/utils.py index 9eb7e9a..789bec3 100644 --- a/src/obstools/utils.py +++ b/src/obstools/utils.py @@ -35,10 +35,7 @@ def int2tup(v): """wrap integer in a tuple""" - if isinstance(v, numbers.Integral): - return v, - else: - return tuple(v) + return (v, ) if isinstance(v, numbers.Integral) else tuple(v) # else: # raise ValueError('bad item %s of type %r' % (v, type(v))) @@ -183,8 +180,7 @@ def convert_skycoords(ra, dec): try: return SkyCoord(ra=ra, dec=dec, unit=('h', 'deg')) except ValueError: - logger.warning( - 'Could not interpret coordinates: %s; %s' % (ra, dec)) + logger.warning(f'Could not interpret coordinates: {ra}; {dec}') def retrieve_coords_ra_dec(name, verbose=True, **fmt): @@ -204,9 +200,7 @@ def retrieve_coords_ra_dec(name, verbose=True, **fmt): def ra_dec_string(coords, **kws): kws_ = dict(precision=2, sep=' ', pad=1) kws_.update(**kws) - return 'α = %s; δ = %s' % ( - coords.ra.to_string(unit='h', **kws_), - coords.dec.to_string(unit='deg', alwayssign=1, **kws_)) + return f"α = {coords.ra.to_string(unit='h', **kws_)}; δ = {coords.dec.to_string(unit='deg', alwayssign=1, **kws_)}" def get_skymapper_table(coords, bands, size=(10, 10)): @@ -259,8 +253,7 @@ def get_skymapper(coords, bands, size=(10, 10), combine=True, # retrieve data possibly from cache logger.info('Retrieving images...') - hdus = [_get_skymapper(url) for url in urls] - return hdus + return [_get_skymapper(url) for url in urls] @caches.to_file(skyCachePath) # memoize for performance @@ -309,9 +302,11 @@ def get_dss(server, ra, dec, size=(10, 10), epoch=2000): 'poss2ukstu_ir', 'quickv' ) # TODO: module scope ? - if not server in known_servers: - raise ValueError('Unknown server: %s. Please select from: %s' - % (server, str(known_servers))) + if server not in known_servers: + raise ValueError( + f'Unknown server: {server}. Please select from: {known_servers}' + ) + # resolve size h, w = size # FIXME: if number @@ -332,9 +327,7 @@ def get_dss(server, ra, dec, size=(10, 10), epoch=2000): with urllib.request.urlopen(url, params) as html: raw = html.read() - # parse error message - error = RGX_DSS_ERROR.search(raw) - if error: + if error := RGX_DSS_ERROR.search(raw): raise STScIServerError(error[1]) # log diff --git a/tests/image/test_registration.py b/tests/image/test_registration.py index 38703cd..a4ea37c 100644 --- a/tests/image/test_registration.py +++ b/tests/image/test_registration.py @@ -33,7 +33,7 @@ def make_id(name, n): # dynamically generate some simple fixtures for combinatorial tests -for name, params in dict( +for name in dict( xy=[((5, 10), (5, 6), (8, 2))], sigmas=[1, (1, 0.5), @@ -45,7 +45,7 @@ def make_id(name, n): amplitudes=[10, (1, 2, 3) ] -).items(): +): exec(textwrap.dedent( f""" @pytest.fixture(params=params, ids=make_id(name, len(params))) diff --git a/tests/image/test_segmentation.py b/tests/image/test_segmentation.py index b76ad18..ec734df 100644 --- a/tests/image/test_segmentation.py +++ b/tests/image/test_segmentation.py @@ -127,10 +127,8 @@ def test_add_segments(): def test_trace_contours(): from scipy import ndimage - tests = [] d = ndimage.distance_transform_edt(np.ones((15, 15))) - tests.append(d > 5) - + tests = [d > 5] z = np.square(np.indices((15, 15)) - 7.5).sum(0) < 7.5 z[9, 4] = 1 tests.append(z) @@ -139,7 +137,7 @@ def test_trace_contours(): z[3:5, 2] = 1 tests.append(z) - for i, t in enumerate(tests): + for t in tests: boundary = trace_boundary(t) seg = SegmentedImage(t)