From 8eb343b55043f4c5ea5fc98b4ac2e6e7aa810c49 Mon Sep 17 00:00:00 2001
From: Tprojects66554 <38327606992@mby.co.il>
Date: Tue, 4 Nov 2025 18:53:54 +0200
Subject: [PATCH 1/5] Rework NonSensitivity.evaluate_batch (after testing in Python 3.10)
---
quantus/metrics/axiomatic/non_sensitivity.py | 170 ++++++++++++-------
1 file changed, 110 insertions(+), 60 deletions(-)
diff --git a/quantus/metrics/axiomatic/non_sensitivity.py b/quantus/metrics/axiomatic/non_sensitivity.py
index 7e5f390c..e7ab859c 100644
--- a/quantus/metrics/axiomatic/non_sensitivity.py
+++ b/quantus/metrics/axiomatic/non_sensitivity.py
@@ -1,10 +1,18 @@
"""This module contains the implementation of the Non-Sensitivity metric."""
# This file is part of Quantus.
-# Quantus is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version.
-# Quantus is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details.
-# You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see <https://www.gnu.org/licenses/>.
-# Quantus project URL: <https://github.com/understandable-machine-intelligence-lab/Quantus>.
+# Quantus is free software: you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation, either version 3 of the License, or (at your option)
+# any later version.
+# Quantus is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for
+# more details.
+# You should have received a copy of the GNU Lesser General Public License
+# along with Quantus. If not, see <https://www.gnu.org/licenses/>.
+# Quantus project URL:
+# <https://github.com/understandable-machine-intelligence-lab/Quantus>.
import sys
import math
@@ -20,7 +28,6 @@
ModelType,
ScoreDirection,
)
-from quantus.helpers.model.model_interface import ModelInterface
from quantus.helpers.perturbation_utils import make_perturb_func
from quantus.metrics.base import Metric
@@ -35,8 +42,8 @@ class NonSensitivity(Metric[List[float]]):
"""
Implementation of NonSensitivity by Nguyen et al., 2020.
- Non-sensitivity measures if zero-importance is only assigned to features, that the model is not
- functionally dependent on.
+ Non-sensitivity measures if zero-importance is only assigned to features,
+ that the model is not functionally dependent on.
References:
1) An-phi Nguyen and María Rodríguez Martínez.: "On quantitative aspects of model
@@ -135,7 +142,9 @@ def __init__(
# Save metric-specific attributes.
self.eps = eps
self.features_in_step = features_in_step
- self.perturb_func = make_perturb_func(perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline)
+ self.perturb_func = make_perturb_func(
+ perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline
+ )
# Asserts and warnings.
if not self.disable_warnings:
@@ -285,72 +294,113 @@ def custom_preprocess(
def evaluate_batch(
self,
- model: ModelInterface,
+ model,
x_batch: np.ndarray,
y_batch: np.ndarray,
a_batch: np.ndarray,
**kwargs,
- ) -> List[int]:
+ ):
"""
- This method performs XAI evaluation on a single batch of explanations.
- For more information on the specific logic, we refer the metric’s initialisation docstring.
+        Evaluate a batch of explanations with the Non-Sensitivity metric.
+
+ This implementation perturbs *feature* and *non-feature* pixels separately,
+ evaluating how sensitive the model’s predictions are to each perturbation.
+ The metric quantifies violations of the Non-Sensitivity principle in both directions:
+ (1) when perturbing **non-feature** pixels *does* affect model predictions, and
+ (2) when perturbing **feature** pixels *does not* affect model predictions.
Parameters
----------
- model: ModelInterface
- A ModelInterface that is subject to explanation.
- x_batch: np.ndarray
- The input to be evaluated on a batch-basis.
- y_batch: np.ndarray
- The output to be evaluated on a batch-basis.
- a_batch: np.ndarray
- The explanation to be evaluated on a batch-basis.
- kwargs:
- Unused.
+ model : object
+ Model with `predict` and `shape_input` methods, compatible with Quantus conventions.
+ x_batch : np.ndarray
+ Input batch, of shape (B, C, H, W) or (B, H, W, C), depending on `channel_first`.
+ y_batch : np.ndarray
+ Ground truth or target class indices, of shape (B,).
+ a_batch : np.ndarray
+ Attribution maps aligned with `x_batch`, of shape (B, C, H, W).
+ **kwargs :
+            Additional keyword arguments; accepted for interface compatibility and unused.
Returns
-------
- scores_batch:
- The evaluation results.
- """
+ np.ndarray
+            Array of shape (B,), where each value is the total number of
+            non-sensitivity violations for that sample. Lower values mean
+            fewer violations, i.e. closer adherence to the Non-Sensitivity principle.
+
+ Notes
+ -----
+        - The function treats attribution values below `self.eps` as
+          "non-feature" pixels.
+ - Perturbations are applied in groups of size `self.features_in_step`.
+ - The perturbation function `self.perturb_func` must follow the Quantus API:
+ it receives an array and an index mask, and returns a perturbed copy.
+ - Designed to comply with Quantus internal metric conventions and to be
+ lint-clean under `black` and `flake8`.
- # Prepare shapes. Expand a_batch if not the same shape
+ """
+ # --- Step 1. Prepare shapes ---
if x_batch.shape != a_batch.shape:
a_batch = np.broadcast_to(a_batch, x_batch.shape)
- # Flatten the attributions.
- batch_size = a_batch.shape[0]
- a_batch = a_batch.reshape(batch_size, -1)
- n_features = a_batch.shape[-1]
-
- non_features = a_batch < self.eps
-
- x_input = model.shape_input(x_batch, x_batch.shape, channel_first=True, batched=True)
- y_pred = model.predict(x_input)[np.arange(batch_size), y_batch]
-
- # Prepare lists.
- n_perturbations = math.ceil(n_features / self.features_in_step)
- preds = []
- x_perturbed = x_batch.copy()
- x_batch_shape = x_batch.shape
- a_indices = np.stack([np.arange(n_features) for _ in x_batch])
- for perturbation_step_index in range(n_perturbations):
- # Perturb input by indices of attributions.
- a_ix = a_indices[
- :,
- perturbation_step_index * self.features_in_step : (perturbation_step_index + 1) * self.features_in_step,
- ]
- x_perturbed = self.perturb_func(
- arr=x_batch.reshape(batch_size, -1),
- indices=a_ix,
- )
- x_perturbed = x_perturbed.reshape(*x_batch_shape)
-
- # Predict on perturbed input x.
- x_input = model.shape_input(x_perturbed, x_batch.shape, channel_first=True, batched=True)
- y_pred_perturb = model.predict(x_input)[np.arange(batch_size), y_batch]
- preds.append(y_pred_perturb)
- preds = np.stack(preds, axis=1)
- preds_differences = abs(preds - y_pred[:, None]) < self.eps
+ B = x_batch.shape[0]
+ x_shape = x_batch.shape
+ x_flat = x_batch.reshape(B, -1)
+ a_flat = a_batch.reshape(B, -1)
+
+ # --- Step 2. Split feature vs non-feature ---
+ non_features = a_flat < self.eps
+ features = ~non_features
+
+ # --- Step 3. Get base predictions ---
+ x_input = model.shape_input(x_batch, x_shape, channel_first=True, batched=True)
+ y_pred = model.predict(x_input)[np.arange(B), y_batch]
+
+ # --- Step 4. Allocate score map ---
+ pixel_scores = np.zeros_like(a_flat, dtype=float)
+
+ # --- Helper: perturbation loop ---
+ def perturb_and_record(indices_mask, desc="nonfeature"):
+ for b in range(B):
+ indices = np.where(indices_mask[b])[0]
+ n_pixels = len(indices)
+ if n_pixels == 0:
+ continue
+
+ n_steps = math.ceil(n_pixels / self.features_in_step)
+ for step in range(n_steps):
+ print(
+ f"Processing batch {b+1}/{B}, {desc} step {step+1}/{n_steps}",
+ end="\r",
+ )
+ start = step * self.features_in_step
+ end = min((step + 1) * self.features_in_step, n_pixels)
+ subset_idx = indices[start:end]
+ indices_2d = np.expand_dims(subset_idx, axis=0)
+
+ # --- Perturb only selected pixels ---
+ perturbed_flat = x_flat.copy()
+ perturbed_flat[b] = self.perturb_func(
+ arr=perturbed_flat[b : b + 1, :],
+ indices=indices_2d,
+ )
+ x_perturbed = perturbed_flat.reshape(x_shape)
+ x_input = model.shape_input(
+ x_perturbed, x_shape, channel_first=True, batched=True
+ )
+ y_pred_perturb = model.predict(x_input)[np.arange(B), y_batch]
+
+ # Assign scores for the perturbed pixels
+ pixel_scores[b, subset_idx] = y_pred_perturb[b]
+
+ # --- Step 5. Run loops ---
+ perturb_and_record(non_features, "nonfeature")
+ perturb_and_record(features, "feature")
+
+ # --- Step 6. Reshape to image shape ---
+ preds_differences = np.abs(y_pred[:, np.newaxis] - pixel_scores)
+ preds_differences = preds_differences < self.eps
+ pixel_scores = pixel_scores.reshape(x_shape)
return (preds_differences ^ non_features).sum(-1)
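
(For reference, outside the diff: a minimal numeric sketch of the counting rule
the new evaluate_batch implements, with toy values standing in for the recorded
per-pixel perturbed predictions.)

    import numpy as np

    eps = 1e-5
    y_pred = np.array([0.9])                               # base prediction, one sample
    pixel_scores = np.array([[0.9, 0.9, 0.1, 0.1]])        # prediction after perturbing each pixel
    non_features = np.array([[True, False, True, False]])  # attribution < eps

    # A pixel violates Non-Sensitivity when "prediction unchanged" disagrees
    # with "pixel marked unimportant"; XOR flags both violation directions.
    preds_differences = np.abs(y_pred[:, None] - pixel_scores) < eps
    print((preds_differences ^ non_features).sum(-1))      # [2]
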
From 15e4c4a88346d9156aa727e2bac731f165db5364 Mon Sep 17 00:00:00 2001
From: Tprojects66554 <38327606992@mby.co.il>
Date: Tue, 4 Nov 2025 18:59:50 +0200
Subject: [PATCH 2/5] Test NonSensitivity with features_in_step=2
---
tests/metrics/test_axiomatic_metrics.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/tests/metrics/test_axiomatic_metrics.py b/tests/metrics/test_axiomatic_metrics.py
index 31fd8eb1..15de8fc8 100644
--- a/tests/metrics/test_axiomatic_metrics.py
+++ b/tests/metrics/test_axiomatic_metrics.py
@@ -231,6 +231,7 @@ def test_completeness(
"normalise": True,
"disable_warnings": False,
"display_progressbar": False,
+ "features_in_step": 2,
},
"call": {
"explain_func": explain,
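
(For reference, outside the diff: features_in_step controls how many flattened
pixels are perturbed per prediction step; a rough sketch of the chunking done by
the loop from PATCH 1, with a hypothetical pixel count.)

    import math

    features_in_step = 2
    n_pixels = 5  # hypothetical flattened pixel count
    n_steps = math.ceil(n_pixels / features_in_step)
    groups = [
        list(range(step * features_in_step,
                   min((step + 1) * features_in_step, n_pixels)))
        for step in range(n_steps)
    ]
    print(groups)  # [[0, 1], [2, 3], [4]] -- the last group may be smaller
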
From 44217f2ef6d5b9eab587e645308d9d823d346415 Mon Sep 17 00:00:00 2001
From: Tprojects66554 <38327606992@mby.co.il>
Date: Thu, 6 Nov 2025 09:54:53 +0200
Subject: [PATCH 3/5] Add logic tests for NonSensitivity
---
quantus/metrics/axiomatic/non_sensitivity.py | 13 ---
tests/metrics/test_axiomatic_metrics.py | 112 +++++++++++++++++++
2 files changed, 112 insertions(+), 13 deletions(-)
diff --git a/quantus/metrics/axiomatic/non_sensitivity.py b/quantus/metrics/axiomatic/non_sensitivity.py
index e7ab859c..de166b00 100644
--- a/quantus/metrics/axiomatic/non_sensitivity.py
+++ b/quantus/metrics/axiomatic/non_sensitivity.py
@@ -340,7 +340,6 @@ def evaluate_batch(
lint-clean under `black` and `flake8`.
"""
- # --- Step 1. Prepare shapes ---
if x_batch.shape != a_batch.shape:
a_batch = np.broadcast_to(a_batch, x_batch.shape)
@@ -349,18 +348,14 @@ def evaluate_batch(
x_flat = x_batch.reshape(B, -1)
a_flat = a_batch.reshape(B, -1)
- # --- Step 2. Split feature vs non-feature ---
non_features = a_flat < self.eps
features = ~non_features
- # --- Step 3. Get base predictions ---
x_input = model.shape_input(x_batch, x_shape, channel_first=True, batched=True)
y_pred = model.predict(x_input)[np.arange(B), y_batch]
- # --- Step 4. Allocate score map ---
pixel_scores = np.zeros_like(a_flat, dtype=float)
- # --- Helper: perturbation loop ---
def perturb_and_record(indices_mask, desc="nonfeature"):
for b in range(B):
indices = np.where(indices_mask[b])[0]
@@ -370,16 +365,11 @@ def perturb_and_record(indices_mask, desc="nonfeature"):
n_steps = math.ceil(n_pixels / self.features_in_step)
for step in range(n_steps):
- print(
- f"Processing batch {b+1}/{B}, {desc} step {step+1}/{n_steps}",
- end="\r",
- )
start = step * self.features_in_step
end = min((step + 1) * self.features_in_step, n_pixels)
subset_idx = indices[start:end]
indices_2d = np.expand_dims(subset_idx, axis=0)
- # --- Perturb only selected pixels ---
perturbed_flat = x_flat.copy()
perturbed_flat[b] = self.perturb_func(
arr=perturbed_flat[b : b + 1, :],
@@ -391,14 +381,11 @@ def perturb_and_record(indices_mask, desc="nonfeature"):
)
y_pred_perturb = model.predict(x_input)[np.arange(B), y_batch]
- # Assign scores for the perturbed pixels
pixel_scores[b, subset_idx] = y_pred_perturb[b]
- # --- Step 5. Run loops ---
perturb_and_record(non_features, "nonfeature")
perturb_and_record(features, "feature")
- # --- Step 6. Reshape to image shape ---
preds_differences = np.abs(y_pred[:, np.newaxis] - pixel_scores)
preds_differences = preds_differences < self.eps
pixel_scores = pixel_scores.reshape(x_shape)
diff --git a/tests/metrics/test_axiomatic_metrics.py b/tests/metrics/test_axiomatic_metrics.py
index 15de8fc8..18d93d96 100644
--- a/tests/metrics/test_axiomatic_metrics.py
+++ b/tests/metrics/test_axiomatic_metrics.py
@@ -3,10 +3,46 @@
import pytest
from pytest_lazyfixture import lazy_fixture
import numpy as np
+import torch
+import torch.nn as nn
from quantus.functions.explanation_func import explain
from quantus.metrics.axiomatic import Completeness, InputInvariance, NonSensitivity
+
+class SensitiveModel(nn.Module):
+ def shape_input(self, x, shape, channel_first=True, batched=True):
+ return x
+ def forward(self, x):
+ return x.sum(axis=(1, 2, 3), keepdims=True)
+ def predict(self, x):
+ return self.forward(x)
+
+class InsensitiveModel(nn.Module):
+ def shape_input(self, x, shape, channel_first=True, batched=True):
+ return x
+ def forward(self, x):
+ B = x.shape[0]
+ return np.ones((B, 1), dtype=float) * 100.0
+ def predict(self, x):
+ return self.forward(x)
+
+class SemiSensitiveModel(nn.Module):
+ def shape_input(self, x, shape, channel_first=True, batched=True):
+ return x
+ def forward(self, x):
+ top_sum = x[:, :, 0, :].sum(axis=(1, 2))
+ return top_sum[:, None]
+ def predict(self, x):
+ return self.forward(x)
+
+class TrickModel(nn.Module):
+ def shape_input(self, x, shape, channel_first=True, batched=True):
+ return x
+ def predict(self, x):
+ bottom_sum = x[:, :, 1, :].sum(axis=(1, 2))
+ return bottom_sum[:, None]
@pytest.mark.axiomatic
@pytest.mark.parametrize(
@@ -418,6 +454,82 @@ def test_non_sensitivity(
)
assert scores is not None, "Test failed."
+@pytest.mark.axiomatic
+@pytest.mark.parametrize(
+ "scenario,model_factory,x_batch,y_batch,a_batch,expected_violations,kwargs",
+ [
+ (
+ "zero_violations",
+ lambda: SemiSensitiveModel(),
+ np.array([[[[1.0, 2.0], [3.0, 4.0]]]], dtype=float),
+ np.array([0]),
+ np.array([[[[10.0, 10.0], [0.0, 0.0]]]], dtype=float),
+ 0,
+ {"features_in_step": 2, "eps": 1e-5},
+ ),
+ (
+ "low_attr_high_change",
+ lambda: SensitiveModel(),
+ np.array([[[[5.0, 5.0], [5.0, 5.0]]]], dtype=float),
+ np.array([0]),
+ np.random.uniform(1e-6, 2e-6, size=(1, 1, 2, 2)),
+ 4,
+ {"features_in_step": 2, "eps": 1e-5},
+ ),
+ (
+ "high_attr_low_change",
+ lambda: InsensitiveModel(),
+ np.random.rand(1, 1, 4, 4),
+ np.array([0]),
+ np.ones((1, 1, 4, 4)),
+ 16,
+ {"features_in_step": 2, "eps": 1e-5},
+ ),
+ (
+ "half_good_half_bad",
+ lambda: TrickModel(),
+ np.array([[[[1.0, 2.0], [3.0, 4.0]]]], dtype=float),
+ np.array([0]),
+ np.array([[[[10.0, 10.0], [0.0, 0.0]]]], dtype=float),
+ 4,
+ {"features_in_step": 1, "eps": 1e-5},
+ ),
+ ],
+)
+def test_non_sensitivity_logic(
+ scenario,
+ model_factory,
+ x_batch,
+ y_batch,
+ a_batch,
+ expected_violations,
+ kwargs,
+):
+ """
+ Parametrized logic-based tests for NonSensitivity.
+ Each scenario defines a different consistency pattern between attribution and model behavior.
+ """
+ model = model_factory()
+ model.eval()
+ metric = NonSensitivity(
+ disable_warnings=True,
+ perturb_baseline="uniform",
+ normalise=False,
+ **kwargs,
+ )
+
+ scores = metric.evaluate_batch(model, x_batch, y_batch, a_batch)
+
+ # --- Assertions ---
+ assert isinstance(scores, np.ndarray), f"[{scenario}] Output must be np.ndarray"
+ assert scores.shape[0] == x_batch.shape[0], f"[{scenario}] Wrong batch size"
+ assert np.all(np.isfinite(scores)), f"[{scenario}] Scores contain NaN/Inf"
+
+ if expected_violations is not None:
+ assert scores[0] == expected_violations, (
+ f"[{scenario}] expected {expected_violations}, got {scores[0]}"
+ )
@pytest.mark.axiomatic
@pytest.mark.parametrize(
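
(For reference, outside the diff: a hand-check of the "half_good_half_bad"
scenario above, assuming the uniform-baseline perturbation actually changes the
perturbed pixel values.)

    import numpy as np

    # TrickModel reads only the bottom row, while the attribution map credits
    # only the top row, so every pixel is miscounted.
    a = np.array([[[[10.0, 10.0], [0.0, 0.0]]]])
    non_features = a.reshape(1, -1) < 1e-5
    print(non_features)  # [[False False  True  True]]

    # Top-row pixels are marked important, but the model ignores them:
    #   prediction unchanged -> 2 violations of direction (2).
    # Bottom-row pixels are marked unimportant, but the model depends on them:
    #   prediction shifts -> 2 violations of direction (1).
    expected_violations = 2 + 2  # matches the parametrized value of 4
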
From d244aa3a29927d6fe312b3c1666c204577ae557f Mon Sep 17 00:00:00 2001
From: Tprojects66554 <38327606992@mby.co.il>
Date: Thu, 6 Nov 2025 10:09:05 +0200
Subject: [PATCH 4/5] Revise evaluate_batch documentation
---
quantus/metrics/axiomatic/non_sensitivity.py | 4 +---
1 file changed, 1 insertion(+), 3 deletions(-)
diff --git a/quantus/metrics/axiomatic/non_sensitivity.py b/quantus/metrics/axiomatic/non_sensitivity.py
index de166b00..d04c19d2 100644
--- a/quantus/metrics/axiomatic/non_sensitivity.py
+++ b/quantus/metrics/axiomatic/non_sensitivity.py
@@ -336,9 +336,7 @@ def evaluate_batch(
- Perturbations are applied in groups of size `self.features_in_step`.
- The perturbation function `self.perturb_func` must follow the Quantus API:
it receives an array and an index mask, and returns a perturbed copy.
- - Designed to comply with Quantus internal metric conventions and to be
- lint-clean under `black` and `flake8`.
-
+        - Returned scores are counts of violations per sample.
"""
if x_batch.shape != a_batch.shape:
a_batch = np.broadcast_to(a_batch, x_batch.shape)
From 4190c555d12df9cf40b9358a52e6791e5b0eb8d3 Mon Sep 17 00:00:00 2001
From: Tprojects66554 <38327606992@mby.co.il>
Date: Sun, 9 Nov 2025 19:20:39 +0200
Subject: [PATCH 5/5] Split evaluate_batch into helper functions (SRP)
---
quantus/metrics/axiomatic/non_sensitivity.py | 200 ++++++++++---------
tests/metrics/test_axiomatic_metrics.py | 19 ++
2 files changed, 122 insertions(+), 97 deletions(-)
diff --git a/quantus/metrics/axiomatic/non_sensitivity.py b/quantus/metrics/axiomatic/non_sensitivity.py
index d04c19d2..16c0dd46 100644
--- a/quantus/metrics/axiomatic/non_sensitivity.py
+++ b/quantus/metrics/axiomatic/non_sensitivity.py
@@ -1,18 +1,10 @@
"""This module contains the implementation of the Non-Sensitivity metric."""
# This file is part of Quantus.
-# Quantus is free software: you can redistribute it and/or modify it under the
-# terms of the GNU Lesser General Public License as published by the Free
-# Software Foundation, either version 3 of the License, or (at your option)
-# any later version.
-# Quantus is distributed in the hope that it will be useful, but WITHOUT ANY
-# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
-# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for
-# more details.
-# You should have received a copy of the GNU Lesser General Public License
-# along with Quantus. If not, see <https://www.gnu.org/licenses/>.
-# Quantus project URL:
-# <https://github.com/understandable-machine-intelligence-lab/Quantus>.
+# Quantus is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version.
+# Quantus is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details.
+# You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see <https://www.gnu.org/licenses/>.
+# Quantus project URL: <https://github.com/understandable-machine-intelligence-lab/Quantus>.
import sys
import math
@@ -28,6 +20,7 @@
ModelType,
ScoreDirection,
)
+from quantus.helpers.model.model_interface import ModelInterface
from quantus.helpers.perturbation_utils import make_perturb_func
from quantus.metrics.base import Metric
@@ -42,8 +35,8 @@ class NonSensitivity(Metric[List[float]]):
"""
Implementation of NonSensitivity by Nguyen et al., 2020.
- Non-sensitivity measures if zero-importance is only assigned to features,
- that the model is not functionally dependent on.
+ Non-sensitivity measures if zero-importance is only assigned to features, that the model is not
+ functionally dependent on.
References:
1) An-phi Nguyen and María Rodríguez Martínez.: "On quantitative aspects of model
@@ -142,9 +135,7 @@ def __init__(
# Save metric-specific attributes.
self.eps = eps
self.features_in_step = features_in_step
- self.perturb_func = make_perturb_func(
- perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline
- )
+ self.perturb_func = make_perturb_func(perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline)
# Asserts and warnings.
if not self.disable_warnings:
@@ -291,101 +282,116 @@ def custom_preprocess(
features_in_step=self.features_in_step,
input_shape=x_batch.shape[2:],
)
+
def evaluate_batch(
self,
- model,
+ model: ModelInterface,
x_batch: np.ndarray,
y_batch: np.ndarray,
a_batch: np.ndarray,
**kwargs,
- ):
- """
-        Evaluate a batch of explanations with the Non-Sensitivity metric.
-
- This implementation perturbs *feature* and *non-feature* pixels separately,
- evaluating how sensitive the model’s predictions are to each perturbation.
- The metric quantifies violations of the Non-Sensitivity principle in both directions:
- (1) when perturbing **non-feature** pixels *does* affect model predictions, and
- (2) when perturbing **feature** pixels *does not* affect model predictions.
-
- Parameters
- ----------
- model : object
- Model with `predict` and `shape_input` methods, compatible with Quantus conventions.
- x_batch : np.ndarray
- Input batch, of shape (B, C, H, W) or (B, H, W, C), depending on `channel_first`.
- y_batch : np.ndarray
- Ground truth or target class indices, of shape (B,).
- a_batch : np.ndarray
- Attribution maps aligned with `x_batch`, of shape (B, C, H, W).
- **kwargs :
-            Additional keyword arguments; accepted for interface compatibility and unused.
-
- Returns
- -------
- np.ndarray
-            Array of shape (B,), where each value is the total number of
-            non-sensitivity violations for that sample. Lower values mean
-            fewer violations, i.e. closer adherence to the Non-Sensitivity principle.
-
- Notes
- -----
-        - The function treats attribution values below `self.eps` as
-          "non-feature" pixels.
- - Perturbations are applied in groups of size `self.features_in_step`.
- - The perturbation function `self.perturb_func` must follow the Quantus API:
- it receives an array and an index mask, and returns a perturbed copy.
-        - Returned scores are counts of violations per sample.
+ ) -> List[int]:
"""
+ This method performs XAI evaluation on a single batch of explanations.
+ For more information on the specific logic, we refer the metric’s initialisation docstring.
+
+ Parameters
+ ----------
+ model: ModelInterface
+ A ModelInterface that is subject to explanation.
+ x_batch: np.ndarray
+ The input to be evaluated on a batch-basis.
+ y_batch: np.ndarray
+ The output to be evaluated on a batch-basis.
+ a_batch: np.ndarray
+ The explanation to be evaluated on a batch-basis.
+ kwargs:
+ Unused.
+
+ Returns
+ -------
+ np.ndarray
+            Array of shape (batch_size,), one score per sample.
+ Lower values are better.
+ """
+
+ # Prepare shapes. Expand a_batch if not the same shape
if x_batch.shape != a_batch.shape:
a_batch = np.broadcast_to(a_batch, x_batch.shape)
- B = x_batch.shape[0]
+ # Flatten the attributions.
+ batch_size = a_batch.shape[0]
x_shape = x_batch.shape
- x_flat = x_batch.reshape(B, -1)
- a_flat = a_batch.reshape(B, -1)
+ x_batch = x_batch.reshape(batch_size, -1)
+ a_batch = a_batch.reshape(batch_size, -1)
- non_features = a_flat < self.eps
+ non_features = a_batch < self.eps
features = ~non_features
x_input = model.shape_input(x_batch, x_shape, channel_first=True, batched=True)
- y_pred = model.predict(x_input)[np.arange(B), y_batch]
-
- pixel_scores = np.zeros_like(a_flat, dtype=float)
-
- def perturb_and_record(indices_mask, desc="nonfeature"):
- for b in range(B):
- indices = np.where(indices_mask[b])[0]
- n_pixels = len(indices)
- if n_pixels == 0:
- continue
-
- n_steps = math.ceil(n_pixels / self.features_in_step)
- for step in range(n_steps):
- start = step * self.features_in_step
- end = min((step + 1) * self.features_in_step, n_pixels)
- subset_idx = indices[start:end]
- indices_2d = np.expand_dims(subset_idx, axis=0)
-
- perturbed_flat = x_flat.copy()
- perturbed_flat[b] = self.perturb_func(
- arr=perturbed_flat[b : b + 1, :],
- indices=indices_2d,
- )
- x_perturbed = perturbed_flat.reshape(x_shape)
- x_input = model.shape_input(
- x_perturbed, x_shape, channel_first=True, batched=True
- )
- y_pred_perturb = model.predict(x_input)[np.arange(B), y_batch]
-
- pixel_scores[b, subset_idx] = y_pred_perturb[b]
-
- perturb_and_record(non_features, "nonfeature")
- perturb_and_record(features, "feature")
-
- preds_differences = np.abs(y_pred[:, np.newaxis] - pixel_scores)
- preds_differences = preds_differences < self.eps
- pixel_scores = pixel_scores.reshape(x_shape)
+ y_pred = model.predict(x_input)[np.arange(batch_size), y_batch]
+
+ pixel_scores_non = self._process_mask(
+ model, x_batch, y_batch, non_features, x_shape
+ )
+ pixel_scores_feat = self._process_mask(
+ model, x_batch, y_batch, features, x_shape
+ )
+
+ preds_differences = np.abs(
+ y_pred[:, None] - (pixel_scores_non + pixel_scores_feat)
+ ) < self.eps
return (preds_differences ^ non_features).sum(-1)
+
+ def _create_index_groups(self, mask: np.ndarray) -> List[np.ndarray]:
+ """Divide mask indices into perturbation groups."""
+ indices = np.where(mask)[0]
+ return [
+ indices[i:i + self.features_in_step]
+ for i in range(0, len(indices), self.features_in_step)
+ ]
+
+ def _perturb_sample_batch(
+ self, x_batch: np.ndarray, b: int, indices: np.ndarray
+ ) -> np.ndarray:
+ """Perturb a single sample within the batch."""
+ perturbed_flat = x_batch.copy()
+ indices_2d = np.expand_dims(indices, axis=0)
+ perturbed_flat[b] = self.perturb_func(
+ arr=perturbed_flat[b:b + 1, :],
+ indices=indices_2d,
+ )
+ return perturbed_flat
+
+ def _predict_scores(
+ self, model, x_batch: np.ndarray, y_batch: np.ndarray, x_shape: tuple
+ ) -> np.ndarray:
+ """Predict scores for the true labels of the given batch."""
+ x_input = model.shape_input(
+ x_batch.reshape(x_shape), x_shape, channel_first=True, batched=True
+ )
+ return model.predict(x_input)[np.arange(x_batch.shape[0]), y_batch]
+
+ def _process_mask(
+ self,
+ model,
+ x_batch: np.ndarray,
+ y_batch: np.ndarray,
+ mask: np.ndarray,
+ x_shape: tuple,
+ ) -> np.ndarray:
+ """Handle perturbation and prediction workflow for a single mask type."""
+ batch_size = x_batch.shape[0]
+ pixel_scores = np.zeros_like(x_batch, dtype=float)
+
+ for b in range(batch_size):
+ for indices in self._create_index_groups(mask[b]):
+ perturbed_batch = self._perturb_sample_batch(x_batch, b, indices)
+ preds = self._predict_scores(model, perturbed_batch, y_batch, x_shape)
+ pixel_scores[b, indices] = preds[b]
+
+ return pixel_scores
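
(For reference, outside the diff: a standalone sketch of what
_create_index_groups yields for a boolean mask, assuming features_in_step = 2.)

    import numpy as np

    features_in_step = 2
    mask = np.array([True, False, True, True, False, True])
    indices = np.where(mask)[0]  # array([0, 2, 3, 5])
    groups = [
        indices[i:i + features_in_step]
        for i in range(0, len(indices), features_in_step)
    ]
    print([g.tolist() for g in groups])  # [[0, 2], [3, 5]]
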
diff --git a/tests/metrics/test_axiomatic_metrics.py b/tests/metrics/test_axiomatic_metrics.py
index 18d93d96..83d3836e 100644
--- a/tests/metrics/test_axiomatic_metrics.py
+++ b/tests/metrics/test_axiomatic_metrics.py
@@ -11,18 +11,34 @@
+def _ensure_4d(x):
+ """Make sure x is (B, C, H, W), even if passed as (B, N)."""
+ x = np.array(x)
+ if x.ndim == 2:
+ B, N = x.shape
+ side = int(np.sqrt(N))
+ x = x.reshape(B, 1, side, side)
+ elif x.ndim == 3:
+ x = x[:, None, :, :]
+ return x
+
class SensitiveModel(nn.Module):
def shape_input(self, x, shape, channel_first=True, batched=True):
return x
+
def forward(self, x):
+ x = _ensure_4d(x)
return x.sum(axis=(1, 2, 3), keepdims=True)
+
def predict(self, x):
+ x = _ensure_4d(x)
return self.forward(x)
class InsensitiveModel(nn.Module):
def shape_input(self, x, shape, channel_first=True, batched=True):
return x
def forward(self, x):
+ x = _ensure_4d(x)
B = x.shape[0]
return np.ones((B, 1), dtype=float) * 100.0
def predict(self, x):
@@ -32,15 +48,18 @@ class SemiSensitiveModel(nn.Module):
def shape_input(self, x, shape, channel_first=True, batched=True):
return x
def forward(self, x):
+ x = _ensure_4d(x)
top_sum = x[:, :, 0, :].sum(axis=(1, 2))
return top_sum[:, None]
def predict(self, x):
+ x = _ensure_4d(x)
return self.forward(x)
class TrickModel(nn.Module):
def shape_input(self, x, shape, channel_first=True, batched=True):
return x
def predict(self, x):
+ x = _ensure_4d(x)
bottom_sum = x[:, :, 1, :].sum(axis=(1, 2))
return bottom_sum[:, None]
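
(For reference, outside the diff: _ensure_4d exists because evaluate_batch
flattens inputs to (B, N) before prediction; a quick standalone check of its
behavior, noting that it assumes a square spatial layout.)

    import numpy as np

    def _ensure_4d(x):
        """Reshape (B, N) or (B, H, W) inputs to (B, C, H, W); assumes H == W."""
        x = np.array(x)
        if x.ndim == 2:
            B, N = x.shape
            side = int(np.sqrt(N))
            x = x.reshape(B, 1, side, side)
        elif x.ndim == 3:
            x = x[:, None, :, :]
        return x

    flat = np.arange(8.0).reshape(2, 4)  # two samples, four flattened pixels each
    print(_ensure_4d(flat).shape)        # (2, 1, 2, 2)
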