From 8eb343b55043f4c5ea5fc98b4ac2e6e7aa810c49 Mon Sep 17 00:00:00 2001
From: Tprojects66554 <38327606992@mby.co.il>
Date: Tue, 4 Nov 2025 18:53:54 +0200
Subject: [PATCH 1/5] after_testing_in_py310

---
 quantus/metrics/axiomatic/non_sensitivity.py | 170 ++++++++++++-------
 1 file changed, 110 insertions(+), 60 deletions(-)

diff --git a/quantus/metrics/axiomatic/non_sensitivity.py b/quantus/metrics/axiomatic/non_sensitivity.py
index 7e5f390c..e7ab859c 100644
--- a/quantus/metrics/axiomatic/non_sensitivity.py
+++ b/quantus/metrics/axiomatic/non_sensitivity.py
@@ -1,10 +1,18 @@
 """This module contains the implementation of the Non-Sensitivity metric."""
 
 # This file is part of Quantus.
-# Quantus is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version.
-# Quantus is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details.
-# You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see <https://www.gnu.org/licenses/>.
-# Quantus project URL: <https://github.com/understandable-machine-intelligence-lab/Quantus>.
+# Quantus is free software: you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation, either version 3 of the License, or (at your option)
+# any later version.
+# Quantus is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for
+# more details.
+# You should have received a copy of the GNU Lesser General Public License
+# along with Quantus. If not, see <https://www.gnu.org/licenses/>.
+# Quantus project URL:
+# <https://github.com/understandable-machine-intelligence-lab/Quantus>.
 
 import sys
 import math
@@ -20,7 +28,6 @@
     ModelType,
     ScoreDirection,
 )
-from quantus.helpers.model.model_interface import ModelInterface
 from quantus.helpers.perturbation_utils import make_perturb_func
 from quantus.metrics.base import Metric
 
@@ -35,8 +42,8 @@ class NonSensitivity(Metric[List[float]]):
     """
     Implementation of NonSensitivity by Nguyen et al., 2020.
 
-    Non-sensitivity measures if zero-importance is only assigned to features, that the model is not
-    functionally dependent on.
+    Non-sensitivity measures if zero-importance is only assigned to features,
+    that the model is not functionally dependent on.
 
     References:
         1) An-phi Nguyen and María Rodríguez Martínez.: "On quantitative aspects of model
@@ -135,7 +142,9 @@ def __init__(
         # Save metric-specific attributes.
         self.eps = eps
         self.features_in_step = features_in_step
-        self.perturb_func = make_perturb_func(perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline)
+        self.perturb_func = make_perturb_func(
+            perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline
+        )
 
         # Asserts and warnings.
         if not self.disable_warnings:
@@ -285,72 +294,113 @@ def custom_preprocess(
 
     def evaluate_batch(
         self,
-        model: ModelInterface,
+        model,
         x_batch: np.ndarray,
         y_batch: np.ndarray,
         a_batch: np.ndarray,
         **kwargs,
-    ) -> List[int]:
+    ):
         """
-        This method performs XAI evaluation on a single batch of explanations.
-        For more information on the specific logic, we refer the metric’s initialisation docstring.
+        Evaluate a batch for the custom Non-Sensitivity metric.
+
+        This implementation perturbs *feature* and *non-feature* pixels separately,
+        evaluating how sensitive the model’s predictions are to each perturbation.
+        The metric quantifies violations of the Non-Sensitivity principle in both directions:
+        (1) when perturbing **non-feature** pixels *does* affect model predictions, and
+        (2) when perturbing **feature** pixels *does not* affect model predictions.
 
         Parameters
         ----------
-        model: ModelInterface
-            A ModelInterface that is subject to explanation.
-        x_batch: np.ndarray
-            The input to be evaluated on a batch-basis.
-        y_batch: np.ndarray
-            The output to be evaluated on a batch-basis.
-        a_batch: np.ndarray
-            The explanation to be evaluated on a batch-basis.
-        kwargs:
-            Unused.
+        model : object
+            Model with `predict` and `shape_input` methods, compatible with Quantus conventions.
+        x_batch : np.ndarray
+            Input batch, of shape (B, C, H, W) or (B, H, W, C), depending on `channel_first`.
+        y_batch : np.ndarray
+            Ground truth or target class indices, of shape (B,).
+        a_batch : np.ndarray
+            Attribution maps aligned with `x_batch`, of shape (B, C, H, W).
+        **kwargs :
+            Additional keyword arguments passed through for flexibility.
 
         Returns
         -------
-        scores_batch:
-            The evaluation results.
-        """
+        np.ndarray
+            Array of shape (B,), where each value indicates the total number of
+            non-sensitivity violations per sample. Lower values indicate fewer
+            violations of the axiom, whereas higher values indicate more.
+
+        Notes
+        -----
+        - The function assumes that a low attribution value (below `self.eps`)
+          represents a "non-feature" pixel.
+        - Perturbations are applied in groups of size `self.features_in_step`.
+        - The perturbation function `self.perturb_func` must follow the Quantus API:
+          it receives an array and an index mask, and returns a perturbed copy.
+        - Designed to comply with Quantus internal metric conventions and to be
+          lint-clean under `black` and `flake8`.
+
+        """
 
-        # Prepare shapes. Expand a_batch if not the same shape
+        # --- Step 1. Prepare shapes ---
         if x_batch.shape != a_batch.shape:
             a_batch = np.broadcast_to(a_batch, x_batch.shape)
 
-        # Flatten the attributions.
-        batch_size = a_batch.shape[0]
-        a_batch = a_batch.reshape(batch_size, -1)
-        n_features = a_batch.shape[-1]
-
-        non_features = a_batch < self.eps
-
-        x_input = model.shape_input(x_batch, x_batch.shape, channel_first=True, batched=True)
-        y_pred = model.predict(x_input)[np.arange(batch_size), y_batch]
-
-        # Prepare lists.
-        n_perturbations = math.ceil(n_features / self.features_in_step)
-        preds = []
-        x_perturbed = x_batch.copy()
-        x_batch_shape = x_batch.shape
-        a_indices = np.stack([np.arange(n_features) for _ in x_batch])
-        for perturbation_step_index in range(n_perturbations):
-            # Perturb input by indices of attributions.
-            a_ix = a_indices[
-                :,
-                perturbation_step_index * self.features_in_step : (perturbation_step_index + 1) * self.features_in_step,
-            ]
-            x_perturbed = self.perturb_func(
-                arr=x_batch.reshape(batch_size, -1),
-                indices=a_ix,
-            )
-            x_perturbed = x_perturbed.reshape(*x_batch_shape)
-
-            # Predict on perturbed input x.
-            x_input = model.shape_input(x_perturbed, x_batch.shape, channel_first=True, batched=True)
-            y_pred_perturb = model.predict(x_input)[np.arange(batch_size), y_batch]
-            preds.append(y_pred_perturb)
-        preds = np.stack(preds, axis=1)
-        preds_differences = abs(preds - y_pred[:, None]) < self.eps
+        B = x_batch.shape[0]
+        x_shape = x_batch.shape
+        x_flat = x_batch.reshape(B, -1)
+        a_flat = a_batch.reshape(B, -1)
+
+        # --- Step 2. Split feature vs non-feature ---
+        non_features = a_flat < self.eps
+        features = ~non_features
+
+        # --- Step 3. Get base predictions ---
+        x_input = model.shape_input(x_batch, x_shape, channel_first=True, batched=True)
+        y_pred = model.predict(x_input)[np.arange(B), y_batch]
+
+        # --- Step 4. Allocate score map ---
+        pixel_scores = np.zeros_like(a_flat, dtype=float)
+
+        # --- Helper: perturbation loop ---
+        def perturb_and_record(indices_mask, desc="nonfeature"):
+            for b in range(B):
+                indices = np.where(indices_mask[b])[0]
+                n_pixels = len(indices)
+                if n_pixels == 0:
+                    continue
+
+                n_steps = math.ceil(n_pixels / self.features_in_step)
+                for step in range(n_steps):
+                    print(
+                        f"Processing batch {b+1}/{B}, {desc} step {step+1}/{n_steps}",
+                        end="\r",
+                    )
+                    start = step * self.features_in_step
+                    end = min((step + 1) * self.features_in_step, n_pixels)
+                    subset_idx = indices[start:end]
+                    indices_2d = np.expand_dims(subset_idx, axis=0)
+
+                    # --- Perturb only selected pixels ---
+                    perturbed_flat = x_flat.copy()
+                    perturbed_flat[b] = self.perturb_func(
+                        arr=perturbed_flat[b : b + 1, :],
+                        indices=indices_2d,
+                    )
+                    x_perturbed = perturbed_flat.reshape(x_shape)
+                    x_input = model.shape_input(
+                        x_perturbed, x_shape, channel_first=True, batched=True
+                    )
+                    y_pred_perturb = model.predict(x_input)[np.arange(B), y_batch]
+
+                    # Assign scores for the perturbed pixels
+                    pixel_scores[b, subset_idx] = y_pred_perturb[b]
+
+        # --- Step 5. Run loops ---
+        perturb_and_record(non_features, "nonfeature")
+        perturb_and_record(features, "feature")
+
+        # --- Step 6. Reshape to image shape ---
+        preds_differences = np.abs(y_pred[:, np.newaxis] - pixel_scores)
+        preds_differences = preds_differences < self.eps
+        pixel_scores = pixel_scores.reshape(x_shape)
 
         return (preds_differences ^ non_features).sum(-1)
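The violation count returned at the end of patch 1 is worth unpacking: `preds_differences` flags pixels whose perturbation left the prediction within `eps`, `non_features` flags pixels with near-zero attribution, and the XOR counts every pixel where the two disagree. A minimal sketch of that counting rule, using made-up mask values rather than real model output:

    import numpy as np

    # preds_differences[b, i]: perturbing pixel i left the prediction within eps.
    # non_features[b, i]: pixel i has attribution below eps.
    preds_differences = np.array([[True, True, False, False]])
    non_features = np.array([[True, False, True, False]])

    # XOR flags both violation directions: a zero-attribution pixel whose
    # perturbation changed the prediction, and an attributed pixel whose
    # perturbation did not.
    violations = (preds_differences ^ non_features).sum(-1)
    print(violations)  # [2] -> pixels 1 and 2 disagree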
From 15e4c4a88346d9156aa727e2bac731f165db5364 Mon Sep 17 00:00:00 2001
From: Tprojects66554 <38327606992@mby.co.il>
Date: Tue, 4 Nov 2025 18:59:50 +0200
Subject: [PATCH 2/5] testing_non_sensitivity_with_features_in_step_2

---
 tests/metrics/test_axiomatic_metrics.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/metrics/test_axiomatic_metrics.py b/tests/metrics/test_axiomatic_metrics.py
index 31fd8eb1..15de8fc8 100644
--- a/tests/metrics/test_axiomatic_metrics.py
+++ b/tests/metrics/test_axiomatic_metrics.py
@@ -231,6 +231,7 @@ def test_completeness(
             "normalise": True,
             "disable_warnings": False,
             "display_progressbar": False,
+            "features_in_step": 2,
         },
         "call": {
             "explain_func": explain,
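Patch 2 exercises the metric with `features_in_step=2`, i.e. pixels are perturbed two at a time. The grouping follows the same ceiling-division loop bounds as `perturb_and_record`; a small sketch with made-up sizes:

    import math
    import numpy as np

    features_in_step = 2
    indices = np.arange(5)  # e.g. five pixels selected by one mask

    n_steps = math.ceil(len(indices) / features_in_step)  # 3 steps
    groups = [
        indices[step * features_in_step : (step + 1) * features_in_step]
        for step in range(n_steps)
    ]
    print(groups)  # [array([0, 1]), array([2, 3]), array([4])]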
From 44217f2ef6d5b9eab587e645308d9d823d346415 Mon Sep 17 00:00:00 2001
From: Tprojects66554 <38327606992@mby.co.il>
Date: Thu, 6 Nov 2025 09:54:53 +0200
Subject: [PATCH 3/5] with_logic_tests

---
 quantus/metrics/axiomatic/non_sensitivity.py |  13 ---
 tests/metrics/test_axiomatic_metrics.py      | 112 +++++++++++++++
 2 files changed, 112 insertions(+), 13 deletions(-)

diff --git a/quantus/metrics/axiomatic/non_sensitivity.py b/quantus/metrics/axiomatic/non_sensitivity.py
index e7ab859c..de166b00 100644
--- a/quantus/metrics/axiomatic/non_sensitivity.py
+++ b/quantus/metrics/axiomatic/non_sensitivity.py
@@ -340,7 +340,6 @@ def evaluate_batch(
           lint-clean under `black` and `flake8`.
 
         """
-        # --- Step 1. Prepare shapes ---
         if x_batch.shape != a_batch.shape:
             a_batch = np.broadcast_to(a_batch, x_batch.shape)
 
@@ -349,18 +348,14 @@ def evaluate_batch(
         x_flat = x_batch.reshape(B, -1)
         a_flat = a_batch.reshape(B, -1)
 
-        # --- Step 2. Split feature vs non-feature ---
         non_features = a_flat < self.eps
         features = ~non_features
 
-        # --- Step 3. Get base predictions ---
         x_input = model.shape_input(x_batch, x_shape, channel_first=True, batched=True)
         y_pred = model.predict(x_input)[np.arange(B), y_batch]
 
-        # --- Step 4. Allocate score map ---
         pixel_scores = np.zeros_like(a_flat, dtype=float)
 
-        # --- Helper: perturbation loop ---
         def perturb_and_record(indices_mask, desc="nonfeature"):
             for b in range(B):
                 indices = np.where(indices_mask[b])[0]
@@ -370,16 +365,11 @@ def perturb_and_record(indices_mask, desc="nonfeature"):
 
                 n_steps = math.ceil(n_pixels / self.features_in_step)
                 for step in range(n_steps):
-                    print(
-                        f"Processing batch {b+1}/{B}, {desc} step {step+1}/{n_steps}",
-                        end="\r",
-                    )
                     start = step * self.features_in_step
                     end = min((step + 1) * self.features_in_step, n_pixels)
                     subset_idx = indices[start:end]
                     indices_2d = np.expand_dims(subset_idx, axis=0)
 
-                    # --- Perturb only selected pixels ---
                     perturbed_flat = x_flat.copy()
                     perturbed_flat[b] = self.perturb_func(
                         arr=perturbed_flat[b : b + 1, :],
                         indices=indices_2d,
                     )
@@ -391,14 +381,11 @@ def perturb_and_record(indices_mask, desc="nonfeature"):
                     )
                     y_pred_perturb = model.predict(x_input)[np.arange(B), y_batch]
 
-                    # Assign scores for the perturbed pixels
                     pixel_scores[b, subset_idx] = y_pred_perturb[b]
 
-        # --- Step 5. Run loops ---
         perturb_and_record(non_features, "nonfeature")
         perturb_and_record(features, "feature")
 
-        # --- Step 6. Reshape to image shape ---
         preds_differences = np.abs(y_pred[:, np.newaxis] - pixel_scores)
         preds_differences = preds_differences < self.eps
         pixel_scores = pixel_scores.reshape(x_shape)
diff --git a/tests/metrics/test_axiomatic_metrics.py b/tests/metrics/test_axiomatic_metrics.py
index 15de8fc8..18d93d96 100644
--- a/tests/metrics/test_axiomatic_metrics.py
+++ b/tests/metrics/test_axiomatic_metrics.py
@@ -3,10 +3,46 @@
 import pytest
 from pytest_lazyfixture import lazy_fixture
 import numpy as np
+import torch
+import torch.nn as nn
 
 from quantus.functions.explanation_func import explain
 from quantus.metrics.axiomatic import Completeness, InputInvariance, NonSensitivity
 
+# test_axiomatic_metrics.py (or similar)
+
+class SensitiveModel(nn.Module):
+    def shape_input(self, x, shape, channel_first=True, batched=True):
+        return x
+    def forward(self, x):
+        return x.sum(axis=(1, 2, 3), keepdims=True)
+    def predict(self, x):
+        return self.forward(x)
+
+class InsensitiveModel(nn.Module):
+    def shape_input(self, x, shape, channel_first=True, batched=True):
+        return x
+    def forward(self, x):
+        B = x.shape[0]
+        return np.ones((B, 1), dtype=float) * 100.0
+    def predict(self, x):
+        return self.forward(x)
+
+class SemiSensitiveModel(nn.Module):
+    def shape_input(self, x, shape, channel_first=True, batched=True):
+        return x
+    def forward(self, x):
+        top_sum = x[:, :, 0, :].sum(axis=(1, 2))
+        return top_sum[:, None]
+    def predict(self, x):
+        return self.forward(x)
+
+class TrickModel(nn.Module):
+    def shape_input(self, x, shape, channel_first=True, batched=True):
+        return x
+    def predict(self, x):
+        bottom_sum = x[:, :, 1, :].sum(axis=(1, 2))
+        return bottom_sum[:, None]
 
 @pytest.mark.axiomatic
 @pytest.mark.parametrize(
@@ -418,6 +454,82 @@ def test_non_sensitivity(
     )
     assert scores is not None, "Test failed."
 
+@pytest.mark.axiomatic
+@pytest.mark.parametrize(
+    "scenario,model_factory,x_batch,y_batch,a_batch,expected_violations,kwargs",
+    [
+
+        (
+            "zero_violations",
+            lambda: SemiSensitiveModel(),
+            np.array([[[[1.0, 2.0], [3.0, 4.0]]]], dtype=float),
+            np.array([0]),
+            np.array([[[[10.0, 10.0], [0.0, 0.0]]]], dtype=float),
+            0,
+            {"features_in_step": 2, "eps": 1e-5},
+        ),
+        (
+            "low_attr_high_change",
+            lambda: SensitiveModel(),
+            np.array([[[[5.0, 5.0], [5.0, 5.0]]]], dtype=float),
+            np.array([0]),
+            np.random.uniform(1e-6, 2e-6, size=(1, 1, 2, 2)),
+            4,
+            {"features_in_step": 2, "eps": 1e-5},
+        ),
+        (
+            "high_attr_low_change",
+            lambda: InsensitiveModel(),
+            np.random.rand(1, 1, 4, 4),
+            np.array([0]),
+            np.ones((1, 1, 4, 4)),
+            16,
+            {"features_in_step": 2, "eps": 1e-5},
+        ),
+        (
+            "half_good_half_bad",
+            lambda: TrickModel(),
+            np.array([[[[1.0, 2.0], [3.0, 4.0]]]], dtype=float),
+            np.array([0]),
+            np.array([[[[10.0, 10.0], [0.0, 0.0]]]], dtype=float),
+            4,
+            {"features_in_step": 1, "eps": 1e-5},
+        ),
+    ],
+)
+def test_my_non_sensitivity_logics(
+    scenario,
+    model_factory,
+    x_batch,
+    y_batch,
+    a_batch,
+    expected_violations,
+    kwargs,
+):
+    """
+    Parametrized logic-based tests for NonSensitivity.
+    Each scenario defines a different consistency pattern between attribution
+    and model behavior.
+    """
+    model = model_factory()
+    model.eval()
+    metric = NonSensitivity(
+        disable_warnings=True,
+        perturb_baseline="uniform",
+        normalise=False,
+        **kwargs,
+    )
+
+    scores = metric.evaluate_batch(model, x_batch, y_batch, a_batch)
+
+    # --- Assertions ---
+    assert isinstance(scores, np.ndarray), f"[{scenario}] Output must be np.ndarray"
+    assert scores.shape[0] == x_batch.shape[0], f"[{scenario}] Wrong batch size"
+    assert np.all(np.isfinite(scores)), f"[{scenario}] Scores contain NaN/Inf"
+
+    if expected_violations is not None:
+        assert scores[0] == expected_violations, (
+            f"[{scenario}] expected {expected_violations}, got {scores[0]}"
+        )
 
 @pytest.mark.axiomatic
 @pytest.mark.parametrize(
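The expected violation counts in these scenarios can be checked by hand. In `low_attr_high_change`, for instance, all four attributions sit below `eps`, so every pixel is a "non-feature"; the scenario is constructed so that `SensitiveModel`, which sums every pixel, moves its output by more than `eps` whenever a pixel is perturbed, so all four pixels count as violations. A sketch of that arithmetic with the masks written out (illustrative values, not the metric's internals):

    import numpy as np

    # All attributions < eps -> every pixel is a non-feature.
    non_features = np.array([[True, True, True, True]])
    # The sum model reacts to every perturbation -> no prediction stays within eps.
    preds_within_eps = np.array([[False, False, False, False]])

    print((preds_within_eps ^ non_features).sum(-1))  # [4], matching the test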
From d244aa3a29927d6fe312b3c1666c204577ae557f Mon Sep 17 00:00:00 2001
From: Tprojects66554 <38327606992@mby.co.il>
Date: Thu, 6 Nov 2025 10:09:05 +0200
Subject: [PATCH 4/5] change evaluate_batch_ documentation

---
 quantus/metrics/axiomatic/non_sensitivity.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/quantus/metrics/axiomatic/non_sensitivity.py b/quantus/metrics/axiomatic/non_sensitivity.py
index de166b00..d04c19d2 100644
--- a/quantus/metrics/axiomatic/non_sensitivity.py
+++ b/quantus/metrics/axiomatic/non_sensitivity.py
@@ -336,9 +336,7 @@ def evaluate_batch(
         - Perturbations are applied in groups of size `self.features_in_step`.
         - The perturbation function `self.perturb_func` must follow the Quantus API:
           it receives an array and an index mask, and returns a perturbed copy.
-        - Designed to comply with Quantus internal metric conventions and to be
-          lint-clean under `black` and `flake8`.
-
+        - Returned scores are counts of violations per sample or aggregated across samples.
""" if x_batch.shape != a_batch.shape: a_batch = np.broadcast_to(a_batch, x_batch.shape) From 4190c555d12df9cf40b9358a52e6791e5b0eb8d3 Mon Sep 17 00:00:00 2001 From: Tprojects66554 <38327606992@mby.co.il> Date: Sun, 9 Nov 2025 19:20:39 +0200 Subject: [PATCH 5/5] Splitting functions for SRP --- quantus/metrics/axiomatic/non_sensitivity.py | 200 ++++++++++--------- tests/metrics/test_axiomatic_metrics.py | 19 ++ 2 files changed, 122 insertions(+), 97 deletions(-) diff --git a/quantus/metrics/axiomatic/non_sensitivity.py b/quantus/metrics/axiomatic/non_sensitivity.py index d04c19d2..16c0dd46 100644 --- a/quantus/metrics/axiomatic/non_sensitivity.py +++ b/quantus/metrics/axiomatic/non_sensitivity.py @@ -1,18 +1,10 @@ """This module contains the implementation of the Non-Sensitivity metric.""" # This file is part of Quantus. -# Quantus is free software: you can redistribute it and/or modify it under the -# terms of the GNU Lesser General Public License as published by the Free -# Software Foundation, either version 3 of the License, or (at your option) -# any later version. -# Quantus is distributed in the hope that it will be useful, but WITHOUT ANY -# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS -# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for -# more details. -# You should have received a copy of the GNU Lesser General Public License -# along with Quantus. If not, see . -# Quantus project URL: -# . +# Quantus is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. +# Quantus is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. +# You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . +# Quantus project URL: . import sys import math @@ -28,6 +20,7 @@ ModelType, ScoreDirection, ) +from quantus.helpers.model.model_interface import ModelInterface from quantus.helpers.perturbation_utils import make_perturb_func from quantus.metrics.base import Metric @@ -42,8 +35,8 @@ class NonSensitivity(Metric[List[float]]): """ Implementation of NonSensitivity by Nguyen et al., 2020. - Non-sensitivity measures if zero-importance is only assigned to features, - that the model is not functionally dependent on. + Non-sensitivity measures if zero-importance is only assigned to features, that the model is not + functionally dependent on. References: 1) An-phi Nguyen and María Rodríguez Martínez.: "On quantitative aspects of model @@ -142,9 +135,7 @@ def __init__( # Save metric-specific attributes. self.eps = eps self.features_in_step = features_in_step - self.perturb_func = make_perturb_func( - perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline - ) + self.perturb_func = make_perturb_func(perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline) # Asserts and warnings. if not self.disable_warnings: @@ -291,101 +282,116 @@ def custom_preprocess( features_in_step=self.features_in_step, input_shape=x_batch.shape[2:], ) + + + def evaluate_batch( self, - model, + model: ModelInterface, x_batch: np.ndarray, y_batch: np.ndarray, a_batch: np.ndarray, **kwargs, - ): - """ - Evaluate a batch for the custom Non-Sensitivity metric. 
-
-        This implementation perturbs *feature* and *non-feature* pixels separately,
-        evaluating how sensitive the model’s predictions are to each perturbation.
-        The metric quantifies violations of the Non-Sensitivity principle in both directions:
-        (1) when perturbing **non-feature** pixels *does* affect model predictions, and
-        (2) when perturbing **feature** pixels *does not* affect model predictions.
+        This method performs XAI evaluation on a single batch of explanations.
+        For more information on the specific logic, we refer the metric’s initialisation docstring.
 
         Parameters
         ----------
-        model : object
-            Model with `predict` and `shape_input` methods, compatible with Quantus conventions.
-        x_batch : np.ndarray
-            Input batch, of shape (B, C, H, W) or (B, H, W, C), depending on `channel_first`.
-        y_batch : np.ndarray
-            Ground truth or target class indices, of shape (B,).
-        a_batch : np.ndarray
-            Attribution maps aligned with `x_batch`, of shape (B, C, H, W).
-        **kwargs :
-            Additional keyword arguments passed through for flexibility.
+        model: ModelInterface
+            A ModelInterface that is subject to explanation.
+        x_batch: np.ndarray
+            The input to be evaluated on a batch-basis.
+        y_batch: np.ndarray
+            The output to be evaluated on a batch-basis.
+        a_batch: np.ndarray
+            The explanation to be evaluated on a batch-basis.
+        kwargs:
+            Unused.
 
         Returns
         -------
         np.ndarray
-            Array of shape (B,), where each value indicates the total number of
-            non-sensitivity violations per sample. Lower values indicate fewer
-            violations of the axiom, whereas higher values indicate more.
-
-        Notes
-        -----
-        - The function assumes that a low attribution value (below `self.eps`)
-          represents a "non-feature" pixel.
-        - Perturbations are applied in groups of size `self.features_in_step`.
-        - The perturbation function `self.perturb_func` must follow the Quantus API:
-          it receives an array and an index mask, and returns a perturbed copy.
-        - Returned scores are counts of violations per sample or aggregated across samples.
+            Array of shape (batch_size,), per-sample score.
+            Lower values are better.
         """
+
+        # Prepare shapes. Expand a_batch if not the same shape
         if x_batch.shape != a_batch.shape:
             a_batch = np.broadcast_to(a_batch, x_batch.shape)
 
-        B = x_batch.shape[0]
+        # Flatten the attributions.
+        batch_size = a_batch.shape[0]
         x_shape = x_batch.shape
-        x_flat = x_batch.reshape(B, -1)
-        a_flat = a_batch.reshape(B, -1)
+        x_batch = x_batch.reshape(batch_size, -1)
+        a_batch = a_batch.reshape(batch_size, -1)
 
-        non_features = a_flat < self.eps
+        non_features = a_batch < self.eps
         features = ~non_features
 
         x_input = model.shape_input(x_batch, x_shape, channel_first=True, batched=True)
-        y_pred = model.predict(x_input)[np.arange(B), y_batch]
-
-        pixel_scores = np.zeros_like(a_flat, dtype=float)
-
-        def perturb_and_record(indices_mask, desc="nonfeature"):
-            for b in range(B):
-                indices = np.where(indices_mask[b])[0]
-                n_pixels = len(indices)
-                if n_pixels == 0:
-                    continue
-
-                n_steps = math.ceil(n_pixels / self.features_in_step)
-                for step in range(n_steps):
-                    start = step * self.features_in_step
-                    end = min((step + 1) * self.features_in_step, n_pixels)
-                    subset_idx = indices[start:end]
-                    indices_2d = np.expand_dims(subset_idx, axis=0)
-
-                    perturbed_flat = x_flat.copy()
-                    perturbed_flat[b] = self.perturb_func(
-                        arr=perturbed_flat[b : b + 1, :],
-                        indices=indices_2d,
-                    )
-                    x_perturbed = perturbed_flat.reshape(x_shape)
-                    x_input = model.shape_input(
-                        x_perturbed, x_shape, channel_first=True, batched=True
-                    )
-                    y_pred_perturb = model.predict(x_input)[np.arange(B), y_batch]
-
-                    pixel_scores[b, subset_idx] = y_pred_perturb[b]
-
-        perturb_and_record(non_features, "nonfeature")
-        perturb_and_record(features, "feature")
-
-        preds_differences = np.abs(y_pred[:, np.newaxis] - pixel_scores)
-        preds_differences = preds_differences < self.eps
-        pixel_scores = pixel_scores.reshape(x_shape)
+        y_pred = model.predict(x_input)[np.arange(batch_size), y_batch]
+
+        pixel_scores_non = self._process_mask(
+            model, x_batch, y_batch, non_features, x_shape
+        )
+        pixel_scores_feat = self._process_mask(
+            model, x_batch, y_batch, features, x_shape
+        )
+
+        preds_differences = np.abs(
+            y_pred[:, None] - (pixel_scores_non + pixel_scores_feat)
+        ) < self.eps
 
         return (preds_differences ^ non_features).sum(-1)
+
+    def _create_index_groups(self, mask: np.ndarray) -> List[np.ndarray]:
+        """Divide mask indices into perturbation groups."""
+        indices = np.where(mask)[0]
+        return [
+            indices[i:i + self.features_in_step]
+            for i in range(0, len(indices), self.features_in_step)
+        ]
+
+    def _perturb_sample_batch(
+        self, x_batch: np.ndarray, b: int, indices: np.ndarray
+    ) -> np.ndarray:
+        """Perturb a single sample within the batch."""
+        perturbed_flat = x_batch.copy()
+        indices_2d = np.expand_dims(indices, axis=0)
+        perturbed_flat[b] = self.perturb_func(
+            arr=perturbed_flat[b:b + 1, :],
+            indices=indices_2d,
+        )
+        return perturbed_flat
+
+    def _predict_scores(
+        self, model, x_batch: np.ndarray, y_batch: np.ndarray, x_shape: tuple
+    ) -> np.ndarray:
+        """Predict scores for the true labels of the given batch."""
+        x_input = model.shape_input(
+            x_batch.reshape(x_shape), x_shape, channel_first=True, batched=True
+        )
+        return model.predict(x_input)[np.arange(x_batch.shape[0]), y_batch]
+
+    def _process_mask(
+        self,
+        model,
+        x_batch: np.ndarray,
+        y_batch: np.ndarray,
+        mask: np.ndarray,
+        x_shape: tuple,
+    ) -> np.ndarray:
+        """Handle perturbation and prediction workflow for a single mask type."""
+        batch_size = x_batch.shape[0]
+        pixel_scores = np.zeros_like(x_batch, dtype=float)
+
+        for b in range(batch_size):
+            for indices in self._create_index_groups(mask[b]):
+                perturbed_batch = self._perturb_sample_batch(x_batch, b, indices)
+                preds = self._predict_scores(model, perturbed_batch, y_batch, x_shape)
+                pixel_scores[b, indices] = preds[b]
+
+        return pixel_scores
diff --git a/tests/metrics/test_axiomatic_metrics.py b/tests/metrics/test_axiomatic_metrics.py
index 18d93d96..83d3836e 100644
--- a/tests/metrics/test_axiomatic_metrics.py
+++ b/tests/metrics/test_axiomatic_metrics.py
@@ -11,18 +11,34 @@
 
 # test_axiomatic_metrics.py (or similar)
 
+def _ensure_4d(x):
+    """Make sure x is (B, C, H, W), even if passed as (B, N)."""
+    x = np.array(x)
+    if x.ndim == 2:
+        B, N = x.shape
+        side = int(np.sqrt(N))
+        x = x.reshape(B, 1, side, side)
+    elif x.ndim == 3:
+        x = x[:, None, :, :]
+    return x
+
 class SensitiveModel(nn.Module):
     def shape_input(self, x, shape, channel_first=True, batched=True):
         return x
+
     def forward(self, x):
+        x = _ensure_4d(x)
         return x.sum(axis=(1, 2, 3), keepdims=True)
+
     def predict(self, x):
+        x = _ensure_4d(x)
         return self.forward(x)
 
 class InsensitiveModel(nn.Module):
     def shape_input(self, x, shape, channel_first=True, batched=True):
         return x
     def forward(self, x):
+        x = _ensure_4d(x)
         B = x.shape[0]
         return np.ones((B, 1), dtype=float) * 100.0
     def predict(self, x):
         return self.forward(x)
 
 class SemiSensitiveModel(nn.Module):
     def shape_input(self, x, shape, channel_first=True, batched=True):
         return x
     def forward(self, x):
+        x = _ensure_4d(x)
         top_sum = x[:, :, 0, :].sum(axis=(1, 2))
         return top_sum[:, None]
     def predict(self, x):
+        x = _ensure_4d(x)
         return self.forward(x)
 
 class TrickModel(nn.Module):
     def shape_input(self, x, shape, channel_first=True, batched=True):
         return x
     def predict(self, x):
+        x = _ensure_4d(x)
         bottom_sum = x[:, :, 1, :].sum(axis=(1, 2))
         return bottom_sum[:, None]
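Taken together, patch 5 splits `evaluate_batch` into four helpers: `_create_index_groups` chunks a mask's indices into `features_in_step`-sized groups, `_perturb_sample_batch` perturbs one sample's group, `_predict_scores` re-runs the model, and `_process_mask` drives the loop and fills a per-pixel score map; the two maps are disjoint, so adding them merges the results. A self-contained sketch of that control flow with a toy one-sample "model" and a zero baseline standing in for `perturb_func` (none of this is Quantus code):

    import numpy as np

    eps = 1e-5
    features_in_step = 2

    x = np.array([[1.0, 2.0, 3.0, 4.0]])    # one sample, four flattened "pixels"
    a = np.array([[0.0, 0.0, 10.0, 10.0]])  # attributions: two zero, two large

    def predict(arr):
        return arr.sum(-1)                  # toy model: sum of all pixels

    non_features = a < eps
    y_pred = predict(x)                     # baseline prediction: [10.]

    def process_mask(mask):
        """Perturb each index group under the mask, record the new prediction per pixel."""
        scores = np.zeros_like(x)
        idx = np.where(mask[0])[0]
        groups = [idx[i:i + features_in_step] for i in range(0, len(idx), features_in_step)]
        for g in groups:
            x_pert = x.copy()
            x_pert[0, g] = 0.0              # zero baseline stands in for self.perturb_func
            scores[0, g] = predict(x_pert)
        return scores

    # Disjoint masks, so adding the two score maps merges them, as in the patch.
    pixel_scores = process_mask(non_features) + process_mask(~non_features)
    unchanged = np.abs(y_pred[:, None] - pixel_scores) < eps
    print((unchanged ^ non_features).sum(-1))  # [2]: both zero-attribution pixels mattered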