269 changes: 269 additions & 0 deletions packages/common/src/weathergen/common/data.py
@@ -0,0 +1,269 @@
# (C) Copyright 2025 WeatherGenerator contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import dataclasses
import logging

import numpy as np
from numpy import datetime64
from numpy.typing import NDArray

type DType = np.float32
type NPDT64 = datetime64

_logger = logging.getLogger(__name__)

_DT_ZERO = np.datetime64("1850-01-01T00:00")


@dataclasses.dataclass
class DTRange:
"""
    Defines a time window for indexing into datasets.

    Start and end are numpy datetime64 values; the window is treated as
    half-open, i.e. [start, end).
    """

start: NPDT64
end: NPDT64

def __post_init__(self):
if self.start >= self.end:
raise ValueError("start time must be before end time")
if self.start <= _DT_ZERO:
raise ValueError("start time must be after 1850-01-01T00:00")
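
# Illustrative usage sketch (the timestamps below are arbitrary examples):
#
#   window = DTRange(np.datetime64("2020-01-01T00:00"), np.datetime64("2020-01-07T00:00"))
#   DTRange(window.end, window.start)  # raises ValueError: start time must be before end time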


@dataclasses.dataclass
class ReaderData:
"""
Wrapper for return values from DataReader.get_source and DataReader.get_target.
"""

coords: NDArray[DType]
geoinfos: NDArray[DType]
data: NDArray[DType]
datetimes: NDArray[NPDT64]
is_spoof: bool = False

def __len__(self):
return len(self.data)

@classmethod
def create(cls, other: "ReaderData") -> "ReaderData":
"""
Create an instance from another ReaderData instance.

Parameters
------
other: ReaderData
Input data.

Returns
------
ReaderData:
Has the same underlying data as `other`.
"""
if other is None:
raise TypeError("Input cannot be None.")

if not isinstance(other, ReaderData):
raise TypeError(f"Expected input of type ReaderData. Got {type(other)}")

coords = np.asarray(other.coords)
geoinfos = np.asarray(other.geoinfos)
data = np.asarray(other.data)
datetimes = np.asarray(other.datetimes)

n_datapoints = len(data)

        assert coords.shape == (n_datapoints, 2), "number of coords does not match data"
        assert geoinfos.shape[0] == n_datapoints, "number of geoinfos does not match data"
        assert datetimes.shape[0] == n_datapoints, "number of datetimes does not match data"

        return cls(
            coords=coords,
            geoinfos=geoinfos,
            data=data,
            datetimes=datetimes,
            is_spoof=other.is_spoof,
        )

@classmethod
def combine(cls, others: list["ReaderData"]) -> "ReaderData":
"""
        Create an instance by combining multiple ReaderData instances.

Parameters
------
others: list[ReaderData]
            A list of ReaderData instances to combine.

Returns
------
ReaderData
Instance with concatenated input data.
"""
if others is None:
raise TypeError("Input cannot be None.")

        if not isinstance(others, list):
            raise TypeError(f"Input must be a list. Got {type(others)}")

        assert len(others) > 0, "Input list must not be empty."

first = others[0]
coords = np.zeros((0, first.coords.shape[1]), dtype=first.coords.dtype)
geoinfos = np.zeros((0, first.geoinfos.shape[1]), dtype=first.geoinfos.dtype)
data = np.zeros((0, first.data.shape[1]), dtype=first.data.dtype)
datetimes = np.array([], dtype=first.datetimes.dtype)
is_spoof = True

for item in others:
n_datapoints = len(item.data)
            assert item.coords.shape == (n_datapoints, 2), "number of coords does not match data"
            assert item.geoinfos.shape[0] == n_datapoints, "number of geoinfos does not match data"
            assert item.datetimes.shape[0] == n_datapoints, "number of datetimes does not match data"

coords = np.concatenate([coords, item.coords])
geoinfos = np.concatenate([geoinfos, item.geoinfos])
data = np.concatenate([data, item.data])
datetimes = np.concatenate([datetimes, item.datetimes])
is_spoof = is_spoof and item.is_spoof

return cls(coords, geoinfos, data, datetimes, is_spoof)

@staticmethod
def empty(num_data_fields: int, num_geo_fields: int) -> "ReaderData":
"""
        Create an empty ReaderData object.

        Parameters
        ------
        num_data_fields: int
            Number of data fields.
        num_geo_fields: int
            Number of geo fields.

        Returns
        -------
        ReaderData
            Empty ReaderData object.
        """
return ReaderData(
coords=np.zeros((0, 2), dtype=np.float32),
geoinfos=np.zeros((0, num_geo_fields), dtype=np.float32),
data=np.zeros((0, num_data_fields), dtype=np.float32),
datetimes=np.zeros((0,), dtype=np.datetime64),
is_spoof=False,
)
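
    # Illustrative sketch (field counts are arbitrary): empty() yields a
    # zero-length container, e.g. as a seed or placeholder.
    #
    #   placeholder = ReaderData.empty(num_data_fields=3, num_geo_fields=1)
    #   assert placeholder.is_empty()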

    def is_empty(self) -> bool:
        """
        Test whether the data object is empty.
        """
return len(self) == 0

def remove_nan_coords(self) -> "ReaderData":
"""
        Remove all data points whose coordinates contain NaN.

        Returns
        -------
        ReaderData
            New instance containing only points with valid coordinates.
        """
# Identify valid coordinates (where both lat/lon are not NaN)
idx_valid = ~np.isnan(self.coords).any(axis=1)

# Apply filtering
return ReaderData(
self.coords[idx_valid],
self.geoinfos[idx_valid],
self.data[idx_valid],
self.datetimes[idx_valid],
self.is_spoof,
)
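
    # Illustrative sketch (rdata stands for any ReaderData instance): points
    # with NaN lat/lon are dropped and all fields stay aligned.
    #
    #   clean = rdata.remove_nan_coords()
    #   assert not np.isnan(clean.coords).any()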

def shuffle(self, rng, shuffle: bool, num_subset: int) -> "ReaderData":
"""
        Optionally shuffle the data points and/or keep only a random subset.

        Parameters
        ------
        rng:
            Numpy random generator used to draw the subset.
        shuffle: bool
            If True, the retained points are returned in random order.
        num_subset: int
            Number of points to keep; -1 means no points are dropped.

        Returns
        -------
        ReaderData
            self, modified in place.
        """

# nothing to be done
if num_subset < 0 and shuffle is False:
return self

num_datapoints = self.coords.shape[0]
if (num_datapoints == 0) or (num_datapoints < num_subset and shuffle is False):
return self

# only shuffling
if num_subset == -1 and shuffle is True:
num_subset = num_datapoints

# ensure num_subset <= num_datapoints
num_subset = min(num_subset, num_datapoints)

idxs_subset = rng.choice(num_datapoints, num_subset, replace=False)
if shuffle is False:
idxs_subset = np.sort(idxs_subset)

self.coords = self.coords[idxs_subset]
self.geoinfos = self.geoinfos[idxs_subset]
self.data = self.data[idxs_subset]
self.datetimes = self.datetimes[idxs_subset]

return self
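
    # Illustrative sketch (rdata stands for any ReaderData instance): keep a
    # random subset of at most 100 points without reordering them.
    #
    #   rdata.shuffle(rng=np.random.default_rng(0), shuffle=False, num_subset=100)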


def check_reader_data(rdata: ReaderData, dtr: DTRange) -> None:
"""
    Check that a ReaderData instance is valid.

    Parameters
    ----------
    rdata :
        ReaderData to check.
    dtr :
        Datetime range of the window within which the data must lie.

    Returns
    -------
    None
    """

# Validate dimensions
assert rdata.coords.ndim == 2, f"coords must be 2D {rdata.coords.shape}"
assert rdata.coords.shape[1] == 2, (
f"coords must have 2 columns (lat, lon), got {rdata.coords.shape}"
)
assert rdata.geoinfos.ndim == 2, f"geoinfos must be 2D, got {rdata.geoinfos.shape}"
assert rdata.data.ndim == 2, f"data must be 2D {rdata.data.shape}"
assert rdata.datetimes.ndim == 1, f"datetimes must be 1D {rdata.datetimes.shape}"

    # Validate that all fields have the same length
    assert (
        rdata.coords.shape[0]
        == rdata.geoinfos.shape[0]
        == rdata.data.shape[0]
        == rdata.datetimes.shape[0]
    ), (
        f"coords, geoinfos, data and datetimes must have the same length: "
        f"{rdata.coords.shape[0]}, {rdata.geoinfos.shape[0]}, {rdata.data.shape[0]}, "
        f"{rdata.datetimes.shape[0]}"
    )

# Check that all datetimes fall within the specified range
assert np.logical_and(rdata.datetimes >= dtr.start, rdata.datetimes < dtr.end).all(), (
f"datetimes for data points violate window {dtr}."
)
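

# Minimal end-to-end sketch (illustrative only: the field counts, sizes and
# timestamps below are arbitrary assumptions, not part of the module API).
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    rng = np.random.default_rng(0)

    window = DTRange(np.datetime64("2020-01-01T00:00"), np.datetime64("2020-01-02T00:00"))

    def _fake_block(n: int) -> ReaderData:
        # Random points with 3 data fields and 1 geoinfo field inside the window.
        coords = rng.uniform(-90.0, 90.0, size=(n, 2)).astype(np.float32)
        geoinfos = np.zeros((n, 1), dtype=np.float32)
        data = rng.normal(size=(n, 3)).astype(np.float32)
        datetimes = np.full((n,), np.datetime64("2020-01-01T06:00"))
        return ReaderData(coords, geoinfos, data, datetimes)

    # Combine two blocks, drop NaN coordinates, keep a shuffled subset of 5 points.
    combined = ReaderData.combine([_fake_block(4), _fake_block(2)])
    combined = combined.remove_nan_coords().shuffle(rng, shuffle=True, num_subset=5)
    check_reader_data(combined, window)
    _logger.info("combined ReaderData with %d points passed validation", len(combined))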
79 changes: 4 additions & 75 deletions packages/common/src/weathergen/common/io.py
@@ -27,6 +27,8 @@
from zarr.errors import ZarrUserWarning
from zarr.storage import LocalStore, ZipStore

from weathergen.common.data import ReaderData

# experimental value, should be inferred more intelligently
SHARDING_ENABLED = True
SHARD_N_SAMPLES = 40320
@@ -119,79 +121,6 @@ def forecast_interval(self, forecast_dt_hours: int, fstep: int) -> "TimeRange":
return TimeRange(self.start + offset, self.end + offset)


@dataclasses.dataclass
class IOReaderData:
"""
Equivalent to data_reader_base.ReaderData

This class needs to exist since otherwise the common package would
    have a dependency on the core model. Ultimately a unified data model
should be implemented in the common package.
"""

coords: NDArray[DType]
geoinfos: NDArray[DType]
data: NDArray[DType]
datetimes: NDArray[NPDT64]
is_spoof: bool = False

def is_empty(self):
"""
Test if data object is empty
"""
return len(self.data) == 0

@classmethod
def create(cls, other: typing.Any) -> "IOReaderData":
"""
Create an instance from data_reader_base.ReaderData instance.

other should be such an instance.
"""
coords = np.asarray(other.coords)
geoinfos = np.asarray(other.geoinfos)
data = np.asarray(other.data)
datetimes = np.asarray(other.datetimes)

n_datapoints = len(data)

assert coords.shape == (n_datapoints, 2), "number of datapoints do not match data"
assert geoinfos.shape[0] == n_datapoints, "number of datapoints do not match data"
assert datetimes.shape[0] == n_datapoints, "number of datapoints do not match data"

return cls(**dataclasses.asdict(other))

@classmethod
def combine(cls, others: list["IOReaderData"]) -> "IOReaderData":
"""
        Create an instance by combining multiple data_reader_base.ReaderData instances.

others is list of ReaderData instances.
"""
assert len(others) > 0, len(others)

other = others[0]
coords = np.zeros((0, other.coords.shape[1]), dtype=other.coords.dtype)
geoinfos = np.zeros((0, other.geoinfos.shape[1]), dtype=other.geoinfos.dtype)
data = np.zeros((0, other.data.shape[1]), dtype=other.data.dtype)
datetimes = np.array([], dtype=other.datetimes.dtype)
is_spoof = True

for other in others:
n_datapoints = len(other.data)
assert other.coords.shape == (n_datapoints, 2), "number of datapoints do not match"
assert other.geoinfos.shape[0] == n_datapoints, "number of datapoints do not match"
assert other.datetimes.shape[0] == n_datapoints, "number of datapoints do not match"

coords = np.concatenate([coords, other.coords])
geoinfos = np.concatenate([geoinfos, other.geoinfos])
data = np.concatenate([data, other.data])
datetimes = np.concatenate([datetimes, other.datetimes])
is_spoof = is_spoof and other.is_spoof

return cls(coords, geoinfos, data, datetimes, is_spoof)


@dataclasses.dataclass
class ItemKey:
"""Metadata to identify one output item."""
@@ -555,7 +484,7 @@ class OutputBatchData:

# sample, stream, tensor(datapoint, channel+coords)
    # => datapoints are across all datasets per stream
sources: list[list[IOReaderData]]
sources: list[list[ReaderData]]

# sample
source_intervals: list[TimeRange]
@@ -743,7 +672,7 @@ def _extract_sources(
channels = self.source_channels[stream_idx]
geoinfo_channels = self.geoinfo_channels[stream_idx]

source: IOReaderData = self.sources[sample][stream_idx]
source: ReaderData = self.sources[sample][stream_idx]

assert source.data.shape[1] == len(channels), (
f"Number of source channel names {len(channels)} does not align with source data."
@@ -6,13 +6,12 @@
import numpy as np
import xarray as xr

from weathergen.common.data import ReaderData, check_reader_data
from weathergen.datasets.data_reader_anemoi import _clip_lat, _clip_lon
from weathergen.datasets.data_reader_base import (
DataReaderTimestep,
ReaderData,
TimeWindowHandler,
TIndex,
check_reader_data,
)

############################################################################