diff --git a/quantus/metrics/axiomatic/non_sensitivity.py b/quantus/metrics/axiomatic/non_sensitivity.py
index 7e5f390c..16c0dd46 100644
--- a/quantus/metrics/axiomatic/non_sensitivity.py
+++ b/quantus/metrics/axiomatic/non_sensitivity.py
@@ -282,6 +282,9 @@ def custom_preprocess(
             features_in_step=self.features_in_step,
             input_shape=x_batch.shape[2:],
         )
+
+
+
     def evaluate_batch(
         self,
@@ -292,65 +295,103 @@ def evaluate_batch(
         **kwargs,
     ) -> List[int]:
         """
-        This method performs XAI evaluation on a single batch of explanations.
-        For more information on the specific logic, we refer the metric’s initialisation docstring.
-
-        Parameters
-        ----------
-        model: ModelInterface
-            A ModelInterface that is subject to explanation.
-        x_batch: np.ndarray
-            The input to be evaluated on a batch-basis.
-        y_batch: np.ndarray
-            The output to be evaluated on a batch-basis.
-        a_batch: np.ndarray
-            The explanation to be evaluated on a batch-basis.
-        kwargs:
-            Unused.
-
-        Returns
-        -------
-        scores_batch:
-            The evaluation results.
-        """
-
+        This method performs XAI evaluation on a single batch of explanations.
+        For more information on the specific logic, we refer to the metric’s initialisation docstring.
+
+        Parameters
+        ----------
+        model: ModelInterface
+            A ModelInterface that is subject to explanation.
+        x_batch: np.ndarray
+            The input to be evaluated on a batch-basis.
+        y_batch: np.ndarray
+            The output to be evaluated on a batch-basis.
+        a_batch: np.ndarray
+            The explanation to be evaluated on a batch-basis.
+        kwargs:
+            Unused.
+
+        Returns
+        -------
+        np.ndarray
+            Array of shape (batch_size,) with one score per sample.
+            Lower values are better.
+        """
+        # Prepare shapes: broadcast a_batch to x_batch's shape if they differ.
         if x_batch.shape != a_batch.shape:
             a_batch = np.broadcast_to(a_batch, x_batch.shape)
 
         # Flatten the attributions.
         batch_size = a_batch.shape[0]
+        x_shape = x_batch.shape
+        x_batch = x_batch.reshape(batch_size, -1)
         a_batch = a_batch.reshape(batch_size, -1)
-        n_features = a_batch.shape[-1]
 
         non_features = a_batch < self.eps
+        features = ~non_features
 
-        x_input = model.shape_input(x_batch, x_batch.shape, channel_first=True, batched=True)
+        x_input = model.shape_input(x_batch, x_shape, channel_first=True, batched=True)
         y_pred = model.predict(x_input)[np.arange(batch_size), y_batch]
 
-        # Prepare lists.
-        n_perturbations = math.ceil(n_features / self.features_in_step)
-        preds = []
-        x_perturbed = x_batch.copy()
-        x_batch_shape = x_batch.shape
-        a_indices = np.stack([np.arange(n_features) for _ in x_batch])
-        for perturbation_step_index in range(n_perturbations):
-            # Perturb input by indices of attributions.
-            a_ix = a_indices[
-                :,
-                perturbation_step_index * self.features_in_step : (perturbation_step_index + 1) * self.features_in_step,
-            ]
-            x_perturbed = self.perturb_func(
-                arr=x_batch.reshape(batch_size, -1),
-                indices=a_ix,
-            )
-            x_perturbed = x_perturbed.reshape(*x_batch_shape)
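+        # For every pixel, record the prediction obtained when the pixel group it
+        # belongs to is perturbed; low-attribution (non-feature) pixels and the
+        # remaining (feature) pixels are handled in separate passes.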
+        pixel_scores_non = self._process_mask(
+            model, x_batch, y_batch, non_features, x_shape
+        )
+        pixel_scores_feat = self._process_mask(
+            model, x_batch, y_batch, features, x_shape
+        )
 
-            # Predict on perturbed input x.
-            x_input = model.shape_input(x_perturbed, x_batch.shape, channel_first=True, batched=True)
-            y_pred_perturb = model.predict(x_input)[np.arange(batch_size), y_batch]
-            preds.append(y_pred_perturb)
-        preds = np.stack(preds, axis=1)
-        preds_differences = abs(preds - y_pred[:, None]) < self.eps
+        # A prediction is "unchanged" if it differs from the original by less than
+        # eps; a violation is a pixel where this disagrees with the attribution
+        # (unchanged prediction but non-negligible attribution, or vice versa).
+        preds_differences = np.abs(
+            y_pred[:, None] - (pixel_scores_non + pixel_scores_feat)
+        ) < self.eps
 
         return (preds_differences ^ non_features).sum(-1)
+
+    def _create_index_groups(self, mask: np.ndarray) -> List[np.ndarray]:
+        """Divide mask indices into perturbation groups of size features_in_step."""
+        indices = np.where(mask)[0]
+        return [
+            indices[i:i + self.features_in_step]
+            for i in range(0, len(indices), self.features_in_step)
+        ]
+
+    def _perturb_sample_batch(
+        self, x_batch: np.ndarray, b: int, indices: np.ndarray
+    ) -> np.ndarray:
+        """Perturb a single sample within the batch."""
+        perturbed_flat = x_batch.copy()
+        indices_2d = np.expand_dims(indices, axis=0)
+        perturbed_flat[b] = self.perturb_func(
+            arr=perturbed_flat[b:b + 1, :],
+            indices=indices_2d,
+        )
+        return perturbed_flat
+
+    def _predict_scores(
+        self, model, x_batch: np.ndarray, y_batch: np.ndarray, x_shape: tuple
+    ) -> np.ndarray:
+        """Predict scores for the true labels of the given batch."""
+        x_input = model.shape_input(
+            x_batch.reshape(x_shape), x_shape, channel_first=True, batched=True
+        )
+        return model.predict(x_input)[np.arange(x_batch.shape[0]), y_batch]
+
+    def _process_mask(
+        self,
+        model,
+        x_batch: np.ndarray,
+        y_batch: np.ndarray,
+        mask: np.ndarray,
+        x_shape: tuple,
+    ) -> np.ndarray:
+        """Handle the perturbation and prediction workflow for a single mask type."""
+        batch_size = x_batch.shape[0]
+        pixel_scores = np.zeros_like(x_batch, dtype=float)
+
+        for b in range(batch_size):
+            for indices in self._create_index_groups(mask[b]):
+                perturbed_batch = self._perturb_sample_batch(x_batch, b, indices)
+                preds = self._predict_scores(model, perturbed_batch, y_batch, x_shape)
+                pixel_scores[b, indices] = preds[b]
+
+        return pixel_scores
diff --git a/tests/metrics/test_axiomatic_metrics.py b/tests/metrics/test_axiomatic_metrics.py
index 31fd8eb1..83d3836e 100644
--- a/tests/metrics/test_axiomatic_metrics.py
+++ b/tests/metrics/test_axiomatic_metrics.py
@@ -3,10 +3,65 @@
 import pytest
 from pytest_lazyfixture import lazy_fixture
 import numpy as np
+import torch
+import torch.nn as nn
 
 from quantus.functions.explanation_func import explain
 from quantus.metrics.axiomatic import Completeness, InputInvariance, NonSensitivity
 
+
+def _ensure_4d(x):
+    """Make sure x is (B, C, H, W), even if passed as (B, N)."""
+    x = np.array(x)
+    if x.ndim == 2:
+        B, N = x.shape
+        side = int(np.sqrt(N))
+        x = x.reshape(B, 1, side, side)
+    elif x.ndim == 3:
+        x = x[:, None, :, :]
+    return x
+
+
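+# Minimal stub models with hand-checkable sensitivity patterns; each provides the
+# shape_input/predict interface that NonSensitivity.evaluate_batch calls.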
+class SensitiveModel(nn.Module):
+    def shape_input(self, x, shape, channel_first=True, batched=True):
+        return x
+
+    def forward(self, x):
+        x = _ensure_4d(x)
+        return x.sum(axis=(1, 2, 3), keepdims=True)
+
+    def predict(self, x):
+        x = _ensure_4d(x)
+        return self.forward(x)
+
+
+class InsensitiveModel(nn.Module):
+    def shape_input(self, x, shape, channel_first=True, batched=True):
+        return x
+
+    def forward(self, x):
+        x = _ensure_4d(x)
+        B = x.shape[0]
+        return np.ones((B, 1), dtype=float) * 100.0
+
+    def predict(self, x):
+        return self.forward(x)
+
+
+class SemiSensitiveModel(nn.Module):
+    def shape_input(self, x, shape, channel_first=True, batched=True):
+        return x
+
+    def forward(self, x):
+        x = _ensure_4d(x)
+        top_sum = x[:, :, 0, :].sum(axis=(1, 2))
+        return top_sum[:, None]
+
+    def predict(self, x):
+        x = _ensure_4d(x)
+        return self.forward(x)
+
+
+class TrickModel(nn.Module):
+    def shape_input(self, x, shape, channel_first=True, batched=True):
+        return x
+
+    def predict(self, x):
+        x = _ensure_4d(x)
+        bottom_sum = x[:, :, 1, :].sum(axis=(1, 2))
+        return bottom_sum[:, None]
+
 
 @pytest.mark.axiomatic
 @pytest.mark.parametrize(
@@ -231,6 +286,7 @@ def test_completeness(
                 "normalise": True,
                 "disable_warnings": False,
                 "display_progressbar": False,
+                "features_in_step": 2,
             },
             "call": {
                 "explain_func": explain,
@@ -417,6 +473,82 @@ def test_non_sensitivity(
     )
     assert scores is not None, "Test failed."
 
 
+@pytest.mark.axiomatic
+@pytest.mark.parametrize(
+    "scenario,model_factory,x_batch,y_batch,a_batch,expected_violations,kwargs",
+    [
+        (
+            "zero_violations",
+            lambda: SemiSensitiveModel(),
+            np.array([[[[1.0, 2.0], [3.0, 4.0]]]], dtype=float),
+            np.array([0]),
+            np.array([[[[10.0, 10.0], [0.0, 0.0]]]], dtype=float),
+            0,
+            {"features_in_step": 2, "eps": 1e-5},
+        ),
+        (
+            "low_attr_high_change",
+            lambda: SensitiveModel(),
+            np.array([[[[5.0, 5.0], [5.0, 5.0]]]], dtype=float),
+            np.array([0]),
+            np.random.uniform(1e-6, 2e-6, size=(1, 1, 2, 2)),
+            4,
+            {"features_in_step": 2, "eps": 1e-5},
+        ),
+        (
+            "high_attr_low_change",
+            lambda: InsensitiveModel(),
+            np.random.rand(1, 1, 4, 4),
+            np.array([0]),
+            np.ones((1, 1, 4, 4)),
+            16,
+            {"features_in_step": 2, "eps": 1e-5},
+        ),
+        (
+            "half_good_half_bad",
+            lambda: TrickModel(),
+            np.array([[[[1.0, 2.0], [3.0, 4.0]]]], dtype=float),
+            np.array([0]),
+            np.array([[[[10.0, 10.0], [0.0, 0.0]]]], dtype=float),
+            4,
+            {"features_in_step": 1, "eps": 1e-5},
+        ),
+    ],
+)
+def test_my_non_sensitivity_logics(
+    scenario,
+    model_factory,
+    x_batch,
+    y_batch,
+    a_batch,
+    expected_violations,
+    kwargs,
+):
+    """
+    Parametrized logic-based tests for NonSensitivity.
+    Each scenario defines a different consistency pattern between attribution and model behavior.
+    """
+    model = model_factory()
+    model.eval()
+    metric = NonSensitivity(
+        disable_warnings=True,
+        perturb_baseline="uniform",
+        normalise=False,
+        **kwargs,
+    )
+
+    scores = metric.evaluate_batch(model, x_batch, y_batch, a_batch)
+
+    # --- Assertions ---
+    assert isinstance(scores, np.ndarray), f"[{scenario}] Output must be np.ndarray"
+    assert scores.shape[0] == x_batch.shape[0], f"[{scenario}] Wrong batch size"
+    assert np.all(np.isfinite(scores)), f"[{scenario}] Scores contain NaN/Inf"
+
+    if expected_violations is not None:
+        assert scores[0] == expected_violations, (
+            f"[{scenario}] expected {expected_violations}, got {scores[0]}"
+        )
+
 
 @pytest.mark.axiomatic
 @pytest.mark.parametrize(