-
Notifications
You must be signed in to change notification settings - Fork 23
Improved the code in IF and TRAK to support passing in dict #225
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
@Suliang-Jin given this is a relatively large PR, could you follow the PR template to provide more detailed context about this PR? |
jiaqima
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems that this PR contained changes from another PR #217. If the proposed changes in this PR depend on the other PR, please make this PR after merging the other one.
README.md
Outdated
|
|
||
| The following is an example to use `IFAttributorCG` and `AttributionTask` to apply data attribution to a PyTorch model. | ||
|
|
||
| Please reference [here](./docs/guide/README.md) for the guide on how to properly define train/test data for Attributor and loss/target function. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is no such file for ./docs/guide/README.md?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think I have created this README in this PR
.github/workflows/examples_test.yml
Outdated
| python examples/brittleness/mnist_lr_brittleness.py --method cg --device cpu | ||
| python examples/data_cleaning/influence_function_data_cleaning.py --device cpu --train_size 1000 --val_size 100 --test_size 100 --remove_number 10 | ||
| python examples/relatIF/influence_function_comparison.py --no_output | ||
| sed -i 's/range(1000)/range(100)/g' examples/lds_vs_gt/mnist.py |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This change should belong to another PR?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, sorry about it. I will fix it.
SummaryI'm sorry about the confusion in this PR. Please refer to the PR created on Nov 26. What’s Changed
MotivationThe original issue is raised from issue #165. How It WorksOn the support of Huggingface Transformers on Influence Function (see in train_batch_data = tuple(
data.to(self.device).unsqueeze(0) for data in train_batch_data_
)Testing
Related IssuesFixes #165 |
dattri/algorithm/base.py
Outdated
| ) | ||
| elif isinstance(train_batch_data_, dict): | ||
| train_batch_data = { | ||
| k: v.unsqueeze(0) for k, v in train_batch_data_.items() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We also assume the value in dictionary to be tensor right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should put the data to self.device here.
| k: v.to(self.device) for k, v in full_data_.items() | ||
| } | ||
| else: | ||
| raise Exception("We currently only support the train/test data to be tuple, list or dict.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No need to fix IFAttributor API, I will delete it in another PR.
dattri/model_util/retrain.py
Outdated
| """ | ||
| if seed is None: | ||
| seed = random.getrandbits(64) | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is coverred in another PR?
examples/lds_vs_gt/mnist.py
Outdated
| # Calculate and print LDS score | ||
| ############################## | ||
| lds_score = lds(score, ground_truth)[0] | ||
| print("lds:", torch.mean(lds_score[~torch.isnan(lds_score)])) No newline at end of file |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is coverred in another PR?
docs/guide/README.md
Outdated
| @@ -0,0 +1,120 @@ | |||
| # User Guide | |||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The documentation is clear, a high-level summary table at the top would be beneficial. It should list the supported data types and callable types for all methods in the https://github.com/TRAIS-Lab/dattri?tab=readme-ov-file#supported-algorithms.
| ) | ||
| logp = -outputs.loss | ||
| return logp - torch.log(1 - torch.exp(logp)) | ||
| ``` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Slightly different requirements should be applied to TRAK (Multi-class Margin) and TracIN (training loss and any target function).
| @@ -0,0 +1,686 @@ | |||
| #!/usr/bin/env python | |||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We only need one script for IF
| @@ -0,0 +1,742 @@ | |||
| #!/usr/bin/env python | |||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We only need one script for TRAK.
9471f17 to
3ae6be8
Compare
docs/guide/README.md
Outdated
| @@ -0,0 +1,142 @@ | |||
| # User Guide | |||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Chage the title to "Data Type Compatibility for Loss and Target Functions". Rename the file to be data_compatibility.md
docs/guide/README.md
Outdated
| | | [EK-FAC](https://arxiv.org/abs/2308.03296) | ✔️ | ✔️ | ❌ | [Code example](../../examples/brittleness/mnist_lr_brittleness.py) | | ||
| | | [RelatIF](https://arxiv.org/pdf/2003.11630) | ✔️ | ✔️ | ❌ | [Code example](../../examples/brittleness/mnist_lr_brittleness.py) | | ||
| | | [LoGra](https://arxiv.org/pdf/2405.13954) | ✔️ | ✔️ | ❌ | [Code example](../../examples/brittleness/mnist_lr_brittleness.py) | | ||
| | | [GraSS](https://arxiv.org/pdf/2505.18976) | ✔️ | ✔️ | ❌ | [Code example](../../examples/brittleness/mnist_lr_brittleness.py) | |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think for GraSS, LoGra, RelateIF, EK-FAC, we don't have their examples in ../../examples/brittleness/mnist_lr_brittleness.py
| type=str, | ||
| default="tuple", | ||
| choices=["tuple", "list", "dict"] | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't need to show-off what we have supported in the examples. Just choose the most convenient way and demonstrate it in the script.
| type=str, | ||
| default="tuple", | ||
| choices=["tuple", "list", "dict"] | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here.
| score = attributor.attribute(train_dataloader, eval_dataloader) | ||
|
|
||
| torch.save(score, "score_IF.pt") | ||
| logger.info("Attribution scores saved to score_IF.pt") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How does IF perform on GPT-2 + wikitext setting?
Description
I have changed some code in IF, TRAK and TracIn so they now support passing in dict. If the user still passes in tuple or list, the behavior doesn't change.
I have also added some experiments on supporting IF and Huggingface transformers.