
Conversation

@Suliang-Jin
Contributor

Description

I have changed some code in IF, TRAK, and TracIn so that they now support passing in a dict. If the user still passes in a tuple or list, the behavior doesn't change.
I have also added some experiments on supporting IF with Hugging Face transformers.

@jiaqima
Contributor

jiaqima commented Nov 27, 2025

@Suliang-Jin Given this is a relatively large PR, could you follow the PR template to provide more detailed context about it?

Contributor

@jiaqima left a comment


It seems that this PR contains changes from another PR, #217. If the proposed changes here depend on that PR, please open this PR after the other one is merged.

README.md Outdated

The following is an example of using `IFAttributorCG` and `AttributionTask` to apply data attribution to a PyTorch model.

Please refer to [this guide](./docs/guide/README.md) for how to properly define train/test data for the Attributor and the loss/target function.
Contributor

There is no such file as ./docs/guide/README.md?

Contributor Author

I think I have created this README in this PR.

```bash
python examples/brittleness/mnist_lr_brittleness.py --method cg --device cpu
python examples/data_cleaning/influence_function_data_cleaning.py --device cpu --train_size 1000 --val_size 100 --test_size 100 --remove_number 10
python examples/relatIF/influence_function_comparison.py --no_output
sed -i 's/range(1000)/range(100)/g' examples/lds_vs_gt/mnist.py
```
Contributor

Should this change belong to another PR?

Contributor Author

Yes, sorry about it. I will fix it.

@Suliang-Jin
Contributor Author

Summary

I'm sorry about the confusion in this PR. Please refer to the PR created on Nov 26.
This PR (1) fixes support for dict-type datasets passed into the attributor and (2) adds new experiments on supporting Hugging Face transformers with Influence Function.

What’s Changed

  • Under the directory algorithm/, I have added data type checking for tuple, list, and dict, along with the corresponding handling for each type, in base.py, influence_function.py, trak.py, and tracin.py. That is, when a user passes in a dataset for attribution computation, the script detects whether the dataset is a tuple, list, or dict and processes it accordingly (e.g., by adding a dummy dimension to each data feature in influence_function.py); see the usage sketch after this list.
  • Under the directory experiments/gpt2_wikitext/, I have added the experiment script score_IF.py to test support for using Influence Function on Hugging Face transformers; the dataset passed in for attribution computation is still a tuple. In contrast, score_IF_dict.py and score_TRAK_dict.py test whether dict-type datasets are processed properly.
  • At the suggestion of @TheaperDeng, I also added the guide docs/guide/README.md to illustrate how to properly define the dataset/data loader and loss/target function.
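
Below is a hypothetical end-to-end sketch of the dict path. Only `AttributionTask`, `IFAttributorCG`, and the `cache`/`attribute` calls follow the repo's README example; the toy dataset, model, and loss function are illustrative assumptions, not this PR's actual code.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

from dattri.algorithm.influence_function import IFAttributorCG
from dattri.task import AttributionTask


class ToyDictDataset(Dataset):
    """Hypothetical dict-style dataset standing in for tokenized WikiText."""

    def __init__(self, n=32):
        self.x = torch.randn(n, 10)
        self.y = torch.randn(n, 1)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        # Each sample is a dict, as a Hugging Face tokenized dataset would be.
        return {"x": self.x[idx], "y": self.y[idx]}


model = nn.Linear(10, 1)


def loss_func(params, data):
    # Written against dict samples, matching the new dict handling.
    yhat = torch.func.functional_call(model, params, data["x"])
    return nn.functional.mse_loss(yhat, data["y"])


task = AttributionTask(
    model=model,
    loss_func=loss_func,
    checkpoints=model.state_dict(),
)
attributor = IFAttributorCG(task=task, device="cpu")
attributor.cache(DataLoader(ToyDictDataset(), batch_size=4))
score = attributor.attribute(
    DataLoader(ToyDictDataset(), batch_size=4),
    DataLoader(ToyDictDataset(8), batch_size=4),
)  # score[i, j]: influence of train sample i on test sample j
```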

Motivation

The original issue was raised in #165.
After discussion with @TheaperDeng, I aimed to provide both support for Hugging Face transformers with Influence Function and support for handling dict-type data in attribution computation. The two issues are intertwined, since the WikiText dataset is natively a dict and is commonly used for benchmarking LLMs. However, they need to be tackled differently.

How It Works

For supporting Hugging Face transformers with Influence Function (see experiments/gpt2_wikitext/), I simply defined special loss functions tailored to how influence_function.py and trak.py under algorithm/ handle data processing. That is, since influence_function.py explicitly adds a dummy dimension to the data using unsqueeze(0), I "unsqueezed" the data in the loss function when calculating Influence Function. In contrast, trak.py doesn't create a dummy dimension, so I didn't "unsqueeze" the data in the loss function when calculating TRAK.
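
A minimal sketch of the two shape-handling variants just described. The function names, the tuple layout of `data`, and the use of `torch.func.functional_call` are illustrative assumptions, not the PR's actual code; which variant applies depends on the shape each attributor hands to the loss function.

```python
import torch
from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2")


def gpt2_loss_for_if(params, data):
    # IF path: add the batch dimension inside the loss, mirroring the
    # unsqueeze(0) convention in influence_function.py, so the model sees
    # input_ids shaped [batch, seq_len].
    input_ids, labels = (t.unsqueeze(0) for t in data)
    outputs = torch.func.functional_call(
        model, params, args=(input_ids,), kwargs={"labels": labels},
    )
    return outputs.loss


def gpt2_loss_for_trak(params, data):
    # TRAK path: trak.py adds no dummy dimension, so use the tensors as-is.
    input_ids, labels = data
    outputs = torch.func.functional_call(
        model, params, args=(input_ids,), kwargs={"labels": labels},
    )
    return outputs.loss
```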
Support for dict-type datasets, on the other hand, is simply a matter of adding an isinstance() check and processing the data in the corresponding way in base.py, influence_function.py, trak.py, and tracin.py under algorithm/. I only added isinstance() checks where data processing was already needed; for example, in base.py, I added the type checking where this happens:

```python
# Existing tuple handling in base.py: move each feature to the target
# device and add a dummy batch dimension.
train_batch_data = tuple(
    data.to(self.device).unsqueeze(0) for data in train_batch_data_
)
```
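
For context, a hedged sketch of what the full dispatch could look like with the dict branch added (simplified; the actual code in base.py may differ, e.g. in the exact error type and message):

```python
if isinstance(train_batch_data_, (tuple, list)):
    # Original behavior, unchanged for tuple/list datasets.
    train_batch_data = tuple(
        data.to(self.device).unsqueeze(0) for data in train_batch_data_
    )
elif isinstance(train_batch_data_, dict):
    # New branch: same device move and dummy dimension, applied per value
    # (assuming every value is a tensor).
    train_batch_data = {
        k: v.to(self.device).unsqueeze(0)
        for k, v in train_batch_data_.items()
    }
else:
    raise TypeError(
        "We currently only support the train/test data to be tuple, list or dict.",
    )
```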

Testing

  • Under experiments/gpt2_wikitext/, I have tested all of score_IF.py, score_IF_dict.py, and score_TRAK_dict.py.
  • To ensure the scripts still properly support tuple-type datasets, I have rerun trak_dropout_lds.py and influence_function_lds.py under examples/pretrained_benchmark/.

Related Issues

Fixes #165

```python
)
elif isinstance(train_batch_data_, dict):
    train_batch_data = {
        k: v.unsqueeze(0) for k, v in train_batch_data_.items()
```
Collaborator

We also assume the values in the dictionary to be tensors, right?

Collaborator

I think we should move the data to self.device here.

```python
    k: v.to(self.device) for k, v in full_data_.items()
}
else:
    raise Exception("We currently only support the train/test data to be tuple, list or dict.")
```
Collaborator

No need to fix the IFAttributor API; I will delete it in another PR.

"""
if seed is None:
seed = random.getrandbits(64)

Collaborator

Is this covered in another PR?

```python
# Calculate and print LDS score
##############################
lds_score = lds(score, ground_truth)[0]
print("lds:", torch.mean(lds_score[~torch.isnan(lds_score)]))
```
Collaborator

Is this covered in another PR?

@@ -0,0 +1,120 @@
# User Guide
Collaborator

The documentation is clear; a high-level summary table at the top would be beneficial. It should list the supported data types and callable types for all methods in https://github.com/TRAIS-Lab/dattri?tab=readme-ov-file#supported-algorithms.

```python
    )
    logp = -outputs.loss
    return logp - torch.log(1 - torch.exp(logp))
```
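
For context, a hedged reading of this snippet (not stated explicitly in the diff): with `logp = -outputs.loss` interpreted as the log-likelihood $\log p$ of the correct tokens, the returned quantity is

$$\log p - \log(1 - p) = \log\frac{p}{1 - p},$$

i.e., the margin-style model output function used in the TRAK paper.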
Collaborator

Slightly different requirements should be applied to TRAK (multi-class margin) and TracIn (training loss and any target function).

@@ -0,0 +1,686 @@
#!/usr/bin/env python
Collaborator

We only need one script for IF.

@@ -0,0 +1,742 @@
#!/usr/bin/env python
Collaborator

We only need one script for TRAK.

@jiaqima jiaqima closed this Jan 19, 2026
@jiaqima jiaqima reopened this Jan 19, 2026
@@ -0,0 +1,142 @@
# User Guide
Collaborator

Change the title to "Data Type Compatibility for Loss and Target Functions" and rename the file to data_compatibility.md.

| | [EK-FAC](https://arxiv.org/abs/2308.03296) | ✔️ | ✔️ || [Code example](../../examples/brittleness/mnist_lr_brittleness.py) |
| | [RelatIF](https://arxiv.org/pdf/2003.11630) | ✔️ | ✔️ || [Code example](../../examples/brittleness/mnist_lr_brittleness.py) |
| | [LoGra](https://arxiv.org/pdf/2405.13954) | ✔️ | ✔️ || [Code example](../../examples/brittleness/mnist_lr_brittleness.py) |
| | [GraSS](https://arxiv.org/pdf/2505.18976) | ✔️ | ✔️ || [Code example](../../examples/brittleness/mnist_lr_brittleness.py) |
Collaborator

I think for GraSS, LoGra, RelatIF, and EK-FAC, we don't have examples in ../../examples/brittleness/mnist_lr_brittleness.py.

```python
    type=str,
    default="tuple",
    choices=["tuple", "list", "dict"],
)
```
Collaborator

We don't need to show off everything we support in the examples. Just choose the most convenient way and demonstrate it in the script.

```python
    type=str,
    default="tuple",
    choices=["tuple", "list", "dict"],
)
```
Collaborator

Same here.

```python
score = attributor.attribute(train_dataloader, eval_dataloader)

torch.save(score, "score_IF.pt")
logger.info("Attribution scores saved to score_IF.pt")
```
Collaborator

How does IF perform in the GPT-2 + WikiText setting?
