Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ ignore = [

[tool.ruff.lint.flake8-tidy-imports.banned-api]
"numpy.ndarray".msg = "Do not use 'ndarray' to describe a numpy array type, it is a function. Use numpy.typing.NDArray or numpy.typing.NDArray[np.float32] for example"
"builtins.getattr".msg = "getattr() is not allowed"

[tool.ruff.format]
# Use Unix `\n` line endings for all files
Expand Down
8 changes: 3 additions & 5 deletions src/weathergen/datasets/tokenizer_masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,12 +316,10 @@ def _select_target_subset(
if max_num_targets is None or max_num_targets <= 0 or num_points <= max_num_targets:
return None

rng = getattr(self, "rng", None)
if rng is None:
rng = np.random.default_rng()
self.rng = rng
if self.rng is None:
self.rng = np.random.default_rng()

selected = np.sort(rng.choice(num_points, max_num_targets, replace=False))
selected = np.sort(self.rng.choice(num_points, max_num_targets, replace=False))

return torch.from_numpy(selected).to(torch.long)

Expand Down
2 changes: 1 addition & 1 deletion src/weathergen/train/loss_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __init__(
[
(
params.get("weight", 1.0),
getattr(LossModules, params.type)(
LossModules.__dict__[params.type](
cf, mode_cfg, stage, self.device, **params.loss_fcts
),
)
Expand Down
7 changes: 4 additions & 3 deletions src/weathergen/train/loss_modules/loss_module_physical.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __init__(
# dynamically load loss functions based on configuration and stage
self.loss_fcts = [
[
getattr(loss_fns, name),
loss_fns.__dict__[name],
params.get("weight", 1.0),
name,
]
Expand Down Expand Up @@ -95,15 +95,16 @@ def _get_output_step_weights(self, len_forecast_steps):
timestep_weight_config = self.mode_cfg.get("forecast", {}).get("timestep_weight", {})
if len(timestep_weight_config) == 0:
return [1.0 for _ in range(len_forecast_steps)]
weights_timestep_fct = getattr(loss_fns, list(timestep_weight_config.keys())[0])
weights_timestep_fct_name = list(timestep_weight_config.keys())[0]
weights_timestep_fct = loss_fns.__dict__[weights_timestep_fct_name]
decay_factor = list(timestep_weight_config.values())[0]["decay_factor"]
return weights_timestep_fct(len_forecast_steps, decay_factor)

def _get_location_weights(self, stream_info, target_coords):
location_weight_type = stream_info.get("location_weight", None)
if location_weight_type is None:
return None
weights_locations_fct = getattr(loss_fns, location_weight_type)
weights_locations_fct = loss_fns.__dict__[location_weight_type]
weights_locations = weights_locations_fct(target_coords)
weights_locations = weights_locations.to(device=self.device, non_blocking=True)

Expand Down
16 changes: 11 additions & 5 deletions src/weathergen/utils/better_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,17 @@ class ABCMeta(NativeABCMeta):
def __call__(cls, *args, **kwargs):
instance = NativeABCMeta.__call__(cls, *args, **kwargs)
# pylint: disable-next=attribute-defined-outside-init
abstract_attributes = {
name
for name in dir(instance)
if hasattr(getattr(instance, name), "__is_abstract_attribute__")
}
abstract_attributes = set()

for name in dir(instance):
# Search the attribute name in the class hierarchy
for base in type(instance).__mro__:
if name in base.__dict__:
attr = base.__dict__[name]
if hasattr(attr, "__is_abstract_attribute__"):
abstract_attributes.add(name)
break

if abstract_attributes:
raise NotImplementedError(
"Can't instantiate abstract class {} with abstract attributes: {}".format(
Expand Down
3 changes: 2 additions & 1 deletion tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,11 @@ def test_inference_defaults(inference_parser):
default_values[:2] = [cli._format_date(date) for date in default_values[:2]]

args = inference_parser.parse_args(BASIC_ARGLIST)
args_dict = vars(args)

assert all(
[
getattr(args, arg) == default_value
args_dict[arg] == default_value
for arg, default_value in zip(default_args, default_values, strict=True)
]
)
Expand Down
Loading