From d74c8d3b0ecfefc0de34a054ea7885673e11d144 Mon Sep 17 00:00:00 2001
From: simone99n
Date: Wed, 28 Jan 2026 15:29:34 +0100
Subject: [PATCH] [1332] remove getattr()

---
 pyproject.toml                                      |  1 +
 src/weathergen/datasets/tokenizer_masking.py        |  8 +++-----
 src/weathergen/train/loss_calculator.py             |  2 +-
 .../train/loss_modules/loss_module_physical.py      |  7 ++++---
 src/weathergen/utils/better_abc.py                  | 16 +++++++++++-----
 tests/test_cli.py                                   |  3 ++-
 6 files changed, 22 insertions(+), 15 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index e4dfc14d4..2334711a4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -140,6 +140,7 @@ ignore = [
 
 [tool.ruff.lint.flake8-tidy-imports.banned-api]
 "numpy.ndarray".msg = "Do not use 'ndarray' to describe a numpy array type, it is a function. Use numpy.typing.NDArray or numpy.typing.NDArray[np.float32] for example"
+"builtins.getattr".msg = "getattr() is not allowed"
 
 [tool.ruff.format]
 # Use Unix `\n` line endings for all files
diff --git a/src/weathergen/datasets/tokenizer_masking.py b/src/weathergen/datasets/tokenizer_masking.py
index f6cb9ed24..6d1962063 100644
--- a/src/weathergen/datasets/tokenizer_masking.py
+++ b/src/weathergen/datasets/tokenizer_masking.py
@@ -316,12 +316,10 @@ def _select_target_subset(
         if max_num_targets is None or max_num_targets <= 0 or num_points <= max_num_targets:
             return None
 
-        rng = getattr(self, "rng", None)
-        if rng is None:
-            rng = np.random.default_rng()
-        self.rng = rng
+        if self.rng is None:
+            self.rng = np.random.default_rng()
 
-        selected = np.sort(rng.choice(num_points, max_num_targets, replace=False))
+        selected = np.sort(self.rng.choice(num_points, max_num_targets, replace=False))
 
         return torch.from_numpy(selected).to(torch.long)
 
diff --git a/src/weathergen/train/loss_calculator.py b/src/weathergen/train/loss_calculator.py
index 2f81940a3..3b956a39f 100644
--- a/src/weathergen/train/loss_calculator.py
+++ b/src/weathergen/train/loss_calculator.py
@@ -67,7 +67,7 @@ def __init__(
             [
                 (
                     params.get("weight", 1.0),
-                    getattr(LossModules, params.type)(
+                    LossModules.__dict__[params.type](
                         cf, mode_cfg, stage, self.device, **params.loss_fcts
                     ),
                 )
diff --git a/src/weathergen/train/loss_modules/loss_module_physical.py b/src/weathergen/train/loss_modules/loss_module_physical.py
index 64e104827..23816a35a 100644
--- a/src/weathergen/train/loss_modules/loss_module_physical.py
+++ b/src/weathergen/train/loss_modules/loss_module_physical.py
@@ -59,7 +59,7 @@ def __init__(
         # dynamically load loss functions based on configuration and stage
         self.loss_fcts = [
             [
-                getattr(loss_fns, name),
+                loss_fns.__dict__[name],
                 params.get("weight", 1.0),
                 name,
             ]
@@ -95,7 +95,8 @@ def _get_output_step_weights(self, len_forecast_steps):
         timestep_weight_config = self.mode_cfg.get("forecast", {}).get("timestep_weight", {})
         if len(timestep_weight_config) == 0:
             return [1.0 for _ in range(len_forecast_steps)]
-        weights_timestep_fct = getattr(loss_fns, list(timestep_weight_config.keys())[0])
+        weights_timestep_fct_name = list(timestep_weight_config.keys())[0]
+        weights_timestep_fct = loss_fns.__dict__[weights_timestep_fct_name]
         decay_factor = list(timestep_weight_config.values())[0]["decay_factor"]
         return weights_timestep_fct(len_forecast_steps, decay_factor)
 
@@ -103,7 +104,7 @@ def _get_location_weights(self, stream_info, target_coords):
         location_weight_type = stream_info.get("location_weight", None)
         if location_weight_type is None:
             return None
-        weights_locations_fct = getattr(loss_fns, location_weight_type)
+        weights_locations_fct = loss_fns.__dict__[location_weight_type]
         weights_locations = weights_locations_fct(target_coords)
         weights_locations = weights_locations.to(device=self.device, non_blocking=True)
 
diff --git a/src/weathergen/utils/better_abc.py b/src/weathergen/utils/better_abc.py
index e322927d4..4c8be724c 100644
--- a/src/weathergen/utils/better_abc.py
+++ b/src/weathergen/utils/better_abc.py
@@ -29,11 +29,17 @@ class ABCMeta(NativeABCMeta):
     def __call__(cls, *args, **kwargs):
         instance = NativeABCMeta.__call__(cls, *args, **kwargs)
         # pylint: disable-next=attribute-defined-outside-init
-        abstract_attributes = {
-            name
-            for name in dir(instance)
-            if hasattr(getattr(instance, name), "__is_abstract_attribute__")
-        }
+        abstract_attributes = set()
+
+        for name in dir(instance):
+            # Search the attribute name in the class hierarchy
+            for base in type(instance).__mro__:
+                if name in base.__dict__:
+                    attr = base.__dict__[name]
+                    if hasattr(attr, "__is_abstract_attribute__"):
+                        abstract_attributes.add(name)
+                    break
+
         if abstract_attributes:
             raise NotImplementedError(
                 "Can't instantiate abstract class {} with abstract attributes: {}".format(
diff --git a/tests/test_cli.py b/tests/test_cli.py
index e652c7924..2a6b605ef 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -88,10 +88,11 @@ def test_inference_defaults(inference_parser):
     default_values[:2] = [cli._format_date(date) for date in default_values[:2]]
 
     args = inference_parser.parse_args(BASIC_ARGLIST)
+    args_dict = vars(args)
 
     assert all(
         [
-            getattr(args, arg) == default_value
+            args_dict[arg] == default_value
            for arg, default_value in zip(default_args, default_values, strict=True)
        ]
    )
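
Note on the lookup pattern (a minimal illustrative sketch, not part of the patch):
for module namespaces such as loss_fns, module.__dict__[name] returns the same
object as getattr(module, name) when the name is defined in that module, but it
raises KeyError instead of accepting a default when the name is missing, and on a
class it only sees names defined directly on that class rather than inherited ones.
The snippet below shows this with the standard-library math module standing in for
the repository's loss_fns namespace and "sqrt" for a configured function name.

    import math

    name = "sqrt"  # stands in for a function name coming from the config

    # Attribute-protocol lookup: getattr can take a default and, on classes,
    # also finds inherited names.
    fn_a = getattr(math, name)

    # Namespace-dict lookup, the pattern adopted in this patch: only names
    # defined directly in the namespace are visible; a missing name raises
    # KeyError, which makes typos in the config fail loudly.
    fn_b = math.__dict__[name]

    assert fn_a is fn_b
    print(fn_b(9.0))  # 3.0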