diff --git a/.flake8 b/.flake8 index 563ac436..8031aef4 100644 --- a/.flake8 +++ b/.flake8 @@ -2,7 +2,7 @@ # Keep in sync with setup.cfg and pyproject.toml which is used for source packages. [flake8] -ignore = W503, E203, B950, B011, B904 +ignore = W503, E203, B950, B011, B904, B907, B905 max-line-length = 100 max-complexity = 18 select = B,C,E,F,W,T4,B9 diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6e2f37d5..f86cb7ec 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -15,7 +15,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.6", "3.7", "3.8"] + python-version: ["3.7", "3.8", "3.9", "3.10"] steps: - uses: actions/checkout@v2 @@ -43,7 +43,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.6", "3.7", "3.8"] + python-version: ["3.8", "3.9", "3.10"] steps: - uses: actions/checkout@v2 @@ -72,7 +72,7 @@ jobs: strategy: matrix: os: [ubuntu-latest] - python-version: ["3.6", "3.7", "3.8"] # there are some issues with numpy multiarray in 3.7 + python-version: ["3.7", "3.8", "3.9", "3.10"] steps: - uses: actions/checkout@v2 @@ -83,15 +83,13 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install Dependencies - # there are some issues with numpy multiarray in 3.7, affecting numba 0.54 installation - # SimpleITK 2.1.0 does not support non-orthonormal directions run: | python -m pip install --upgrade pip - pip install numba==0.53.1 - pip install tensorflow==2.4.1 keras + pip install numba + pip install "tensorflow>=2.0.0" pip install torch pip install sigpy - pip install --upgrade simpleitk==2.0.2 + pip install --upgrade simpleitk make dev pip install -e '.[dev]' @@ -112,6 +110,7 @@ jobs: make test-cov - name: Upload to codecov.io + if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository uses: codecov/codecov-action@v1 with: file: ./coverage.xml diff --git a/.gitignore b/.gitignore index 07f4018b..5ee133c6 100644 --- a/.gitignore +++ b/.gitignore @@ -54,4 +54,13 @@ docs/source/generated # coverage files .coverage -coverage.xml \ No newline at end of file +coverage.xml + +# system files +.DS_Store + +# nipype logfiles +stdout.nipype + +# ignore "scratch" files used for testing / debugging +**/scratch \ No newline at end of file diff --git a/Makefile b/Makefile index d22bceb6..8c1c18f7 100644 --- a/Makefile +++ b/Makefile @@ -30,6 +30,7 @@ build-docs: cd docs && make html dev: + pip install pytest pip install black==21.4b2 click==8.0.2 coverage isort flake8 flake8-bugbear flake8-comprehensions pip install --upgrade mistune==0.8.4 sphinx sphinx-rtd-theme recommonmark m2r2 pip install -r docs/requirements.txt diff --git a/README.md b/README.md index 7e4c81f7..9cb8f47b 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ DOSMA is an AI-powered Python library for medical image analysis. This includes, We hope that this open-source pipeline will be useful for quick anatomy/pathology analysis and will serve as a hub for adding support for analyzing different anatomies and scan sequences. ## Installation -DOSMA requires Python 3.6+. The core module depends on numpy, nibabel, nipype, +DOSMA requires Python 3.7+. The core module depends on numpy, nibabel, nipype, pandas, pydicom, scikit-image, scipy, PyYAML, and tqdm. Additional AI features can be unlocked by installing tensorflow and keras. 
To @@ -30,6 +30,24 @@ pip install dosma pip install dosma[ai] ``` +If you would like to install from source (to get the latest version, or to build for your specific system), use: + +```bash +git clone git@github.com:ad12/DOSMA.git +cd DOSMA +pip install . +``` + +To get AI support: + +```bash +git clone git@github.com:ad12/DOSMA.git +cd DOSMA +pip install '.[ai]' +``` + If you would like to contribute to DOSMA, we recommend you clone the repository and install DOSMA with `pip` in editable mode. @@ -41,6 +59,10 @@ make dev ``` To run tests, build documentation and contribute, run +- If Elastix is not installed, indicate this by running the following on the command line: + `export DOSMA_UNITTEST_DISABLE_ELASTIX=True` +- If the test data is not available, also run: + `export DOSMA_UNITTEST_DISABLE_DATA=True`
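+ +For example, to run the test suite without Elastix and without the local test data (assuming a POSIX-compatible shell): + +```bash +export DOSMA_UNITTEST_DISABLE_ELASTIX=True +export DOSMA_UNITTEST_DISABLE_DATA=True +make test +```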
+ f"Specify `y` in `{type(self).__name__}.fit(y=...)` instead.", + stacklevel=2, ) self._check_y(x, y) self.y = y @@ -655,7 +657,8 @@ def __init__( if mask is not None: warnings.warn( f"Setting `mask` in the constructor can result in significant memory overhead. " - f"Specify `mask` in `{type(self).__name__}.fit(mask=...)` instead." + f"Specify `mask` in `{type(self).__name__}.fit(mask=...)` instead.", + stacklevel=2, ) self.mask = mask @@ -739,7 +742,7 @@ def fit(self, x=None, y: Sequence[MedicalVolume] = None, mask=None): return tc_map, r_squared def _check_y(self, x, y): - if (not isinstance(y, Sequence)) or (not all([isinstance(sv, MedicalVolume) for sv in y])): + if (not isinstance(y, Sequence)) or (not all(isinstance(sv, MedicalVolume) for sv in y)): raise TypeError("`y` must be list of MedicalVolumes.") if any(x.device != cpu_device for x in y): @@ -844,7 +847,9 @@ def curve_fit( oob = y_bounds is not None and ((y < y_bounds[0]).any() or (y > y_bounds[1]).any()) if oob: - warnings.warn("Out of bounds values found. Failure in fit will result in np.nan") + warnings.warn( + "Out of bounds values found. Failure in fit will result in np.nan", stacklevel=2 + ) y_T = y.T if p0_seq: @@ -963,7 +968,9 @@ def _compute_r2_matrix(_x, _y, _popts): oob = y_bounds is not None and ((y < y_bounds[0]).any() or (y > y_bounds[1]).any()) if oob: - warnings.warn("Out of bounds values found. Failure in fit will result in np.nan") + warnings.warn( + "Out of bounds values found. Failure in fit will result in np.nan", stacklevel=2 + ) fitter = partial( _polyfit, x=x, deg=deg, y_bounds=y_bounds, rcond=rcond, w=w, eps=eps, xp=xp.__name__ @@ -1137,7 +1144,7 @@ def _format_p0(p0, param_args, N): f"`p0` has unknown keys: {extra_keys}. " f"Function signature has parameters {param_args}." ) - p0_default = {p: 1.0 for p in param_args} + p0_default = {p: 1.0 for p in param_args} # noqa C420 # Update p0_default to keep the parameter keys in order. p0_default.update(p0) p0 = p0_default @@ -1170,6 +1177,7 @@ def func(t, a, b): "__fit_mono_exp__ is deprecated since v0.0.12 and will no longer be " "supported in v0.0.13. Use `curve_fit` instead.", DeprecationWarning, + stacklevel=2, ) x = np.asarray(x) @@ -1191,6 +1199,7 @@ def __fit_monoexp_tc__(x, ys, tc0, show_pbar=False): # pragma: no cover "__fit_monoexp_tc__ is deprecated since v0.12 and will no longer be " "supported in v0.13. Use `curve_fit` instead.", DeprecationWarning, + stacklevel=2, ) p0 = (1.0, -1 / tc0) @@ -1203,7 +1212,8 @@ def __fit_monoexp_tc__(x, ys, tc0, show_pbar=False): # pragma: no cover if (y < 0).any() and not warned_negative: warned_negative = True warnings.warn( - "Negative values found. Failure in monoexponential fit will result in np.nan" + "Negative values found. 
Failure in monoexponential fit will result in np.nan", + stacklevel=2, ) # Skip any negative values or all values that are 0s diff --git a/dosma/core/io/dicom_io.py b/dosma/core/io/dicom_io.py index 3ea5d82b..d4c51991 100644 --- a/dosma/core/io/dicom_io.py +++ b/dosma/core/io/dicom_io.py @@ -262,7 +262,7 @@ def load( ) if self.num_workers: - fn = functools.partial(pydicom.read_file, force=True) + fn = functools.partial(pydicom.dcmread, force=True) if self.verbose: dicom_slices = process_map(fn, lstFilesDCM, max_workers=self.num_workers) else: @@ -270,7 +270,7 @@ def load( dicom_slices = p.map(fn, lstFilesDCM) else: dicom_slices = [ - pydicom.read_file(fp, force=True) + pydicom.dcmread(fp, force=True) for fp in tqdm(lstFilesDCM, disable=not self.verbose) ] diff --git a/dosma/core/med_volume.py b/dosma/core/med_volume.py index 3ac2ee49..3a093b3f 100644 --- a/dosma/core/med_volume.py +++ b/dosma/core/med_volume.py @@ -394,6 +394,7 @@ def match_orientation(self, mv): "`match_orientation` is deprecated and will be removed in v0.1. " "Use `mv.reformat_as(self, inplace=True)` instead.", DeprecationWarning, + stacklevel=2, ) if not isinstance(mv, MedicalVolume): raise TypeError("`mv` must be a MedicalVolume.") @@ -410,6 +411,7 @@ def match_orientation_batch(self, mvs): # pragma: no cover "`match_orientation_batch` is deprecated and will be removed in v0.1. " "Use `[x.reformat_as(self, inplace=True) for x in mvs]` instead.", DeprecationWarning, + stacklevel=2, ) for mv in mvs: self.match_orientation(mv) @@ -494,26 +496,41 @@ def to_nib(self): return nib.Nifti1Image(self.A, self.affine.copy()) - def to_sitk(self, vdim: int = None, transpose_inplane: bool = False): + def to_sitk( + self, + image_orientation: str = "sagittal", + vdim: int = None, + transpose_inplane: bool = False, + flip_array_x: bool = False, + flip_array_y: bool = False, + flip_array_z: bool = False, + ): """Converts to SimpleITK Image. + SimpleITK loads DICOM files as individual slices that get stacked in ``(x, y, z)`` + order where z is the slice dimension and x/y are the in-plane image dimensions. + However, ``sitk.GetArrayFromImage`` returns numpy arrays with the order + ``(z, y, x)``. Converting to a SimpleITK image using ``sitk.GetImageFromArray`` + also does this conversion (from z, y, x => x, y, z). This has to do with C++ and + numpy memory mapping. To facilitate converting MedicalVolume to SimpleITK Image, + the image_orientation must be specified. By default, we assume sagittal orientation. + SimpleITK Image objects support vector pixel types, which are represented as an extra dimension in numpy arrays. The vector dimension can be specified with ``vdim``. MedicalVolume must be on cpu. Use ``self.cpu()`` to move. - SimpleITK loads DICOM files as individual slices that get stacked in ``(z, x, y)`` - order. Thus, ``sitk.GetArrayFromImage`` returns an array in ``(y, x, z)`` order. - To return a SimpleITK Image that will follow this convention, set - ``transpose_inplane=True``. If you have been using SimpleITK to load DICOM files, - you will likely want to specify this parameter. - Args: + image_orientation (str, optional): The image orientation. Options are + ``"sagittal"``, ``"coronal"``, and ``"axial"``. Default is ``"sagittal"``. vdim (int, optional): The vector dimension. transpose_inplane (bool, optional): If ``True``, transpose inplane axes. Recommended to be ``True`` for users who are familiar with SimpleITK's DICOM loading convention. + flip_array_x (bool, optional): If ``True``, flip array along x-axis. + flip_array_y (bool, optional): If ``True``, flip array along y-axis. + flip_array_z (bool, optional): If ``True``, flip array along z-axis. Returns: SimpleITK.Image
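+ + Examples: + >>> # Minimal sketch (assumes SimpleITK is installed); remember that + >>> # ``sitk.GetArrayFromImage`` returns arrays in ``(z, y, x)`` order. + >>> mv = MedicalVolume(np.ones((4, 5, 6)), np.eye(4)) + >>> img = mv.to_sitk(image_orientation="sagittal") + >>> arr = sitk.GetArrayFromImage(img)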
@@ -531,6 +548,15 @@ def to_sitk(self, vdim: int = None, transpose_inplane: bool = False): if device != cpu_device: raise RuntimeError(f"MedicalVolume must be on cpu, got {self.device}") + if image_orientation.lower() == "sagittal": + self.reformat(("LR", "SI", "AP"), inplace=True) + elif image_orientation.lower() == "coronal": + self.reformat(("AP", "SI", "RL"), inplace=True) + elif image_orientation.lower() == "axial": + self.reformat(("IS", "AP", "RL"), inplace=True) + else: + raise ValueError(f"Unsupported image orientation: {image_orientation}") + arr = self.volume ndim = arr.ndim @@ -538,15 +564,17 @@ def to_sitk(self, vdim: int = None, transpose_inplane: bool = False): if vdim < 0: vdim = ndim + vdim axes = tuple(i for i in range(ndim) if i != vdim)[::-1] + (vdim,) - else: - axes = range(ndim)[::-1] - arr = np.transpose(arr, axes) + arr = np.transpose(arr, axes) affine = self.affine.copy() affine[:2] = -affine[:2] # RAS+ -> LPS+ + # Swap columns to adjust from (z, y, x) -> (x, y, z) for SimpleITK + affine[:, :3] = affine[:, [2, 1, 0]] # This swaps the columns + + # Adjust origin and spacing for SimpleITK's expected order + origin = tuple(affine[:3, 3]) # origin stays the same. + spacing = self.pixel_spacing[::-1] # Swap x, y, z -> z, y, x direction = affine[:3, :3] / np.asarray(spacing) img = sitk.GetImageFromArray(arr, isVector=vdim is not None) @@ -554,6 +582,10 @@ img.SetSpacing(spacing) img.SetDirection(tuple(direction.flatten())) + if any([flip_array_x, flip_array_y, flip_array_z]): + flip_logic = [flip_array_x, flip_array_y, flip_array_z] + img = sitk.Flip(img, flip_logic) + if transpose_inplane: pa = sitk.PermuteAxesImageFilter() pa.SetOrder([1, 0, 2]) @@ -601,7 +633,7 @@ def to_torch( [[1., 1.], [1., 1.]]], device="cuda:0", dtype=torch.float64) >>> # view complex array as real tensor - >>> mv = MedicalVolume(np.ones((3,4,5), dtype=np.complex), np.eye(4)) + >>> mv = MedicalVolume(np.ones((3,4,5), dtype=complex), np.eye(4)) >>> tensor = mv.to_torch(view_as_real=True) >>> tensor.shape (3, 4, 5, 2) @@ -718,7 +750,8 @@ def set_metadata(self, key, value, force: bool = False): self._headers = self._validate_and_format_headers([pydicom.Dataset()]) warnings.warn( "Headers were generated and may not contain all attributes " - "required to save the volume in DICOM format." + "required to save the volume in DICOM format.", + stacklevel=2, ) VR_registry = {float: "DS", int: "IS", str: "LS"} @@ -994,11 +1027,10 @@ def from_sitk(cls, image, copy=False, transpose_inplane: bool = False) -> "Medic spacing = image.GetSpacing() direction = np.asarray(image.GetDirection()).reshape(-1, 3) - affine = np.zeros((4, 4)) + affine = np.eye(4) affine[:3, :3] = direction * np.asarray(spacing) affine[:3, 3] = origin affine[:2] = -affine[:2] # LPS+ -> RAS+ - affine[3, 3] = 1 return cls(arr, affine) @@ -1067,7 +1099,7 @@ def from_torch(cls, tensor, affine, headers=None, to_complex: bool = None) -> "M torch_version = env.get_version(torch) supports_cplx = version.Version(torch_version) >= _TORCH_COMPLEX_SUPPORT_VERSION - # Check if tensor needs to be converted to np.complex type. + # Check if tensor needs to be converted to `complex` type. 
# If tensor is of torch.complex64 or torch.complex128 dtype, then from_numpy will take # care of conversion to appropriate numpy dtype, and we do not need to do the to_complex # logic. @@ -1243,7 +1275,7 @@ def __getitem__(self, _slice): _slice_headers.append(0) else: _slice_headers.append(x) - headers = headers[_slice_headers] + headers = headers[tuple(_slice_headers)] affine = slicer.slice_affine(_slice) return self._partial_clone(volume=volume, affine=affine, headers=headers) diff --git a/dosma/core/numpy_routines.py b/dosma/core/numpy_routines.py index 74f16813..7dc4745c 100644 --- a/dosma/core/numpy_routines.py +++ b/dosma/core/numpy_routines.py @@ -161,7 +161,7 @@ def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None): return x._partial_clone(volume=vol) -@implements(np.around, np.round, np.round_) +@implements(np.around, np.round) def around(x, decimals=0, affine=False): """Round medical image pixel data (and optionally affine) to the given number of decimals. @@ -367,12 +367,13 @@ def concatenate(xs, axis: int = -1): else: headers = np.concatenate(headers, axis=axis) if headers.ndim != volume.ndim or any( - [hs != 1 and hs != vs for hs, vs in zip(headers.shape, volume.shape)] + hs != 1 and hs != vs for hs, vs in zip(headers.shape, volume.shape) ): warnings.warn( "Got invalid headers shape ({}) given concatenated output shape ({}). " "Expected header dimensions to be 1 or same as volume dimension for all axes. " - "Dropping all headers in concatenated output.".format(volume.shape, headers.shape) + "Dropping all headers in concatenated output.".format(headers.shape, volume.shape), + stacklevel=2, ) headers = None diff --git a/dosma/core/orientation.py b/dosma/core/orientation.py index 5883e8b4..340902fa 100644 --- a/dosma/core/orientation.py +++ b/dosma/core/orientation.py @@ -304,7 +304,7 @@ def _format_numbers(input, default_val, name, expected_num): end_ornt = nibo.axcodes2ornt(orientation_standard_to_nib(orientation)) ornt = nibo.ornt_transform(start_ornt, end_ornt) - transpose_idxs = ornt[:, 0].astype(np.int) + transpose_idxs = ornt[:, 0].astype(int) flip_idxs = ornt[:, 1] affine[:3] = affine[:3][transpose_idxs] diff --git a/dosma/core/quant_vals.py b/dosma/core/quant_vals.py index 7909feb7..0e6d1730 100644 --- a/dosma/core/quant_vals.py +++ b/dosma/core/quant_vals.py @@ -92,7 +92,8 @@ def save_data( warnings.warn( "Due to bit depth issues, only nifti format is supported for quantitative values. " - "Writing as nifti file..." + "Writing as nifti file...", + stacklevel=2, ) data_format = ImageDataFormat.nifti diff --git a/dosma/core/registration.py b/dosma/core/registration.py index b87409db..48d43cbe 100644 --- a/dosma/core/registration.py +++ b/dosma/core/registration.py @@ -328,7 +328,8 @@ def symlink_elastix(path: str = None, lib_only: bool = True, force: bool = False assert system in ["windows", "darwin", "linux"] if system != "darwin": warnings.warn( - f"Symlinking elastix/transformix paths not recommended for {system} " f"machines" + f"Symlinking elastix/transformix paths not recommended for {system} machines", + stacklevel=2, ) if path is None: diff --git a/dosma/defaults.py b/dosma/defaults.py index 8e5aacdf..eabb7a91 100644 --- a/dosma/defaults.py +++ b/dosma/defaults.py @@ -85,7 +85,8 @@ def _merge_dicts(self, target, base, prefix=""): elif k not in target_keys and k in base_keys: warnings.warn( "Your preferences file may be outdated. " - "Run `preferences.save()` to save your updated preferences." 
+ "Run `preferences.save()` to save your updated preferences.", + stacklevel=2, ) target[k] = base[k] else: @@ -199,7 +200,7 @@ def set(self, key: str, value: Any, prefix: str = ""): val, subdict = self.__get(self.__config, key, prefix) # type of new value has to be the same type as old value - if type(value) != type(val): + if not isinstance(value, type(val)): try: value = type(val)(value) except (ValueError, TypeError): diff --git a/dosma/models/__init__.py b/dosma/models/__init__.py index 519d69cc..c45b759d 100644 --- a/dosma/models/__init__.py +++ b/dosma/models/__init__.py @@ -1,9 +1,13 @@ -from dosma.models import oaiunet2d, stanford_qdess, util +from dosma.models import model_loading_util, oaiunet2d, stanford_qdess, stanford_qdess_bone from dosma.models.oaiunet2d import * # noqa from dosma.models.stanford_qdess import * # noqa: F401, F403 +from dosma.models.stanford_qdess_bone import * # noqa: F401, F403 +from dosma.models.stanford_cube_bone import * # noqa: F401, F403 +from dosma.models.model_loading_util import * # noqa from dosma.models.util import * # noqa __all__ = [] -__all__.extend(util.__all__) +__all__.extend(model_loading_util.__all__) __all__.extend(oaiunet2d.__all__) __all__.extend(stanford_qdess.__all__) +__all__.extend(stanford_qdess_bone.__all__) diff --git a/dosma/models/model_loading_util.py b/dosma/models/model_loading_util.py new file mode 100644 index 00000000..55c73109 --- /dev/null +++ b/dosma/models/model_loading_util.py @@ -0,0 +1,94 @@ +""" +Functions for loading Keras models + +@author: Arjun Desai + (C) Stanford University, 2019 +""" +import os +import yaml +from functools import partial +from typing import Sequence + +from dosma.models.oaiunet2d import IWOAIOAIUnet2D, IWOAIOAIUnet2DNormalized, OAIUnet2D +from dosma.models.seg_model import SegModel + +__all__ = ["get_model", "SUPPORTED_MODELS"] + +# Network architectures currently supported +__SUPPORTED_MODELS__ = [OAIUnet2D, IWOAIOAIUnet2D, IWOAIOAIUnet2DNormalized] + +# Initialize supported models for the command line +SUPPORTED_MODELS = [x.ALIASES[0] for x in __SUPPORTED_MODELS__] + + +def get_model(model_str, input_shape, weights_path, **kwargs): + """Get a Keras model + :param model_str: model identifier + :param input_shape: tuple or list of tuples for initializing input(s) into Keras model + :param weights_path: filepath to weights used to initialize Keras model + :return: a Keras model + """ + for m in __SUPPORTED_MODELS__: + if model_str in m.ALIASES or model_str == m.__name__: + return m(input_shape, weights_path, **kwargs) + + raise LookupError("%s model type not supported" % model_str) + + +def model_from_config(cfg_file_or_dict, weights_dir=None, **kwargs) -> SegModel: + """Builds a new model from a config file. + + This function is useful for building models that have similar structure/architecture + to existing models supported in DOSMA, but have different weights and categories. + The config defines what dosma model should be used as a base, what weights should be loaded, + and what are the categories. + + The config file should be a yaml file that has the following keys: + * "DOSMA_MODEL": The base model that exists in DOSMA off of which data should be built. + * "CATEGORIES": The categories that are supposed to be loaded. + * "WEIGHTS_FILE": The basename of (or full path to) weights that should be loaded. + + Args: + cfg_file_or_dict (str or dict): The yaml file or dictionary corresponding to the config. + weights_dir (str): The directory where weights are stored. 
+ """ + + def _gen_mask(func, *_args, **_kwargs): + out = func(*_args, **_kwargs) + if isinstance(out, dict): + # Assumes that the dict is ordered, which it is for python>=3.6 + out = out.values() + elif not isinstance(out, Sequence): + out = [out] + if not len(categories) == len(out): + raise ValueError("Got {} outputs, but {} categories".format(len(out), len(categories))) + return dict(zip(categories, out)) + + if isinstance(cfg_file_or_dict, str): + with open(cfg_file_or_dict, "r") as f: + cfg = yaml.safe_load(f) + else: + cfg = cfg_file_or_dict + + base_model = cfg["DOSMA_MODEL"] + categories = cfg["CATEGORIES"] + weights = cfg["WEIGHTS_FILE"] + if not os.path.isfile(weights): + assert weights_dir, "`weights_dir` must be specified" + weights = os.path.join(weights_dir, cfg["WEIGHTS_FILE"]) + + try: + model: SegModel = get_model(base_model, weights_path=weights, force_weights=True, **kwargs) + except LookupError as e: + raise LookupError("DOSMA_MODEL '{}' not supported \n{}".format(base_model, e)) + + # Override generate mask function + model.generate_mask = partial(_gen_mask, model.generate_mask) + + return model diff --git a/dosma/models/oaiunet2d.py b/dosma/models/oaiunet2d.py index 9910d47e..5a919f7e 100644 --- a/dosma/models/oaiunet2d.py +++ b/dosma/models/oaiunet2d.py @@ -6,9 +6,16 @@ import numpy as np try: - from keras.layers import BatchNormalization as BN - from keras.layers import Concatenate, Conv2D, Conv2DTranspose, Dropout, Input, MaxPooling2D - from keras.models import Model + from tensorflow.keras.layers import BatchNormalization as BN + from tensorflow.keras.layers import ( + Concatenate, + Conv2D, + Conv2DTranspose, + Dropout, + Input, + MaxPooling2D, + ) + from tensorflow.keras.models import Model _SUPPORTS_KERAS = True except ImportError: # pragma: no-cover @@ -17,6 +24,7 @@ from dosma.core.med_volume import MedicalVolume from dosma.core.orientation import SAGITTAL from dosma.models.seg_model import KerasSegModel, whiten_volume +from dosma.models.util import get_tensor_shape_as_list __all__ = ["OAIUnet2D", "IWOAIOAIUnet2D", "IWOAIOAIUnet2DNormalized"] @@ -89,7 +97,9 @@ def __load_keras_model__(self, input_shape, force_weights=False): if depth_cnt < depth - 1: # If size of input is odd, only do a 3x3 max pool - xres = conv.shape.as_list()[1] + shape_ = get_tensor_shape_as_list(conv) + xres = shape_[1] + if xres % 2 == 0: pooling_size = (2, 2) elif xres % 2 == 1: @@ -100,7 +110,7 @@ # step up convolutional layers for depth_cnt in range(depth - 2, -1, -1): - deconv_shape = conv_ptr[depth_cnt].shape.as_list() + deconv_shape = get_tensor_shape_as_list(conv_ptr[depth_cnt]) deconv_shape[0] = None # If size of input is odd, then do a 3x3 deconv @@ -234,7 +244,8 @@ def __load_keras_model__(self, input_shape): if depth_cnt < depth - 1: # If size of input is odd, only do a 3x3 max pool - xres = conv.shape.as_list()[1] + shape_ = get_tensor_shape_as_list(conv) + xres = shape_[1] if xres % 2 == 0: pooling_size = (2, 2) elif xres % 2 == 1: @@ -245,7 +256,7 @@ # step up convolutional layers for depth_cnt in range(depth - 2, -1, -1): - deconv_shape = conv_ptr[depth_cnt].shape.as_list() + deconv_shape = 
get_tensor_shape_as_list(conv_ptr[depth_cnt]) deconv_shape[0] = None # If size of input is odd, then do a 3x3 deconv diff --git a/dosma/models/seg_model.py b/dosma/models/seg_model.py index 7da9fec9..afb6c02d 100644 --- a/dosma/models/seg_model.py +++ b/dosma/models/seg_model.py @@ -1,12 +1,14 @@ from abc import ABC, abstractmethod import numpy as np +from scipy.ndimage import binary_fill_holes from dosma.core.med_volume import MedicalVolume from dosma.defaults import preferences +from dosma.tissues.tissue import largest_cc try: - import keras.backend as K + import tensorflow.keras.backend as K except ImportError: # pragma: no cover pass @@ -125,3 +127,53 @@ def whiten_volume(x: np.ndarray, eps: float = 0.0): raise ValueError(f"Input has {x.ndims} dimensions. Expected {__VOLUME_DIMENSIONS__}") return (x - np.mean(x)) / (np.std(x) + eps) + + +# ============================ Postprocessing utils ============================ +def get_connected_segments(mask: np.ndarray) -> np.ndarray: + """ + Keep only the largest connected component of each labeled tissue in a segmentation mask. + + Args: + mask (np.ndarray): 3D volume of all segmented tissues. + + Returns: + np.ndarray: 3D volume with only the largest connected component of each tissue. + """ + + unique_tissues = np.unique(mask) + + mask_ = np.zeros_like(mask) + + for idx in unique_tissues: + if idx == 0: + continue + mask_binary = np.zeros_like(mask) + mask_binary[mask == idx] = 1 + mask_binary_connected = np.asarray(largest_cc(mask_binary), dtype=np.uint8) + mask_[mask_binary_connected == 1] = idx + + return mask_ + + +def fill_holes(mask: np.ndarray, label_idx=None) -> np.ndarray: + """ + Fill holes in binary mask. + + Args: + mask (np.ndarray): 3D volume of all segmented tissues. + label_idx (int, optional): Label index to fill holes in mask. + + Returns: + np.ndarray: Binary 3D mask with holes filled (for ``label_idx`` if given). + """ + + if label_idx is None: + assert len(np.unique(mask)) == 2, "Mask should be binary" + else: + mask = mask == label_idx + + mask_ = binary_fill_holes(mask) + + return mask_ diff --git a/dosma/models/stanford_cube_bone.py b/dosma/models/stanford_cube_bone.py new file mode 100644 index 00000000..6b9b9a89 --- /dev/null +++ b/dosma/models/stanford_cube_bone.py @@ -0,0 +1,63 @@ +from dosma.models.stanford_qdess_bone import StanfordQDessBoneUNet2D + +__all__ = ["StanfordCubeBoneUNet2D"] + + +class StanfordCubeBoneUNet2D(StanfordQDessBoneUNet2D): + """ + This model segments bones (femur, tibia, patella) from cube knee scans. + + There are a few weights files that are associated with this model. + We provide a short description of each below: + + * ````: This is the baseline + model trained on a subset of the K2S MRI reconstruction challenge + hosted by UCSF using the 2D network from Gatti et al. MAGMA, 2021. + + By default this class will resample the input to be the size of the trained + model (512x512) for segmentation and then will resample the output + segmentation to match the original volume. + + By default, we return the largest connected component of each tissue. This + can be disabled by setting `connected_only=False` in `model.generate_mask()`. + + By default, we fill holes in bone segmentations. This can be disabled by setting + `fill_bone_holes=False` in `model.generate_mask()`. + + The output includes individual objects for each segmented tissue. It also + includes a combined label for all of the tissues in a single 3D mask. + + Examples: + + >>> # Create model. 
+ >>> model = StanfordCubeBoneUNet2D("/path/to/model.h5") + + >>> # Generate mask from medical volume. + >>> model.generate_mask(medvol) + + >>> # Generate mask from medical volume without getting largest connected components. + >>> model.generate_mask(medvol, connected_only=False) + + """ + + ALIASES = ("stanford-cube-2024-unet2d-bone", "k2s-unet2d-bone") + + def __init__( + self, + model_path: str, + resample_images: bool = True, + orig_model_image_size: tuple = (512, 512), + tissue_names: tuple = ("fem", "tib", "pat"), + tissues_to_combine: tuple = (), + bone_indices: tuple = (7, 8, 9) + # *args, + # **kwargs + ): + super().__init__( + model_path, + resample_images, + orig_model_image_size, + tissue_names, + tissues_to_combine, + bone_indices, + ) diff --git a/dosma/models/stanford_qdess.py b/dosma/models/stanford_qdess.py index d83f48b9..8dfc91e9 100644 --- a/dosma/models/stanford_qdess.py +++ b/dosma/models/stanford_qdess.py @@ -9,9 +9,16 @@ import numpy as np try: - from keras.layers import BatchNormalization as BN - from keras.layers import Concatenate, Conv2D, Conv2DTranspose, Dropout, Input, MaxPooling2D - from keras.models import Model + from tensorflow.keras.layers import BatchNormalization as BN + from tensorflow.keras.layers import ( + Concatenate, + Conv2D, + Conv2DTranspose, + Dropout, + Input, + MaxPooling2D, + ) + from tensorflow.keras.models import Model _SUPPORTS_KERAS = True except ImportError: # pragma: no-cover @@ -20,6 +27,7 @@ from dosma.core.med_volume import MedicalVolume from dosma.core.orientation import SAGITTAL from dosma.models.seg_model import KerasSegModel, whiten_volume +from dosma.models.util import get_tensor_shape_as_list __all__ = ["StanfordQDessUNet2D"] @@ -101,7 +109,8 @@ def __load_keras_model__(self, input_shape): if depth_cnt < depth - 1: # If size of input is odd, only do a 3x3 max pool - xres = conv.shape.as_list()[1] + shape_ = get_tensor_shape_as_list(conv) + xres = shape_[1] if xres % 2 == 0: pooling_size = (2, 2) elif xres % 2 == 1: @@ -112,7 +121,7 @@ def __load_keras_model__(self, input_shape): # step up convolutional layers for depth_cnt in range(depth - 2, -1, -1): - deconv_shape = conv_ptr[depth_cnt].shape.as_list() + deconv_shape = get_tensor_shape_as_list(conv_ptr[depth_cnt]) deconv_shape[0] = None # If size of input is odd, then do a 3x3 deconv diff --git a/dosma/models/stanford_qdess_bone.py b/dosma/models/stanford_qdess_bone.py new file mode 100644 index 00000000..beab724a --- /dev/null +++ b/dosma/models/stanford_qdess_bone.py @@ -0,0 +1,407 @@ +from copy import deepcopy +import time +import numpy as np +import skimage +import gc +from dosma.core.med_volume import MedicalVolume +from dosma.core.orientation import SAGITTAL +from dosma.defaults import preferences +from dosma.models.seg_model import SegModel, fill_holes, get_connected_segments, whiten_volume + +import SimpleITK as sitk +from tensorflow.keras.models import load_model + +__all__ = [ + "StanfordQDessBoneUNet2D", "StanfordQDessBoneUNet2DCoronal", + "StanfordQDessBoneUNet2DAxial", "StanfordQDessBoneUNet2DSagittal", + "StanfordQDessBoneUNet2DSTAPLE" + ] + + +class StanfordQDessBoneUNet2D(SegModel): + """ + This model segments patellar cartilage ("pc"), femoral cartilage ("fc"), + tibial cartilage ("tc", "mtc", "ltc"), the meniscus ("men", "med_men", + "lat_men"), and bones ("fem", "tib", "pat") from quantitative + double echo steady state (qDESS) knee scans. The segmentation is computed on + the root-sum-of-squares (RSS) of the two echoes. 
+ + There are a few weights files that are associated with this model. + We provide a short description of each below: + + * ``coarse_2d_sagittal_v1_11-18_Dec-04-2022.h5``: This is the baseline + model trained on a subset of the SKM-TEA dataset (v1.0.0) with bone + labels using the 2D network from Gatti et al. MAGMA, 2021. + * PROVIDE ANOTHER SET OF WEIGHTS USING SAME MODEL AS stanford_qdess + + By default this class will resample the input to be the size of the trained + model (384x384) for segmentation and then will resample the output + segmentation to match the original volume. + + By default, we return the largest connected component of each tissue. This + can be disabled by setting `connected_only=False` in `model.generate_mask()`. + + The output includes individual objects for each segmented tissue, including + separate medial/lateral segments of the meniscus and tibial cartilage. It also + includes a combined label for the meniscus and tibial cartilage, and a combined + label for all of the tissues in a single 3D mask. + + Examples: + + >>> # Create model. + >>> model = StanfordQDessBoneUNet2D("/path/to/model.h5") + + >>> # Generate mask from root-sum-of-squares (rss) volume. + >>> model.generate_mask(rss) + + >>> # Generate mask from dual-echo volume `de_vol` - shape: (SI, AP, LR, 2) + >>> model.generate_mask(de_vol) + + >>> # Generate mask from rss volume without getting largest connected components. + >>> model.generate_mask(rss, connected_only=False) + + """ + + ALIASES = ("stanford-qdess-2022-unet2d-bone", "skm-tea-unet2d-bone") + + TARGET_ORIENTATION = SAGITTAL + DEFAULT_IMAGE_SIZE = (512, 512) + + def __init__( + self, + model_path: str, + resample_images: bool = True, + orig_model_image_size: tuple = None, + tissue_names: tuple = ("pc", "fc", "mtc", "ltc", "med_men", "lat_men", "fem", "tib", "pat"), + tissues_to_combine: tuple = ( + (("lat_men", "med_men"), "men"), + (("mtc", "ltc"), "tc"), + ), + bone_indices: tuple = (7, 8, 9) + ): + """ + Args: + model_path (str): Path to model & weights file. + resample_images (bool): Whether or not to resample input volumes to + match original model size. If False, will build new model specific + to loaded image and will load model weights only. Default: True. + """ + + if orig_model_image_size is None: + orig_model_image_size = self.DEFAULT_IMAGE_SIZE + + self.batch_size = preferences.segmentation_batch_size + self.orig_model_image_size = orig_model_image_size + self.resample_images = resample_images + self.seg_model = self.build_model(model_path=model_path) + + self.tissue_names = tissue_names + self.tissues_to_combine = tissues_to_combine + self.bone_indices = bone_indices + + def build_model(self, model_path: str): + """ + Loads a segmentation model and its weights. + + Args: + model_path: Path to model & its weights. + + Returns: + Keras segmentation model + """ + if self.resample_images is True: + model = load_model(model_path, compile=False) + else: + raise Exception("Segmenting without resampling is not supported yet.") + + return model + + def generate_mask( + self, + volume: MedicalVolume, + connected_only: bool = True, + fill_bone_holes: bool = True, + ): + """Segment tissues. + + Args: + volume (MedicalVolume): The volume to segment. Either 3D or 4D. + If the volume is 3D, it is assumed to be the root-sum-of-squares (RSS) + of the two qDESS echoes. If 4D, volume must be of the shape ``(..., 2)``, + where the last dimension corresponds to echo 1 and 2, respectively. 
+ connected_only (bool): If True, only the largest connected component of + each tissue is returned. Default: True. + fill_bone_holes (bool): If True, fill holes in bone segmentations. Default: True. + + Returns: + dict: Dictionary of segmented tissues. + """ + ndim = volume.ndim + if ndim not in (3, 4): + raise ValueError("`volume` must either be 3D or 4D") + + vol_copy = deepcopy(volume) + + if ndim == 4: + # if 4D, assume last dimension is echo 1 and 2 + vol_copy = np.sqrt(np.sum(vol_copy ** 2, axis=-1)) + + # reorient to the model's target plane + vol_copy.reformat(self.TARGET_ORIENTATION, inplace=True) + + vol = vol_copy.volume + vol = self.__preprocess_volume__(vol) + + # reshape volumes to be (slice, 1, x, y) + v = np.transpose(vol, (2, 0, 1)) + v = np.expand_dims(v, axis=1) + + mask = self.seg_model.predict(v, batch_size=self.batch_size, verbose=1) + + # one-hot encode mask, reorder axes, and resize to input shape + mask = self.__postprocess_segmentation__( + mask, connected_only=connected_only, fill_bone_holes=fill_bone_holes + ) + + # Create temporary dictionary to hold target-oriented volumes + vols_target = {} + + # Create 'all' volume in target orientation + vol_all_target = deepcopy(vol_copy) + vol_all_target.volume = deepcopy(mask) + vols_target["all"] = vol_all_target + + # Create individual tissues in target orientation + for i, category in enumerate(self.tissue_names): + vol_target = deepcopy(vol_copy) + vol_target.volume = np.zeros_like(mask) + vol_target.volume[mask == i + 1] = 1 + vols_target[category] = vol_target + + # Combine tissues in target orientation space + for tissues, tissue_name in self.tissues_to_combine: + vol_target = deepcopy(vol_copy) + vol_target.volume = np.zeros_like(mask) + # Use logical OR instead of addition for boolean arrays + vol_target.volume[ + (vols_target[tissues[0]].volume == 1) | (vols_target[tissues[1]].volume == 1) + ] = 1 + vols_target[tissue_name] = vol_target + + # Now reformat all volumes to original orientation + vols = {} + for name, vol_target in vols_target.items(): + vol_cp = deepcopy(vol_target) + vol_cp.reformat(volume.orientation, inplace=True) + vols[name] = vol_cp + + return vols + + def __preprocess_volume__(self, volume: np.ndarray): + # TODO: Remove epsilon if the difference in performance is not large. 
+ + self.original_image_shape = volume.shape + + if self.resample_images is True: + volume = skimage.transform.resize( + image=volume, output_shape=self.orig_model_image_size + (volume.shape[-1],), order=3 + ) + else: + raise Exception("Segmenting without resampling is not supported yet.") + + return whiten_volume(volume, eps=1e-8) + + def __postprocess_segmentation__( + self, mask: np.ndarray, connected_only: bool = True, fill_bone_holes: bool = True + ): + + # Use argmax to get a single-volume segmentation of all tissues + mask = np.argmax(mask, axis=1) + # reshape mask to be (x, y, slice) + mask = np.transpose(mask, (1, 2, 0)) + + if self.resample_images is True: + mask = skimage.transform.resize( + image=mask, output_shape=self.original_image_shape, order=0 + ) + else: + raise Exception("Segmenting without resampling is not supported yet.") + + if connected_only is True: + mask = get_connected_segments(mask) + + if fill_bone_holes is True: + for bone_idx in self.bone_indices: + mask_ = fill_holes(mask, label_idx=bone_idx) + mask[mask_ == 1] = bone_idx + + mask = mask.astype(np.uint8) + + return mask + + +class StanfordQDessBoneUNet2DCoronal(StanfordQDessBoneUNet2D): + """2D UNet for bone segmentation in coronal plane""" + ALIASES = ("stanford-qdess-2022-unet2d-bone-coronal",) + CORONAL_TRANSPOSED = ('LR', 'SI', 'AP') + TARGET_ORIENTATION = CORONAL_TRANSPOSED + DEFAULT_IMAGE_SIZE = (160, 512) + + +class StanfordQDessBoneUNet2DAxial(StanfordQDessBoneUNet2D): + """2D UNet for bone segmentation in axial plane""" + ALIASES = ("stanford-qdess-2022-unet2d-bone-axial",) + AXIAL_TRANSPOSED = ('LR', 'AP', 'SI') + TARGET_ORIENTATION = AXIAL_TRANSPOSED + DEFAULT_IMAGE_SIZE = (160, 512) # Example different size + + +class StanfordQDessBoneUNet2DSagittal(StanfordQDessBoneUNet2D): + """2D UNet for bone segmentation in sagittal plane""" + ALIASES = ("stanford-qdess-2022-unet2d-bone-sagittal",) + TARGET_ORIENTATION = SAGITTAL + DEFAULT_IMAGE_SIZE = (512, 512) # Example different size + + +class StanfordQDessBoneUNet2DSTAPLE: + """ + This model applies the sagittal, coronal, and axial UNet + models to the input volume and then combines the results + using STAPLE.
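+ + Example (sketch; the weights paths below are hypothetical): + + >>> model = StanfordQDessBoneUNet2DSTAPLE("sag.h5", "cor.h5", "ax.h5") + >>> masks = model.generate_mask(medvol) # dict with "all" plus per-tissue masks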
+ """ + + # need to combine labels from multiple models - but only trust + # some models for certain tissues. + # ("pc", "fc", "mtc", "ltc", "med_men", "lat_men", "fem", "tib", "pat") + # Sag - include all of them. + # Cor - only: "fc", "mtc", "ltc", "med_men", "lat_men", "fem", "tib", + # Ax - only: "pc", "fem", "tib", "pat" + + list_idx_not_include_STAPLE = [ + [], # what not to include for sagittal + [1], # what not to include for coronal + [2, 3, 4, 5, 6] # what not to include for axial + ] + # dict_tissues_combine_staple = { + # "pc": ["sag", "ax"], + # "fc": ["sag", "cor", "ax"], + # "mtc": ["sag", "cor"], + # "ltc": ["sag", "cor"], + # "med_men": ["sag", "cor"], + # "lat_men": ["sag", "cor"], + # "fem": ["sag", "cor", "ax"], + # "tib": ["sag", "cor", "ax"], + # "pat": ["sag", "ax"] + # } + dict_plane_idx = { + "sag": 0, + "cor": 1, + "ax": 2 + } + + def __init__( + self, + sagittal_model_path, coronal_model_path, axial_model_path, + tissue_names: tuple = ("pc", "fc", "mtc", "ltc", "med_men", "lat_men", "fem", "tib", "pat"), + tissues_to_combine: tuple = ( + (("lat_men", "med_men"), "men"), + (("mtc", "ltc"), "tc"), + ), + verbose=False, + ): + + self.sagittal_model_path = sagittal_model_path + self.coronal_model_path = coronal_model_path + self.axial_model_path = axial_model_path + self.tissue_names = tissue_names + self.tissues_to_combine = tissues_to_combine + self.verbose = verbose + + def generate_mask(self, volume: MedicalVolume): + """ + Iterate over the models, loading each one, generating its mask, + and then deleting it from memory, to avoid GPU memory issues. + """ + vol_copy = deepcopy(volume) + + list_models = [ + [self.sagittal_model_path, StanfordQDessBoneUNet2DSagittal], + [self.coronal_model_path, StanfordQDessBoneUNet2DCoronal], + [self.axial_model_path, StanfordQDessBoneUNet2DAxial] + ] + + masks = [] + for model_idx, (model_path, model_class) in enumerate(list_models): + start_time = time.time() + model = model_class(model_path) + masks_dict_ = model.generate_mask(volume) + masks.append(masks_dict_["all"]) + del model + gc.collect() + if self.verbose: + print(f"Time taken to generate mask {model_idx}: {time.time() - start_time} seconds") + + # for each mask, go in and set the regions we are not using to zero. + tic = time.time() + for i, mask in enumerate(masks): + for idx in self.list_idx_not_include_STAPLE[i]: + mask.volume[mask.volume == idx] = 0 + if self.verbose: + print(f"Time taken to set the regions we are not using to zero: {time.time() - tic} seconds") + tic = time.time() + masks_sitk = [mask.to_sitk() for mask in masks] + # cast to uint16 to prevent overflow + masks_sitk = [sitk.Cast(mask, sitk.sitkUInt16) for mask in masks_sitk] + + # unpack the sitk_masks + staple_mask_sitk = sitk.MultiLabelSTAPLE(*masks_sitk) + + if self.verbose: + print(f"Time to run STAPLE: {time.time() - tic} seconds") + + tic = time.time() + + staple_mask_mv = MedicalVolume.from_sitk(staple_mask_sitk) + # set any labels > 9 (the expected max label) to 0 + # these were likely undecided labels in the STAPLE algorithm + staple_mask_mv.volume[staple_mask_mv.volume > 9] = 0 + # set to uint8 + staple_mask_mv.volume = staple_mask_mv.volume.astype(np.uint8) + staple_mask_mv.reformat(volume.orientation, inplace=True) + + # Now create the individual tissue masks, matching the output + # format of the other models. 
+ # Create temporary dictionary to hold target-oriented volumes + vols = {} + # Create 'all' volume in target orientation + vols["all"] = staple_mask_mv + + # Create individual tissues in target orientation + for i, category in enumerate(self.tissue_names): + vol = deepcopy(vol_copy) + vol.volume = np.zeros_like(staple_mask_mv.volume) + vol.volume[staple_mask_mv.volume == i + 1] = 1 + vols[category] = vol + + # Combine tissues in target orientation space + for tissues, tissue_name in self.tissues_to_combine: + vol = deepcopy(vol_copy) + vol.volume = np.zeros_like(staple_mask_mv.volume) + # Use logical OR instead of addition for boolean arrays + vol.volume[(vols[tissues[0]].volume == 1) | (vols[tissues[1]].volume == 1)] = 1 + vols[tissue_name] = vol + + if self.verbose: + print(f"Time taken to create the individual tissue masks: {time.time() - tic} seconds") + + return vols + + + def __combine_masks__(self, list_masks, vol_copy): + """ + Combine the masks from the different planes. + """ + + # this should use STAPLE algorithm + # this is built into SimpleITK + # need to convert MedicalVolume to SimpleITK image + # then do combination + # then convert back to MedicalVolume + # Then return the MedicalVolume diff --git a/dosma/models/util.py b/dosma/models/util.py index 0654cfa4..12ee7b8a 100644 --- a/dosma/models/util.py +++ b/dosma/models/util.py @@ -1,94 +1,30 @@ """ -Functions for loading Keras models - -@author: Arjun Desai - (C) Stanford University, 2019 +Utility functions common to multiple models and their files. """ -import os -import yaml -from functools import partial -from typing import Sequence - -from dosma.models.oaiunet2d import IWOAIOAIUnet2D, IWOAIOAIUnet2DNormalized, OAIUnet2D -from dosma.models.seg_model import SegModel - -__all__ = ["get_model", "SUPPORTED_MODELS"] - -# Network architectures currently supported -__SUPPORTED_MODELS__ = [OAIUnet2D, IWOAIOAIUnet2D, IWOAIOAIUnet2DNormalized] -# Initialize supported models for the command line -SUPPORTED_MODELS = [x.ALIASES[0] for x in __SUPPORTED_MODELS__] +import tensorflow as tf -def get_model(model_str, input_shape, weights_path, **kwargs): - """Get a Keras model - :param model_str: model identifier - :param input_shape: tuple or list of tuples for initializing input(s) into Keras model - :param weights_path: filepath to weights used to initialize Keras model - :return: a Keras model +def get_tensor_shape_as_list(x): """ - for m in __SUPPORTED_MODELS__: - if model_str in m.ALIASES or model_str == m.__name__: - return m(input_shape, weights_path, **kwargs) - - raise LookupError("%s model type not supported" % model_str) - - -def model_from_config(cfg_file_or_dict, weights_dir=None, **kwargs) -> SegModel: - """Builds a new model from a config file. - - This function is useful for building models that have similar structure/architecture - to existing models supported in DOSMA, but have different weights and categories. - The config defines what dosma model should be used as a base, what weights should be loaded, - and what are the categories. - - The config file should be a yaml file that has the following keys: - * "DOSMA_MODEL": The base model that exists in DOSMA off of which data should be built. - * "CATEGORIES": The categories that are supposed to be loaded. - * "WEIGHTS_FILE": The basename of (or full path to) weights that should be loaded. + Get the shape of a tensor as a list Args: - cfg_file_or_dict (str or dict): The yaml file or dictionary corresponding to the config. 
- weights_dir (str): The directory where weights are stored. If not specified, assumes - "WEIGHTS_FILE" field in the config is the full path to the weights. - **kwargs: Keyword arguments for base model `__init__` + x: tf.Tensor or tf.TensorShape or list or tuple Returns: - SegModel: A segmentation model with appropriate changes to `generate_mask` to produce - the right masks. - """ - - def _gen_mask(func, *_args, **_kwargs): - out = func(*_args, **_kwargs) - if isinstance(out, dict): - # Assumes that the dict is ordered, which it is for python>=3.6 - out = out.values() - elif not isinstance(out, Sequence): - out = [out] - if not len(categories) == len(out): - raise ValueError("Got {} outputs, but {} categories".format(len(out), len(categories))) - return {cat: out for cat, out in zip(categories, out)} + list: shape of the tensor - if isinstance(cfg_file_or_dict, str): - with open(cfg_file_or_dict, "r") as f: - cfg = yaml.load(f) - else: - cfg = cfg_file_or_dict + Notes: This was implemented because getting conv.shape was returning a tuple in some + versions of tensorflow/keras. - base_model = cfg["DOSMA_MODEL"] - categories = cfg["CATEGORIES"] - weights = cfg["WEIGHTS_FILE"] - if not os.path.isfile(weights): - assert weights_dir, "`weights_dir` must be specified" - weights = os.path.join(weights_dir, cfg["WEIGHTS_FILE"]) - - try: - model: SegModel = get_model(base_model, weights_path=weights, force_weights=True, **kwargs) - except LookupError as e: - raise LookupError("BASE_MODEL '{}' not supported \n{}".format(base_model, e)) - - # Override generate mask function - model.generate_mask = partial(_gen_mask, model.generate_mask) + """ - return model + # handle different tensorflow/keras versions returning tuple vs. tf.TensorShape + shape_ = x.shape + if isinstance(shape_, list): + return shape_ + elif isinstance(shape_, tuple): + return list(shape_) + elif isinstance(shape_, tf.TensorShape): + return shape_.as_list() diff --git a/dosma/scan_sequences/mri/qdess.py b/dosma/scan_sequences/mri/qdess.py index 3358b243..a308a395 100644 --- a/dosma/scan_sequences/mri/qdess.py +++ b/dosma/scan_sequences/mri/qdess.py @@ -118,6 +118,7 @@ def generate_t2_map( nan_bounds: Tuple[float, float] = (0, 100), nan_to_num: float = 0.0, decimals: int = 1, + spoiling: bool = True, ): """Generate 3D T2 map. @@ -158,6 +159,8 @@ will not be replaced. decimals (int): Number of decimal places to round to. If ``None``, values will not be rounded. + spoiling (bool, optional): If ``True``, use spoiling parameters from dicom headers. + If ``False``, assume low spoiling and use alternative equations. Returns: qv.T2: T2 fit map. @@ -170,9 +173,15 @@ self.get_metadata(self.__GL_AREA_TAG__, gl_area) is None or self.get_metadata(self.__TG_TAG__, tg) is None ): - raise ValueError( - "Dicom headers do not contain tags for `gl_area` and `tg`. Please input manually" + # warning that spoiling parameters are not found in the dicom headers + warnings.warn( + "Dicom headers do not contain tags for `gl_area` and `tg`. " + + "Assuming low spoiling and dropping these parameters from " + + "the T2 fit. See: Sveinsson et al., 'A Simple Analytic Method " + + "for Estimating T2 in the Knee from DESS', equations 6 & 7.", + stacklevel=2, 
) + spoiling = False xp = self.volumes[0].device.xp ref_dicom = self.ref_dicom if self.ref_dicom is not None else pydicom.Dataset() @@ -187,39 +196,57 @@ def generate_t2_map( # All timing in seconds TR = (float(ref_dicom.RepetitionTime) if tr is None else tr) * 1e-3 TE = (float(ref_dicom.EchoTime) if te is None else te) * 1e-3 - Tg = (float(ref_dicom[self.__TG_TAG__].value) if tg is None else tg) * 1e-6 T1 = (float(tissue.T1_EXPECTED) if t1 is None else t1) * 1e-3 + # Spoiler gradient parameters are only needed (and readable) when spoiling is enabled. + if spoiling: + Tg = (float(ref_dicom[self.__TG_TAG__].value) if tg is None else tg) * 1e-6 + GlArea = float(ref_dicom[self.__GL_AREA_TAG__].value) if gl_area is None else gl_area + if (Tg == 0) or (GlArea == 0): + warnings.warn( + "Dicom headers contain zero values for `gl_area` or `tg`. " + + "Assuming low spoiling and dropping these parameters from " + + "the T2 fit. See: Sveinsson et al., 'A Simple Analytic Method " + + "for Estimating T2 in the Knee from DESS', equations 6 & 7.", + stacklevel=2, + ) + spoiling = False # Flip Angle (degree -> radians) alpha = float(ref_dicom.FlipAngle) if alpha is None else alpha alpha = math.radians(alpha) if np.allclose(math.sin(alpha / 2), 0): - warnings.warn("sin(flip angle) is close to 0 - t2 map may fail.") - - GlArea = float(ref_dicom[self.__GL_AREA_TAG__].value) if gl_area is None else gl_area - - Gl = GlArea / (Tg * 1e6) * 100 - gamma = 4258 * 2 * math.pi # Gamma, Rad / (G * s). - dkL = gamma * Gl * Tg + warnings.warn("sin(flip angle) is close to 0 - t2 map may fail.", stacklevel=2) - # Simply math - k = ( - xp.power((xp.sin(alpha / 2)), 2) - * (1 + xp.exp(-TR / T1 - TR * xp.power(dkL, 2) * diffusivity)) - / (1 - xp.cos(alpha) * xp.exp(-TR / T1 - TR * xp.power(dkL, 2) * diffusivity)) - ) - - c1 = (TR - Tg / 3) * (xp.power(dkL, 2)) * diffusivity - - # T2 fit mask = xp.ones([r, c, num_slices]) ratio = mask * echo_2 / echo_1 ratio = xp.nan_to_num(ratio) + + if spoiling: + Gl = GlArea / (Tg * 1e6) * 100 + gamma = 4258 * 2 * math.pi # Gamma, Rad / (G * s). + dkL = gamma * Gl * Tg + + # Simple math + k = ( + xp.power((xp.sin(alpha / 2)), 2) + * (1 + xp.exp(-TR / T1 - TR * xp.power(dkL, 2) * diffusivity)) + / (1 - xp.cos(alpha) * xp.exp(-TR / T1 - TR * xp.power(dkL, 2) * diffusivity)) + ) + c1 = (TR - Tg / 3) * (xp.power(dkL, 2)) * diffusivity + + else: + # T2 fit + k = ( + xp.power((xp.sin(alpha / 2)), 2) + * (1 + xp.exp(-TR / T1)) + / (1 - xp.cos(alpha) * xp.exp(-TR / T1)) + ) + c1 = 0 + # split the division into steps to avoid overflow error t2map = -2000 * (TR - TE) / (xp.log(abs(ratio) / k) + c1) - t2map = xp.nan_to_num(t2map) # Filter calculated T2 values that are below 0ms and over 100ms diff --git a/dosma/scan_sequences/scan_io.py b/dosma/scan_sequences/scan_io.py index 09841284..8a49ec62 100644 --- a/dosma/scan_sequences/scan_io.py +++ b/dosma/scan_sequences/scan_io.py @@ -113,7 +113,9 @@ def from_dict(cls, data: Dict[str, Any], force: bool = False): for k, v in data.items(): if not hasattr(scan, k) and not force: - warnings.warn(f"{cls.__name__} does not have attribute {k}. Skipping...") + warnings.warn( + f"{cls.__name__} does not have attribute {k}. Skipping...", stacklevel=2 + ) continue scan.__setattr__(k, v) @@ -213,7 +215,8 @@ def load(cls, path_or_data: Union[str, Dict], num_workers: int = 0): return scan except Exception: warnings.warn( - f"Failed to load {cls.__name__} from data. 
Trying to load from dicom file.", + stacklevel=2, ) data = cls._convert_attr_name(data) @@ -238,7 +241,9 @@ def load(cls, path_or_data: Union[str, Dict], num_workers: int = 0): for k, v in data.items(): if not hasattr(scan, k): - warnings.warn(f"{cls.__name__} does not have attribute {k}. Skipping...") + warnings.warn( + f"{cls.__name__} does not have attribute {k}. Skipping...", stacklevel=2 + ) continue scan.__setattr__(k, v) @@ -252,6 +257,7 @@ def save_data( "save_data is deprecated since v0.0.13 and will no longer be " "available in v0.1. Use `save` instead.", DeprecationWarning, + stacklevel=2, ) return self.save(base_save_dirpath, data_format) @@ -304,7 +310,7 @@ def save_custom_data( paths = [paths[k] for k in keys] paths = [os.path.join(_path, f"{k}") for k, _path in zip(keys, paths)] values = self.save_custom_data(metadata.values(), paths, fname_fmt, **kwargs) - metadata = {k: v for k, v in zip(keys, values)} + metadata = dict(zip(keys, values)) elif not isinstance(metadata, str) and isinstance(metadata, (Sequence, Set)): values = list(metadata) paths = [os.path.join(_path, "{:03d}".format(i)) for i, _path in enumerate(paths)] @@ -388,7 +394,7 @@ def load_custom_data(cls, data: Any, **kwargs): if issubclass(dtype, Dict): keys = cls.load_custom_data(data.keys(), **kwargs) values = cls.load_custom_data(data.values(), **kwargs) - data = {k: v for k, v in zip(keys, values)} + data = dict(zip(keys, values)) elif not issubclass(dtype, str) and issubclass(dtype, (list, tuple, set)): data = dtype([cls.load_custom_data(x, **kwargs) for x in data]) else: diff --git a/dosma/scan_sequences/scans.py b/dosma/scan_sequences/scans.py index c5392feb..f4693e15 100644 --- a/dosma/scan_sequences/scans.py +++ b/dosma/scan_sequences/scans.py @@ -148,7 +148,7 @@ def __add_tissue__(self, new_tissue: Tissue): ValueError: If tissue already exists in list. For example, we cannot add FemoralCartilage twice to the list of tissues. """ - contains_tissue = any([tissue.ID == new_tissue.ID for tissue in self.tissues]) + contains_tissue = any(tissue.ID == new_tissue.ID for tissue in self.tissues) if contains_tissue: raise ValueError("Tissue already exists") @@ -185,7 +185,7 @@ class NonTargetSequence(ScanSequence): @abstractmethod def interregister(self, target_path: str, mask_path: str = None): - """Register this scan to the target scan - save as parameter in scan (volumes, subvolumes, etc). + """Register scan to the target scan - save as parameter in scan (volumes, subvolumes, etc). We use the term *interregister* to refer to registration between volumes of different scans. Conversely, *intraregister* refers to registering volumes from the same scan. 
diff --git a/dosma/tissues/femoral_cartilage.py b/dosma/tissues/femoral_cartilage.py
index e5832e79..b63623f9 100644
--- a/dosma/tissues/femoral_cartilage.py
+++ b/dosma/tissues/femoral_cartilage.py
@@ -3,7 +3,12 @@

 import numpy as np
 import pandas as pd
-import scipy.ndimage as sni
+# scipy.ndimage.measurements was deprecated in scipy 1.8 and removed in 1.10.
+# Feature-detect instead of comparing version strings ("1.10" sorts before "1.2").
+try:
+    from scipy.ndimage import center_of_mass as c_of_m
+except ImportError:  # pragma: no cover - very old scipy
+    from scipy.ndimage.measurements import center_of_mass as c_of_m

 from dosma.core.device import get_array_module
 from dosma.core.io.format_io import ImageDataFormat
@@ -164,7 +169,7 @@ def split_regions(

         # medial/lateral division
         # take into account scanning direction
-        center_of_mass = sni.measurements.center_of_mass(mask)
+        center_of_mass = c_of_m(mask)
         com_slicewise = center_of_mass[-1]

         ml_volume = np.asarray(np.zeros(mask.shape), dtype=np.uint16)
@@ -479,7 +484,8 @@ def __save_quant_data__(self, dirpath: str):
             else:
                 warnings.warn(
                     "%s: Pixel value exceeded upper bound (%0.1f). Using normalized scale."
-                    % (quant_val.name, upper_bound)
+                    % (quant_val.name, upper_bound),
+                    stacklevel=2,
                 )

             plt.imshow(data_map, cmap="jet")
@@ -533,7 +539,7 @@ def save_data(self, save_dirpath, data_format: ImageDataFormat = preferences.ima
         )

     def __binarize_region_mask__(self, region_mask, roi):
-        return np.asarray(np.bitwise_and(region_mask, roi) == roi, dtype=np.bool)
+        return np.asarray(np.bitwise_and(region_mask, roi) == roi, dtype=bool)

     def __split_mask__(self):
         assert (
diff --git a/dosma/tissues/meniscus.py b/dosma/tissues/meniscus.py
index 93a9d80e..ac4537ac 100644
--- a/dosma/tissues/meniscus.py
+++ b/dosma/tissues/meniscus.py
@@ -10,7 +10,12 @@

 import numpy as np
 import pandas as pd
-import scipy.ndimage as sni
+# scipy.ndimage.measurements was deprecated in scipy 1.8 and removed in 1.10.
+# Feature-detect instead of comparing version strings ("1.10" sorts before "1.2").
+try:
+    from scipy.ndimage import center_of_mass as c_of_m
+except ImportError:  # pragma: no cover - very old scipy
+    from scipy.ndimage.measurements import center_of_mass as c_of_m

 from dosma.core.device import get_array_module
 from dosma.core.med_volume import MedicalVolume
@@ -108,7 +113,7 @@ def split_regions(self, base_map):
             with tilted mensici. This will be addressed in a later release. To avoid
             computing metrics on these regions, set ``self.split_ml_only=True``.
         """
-        center_of_mass = sni.measurements.center_of_mass(base_map)  # zero indexed
+        center_of_mass = c_of_m(base_map)  # zero indexed

         com_sup_inf = int(np.ceil(center_of_mass[0]))
         com_ant_post = int(np.ceil(center_of_mass[1]))
@@ -166,7 +171,7 @@ def __calc_quant_vals__(self, quant_map: MedicalVolume, map_type: QuantitativeVa
             coronal_categories = [x for x in coronal_categories if x[0] == -1]

         categorical_mask = np.zeros(region_mask.shape[:-1])
-        base_mask = self.__mask__.A.astype(np.bool)
+        base_mask = self.__mask__.A.astype(bool)
         labels = {}
         for idx, (
             (axial, axial_name),
@@ -251,7 +256,7 @@ def __calc_quant_vals_old__(self, quant_map, map_type):
                     axial_map = np.asarray(
                         axial_region_mask == self._SUPERIOR_KEY, dtype=np.float32
                     ) + np.asarray(axial_region_mask == self._INFERIOR_KEY, dtype=np.float32)
-                    axial_map = np.asarray(axial_map, dtype=np.bool)
+                    axial_map = np.asarray(axial_map, dtype=bool)
                 else:
                     axial_map = axial_region_mask == axial

@@ -370,7 +375,8 @@ def __save_quant_data__(self, dirpath):
             else:
                 warnings.warn(
                     "%s: Pixel value exceeded upper bound (%0.1f). Using normalized scale."
-                    % (quant_val.name, upper_bound)
+                    % (quant_val.name, upper_bound),
+                    stacklevel=2,
                 )

             plt.imshow(data_map, cmap="jet")
diff --git a/dosma/tissues/patellar_cartilage.py b/dosma/tissues/patellar_cartilage.py
index aba92e43..a9632443 100644
--- a/dosma/tissues/patellar_cartilage.py
+++ b/dosma/tissues/patellar_cartilage.py
@@ -9,7 +9,12 @@

 import numpy as np
 import pandas as pd
-import scipy.ndimage as sni
+# scipy.ndimage.measurements was deprecated in scipy 1.8 and removed in 1.10.
+# Feature-detect instead of comparing version strings ("1.10" sorts before "1.2").
+try:
+    from scipy.ndimage import center_of_mass as c_of_m
+except ImportError:  # pragma: no cover - very old scipy
+    from scipy.ndimage.measurements import center_of_mass as c_of_m

 from dosma.core.device import get_array_module
 from dosma.core.quant_vals import QuantitativeValueType
@@ -103,16 +108,13 @@ def split_regions(self, base_map):
             If `self.medial_to_lateral`, last dimension should be ML.
         """
         if np.sum(base_map) == 0:
-            warnings.warn("No mask for `%s` was found." % self.FULL_NAME)
+            warnings.warn("No mask for `%s` was found." % self.FULL_NAME, stacklevel=2)

         # Superficial/Deep (A/P)
         locs = base_map.sum(axis=1).nonzero()
         voxels = base_map[locs[0], :, locs[1]]
         com_sup_inf = np.asarray(
-            [
-                int(np.ceil(sni.measurements.center_of_mass(voxels[i, :])[0]))
-                for i in range(voxels.shape[0])
-            ]
+            [int(np.ceil(c_of_m(voxels[i, :])[0])) for i in range(voxels.shape[0])]
         )
         region_mask_sup_deep = np.full(base_map.shape, self._REGION_DEEP_KEY)
         for i in range(len(com_sup_inf)):
@@ -123,7 +125,7 @@ def split_regions(self, base_map):
         # M/L
         cum_ml = np.nonzero(base_map.sum(axis=(0, 1)))[0]  # noqa: F841
         # midpoint_ml = int(np.ceil((np.min(cum_ml) + np.max(cum_ml)) / 2))
-        midpoint_ml = int(np.ceil(sni.measurements.center_of_mass(base_map)[2]))
+        midpoint_ml = int(np.ceil(c_of_m(base_map)[2]))
         region_mask_med_lat = np.full(base_map.shape, self._LATERAL_KEY)
         medial_span = slice(0, midpoint_ml) if self.medial_to_lateral else slice(midpoint_ml, None)
         region_mask_med_lat[:, :, medial_span] = self._MEDIAL_KEY
@@ -161,7 +163,7 @@ def __calc_quant_vals__(self, quant_map, map_type):
                     axial_map = np.asarray(
                         deep_superficial_map == self._REGION_SUPERFICIAL_KEY, dtype=np.float32
                     ) + np.asarray(deep_superficial_map == self._REGION_DEEP_KEY, dtype=np.float32)
-                    axial_map = np.asarray(axial_map, dtype=np.bool)
+                    axial_map = np.asarray(axial_map, dtype=bool)
                 else:
                     axial_map = deep_superficial_map == axial

@@ -277,7 +279,8 @@ def __save_quant_data__(self, dirpath):
             else:
                 warnings.warn(
                     "%s: Pixel value exceeded upper bound (%0.1f). Using normalized scale."
-                    % (quant_val.name, upper_bound)
+                    % (quant_val.name, upper_bound),
+                    stacklevel=2,
                 )

             plt.imshow(data_map, cmap="jet")
diff --git a/dosma/tissues/tibial_cartilage.py b/dosma/tissues/tibial_cartilage.py
index 2230ffd3..50bb5bd4 100644
--- a/dosma/tissues/tibial_cartilage.py
+++ b/dosma/tissues/tibial_cartilage.py
@@ -195,7 +195,7 @@ def __calc_quant_vals__(self, quant_map, map_type):
                 axial_map = (axial_region_mask == self._SUPERIOR_KEY).astype(np.float32) + (
                     axial_region_mask == self._INFERIOR_KEY
                 ).astype(np.float32)
-                axial_map = axial_map.astype(np.bool)
+                axial_map = axial_map.astype(bool)
             else:
                 axial_map = axial_region_mask == axial

@@ -311,7 +311,8 @@ def __save_quant_data__(self, dirpath):
             else:
                 warnings.warn(
                     "%s: Pixel value exceeded upper bound (%0.1f). Using normalized scale."
-                    % (quant_val.name, upper_bound)
+                    % (quant_val.name, upper_bound),
+                    stacklevel=2,
                 )

             plt.imshow(data_map, cmap="jet")
diff --git a/dosma/tissues/tissue.py b/dosma/tissues/tissue.py
index 5b8b757c..d7d372fa 100644
--- a/dosma/tissues/tissue.py
+++ b/dosma/tissues/tissue.py
@@ -319,7 +319,7 @@ def largest_cc(mask, num=1):
     """Return the largest `num` connected component(s) of a 3D mask array.

     Args:
-        mask (np.ndarray): 3D mask array (`np.bool` or `np.[u]int`).
+        mask (np.ndarray): 3D mask array (`bool` or `np.[u]int`).
         num (int, optional): Maximum number of connected components to keep.

     Returns:
@@ -336,8 +336,8 @@ def largest_cc(mask, num=1):
     if not label_nb:
         raise ValueError("No non-zero values: no connected components")
     if label_nb == 1:
-        return mask.astype(np.bool)
-    label_count = np.bincount(labels.ravel().astype(np.int))
+        return mask.astype(bool)
+    label_count = np.bincount(labels.ravel().astype(int))

     # discard 0 the 0 label
     label_count[0] = 0
diff --git a/dosma/utils/geometry_utils.py b/dosma/utils/geometry_utils.py
index c46f1377..7a9d8575 100644
--- a/dosma/utils/geometry_utils.py
+++ b/dosma/utils/geometry_utils.py
@@ -1,5 +1,11 @@
 import numpy as np
-import scipy.ndimage as sni
+# scipy.ndimage.measurements was deprecated in scipy 1.8 and removed in 1.10.
+# Feature-detect instead of comparing version strings ("1.10" sorts before "1.2").
+try:
+    from scipy.ndimage import center_of_mass as c_of_m
+except ImportError:  # pragma: no cover - very old scipy
+    from scipy.ndimage.measurements import center_of_mass as c_of_m
+
 from scipy import optimize

 from dosma.core.device import get_array_module
@@ -112,13 +118,12 @@ def center_of_mass(input, labels=None, index=None):
     Note:
         This is adapted from scipy.ndimage to support cupy.
     """
-    _sni = sni
+    _c_of_m = c_of_m
     if env.cupy_available():
         import cupy as cp

         if get_array_module(input) == cp:
             import cupyx.scipy.ndimage as csni

-            _sni = csni
-
-    return _sni.center_of_mass(input, labels=labels, index=index)
+            _c_of_m = csni.center_of_mass
+    return _c_of_m(input, labels=labels, index=index)
diff --git a/dosma/utils/img_utils.py b/dosma/utils/img_utils.py
index a4ebdb80..3269e550 100644
--- a/dosma/utils/img_utils.py
+++ b/dosma/utils/img_utils.py
@@ -36,8 +36,8 @@ def downsample_slice(img_array, ds_factor, is_mask=False):
     L = list(img_array)

     def grouper(iterable, n):
-        args = [iter(iterable)] * n
-        return itertools.zip_longest(fillvalue=0, *args)
+        args = [iter(iterable)] * n
+        return itertools.zip_longest(*args, fillvalue=0)  # *args before keyword resolves B026

     final = np.array([sum(x) for x in grouper(L, ds_factor)])
     final = np.transpose(final, (1, 2, 0))
diff --git a/dosma/utils/io_utils.py b/dosma/utils/io_utils.py
index e06bf28c..acc991b0 100644
--- a/dosma/utils/io_utils.py
+++ b/dosma/utils/io_utils.py
@@ -95,7 +95,10 @@ def load_h5(file_path):
     data = {}
     with h5py.File(file_path, "r") as f:
         for key in f.keys():
-            data[key] = f.get(key).value
+            # `Dataset.value` was removed in h5py 3.0. Indexing with an empty
+            # tuple reads the full dataset (scalars included) on both h5py 2
+            # and h5py 3, so no version check is needed.
+            data[key] = f[key][()]

     return data

@@ -126,7 +129,10 @@ def save_tables(
             df = data_frames[i]
             df.to_excel(writer, sheet_names[i], index=False)

-    writer.save()
+    # `ExcelWriter.save` was deprecated in pandas 1.5 and removed in 2.0;
+    # `close()` saves and closes the writer on every pandas version DOSMA
+    # supports, so no version check is needed.
+    writer.close()


 def init_logger(log_file: str, debug: bool = False):  # pragma: no cover
@@ -142,6 +148,7 @@ def init_logger(log_file: str, debug: bool = False):  # pragma: no cover
         "init_logger is deprecated since v0.0.14 and will no longer be "
         "supported in v0.13. Use `dosma.setup_logger` instead.",
         DeprecationWarning,
+        stacklevel=2,
     )

     level = logging.DEBUG if debug else logging.INFO
diff --git a/setup.py b/setup.py
index 70c1f668..6dec842d 100644
--- a/setup.py
+++ b/setup.py
@@ -131,11 +131,10 @@ def run(self):
             "sphinx",
             "sphinxcontrib.bibtex",
             "m2r2",
-            "tensorflow<=2.4.1",
-            "keras<=2.4.3",
+            "tensorflow>=2.0.0",
             "sigpy",
         ],
-        "ai": ["tensorflow<=2.4.1", "keras<=2.4.3"],
+        "ai": ["tensorflow>=2.0.0"],
         "docs": ["mistune>=0.8.1,<2.0.0", "sphinx", "sphinxcontrib.bibtex", "m2r2"],
     }

diff --git a/tests/core/io/test_dicom_io.py b/tests/core/io/test_dicom_io.py
index e902450f..82ab7bd0 100644
--- a/tests/core/io/test_dicom_io.py
+++ b/tests/core/io/test_dicom_io.py
@@ -179,7 +179,7 @@ def test_dicom_reader_single_file(self):
                 if not x.startswith(".") and x.endswith(".dcm")
             ]
         )
-        expected = pydicom.read_file(dcm_file, force=True)
+        expected = pydicom.dcmread(dcm_file, force=True)
         vol = self.dr.load(dcm_file)[0]

         assert vol.volume.ndim == 3
@@ -451,7 +451,7 @@ def test_init_params(self):

     def test_sample_pydicom_data(self):
         """Test DICOM reader with sample pydicom data."""
         filepath = get_testdata_file("MR_small.dcm")
-        mv_pydicom = pydicom.read_file(filepath)
+        mv_pydicom = pydicom.dcmread(filepath)
         arr = mv_pydicom.pixel_array

         dr = DicomReader(group_by=None)
@@ -466,7 +466,7 @@ def test_sample_pydicom_data(self):
         out_path = os.path.join(out_dir, "I0001.dcm")
         dw(mv, dir_path=out_dir)

-        mv_pydicom_loaded = pydicom.read_file(out_path)
+        mv_pydicom_loaded = pydicom.dcmread(out_path)
         assert np.all(mv_pydicom_loaded.pixel_array == arr)
         assert self.are_equivalent_headers(mv_pydicom_loaded, mv_pydicom)

diff --git a/tests/core/io/test_nifti_io.py b/tests/core/io/test_nifti_io.py
index b3f15197..78a93a4c 100644
--- a/tests/core/io/test_nifti_io.py
+++ b/tests/core/io/test_nifti_io.py
@@ -72,7 +72,7 @@ def test_nifti_nib(self):
     def test_state(self):
         nr1 = NiftiReader()
         state_dict = nr1.state_dict()
-        state_dict = {k: "foo" for k in state_dict}
+        state_dict = dict.fromkeys(state_dict, "foo")

         nr2 = NiftiReader()
         nr2.load_state_dict(state_dict)
@@ -81,7 +81,7 @@ def test_state(self):

         nw1 = NiftiWriter()
         state_dict = nw1.state_dict()
-        state_dict = {k: "bar" for k in state_dict}
+        state_dict = dict.fromkeys(state_dict, "bar")

         nw2 = NiftiWriter()
         nw2.load_state_dict(state_dict)
diff --git a/tests/core/test_fitting.py b/tests/core/test_fitting.py
index 633515d7..f1e20d00 100644
--- a/tests/core/test_fitting.py
+++ b/tests/core/test_fitting.py
@@ -268,7 +268,7 @@ def test_polyfit_initialization(self):
         # The values will not be accurate, but other pixel values should be.
x, y, b = _generate_monoexp_data((10, 10, 20)) t = 1 / np.abs(b) - mask_arr = np.zeros(y[0].shape, dtype=np.bool) + mask_arr = np.zeros(y[0].shape, dtype=bool) mask_arr[:5, :5] = 1 y[0][mask_arr] = 0 @@ -313,11 +313,11 @@ def test_mask(self): assert np.allclose(a_hat.volume[mask_arr != 0], 1.0) assert np.allclose(b_hat.volume[mask_arr != 0], b[mask_arr != 0]) - with self.assertRaises(TypeError): + with self.assertRaises(TypeError): # noqa B908 fitter = CurveFitter(monoexponential) popt = fitter.fit(x, y, mask="foo")[0] - with self.assertRaises(RuntimeError): + with self.assertRaises(RuntimeError): # noqa B908 mask_incorrect_shape = np.random.rand(5, 5, 5) > 0.5 fitter = CurveFitter(monoexponential) popt = fitter.fit(x, y, mask=mask_incorrect_shape)[0] diff --git a/tests/core/test_med_volume.py b/tests/core/test_med_volume.py index aa7aa41d..6a4e5e88 100644 --- a/tests/core/test_med_volume.py +++ b/tests/core/test_med_volume.py @@ -6,7 +6,8 @@ import nibabel as nib import nibabel.testing as nib_testing import numpy as np -import pydicom.data as pydd + +# import pydicom.data as pydd import SimpleITK as sitk from dosma.core.device import Device @@ -31,6 +32,9 @@ class TestMedicalVolume(unittest.TestCase): _TEMP_PATH = os.path.join(ututils.TEMP_PATH, __name__) + TO_SITK_RTOL = 1e-4 + TO_SITK_ATOL = 1e-2 + @classmethod def setUpClass(cls): os.makedirs(cls._TEMP_PATH, exist_ok=True) @@ -170,35 +174,288 @@ def test_to_nib(self): assert np.all(nib_from_mv.get_fdata() == nib_img.get_fdata()) assert np.all(nib_from_mv.affine == nib_img.affine) - def test_to_sitk(self): - mv = MedicalVolume(np.random.rand(10, 20, 30), self._AFFINE) - filepath = os.path.join(ututils.TEMP_PATH, "med_vol_to_sitk.nii.gz") - NiftiWriter().save(mv, filepath) + # def test_to_sitk(self): + # mv = MedicalVolume(np.random.rand(10, 20, 30), self._AFFINE) + # filepath = os.path.join(ututils.TEMP_PATH, "med_vol_to_sitk.nii.gz") + # NiftiWriter().save(mv, filepath) - expected = sitk.ReadImage(filepath) + # expected = sitk.ReadImage(filepath) - nr = NiftiReader() - mv = nr.load(filepath) - img = mv.to_sitk() + # nr = NiftiReader() + # mv = nr.load(filepath) + # img = mv.to_sitk() - assert np.allclose(sitk.GetArrayViewFromImage(img), sitk.GetArrayViewFromImage(expected)) - assert img.GetSize() == mv.shape - assert np.allclose(img.GetOrigin(), expected.GetOrigin()) - assert img.GetSpacing() == img.GetSpacing() - assert img.GetDirection() == expected.GetDirection() + # assert np.allclose(sitk.GetArrayViewFromImage(img), sitk.GetArrayViewFromImage(expected)) + # assert img.GetSize() == mv.shape + # assert np.allclose(img.GetOrigin(), expected.GetOrigin()) + # assert img.GetSpacing() == img.GetSpacing() + # assert img.GetDirection() == expected.GetDirection() + + # mv = MedicalVolume(np.zeros((10, 20, 1, 3)), affine=self._AFFINE) + # img = mv.to_sitk(vdim=-1) + # assert np.all(sitk.GetArrayViewFromImage(img) == 0) + # assert img.GetSize() == (10, 20, 1) + + # filepath = pydd.get_testdata_file("MR_small.dcm") + # dr = DicomReader(group_by=None) + # mv = dr.load(filepath)[0] + # mv2 = MedicalVolume.from_sitk( + # mv.to_sitk(transpose_inplane=True), copy=True, transpose_inplane=True + # ) + # assert mv2.is_identical(mv) + + @unittest.skipIf(not ututils.is_data_available(), "unittest data is not available") + def test_to_sitk_dicom_coronal(self): + axis = "coronal" + cor_folder_names = ["SER00007", "12869311_cor", "12869314_cor"] + + for cor_folder in cor_folder_names: + cor_example_path = ututils.get_scan_dirpath(cor_folder) + + # Load the 
DICOM data using SimpleITK + reader = sitk.ImageSeriesReader() + dicom_names = reader.GetGDCMSeriesFileNames(cor_example_path, useSeriesDetails=True) + reader.SetFileNames(dicom_names) + expected_image = reader.Execute() + + # Load the DICOM data using DOSMA + dr = DicomReader() + mv = dr.load(cor_example_path, group_by=None)[0] + + # Convert to SimpleITK using the MedicalVolume method + img_sitk = mv.to_sitk(image_orientation=axis, vdim=None, transpose_inplane=False) + + # Assertions to ensure correctness + np.testing.assert_allclose( + sitk.GetArrayFromImage(img_sitk), + sitk.GetArrayFromImage(expected_image), + atol=self.TO_SITK_ATOL, + rtol=self.TO_SITK_RTOL, + ) + np.testing.assert_allclose( + img_sitk.GetOrigin(), + expected_image.GetOrigin(), + atol=self.TO_SITK_ATOL, + rtol=self.TO_SITK_RTOL, + ) + np.testing.assert_allclose( + img_sitk.GetSpacing(), + expected_image.GetSpacing(), + atol=self.TO_SITK_ATOL, + rtol=self.TO_SITK_RTOL, + ) + np.testing.assert_allclose( + img_sitk.GetDirection(), + expected_image.GetDirection(), + atol=self.TO_SITK_ATOL, + rtol=self.TO_SITK_RTOL, + ) + + @unittest.skipIf(not ututils.is_data_available(), "unittest data is not available") + def test_to_sitk_dicom_sagittal(self): + axis = "sagittal" + sag_folder_names = ["SER00005", "12869310_sag", "15252_Ser10"] + + for sag_folder in sag_folder_names: + sag_example_path = ututils.get_scan_dirpath(sag_folder) + + # Load the DICOM data using SimpleITK + reader = sitk.ImageSeriesReader() + dicom_names = reader.GetGDCMSeriesFileNames(sag_example_path, useSeriesDetails=True) + reader.SetFileNames(dicom_names) + expected_image = reader.Execute() + + # Load the DICOM data using DOSMA + dr = DicomReader() + mv = dr.load(sag_example_path, group_by=None)[0] + + # Convert to SimpleITK using the MedicalVolume method + img_sitk = mv.to_sitk(image_orientation=axis, vdim=None, transpose_inplane=False) + + # Assertions to ensure correctness + np.testing.assert_allclose( + img_sitk.GetOrigin(), + expected_image.GetOrigin(), + atol=self.TO_SITK_ATOL, + rtol=self.TO_SITK_RTOL, + ) + np.testing.assert_allclose( + img_sitk.GetSpacing(), + expected_image.GetSpacing(), + atol=self.TO_SITK_ATOL, + rtol=self.TO_SITK_RTOL, + ) + np.testing.assert_allclose( + img_sitk.GetDirection(), + expected_image.GetDirection(), + atol=self.TO_SITK_ATOL, + rtol=self.TO_SITK_RTOL, + ) + np.testing.assert_allclose( + sitk.GetArrayFromImage(img_sitk), + sitk.GetArrayFromImage(expected_image), + atol=self.TO_SITK_ATOL, + rtol=self.TO_SITK_RTOL, + ) + + @unittest.skipIf(not ututils.is_data_available(), "unittest data is not available") + def test_to_sitk_dicom_axial(self): + axis = "axial" + ax_folder_names = ["SER00003", "12869304_axial", "12869313_axial"] + + for ax_folder in ax_folder_names: + ax_example_path = ututils.get_scan_dirpath(ax_folder) + + # Load the DICOM data using SimpleITK + reader = sitk.ImageSeriesReader() + dicom_names = reader.GetGDCMSeriesFileNames(ax_example_path, useSeriesDetails=True) + reader.SetFileNames(dicom_names) + expected_image = reader.Execute() + + # Load the DICOM data using DOSMA + dr = DicomReader() + mv = dr.load(ax_example_path, group_by=None)[0] + + # Convert to SimpleITK using the MedicalVolume method + img_sitk = mv.to_sitk(image_orientation=axis, vdim=None, transpose_inplane=False) + + # Assertions to ensure correctness + np.testing.assert_allclose( + sitk.GetArrayFromImage(img_sitk), + sitk.GetArrayFromImage(expected_image), + atol=self.TO_SITK_ATOL, + rtol=self.TO_SITK_RTOL, + ) + 
np.testing.assert_allclose( + img_sitk.GetOrigin(), + expected_image.GetOrigin(), + atol=self.TO_SITK_ATOL, + rtol=self.TO_SITK_RTOL, + ) + np.testing.assert_allclose( + img_sitk.GetSpacing(), + expected_image.GetSpacing(), + atol=self.TO_SITK_ATOL, + rtol=self.TO_SITK_RTOL, + ) + np.testing.assert_allclose( + img_sitk.GetDirection(), + expected_image.GetDirection(), + atol=self.TO_SITK_ATOL, + rtol=self.TO_SITK_RTOL, + ) - mv = MedicalVolume(np.zeros((10, 20, 1, 3)), affine=self._AFFINE) - img = mv.to_sitk(vdim=-1) - assert np.all(sitk.GetArrayViewFromImage(img) == 0) - assert img.GetSize() == (10, 20, 1) + @unittest.skipIf(not ututils.is_data_available(), "unittest data is not available") + def test_to_sitk_dicom_flipping(self): + path_image = ututils.get_scan_dirpath("15252_Ser10") + + # Load the DICOM data using SimpleITK + reader = sitk.ImageSeriesReader() + dicom_names = reader.GetGDCMSeriesFileNames(path_image, useSeriesDetails=True) + reader.SetFileNames(dicom_names) + expected_image = reader.Execute() + + # Load the DICOM data using DOSMA + dr = DicomReader() + mv = dr.load(path_image, group_by=None)[0] + + # Test all combinations of flipping axes + for flip_array_x in [True, False]: + for flip_array_y in [True, False]: + for flip_array_z in [True, False]: + flip_logic = [flip_array_x, flip_array_y, flip_array_z] + + # Flip the SimpleITK image + sitk_img_flipped = sitk.Flip(expected_image, flip_logic) + + # Convert to SimpleITK using the MedicalVolume method with flipping + mv_to_sitk = mv.to_sitk( + image_orientation="sagittal", + vdim=None, + transpose_inplane=False, + flip_array_x=flip_array_x, + flip_array_y=flip_array_y, + flip_array_z=flip_array_z, + ) + + # Assertions to ensure correctness + np.testing.assert_allclose( + sitk.GetArrayFromImage(mv_to_sitk), + sitk.GetArrayFromImage(sitk_img_flipped), + atol=self.TO_SITK_ATOL, + rtol=self.TO_SITK_RTOL, + ) + np.testing.assert_allclose( + mv_to_sitk.GetOrigin(), + sitk_img_flipped.GetOrigin(), + atol=self.TO_SITK_ATOL, + rtol=self.TO_SITK_RTOL, + ) + np.testing.assert_allclose( + mv_to_sitk.GetSpacing(), + sitk_img_flipped.GetSpacing(), + atol=self.TO_SITK_ATOL, + rtol=self.TO_SITK_RTOL, + ) + np.testing.assert_allclose( + mv_to_sitk.GetDirection(), + sitk_img_flipped.GetDirection(), + atol=self.TO_SITK_ATOL, + rtol=self.TO_SITK_RTOL, + ) + + @unittest.skipIf(not ututils.is_data_available(), "unittest data is not available") + def test_to_sitk_dicom_inplane_rotation(self): + path_image = ututils.get_scan_dirpath("15252_Ser10") + + # Load the DICOM data using SimpleITK + reader = sitk.ImageSeriesReader() + dicom_names = reader.GetGDCMSeriesFileNames(path_image, useSeriesDetails=True) + reader.SetFileNames(dicom_names) + expected_image = reader.Execute() - filepath = pydd.get_testdata_file("MR_small.dcm") - dr = DicomReader(group_by=None) - mv = dr.load(filepath)[0] - mv2 = MedicalVolume.from_sitk( - mv.to_sitk(transpose_inplane=True), copy=True, transpose_inplane=True + # Load the DICOM data using DOSMA + dr = DicomReader() + mv = dr.load(path_image, group_by=None)[0] + + # Perform in-plane rotation on the SimpleITK image + sitk_img_rotated = sitk.PermuteAxes(expected_image, [1, 0, 2]) + + # Convert to SimpleITK using the MedicalVolume method with in-plane rotation + mv_to_sitk = mv.to_sitk( + image_orientation="sagittal", + vdim=None, + transpose_inplane=True, + flip_array_x=False, + flip_array_y=False, + flip_array_z=False, + ) + + # Assertions to ensure correctness + np.testing.assert_allclose( + 
sitk.GetArrayFromImage(mv_to_sitk),
+            sitk.GetArrayFromImage(sitk_img_rotated),
+            atol=self.TO_SITK_ATOL,
+            rtol=self.TO_SITK_RTOL,
+        )
+        np.testing.assert_allclose(
+            mv_to_sitk.GetOrigin(),
+            sitk_img_rotated.GetOrigin(),
+            atol=self.TO_SITK_ATOL,
+            rtol=self.TO_SITK_RTOL,
+        )
+        np.testing.assert_allclose(
+            mv_to_sitk.GetSpacing(),
+            sitk_img_rotated.GetSpacing(),
+            atol=self.TO_SITK_ATOL,
+            rtol=self.TO_SITK_RTOL,
+        )
+        np.testing.assert_allclose(
+            mv_to_sitk.GetDirection(),
+            sitk_img_rotated.GetDirection(),
+            atol=self.TO_SITK_ATOL,
+            rtol=self.TO_SITK_RTOL,
         )
-        assert mv2.is_identical(mv)

     @unittest.skipIf(not ututils.is_data_available(), "unittest data is not available")
     def test_to_from_sitk_dicom_convention(self):
@@ -214,7 +471,9 @@ def test_to_from_sitk_dicom_convention(self):
         reader.SetFileNames(dicom_names)
         sitk_image = reader.Execute()

-        sitk_from_mv = mv.to_sitk(transpose_inplane=True)
+        # The transpose is no longer needed; orientation is now handled
+        # explicitly via the `image_orientation` argument.
+        sitk_from_mv = mv.to_sitk(image_orientation="sagittal", transpose_inplane=False)
         img, expected = sitk_from_mv, sitk_image
         assert np.allclose(sitk.GetArrayViewFromImage(img), sitk.GetArrayViewFromImage(expected))
         assert img.GetSize() == mv.shape
@@ -536,7 +795,7 @@ def test_to_torch(self):
         assert torch.all(tensor == torch.from_numpy(vol))
         assert tensor.shape == mv.shape

-        vol = np.ones((10, 20, 30), np.complex)
+        vol = np.ones((10, 20, 30), complex)
         mv = MedicalVolume(vol, self._AFFINE)
         tensor = mv.to_torch()

diff --git a/tests/models/test_oaiunet2d.py b/tests/models/test_oaiunet2d.py
index 84e59fbf..c7b8d66a 100644
--- a/tests/models/test_oaiunet2d.py
+++ b/tests/models/test_oaiunet2d.py
@@ -9,7 +9,7 @@
 from dosma.models.seg_model import whiten_volume
 from dosma.tissues.femoral_cartilage import FemoralCartilage

-import keras.backend as K
+import tensorflow.keras.backend as K

 from .. import util

@@ -132,8 +132,8 @@ def test_segmentation(self):
         K.clear_session()

         for i, tissue in enumerate(classes):
-            pred = masks[tissue].volume.astype(np.bool)
-            gt = expected_seg[..., i].astype(np.bool)
+            pred = masks[tissue].volume.astype(bool)
+            gt = expected_seg[..., i].astype(bool)
             dice = 2 * np.sum(pred & gt) / np.sum(pred.astype(np.uint8) + gt.astype(np.uint8))
             # Zero-mean normalization of 32-bit vs 64-bit data results in slightly different
             # estimations of the mean and standard deviation.
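The flipping test above leans on how SimpleITK updates geometry when an axis is flipped. A small standalone sketch (synthetic image; the geometry values are arbitrary) showing why origin and direction, not just the voxel array, must be compared:

```python
import numpy as np
import SimpleITK as sitk

# Synthetic 2x3x4 volume with arbitrary but non-trivial geometry.
img = sitk.GetImageFromArray(np.arange(24, dtype=np.float32).reshape(2, 3, 4))
img.SetSpacing((1.0, 1.5, 2.0))
img.SetOrigin((10.0, -5.0, 0.0))

# Flip the first (x) axis; flipAboutOrigin defaults to False, so the image
# keeps occupying the same physical space.
flipped = sitk.Flip(img, [True, False, False])

# The voxel buffer is reversed along x, the origin moves to the far edge of
# that axis, and the corresponding direction cosine changes sign - exactly
# the quantities the to_sitk tests assert with atol/rtol tolerances.
print(img.GetOrigin(), "->", flipped.GetOrigin())
print(img.GetDirection(), "->", flipped.GetDirection())
```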
diff --git a/tests/models/test_util.py b/tests/models/test_util.py
index e2481b67..ecbb4b36 100644
--- a/tests/models/test_util.py
+++ b/tests/models/test_util.py
@@ -1,6 +1,6 @@
 import unittest

-from dosma.models import util as m_util
+from dosma.models import model_loading_util as m_util


 class TestUtil(unittest.TestCase):
@@ -13,7 +13,7 @@ def test_aliases_exist(self):
             aliases = m.ALIASES

             # all supported models must have at least 1 alias that is not ''
-            valid_alias = len(aliases) >= 1 and all([x != "" for x in aliases])
+            valid_alias = len(aliases) >= 1 and all(x != "" for x in aliases)

             assert valid_alias, "%s does not have valid aliases" % m

diff --git a/tests/scan_sequences/mri/test_qdess.py b/tests/scan_sequences/mri/test_qdess.py
index 0a680ab5..71a19e0e 100644
--- a/tests/scan_sequences/mri/test_qdess.py
+++ b/tests/scan_sequences/mri/test_qdess.py
@@ -7,11 +7,11 @@

 from dosma.core.med_volume import MedicalVolume
 from dosma.core.quant_vals import QuantitativeValue
-from dosma.models.util import get_model
+from dosma.models.model_loading_util import get_model
 from dosma.scan_sequences.mri.qdess import QDess
 from dosma.tissues.femoral_cartilage import FemoralCartilage

-import keras.backend as K
+import tensorflow.keras.backend as K

 from ... import util

@@ -92,7 +92,7 @@ def test_segmentation_multiclass(self):
         """Test support for multiclass segmentation."""
         scan = self.SCAN_TYPE.from_dicom(self.dicom_dirpath, num_workers=util.num_workers())
         tissue = FemoralCartilage()
-        tissue.find_weights(SEGMENTATION_WEIGHTS_FOLDER),
+        tissue.find_weights(SEGMENTATION_WEIGHTS_FOLDER)
         dims = scan.get_dimensions()
         input_shape = (dims[0], dims[1], 1)
         model = get_model(
diff --git a/tests/scan_sequences/test_scan_io.py b/tests/scan_sequences/test_scan_io.py
index 1f8c8dd2..6f973557 100644
--- a/tests/scan_sequences/test_scan_io.py
+++ b/tests/scan_sequences/test_scan_io.py
@@ -38,7 +38,7 @@ def some_property(self):
 class TestScanIOMixin(ututils.TempPathMixin):
     def test_from_dicom(self):
         mr_dcm = get_testdata_file("MR_small.dcm")
-        fs = pydicom.read_file(mr_dcm)
+        fs = pydicom.dcmread(mr_dcm)
         arr = fs.pixel_array

         scan = MockScanIOMixin.from_dicom(mr_dcm, foo="foofoo", bar="barbar")
diff --git a/tests/util.py b/tests/util.py
index a46f4959..7e621494 100644
--- a/tests/util.py
+++ b/tests/util.py
@@ -30,7 +30,22 @@
     UNITTEST_SCANDATA_PATH, f"temp-{str(uuid.uuid1())}-{str(uuid.uuid4())}"
 )  # should be used when for writing with assert_raises clauses

-SCANS = ["qdess", "mapss", "cubequant", "cones"]
+SCANS = [
+    "qdess",
+    "mapss",
+    "cubequant",
+    "cones",
+    "15252_Ser10",
+    "12869304_axial",
+    "12869310_sag",
+    "12869311_cor",
+    "12869313_axial",
+    "12869314_cor",
+    "MTR_051",
+    "SER00003",
+    "SER00005",
+    "SER00007",
+]
 SCANS_INFO = {
     "mapss": {"expected_num_echos": 7},
     "qdess": {"expected_num_echos": 2},
@@ -109,7 +124,7 @@ def requires_packages(*packages):

     def _decorator(func):
         def _wrapper(*args, **kwargs):
-            if all([env.package_available(x) for x in packages]):
+            if all(env.package_available(x) for x in packages):
                 func(*args, **kwargs)

         return _wrapper
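The `requires_packages` guard above silently skips the test body when a dependency is missing, which makes skipped tests invisible in the test report. A self-contained alternative built on `unittest.skipIf`; this is a sketch mirroring `env.package_available` as used above, not DOSMA's actual helper:

```python
import importlib.util
import unittest


def package_available(name: str) -> bool:
    # Probe for an importable package without actually importing it.
    return importlib.util.find_spec(name) is not None


def requires_packages(*packages):
    missing = [p for p in packages if not package_available(p)]
    # skipIf marks the test as skipped (with a reason) instead of silently
    # turning it into a no-op.
    return unittest.skipIf(missing, f"missing packages: {missing}")


class ExampleTest(unittest.TestCase):
    @requires_packages("numpy")
    def test_with_numpy(self):
        import numpy as np

        self.assertEqual(np.add(1, 1), 2)
```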
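On the same theme as the scipy, h5py, and pandas compatibility shims earlier in this diff: feature detection is generally more robust than comparing version strings, since `"1.10" < "1.2"` lexicographically. A short sketch of both idioms, mirroring the `load_h5` fix above (the file path is a placeholder):

```python
import h5py

# Import-time feature detection: try the modern location first.
try:
    from scipy.ndimage import center_of_mass
except ImportError:  # very old scipy without the top-level export
    from scipy.ndimage.measurements import center_of_mass


def load_h5(file_path):
    # ds[()] reads a whole dataset (scalars included) on h5py 2 and 3,
    # replacing the `.value` attribute that h5py 3.0 removed.
    with h5py.File(file_path, "r") as f:
        return {key: f[key][()] for key in f.keys()}
```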