diff --git a/packages/common/src/weathergen/common/data.py b/packages/common/src/weathergen/common/data.py
new file mode 100644
index 000000000..d36e5b77c
--- /dev/null
+++ b/packages/common/src/weathergen/common/data.py
@@ -0,0 +1,269 @@
+# (C) Copyright 2025 WeatherGenerator contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+import dataclasses
+import logging
+
+import numpy as np
+from numpy import datetime64
+from numpy.typing import NDArray
+
+type DType = np.float32
+type NPDT64 = datetime64
+
+_logger = logging.getLogger(__name__)
+
+_DT_ZERO = np.datetime64("1850-01-01T00:00")
+
+
+@dataclasses.dataclass
+class DTRange:
+    """
+    Defines a time window for indexing into datasets.
+
+    Start and end are numpy datetime64 objects.
+    """
+
+    start: NPDT64
+    end: NPDT64
+
+    def __post_init__(self):
+        if self.start >= self.end:
+            raise ValueError("start time must be before end time")
+        if self.start <= _DT_ZERO:
+            raise ValueError("start time must be after 1850-01-01T00:00")
+
+
+@dataclasses.dataclass
+class ReaderData:
+    """
+    Wrapper for return values from DataReader.get_source and DataReader.get_target.
+    """
+
+    coords: NDArray[DType]
+    geoinfos: NDArray[DType]
+    data: NDArray[DType]
+    datetimes: NDArray[NPDT64]
+    is_spoof: bool = False
+
+    def __len__(self):
+        return len(self.data)
+
+    @classmethod
+    def create(cls, other: "ReaderData") -> "ReaderData":
+        """
+        Create an instance from another ReaderData instance.
+
+        Parameters
+        ----------
+        other: ReaderData
+            Input data.
+
+        Returns
+        -------
+        ReaderData
+            Has the same underlying data as `other`.
+        """
+        if other is None:
+            raise TypeError("Input cannot be None.")
+
+        if not isinstance(other, ReaderData):
+            raise TypeError(f"Expected input of type ReaderData. Got {type(other)}")
+
+        coords = np.asarray(other.coords)
+        geoinfos = np.asarray(other.geoinfos)
+        data = np.asarray(other.data)
+        datetimes = np.asarray(other.datetimes)
+
+        n_datapoints = len(data)
+
+        assert coords.shape == (n_datapoints, 2), "number of datapoints does not match data"
+        assert geoinfos.shape[0] == n_datapoints, "number of datapoints does not match data"
+        assert datetimes.shape[0] == n_datapoints, "number of datapoints does not match data"
+
+        return cls(**dataclasses.asdict(other))
+
+    @classmethod
+    def combine(cls, others: list["ReaderData"]) -> "ReaderData":
+        """
+        Create an instance by combining multiple ReaderData instances.
+
+        Parameters
+        ----------
+        others: list[ReaderData]
+            A list of input data to combine.
+
+        Returns
+        -------
+        ReaderData
+            Instance with concatenated input data.
+        """
+        if others is None:
+            raise TypeError("Input cannot be None.")
+
+        if not isinstance(others, list):
+            raise TypeError(f"Input must be a list. Got {type(others)}")
+
+        assert len(others) > 0, len(others)
+
+        first = others[0]
+        coords = np.zeros((0, first.coords.shape[1]), dtype=first.coords.dtype)
+        geoinfos = np.zeros((0, first.geoinfos.shape[1]), dtype=first.geoinfos.dtype)
+        data = np.zeros((0, first.data.shape[1]), dtype=first.data.dtype)
+        datetimes = np.array([], dtype=first.datetimes.dtype)
+        is_spoof = True
+
+        for item in others:
+            n_datapoints = len(item.data)
+            assert item.coords.shape == (n_datapoints, 2), "number of datapoints does not match"
+            assert item.geoinfos.shape[0] == n_datapoints, "number of datapoints does not match"
+            assert item.datetimes.shape[0] == n_datapoints, "number of datapoints does not match"
+
+            coords = np.concatenate([coords, item.coords])
+            geoinfos = np.concatenate([geoinfos, item.geoinfos])
+            data = np.concatenate([data, item.data])
+            datetimes = np.concatenate([datetimes, item.datetimes])
+            is_spoof = is_spoof and item.is_spoof
+
+        return cls(coords, geoinfos, data, datetimes, is_spoof)
+
+    @staticmethod
+    def empty(num_data_fields: int, num_geo_fields: int) -> "ReaderData":
+        """
+        Create an empty ReaderData object.
+
+        Parameters
+        ----------
+        num_data_fields: int
+            Number of data fields.
+        num_geo_fields: int
+            Number of geo fields.
+
+        Returns
+        -------
+        ReaderData
+            Empty ReaderData object.
+        """
+        return ReaderData(
+            coords=np.zeros((0, 2), dtype=np.float32),
+            geoinfos=np.zeros((0, num_geo_fields), dtype=np.float32),
+            data=np.zeros((0, num_data_fields), dtype=np.float32),
+            datetimes=np.zeros((0,), dtype=np.datetime64),
+            is_spoof=False,
+        )
+
+    def is_empty(self):
+        """
+        Test if the data object is empty.
+        """
+        return len(self) == 0
+
+    def remove_nan_coords(self) -> "ReaderData":
+        """
+        Remove all data points where coords are NaN.
+
+        Returns
+        -------
+        ReaderData
+        """
+        # Identify valid coordinates (where both lat/lon are not NaN)
+        idx_valid = ~np.isnan(self.coords).any(axis=1)
+
+        # Apply filtering
+        return ReaderData(
+            self.coords[idx_valid],
+            self.geoinfos[idx_valid],
+            self.data[idx_valid],
+            self.datetimes[idx_valid],
+            self.is_spoof,
+        )
+
+    def shuffle(self, rng, shuffle: bool, num_subset: int) -> "ReaderData":
+        """
+        Shuffle the data points and/or drop a random subset as specified by
+        num_subset; num_subset = -1 indicates that no points are to be dropped.
+
+        Returns
+        -------
+        self
+        """
+
+        # nothing to be done
+        if num_subset < 0 and shuffle is False:
+            return self
+
+        num_datapoints = self.coords.shape[0]
+        if (num_datapoints == 0) or (num_datapoints < num_subset and shuffle is False):
+            return self
+
+        # only shuffling
+        if num_subset == -1 and shuffle is True:
+            num_subset = num_datapoints
+
+        # ensure num_subset <= num_datapoints
+        num_subset = min(num_subset, num_datapoints)
+
+        idxs_subset = rng.choice(num_datapoints, num_subset, replace=False)
+        if shuffle is False:
+            idxs_subset = np.sort(idxs_subset)
+
+        self.coords = self.coords[idxs_subset]
+        self.geoinfos = self.geoinfos[idxs_subset]
+        self.data = self.data[idxs_subset]
+        self.datetimes = self.datetimes[idxs_subset]
+
+        return self
+
+
+def check_reader_data(rdata: ReaderData, dtr: DTRange) -> None:
+    """
+    Check that ReaderData is valid.
+
+    Parameters
+    ----------
+    rdata :
+        ReaderData to check
+    dtr :
+        datetime range of the window for which the rdata is valid
+
+    Returns
+    -------
+    None
+    """
+
+    # Validate dimensions
+    assert rdata.coords.ndim == 2, f"coords must be 2D, got {rdata.coords.shape}"
+    assert rdata.coords.shape[1] == 2, (
+        f"coords must have 2 columns (lat, lon), got {rdata.coords.shape}"
+    )
+    assert rdata.geoinfos.ndim == 2, f"geoinfos must be 2D, got {rdata.geoinfos.shape}"
+    assert rdata.data.ndim == 2, f"data must be 2D, got {rdata.data.shape}"
+    assert rdata.datetimes.ndim == 1, f"datetimes must be 1D, got {rdata.datetimes.shape}"
+
+    # Validate consistency of lengths
+    n_points = rdata.coords.shape[0]
+    assert n_points == rdata.data.shape[0], "coords and data must have same length"
+    assert n_points == rdata.geoinfos.shape[0], "geoinfos and data must have same length"
+
+    # Check that all fields have the same length
+    assert (
+        rdata.coords.shape[0]
+        == rdata.geoinfos.shape[0]
+        == rdata.data.shape[0]
+        == rdata.datetimes.shape[0]
+    ), (
+        f"coords, geoinfos, data and datetimes must have the same length "
+        f"{rdata.coords.shape[0]}, {rdata.geoinfos.shape[0]}, {rdata.data.shape[0]}, "
+        f"{rdata.datetimes.shape[0]}"
+    )
+
+    # Check that all datetimes fall within the specified range
+    assert np.logical_and(rdata.datetimes >= dtr.start, rdata.datetimes < dtr.end).all(), (
+        f"datetimes for data points violate window {dtr}."
+    )
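Not part of the patch, for reviewers: a minimal usage sketch of the unified `ReaderData` API introduced above. All values and shapes are illustrative.

```python
import numpy as np

from weathergen.common.data import DTRange, ReaderData, check_reader_data

# Two datapoints: coords rows are (lat, lon); one geoinfo and two data channels.
rdata = ReaderData(
    coords=np.array([[50.0, 8.6], [48.1, 11.6]], dtype=np.float32),
    geoinfos=np.zeros((2, 1), dtype=np.float32),
    data=np.ones((2, 2), dtype=np.float32),
    datetimes=np.array(["2020-01-01T00:00", "2020-01-01T06:00"], dtype="datetime64[m]"),
)

# combine() concatenates along the datapoint axis.
combined = ReaderData.combine([rdata, rdata])
assert len(combined) == 4

# check_reader_data() validates shapes and the half-open window [start, end).
check_reader_data(
    combined,
    DTRange(np.datetime64("2020-01-01T00:00"), np.datetime64("2020-01-02T00:00")),
)
```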
diff --git a/packages/common/src/weathergen/common/io.py b/packages/common/src/weathergen/common/io.py
index 2243cdee8..aee730904 100644
--- a/packages/common/src/weathergen/common/io.py
+++ b/packages/common/src/weathergen/common/io.py
@@ -27,6 +27,8 @@
 from zarr.errors import ZarrUserWarning
 from zarr.storage import LocalStore, ZipStore
 
+from weathergen.common.data import ReaderData
+
 # experimental value, should be inferred more intelligently
 SHARDING_ENABLED = True
 SHARD_N_SAMPLES = 40320
@@ -119,79 +121,6 @@ def forecast_interval(self, forecast_dt_hours: int, fstep: int) -> "TimeRange":
         return TimeRange(self.start + offset, self.end + offset)
 
 
-@dataclasses.dataclass
-class IOReaderData:
-    """
-    Equivalent to data_reader_base.ReaderData
-
-    This class needs to exist since otherwise the common package would
-    have a dependecy on the core model. Ultimately a unified data model
-    should be implemented in the common package.
-    """
-
-    coords: NDArray[DType]
-    geoinfos: NDArray[DType]
-    data: NDArray[DType]
-    datetimes: NDArray[NPDT64]
-    is_spoof: bool = False
-
-    def is_empty(self):
-        """
-        Test if data object is empty
-        """
-        return len(self.data) == 0
-
-    @classmethod
-    def create(cls, other: typing.Any) -> "IOReaderData":
-        """
-        Create an instance from data_reader_base.ReaderData instance.
-
-        other should be such an instance.
-        """
-        coords = np.asarray(other.coords)
-        geoinfos = np.asarray(other.geoinfos)
-        data = np.asarray(other.data)
-        datetimes = np.asarray(other.datetimes)
-
-        n_datapoints = len(data)
-
-        assert coords.shape == (n_datapoints, 2), "number of datapoints do not match data"
-        assert geoinfos.shape[0] == n_datapoints, "number of datapoints do not match data"
-        assert datetimes.shape[0] == n_datapoints, "number of datapoints do not match data"
-
-        return cls(**dataclasses.asdict(other))
-
-    @classmethod
-    def combine(cls, others: list["IOReaderData"]) -> "IOReaderData":
-        """
-        Create an instance from data_reader_base.ReaderData instance by combining mulitple ones.
-
-        others is list of ReaderData instances.
-        """
-        assert len(others) > 0, len(others)
-
-        other = others[0]
-        coords = np.zeros((0, other.coords.shape[1]), dtype=other.coords.dtype)
-        geoinfos = np.zeros((0, other.geoinfos.shape[1]), dtype=other.geoinfos.dtype)
-        data = np.zeros((0, other.data.shape[1]), dtype=other.data.dtype)
-        datetimes = np.array([], dtype=other.datetimes.dtype)
-        is_spoof = True
-
-        for other in others:
-            n_datapoints = len(other.data)
-            assert other.coords.shape == (n_datapoints, 2), "number of datapoints do not match"
-            assert other.geoinfos.shape[0] == n_datapoints, "number of datapoints do not match"
-            assert other.datetimes.shape[0] == n_datapoints, "number of datapoints do not match"
-
-            coords = np.concatenate([coords, other.coords])
-            geoinfos = np.concatenate([geoinfos, other.geoinfos])
-            data = np.concatenate([data, other.data])
-            datetimes = np.concatenate([datetimes, other.datetimes])
-            is_spoof = is_spoof and other.is_spoof
-
-        return cls(coords, geoinfos, data, datetimes, is_spoof)
-
-
 @dataclasses.dataclass
 class ItemKey:
     """Metadata to identify one output item."""
@@ -555,7 +484,7 @@ class OutputBatchData:
 
     # sample, stream, tensor(datapoint, channel+coords)
     # => datapoints is accross all datasets per stream
-    sources: list[list[IOReaderData]]
+    sources: list[list[ReaderData]]
 
     # sample
     source_intervals: list[TimeRange]
@@ -743,7 +672,7 @@ def _extract_sources(
         channels = self.source_channels[stream_idx]
         geoinfo_channels = self.geoinfo_channels[stream_idx]
 
-        source: IOReaderData = self.sources[sample][stream_idx]
+        source: ReaderData = self.sources[sample][stream_idx]
 
         assert source.data.shape[1] == len(channels), (
             f"Number of source channel names {len(channels)} does not align with source data."
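Not part of the patch: one behavioral difference worth flagging for callers migrating off the removed `IOReaderData`. Its `create()` accepted any duck-typed object (`other: typing.Any`), while the unified `ReaderData.create()` type-checks its input, so migrating code must pass a genuine `ReaderData`:

```python
import numpy as np

from weathergen.common.data import ReaderData

# IOReaderData.create() accepted any object with the right attributes;
# ReaderData.create() raises TypeError for anything but a ReaderData.
src = ReaderData.empty(num_data_fields=1, num_geo_fields=0)
copy = ReaderData.create(src)
assert copy.is_empty() and not copy.is_spoof
```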
f"geoinfos must be 2D, got {rdata.geoinfos.shape}" + assert rdata.data.ndim == 2, f"data must be 2D {rdata.data.shape}" + assert rdata.datetimes.ndim == 1, f"datetimes must be 1D {rdata.datetimes.shape}" + + # Validate consistency of lengths + n_points = rdata.coords.shape[0] + assert n_points == rdata.data.shape[0], "coords and data must have same length" + assert n_points == rdata.geoinfos.shape[0], "geoinfos and data must have same length" + + # Check that all fields have the same length + assert ( + rdata.coords.shape[0] + == rdata.geoinfos.shape[0] + == rdata.data.shape[0] + == rdata.datetimes.shape[0] + ), ( + f"coords, geoinfos, data and datetimes must have the same length " + f"{rdata.coords.shape[0]}, {rdata.geoinfos.shape[0]}, {rdata.data.shape[0]}, " + f"{rdata.datetimes.shape[0]}" + ) + + # Check that all datetimes fall within the specified range + assert np.logical_and(rdata.datetimes >= dtr.start, rdata.datetimes < dtr.end).all(), ( + f"datetimes for data points violate window {dtr}." + ) diff --git a/packages/common/src/weathergen/common/io.py b/packages/common/src/weathergen/common/io.py index 2243cdee8..aee730904 100644 --- a/packages/common/src/weathergen/common/io.py +++ b/packages/common/src/weathergen/common/io.py @@ -27,6 +27,8 @@ from zarr.errors import ZarrUserWarning from zarr.storage import LocalStore, ZipStore +from weathergen.common.data import ReaderData + # experimental value, should be inferred more intelligently SHARDING_ENABLED = True SHARD_N_SAMPLES = 40320 @@ -119,79 +121,6 @@ def forecast_interval(self, forecast_dt_hours: int, fstep: int) -> "TimeRange": return TimeRange(self.start + offset, self.end + offset) -@dataclasses.dataclass -class IOReaderData: - """ - Equivalent to data_reader_base.ReaderData - - This class needs to exist since otherwise the common package would - have a dependecy on the core model. Ultimately a unified data model - should be implemented in the common package. - """ - - coords: NDArray[DType] - geoinfos: NDArray[DType] - data: NDArray[DType] - datetimes: NDArray[NPDT64] - is_spoof: bool = False - - def is_empty(self): - """ - Test if data object is empty - """ - return len(self.data) == 0 - - @classmethod - def create(cls, other: typing.Any) -> "IOReaderData": - """ - Create an instance from data_reader_base.ReaderData instance. - - other should be such an instance. - """ - coords = np.asarray(other.coords) - geoinfos = np.asarray(other.geoinfos) - data = np.asarray(other.data) - datetimes = np.asarray(other.datetimes) - - n_datapoints = len(data) - - assert coords.shape == (n_datapoints, 2), "number of datapoints do not match data" - assert geoinfos.shape[0] == n_datapoints, "number of datapoints do not match data" - assert datetimes.shape[0] == n_datapoints, "number of datapoints do not match data" - - return cls(**dataclasses.asdict(other)) - - @classmethod - def combine(cls, others: list["IOReaderData"]) -> "IOReaderData": - """ - Create an instance from data_reader_base.ReaderData instance by combining mulitple ones. - - others is list of ReaderData instances. 
- """ - assert len(others) > 0, len(others) - - other = others[0] - coords = np.zeros((0, other.coords.shape[1]), dtype=other.coords.dtype) - geoinfos = np.zeros((0, other.geoinfos.shape[1]), dtype=other.geoinfos.dtype) - data = np.zeros((0, other.data.shape[1]), dtype=other.data.dtype) - datetimes = np.array([], dtype=other.datetimes.dtype) - is_spoof = True - - for other in others: - n_datapoints = len(other.data) - assert other.coords.shape == (n_datapoints, 2), "number of datapoints do not match" - assert other.geoinfos.shape[0] == n_datapoints, "number of datapoints do not match" - assert other.datetimes.shape[0] == n_datapoints, "number of datapoints do not match" - - coords = np.concatenate([coords, other.coords]) - geoinfos = np.concatenate([geoinfos, other.geoinfos]) - data = np.concatenate([data, other.data]) - datetimes = np.concatenate([datetimes, other.datetimes]) - is_spoof = is_spoof and other.is_spoof - - return cls(coords, geoinfos, data, datetimes, is_spoof) - - @dataclasses.dataclass class ItemKey: """Metadata to identify one output item.""" @@ -555,7 +484,7 @@ class OutputBatchData: # sample, stream, tensor(datapoint, channel+coords) # => datapoints is accross all datasets per stream - sources: list[list[IOReaderData]] + sources: list[list[ReaderData]] # sample source_intervals: list[TimeRange] @@ -743,7 +672,7 @@ def _extract_sources( channels = self.source_channels[stream_idx] geoinfo_channels = self.geoinfo_channels[stream_idx] - source: IOReaderData = self.sources[sample][stream_idx] + source: ReaderData = self.sources[sample][stream_idx] assert source.data.shape[1] == len(channels), ( f"Number of source channel names {len(channels)} does not align with source data." diff --git a/packages/readers_extra/src/weathergen/readers_extra/data_reader_cams.py b/packages/readers_extra/src/weathergen/readers_extra/data_reader_cams.py index 5f76098f0..227f77551 100644 --- a/packages/readers_extra/src/weathergen/readers_extra/data_reader_cams.py +++ b/packages/readers_extra/src/weathergen/readers_extra/data_reader_cams.py @@ -6,13 +6,12 @@ import numpy as np import xarray as xr +from weathergen.common.data import ReaderData, check_reader_data from weathergen.datasets.data_reader_anemoi import _clip_lat, _clip_lon from weathergen.datasets.data_reader_base import ( DataReaderTimestep, - ReaderData, TimeWindowHandler, TIndex, - check_reader_data, ) ############################################################################ diff --git a/packages/readers_extra/src/weathergen/readers_extra/data_reader_eobs.py b/packages/readers_extra/src/weathergen/readers_extra/data_reader_eobs.py index 4f0157792..470f2993e 100644 --- a/packages/readers_extra/src/weathergen/readers_extra/data_reader_eobs.py +++ b/packages/readers_extra/src/weathergen/readers_extra/data_reader_eobs.py @@ -15,12 +15,11 @@ import xarray as xr from numpy.typing import NDArray +from weathergen.common.data import ReaderData, check_reader_data from weathergen.datasets.data_reader_base import ( DataReaderTimestep, - ReaderData, TimeWindowHandler, TIndex, - check_reader_data, str_to_timedelta, ) diff --git a/packages/readers_extra/src/weathergen/readers_extra/data_reader_icon_esm.py b/packages/readers_extra/src/weathergen/readers_extra/data_reader_icon_esm.py index d167f4dfd..8d12d70e4 100644 --- a/packages/readers_extra/src/weathergen/readers_extra/data_reader_icon_esm.py +++ b/packages/readers_extra/src/weathergen/readers_extra/data_reader_icon_esm.py @@ -17,13 +17,12 @@ import xarray as xr import zarr +from 
diff --git a/src/weathergen/datasets/data_reader_anemoi.py b/src/weathergen/datasets/data_reader_anemoi.py
index 88031b3a7..24333fa7a 100644
--- a/src/weathergen/datasets/data_reader_anemoi.py
+++ b/src/weathergen/datasets/data_reader_anemoi.py
@@ -17,12 +17,11 @@
 from anemoi.datasets.data.dataset import Dataset
 from numpy.typing import NDArray
 
+from weathergen.common.data import ReaderData, check_reader_data
 from weathergen.datasets.data_reader_base import (
     DataReaderTimestep,
-    ReaderData,
     TimeWindowHandler,
     TIndex,
-    check_reader_data,
 )
 
 _logger = logging.getLogger(__name__)
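Not part of the patch: `remove_nan_coords()` and `shuffle()` move to the common package with essentially unchanged semantics (the relocated `remove_nan_coords` now also propagates `is_spoof`; the originals are removed in the next file). A short sketch, values illustrative:

```python
import numpy as np

from weathergen.common.data import ReaderData

rng = np.random.default_rng(0)
rdata = ReaderData(
    coords=np.array([[50.0, 8.6], [np.nan, 11.6], [48.1, 11.6]], dtype=np.float32),
    geoinfos=np.zeros((3, 1), dtype=np.float32),
    data=np.arange(6, dtype=np.float32).reshape(3, 2),
    datetimes=np.array(["2020-01-01T00:00"] * 3, dtype="datetime64[m]"),
)

# Rows with any NaN coordinate are dropped; a new instance is returned.
clean = rdata.remove_nan_coords()
assert len(clean) == 2

# In place: keep one random point; shuffle=False preserves the original
# order by sorting the selected indices.
clean.shuffle(rng, shuffle=False, num_subset=1)
assert len(clean) == 1
```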
- """ - - start: NPDT64 - end: NPDT64 - - def __post_init__(self): - assert self.start < self.end, "start time must be before end time" - assert self.start > _DT_ZERO, "start time must be after 1850-01-01T00:00" - - class TimeWindowHandler: """ Handler for time windows and translation of indices to times @@ -145,150 +130,6 @@ def window(self, idx: TIndex) -> DTRange: return DTRange(t_start_win, t_end_win) -@dataclass -class ReaderData: - """ - Wrapper for return values from DataReader.get_source and DataReader.get_target - """ - - coords: NDArray[DType] - geoinfos: NDArray[DType] - data: NDArray[DType] - datetimes: NDArray[NPDT64] - is_spoof: bool = False - - @staticmethod - def empty(num_data_fields: int, num_geo_fields: int) -> "ReaderData": - """ - Create an empty ReaderData object - - Returns - ------- - ReaderData - Empty ReaderData object - """ - return ReaderData( - coords=np.zeros((0, 2), dtype=np.float32), - geoinfos=np.zeros((0, num_geo_fields), dtype=np.float32), - data=np.zeros((0, num_data_fields), dtype=np.float32), - datetimes=np.zeros((0,), dtype=np.datetime64), - is_spoof=False, - ) - - def is_empty(self): - return self.len() == 0 - - def len(self): - """ - Length of data - - Returns - ------- - length of data - """ - return len(self.data) - - def remove_nan_coords(self) -> "ReaderData": - """ - Remove all data points where coords are NaN - - Returns - ------- - self - """ - idx_valid = ~np.isnan(self.coords) - # filter should be if any (of the two) coords is NaN - idx_valid = np.logical_and(idx_valid[:, 0], idx_valid[:, 1]) - - # apply - return ReaderData( - self.coords[idx_valid], - self.geoinfos[idx_valid], - self.data[idx_valid], - self.datetimes[idx_valid], - ) - - def shuffle(self, rng, shuffle: bool, num_subset: int) -> "ReaderData": - """ - Drop a random subset of points as specified by num_subset - num_subset = -1 indicates no points to be dropped - - Returns - ------- - self - """ - - # nothing to be done - if num_subset < 0 and shuffle is False: - return self - - num_datapoints = self.coords.shape[0] - if (num_datapoints == 0) or (num_datapoints < num_subset and shuffle is False): - return self - - # only shuffling - if num_subset == -1 and shuffle is True: - num_subset = num_datapoints - - # ensure num_subset <= num_datapoints - num_subset = min(num_subset, num_datapoints) - - idxs_subset = rng.choice(num_datapoints, num_subset, replace=False) - if shuffle is False: - idxs_subset = np.sort(idxs_subset) - - self.coords = self.coords[idxs_subset] - self.geoinfos = self.geoinfos[idxs_subset] - self.data = self.data[idxs_subset] - self.datetimes = self.datetimes[idxs_subset] - - return self - - -def check_reader_data(rdata: ReaderData, dtr: DTRange) -> None: - """ - Check that ReaderData is valid - - Parameters - ---------- - rdata : - ReaderData to check - dtr : - datetime range of window for which the rdata is valid - - Returns - ------- - None - """ - - assert rdata.coords.ndim == 2, f"coords must be 2D {rdata.coords.shape}" - assert rdata.coords.shape[1] == 2, ( - f"coords must have 2 columns (lat, lon), got {rdata.coords.shape}" - ) - assert rdata.geoinfos.ndim == 2, f"geoinfos must be 2D, got {rdata.geoinfos.shape}" - assert rdata.data.ndim == 2, f"data must be 2D {rdata.data.shape}" - assert rdata.datetimes.ndim == 1, f"datetimes must be 1D {rdata.datetimes.shape}" - - assert rdata.coords.shape[0] == rdata.data.shape[0], "coords and data must have same length" - assert rdata.geoinfos.shape[0] == rdata.data.shape[0], "geoinfos and data must have same length" 
diff --git a/src/weathergen/datasets/data_reader_fesom.py b/src/weathergen/datasets/data_reader_fesom.py
index db46bdf73..c138e1fba 100644
--- a/src/weathergen/datasets/data_reader_fesom.py
+++ b/src/weathergen/datasets/data_reader_fesom.py
@@ -17,11 +17,10 @@
 import numpy as np
 import zarr
 
+from weathergen.common.data import DTRange, ReaderData
 from weathergen.datasets.data_reader_base import (
     DataReaderTimestep,
-    DTRange,
     NDArray,
-    ReaderData,
     TimeWindowHandler,
     TIndex,
     t_epsilon,
diff --git a/src/weathergen/datasets/data_reader_obs.py b/src/weathergen/datasets/data_reader_obs.py
index 139a6aa7e..506102c6f 100644
--- a/src/weathergen/datasets/data_reader_obs.py
+++ b/src/weathergen/datasets/data_reader_obs.py
@@ -15,11 +15,10 @@
 import numpy as np
 import zarr
 
+from weathergen.common.data import ReaderData, check_reader_data
 from weathergen.datasets.data_reader_base import (
     DataReaderBase,
-    ReaderData,
     TimeWindowHandler,
-    check_reader_data,
 )
 
 _logger = logging.getLogger(__name__)
diff --git a/src/weathergen/datasets/memory_pinning.py b/src/weathergen/datasets/memory_pinning.py
index da69e6e48..961c6f54f 100644
--- a/src/weathergen/datasets/memory_pinning.py
+++ b/src/weathergen/datasets/memory_pinning.py
@@ -2,7 +2,7 @@
 
 import torch
 
-from weathergen.common.io import IOReaderData
+from weathergen.common.data import ReaderData
 
 
 @runtime_checkable
@@ -20,13 +20,13 @@ class Pinnable(Protocol):
     def pin_memory(self): ...
 
 
-def pin_object(obj: Pinnable | torch.Tensor | IOReaderData | list | dict | None):
+def pin_object(obj: Pinnable | torch.Tensor | ReaderData | list | dict | None):
     if obj is None:
         return
     elif isinstance(obj, torch.Tensor | Pinnable):
         obj.pin_memory()
-    elif isinstance(obj, IOReaderData):
-        # Special case: IOReaderData is in common package and can't have torch deps
+    elif isinstance(obj, ReaderData):
+        # Special case: ReaderData is in common package and can't have torch deps
         # Note: These SHOULD be numpy arrays per the type hints, but might be tensors
         pin_object(obj.coords)
         pin_object(obj.data)
diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py
index 86049d389..805e1f373 100644
--- a/src/weathergen/datasets/multi_stream_data_sampler.py
+++ b/src/weathergen/datasets/multi_stream_data_sampler.py
@@ -14,7 +14,7 @@
 import torch
 
 from weathergen.common.config import Config
-from weathergen.common.io import IOReaderData
+from weathergen.common.data import ReaderData
 from weathergen.datasets.batch import ModelBatch
 from weathergen.datasets.data_reader_anemoi import DataReaderAnemoi
 from weathergen.datasets.data_reader_base import (
@@ -41,7 +41,7 @@
 logger = logging.getLogger(__name__)
 
 
-def collect_datasources(stream_datasets: list, idx: int, type: str, rng) -> IOReaderData:
+def collect_datasources(stream_datasets: list, idx: int, type: str, rng) -> ReaderData:
     """
     Utility function to collect all sources / targets from streams list
 
@@ -72,7 +72,7 @@
         rdata.geoinfos = ds.normalize_geoinfos(rdata.geoinfos)
         rdatas += [rdata]
 
-    return IOReaderData.combine(rdatas)
+    return ReaderData.combine(rdatas)
 
 
 class MultiStreamDataSampler(torch.utils.data.IterableDataset):
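Not part of the patch: `collect_datasources()` now returns `ReaderData.combine(...)`, whose `is_spoof` flag is the logical AND over all inputs, i.e. a combined stream counts as spoofed only if every contributing dataset is spoofed. A sketch:

```python
import dataclasses

import numpy as np

from weathergen.common.data import ReaderData

real = ReaderData(
    coords=np.zeros((1, 2), dtype=np.float32),
    geoinfos=np.zeros((1, 1), dtype=np.float32),
    data=np.zeros((1, 4), dtype=np.float32),
    datetimes=np.array(["2020-01-01T00:00"], dtype="datetime64[m]"),
)
spoofed = dataclasses.replace(real, is_spoof=True)

# combine() ANDs the flags: the result is spoofed only if every input is.
assert ReaderData.combine([spoofed, spoofed]).is_spoof
assert not ReaderData.combine([real, spoofed]).is_spoof
```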
diff --git a/src/weathergen/datasets/stream_data.py b/src/weathergen/datasets/stream_data.py
index 152c092b2..b848efb58 100644
--- a/src/weathergen/datasets/stream_data.py
+++ b/src/weathergen/datasets/stream_data.py
@@ -11,7 +11,7 @@
 import numpy as np
 import torch
 
-from weathergen.common.io import IOReaderData
+from weathergen.common.data import ReaderData
 
 
 def _pin_tensor(tensor: torch.Tensor) -> torch.Tensor:
@@ -128,7 +128,7 @@ def pin_memory(self):
         self.source_idxs_embed = _pin_tensor_list(self.source_idxs_embed)
         self.source_idxs_embed_pe = _pin_tensor_list(self.source_idxs_embed_pe)
 
-        # Pin source_raw (list of IOReaderData objects)
+        # Pin source_raw (list of ReaderData objects)
         if hasattr(self, "source_raw"):
             for raw_data in self.source_raw:
                 if raw_data is not None and hasattr(raw_data, "pin_memory"):
@@ -170,14 +170,14 @@ def to_device(self, device: str) -> None:
         return self
 
     def add_source(
-        self, step: int, ss_raw: IOReaderData, ss_lens: torch.Tensor, ss_cells: list
+        self, step: int, ss_raw: ReaderData, ss_lens: torch.Tensor, ss_cells: list
     ) -> None:
         """
        Add data for source for one input.
 
        Parameters
        ----------
-        ss_raw : IOReaderData( dataclass containing coords, geoinfos, data, and datetimes )
+        ss_raw : ReaderData( dataclass containing coords, geoinfos, data, and datetimes )
         ss_lens : torch.Tensor( number of healpix cells )
         ss_cells : list( number of healpix cells )
             [ torch.Tensor( tokens per cell, token size, number of channels) ]
@@ -373,7 +373,7 @@ def is_spoof(self) -> bool:
         return self.source_is_spoof or self.target_is_spoof
 
 
-def spoof(healpix_level: int, datetime, geoinfo_size, mean_of_data) -> IOReaderData:
+def spoof(healpix_level: int, datetime, geoinfo_size, mean_of_data) -> ReaderData:
     """
     Spoof an instance from data_reader_base.ReaderData instance.
     other should be such an instance.
@@ -409,4 +409,4 @@ def spoof(healpix_level: int, datetime, geoinfo_size, mean_of_data) -> IOReaderData:
         (n_datapoints,),
     )
 
-    return IOReaderData(coords, geoinfos, data, datetimes)
+    return ReaderData(coords, geoinfos, data, datetimes)
diff --git a/src/weathergen/datasets/tokenizer_masking.py b/src/weathergen/datasets/tokenizer_masking.py
index 6dfe71c89..b8b00ecdc 100644
--- a/src/weathergen/datasets/tokenizer_masking.py
+++ b/src/weathergen/datasets/tokenizer_masking.py
@@ -11,7 +11,7 @@
 import numpy as np
 import torch
 
-from weathergen.common.io import IOReaderData
+from weathergen.common.data import ReaderData
 from weathergen.datasets.batch import SampleMetaData
 from weathergen.datasets.masking import Masker
 from weathergen.datasets.tokenizer import Tokenizer
@@ -25,7 +25,7 @@
 )
 
 
-def readerdata_to_torch(rdata: IOReaderData) -> IOReaderData:
+def readerdata_to_torch(rdata: ReaderData) -> ReaderData:
     """
     Convert data, coords, and geoinfos to torch tensor
     """
@@ -122,7 +122,7 @@ def cell_to_token_mask(self, idxs_cells, idxs_cells_lens, mask):
     def get_source(
         self,
         stream_info: dict,
-        rdata: IOReaderData,
+        rdata: ReaderData,
         idxs_cells_data,
         time_win: tuple,
         cell_mask: torch.Tensor,
@@ -153,7 +153,7 @@ def get_source(
     def get_target_coords(
         self,
         stream_info: dict,
-        rdata: IOReaderData,
+        rdata: ReaderData,
         token_data,
         time_win: tuple,
         cell_mask,
@@ -186,7 +186,7 @@ def get_target_coords(
     def get_target_values(
         self,
         stream_info: dict,
-        rdata: IOReaderData,
+        rdata: ReaderData,
         token_data,
         time_win: tuple,
         cell_mask,
diff --git a/src/weathergen/datasets/tokenizer_utils.py b/src/weathergen/datasets/tokenizer_utils.py
index 17295d18a..e1598de3a 100644
--- a/src/weathergen/datasets/tokenizer_utils.py
+++ b/src/weathergen/datasets/tokenizer_utils.py
@@ -4,7 +4,7 @@
 from astropy_healpix.healpy import ang2pix
 from torch import Tensor
 
-from weathergen.common.io import IOReaderData
+from weathergen.common.data import ReaderData
 from weathergen.datasets.utils import (
     locs_to_cell_coords_ctrs,
     locs_to_ctr_coords,
@@ -209,7 +209,7 @@ def tokenize_spacetime(
     for _, t in enumerate(t_unique):
         # data for current time step
         mask = t == rdata.datetimes
-        rdata_cur = IOReaderData(
+        rdata_cur = ReaderData(
             rdata.coords[mask], rdata.geoinfos[mask], rdata.data[mask], rdata.datetimes[mask]
         )
         idxs_cur, idxs_cur_lens = tokenize_space(rdata_cur, token_size, hl, pad_tokens)
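Not part of the patch: `tokenize_spacetime` above slices a stream's `ReaderData` into per-timestep views with a boolean mask. A standalone sketch of that pattern, values illustrative:

```python
import numpy as np

from weathergen.common.data import ReaderData

rdata = ReaderData(
    coords=np.zeros((3, 2), dtype=np.float32),
    geoinfos=np.zeros((3, 1), dtype=np.float32),
    data=np.arange(6, dtype=np.float32).reshape(3, 2),
    datetimes=np.array(
        ["2020-01-01T00:00", "2020-01-01T00:00", "2020-01-01T06:00"], dtype="datetime64[m]"
    ),
)

# One slice per unique timestamp, mirroring the loop in tokenize_spacetime.
for t in np.unique(rdata.datetimes):
    mask = t == rdata.datetimes
    rdata_cur = ReaderData(
        rdata.coords[mask], rdata.geoinfos[mask], rdata.data[mask], rdata.datetimes[mask]
    )
    print(t, len(rdata_cur))  # 2 points at 00:00, 1 point at 06:00
```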