# feat: add text span #175
## Interpreting CLIP's Image Representation via Text-Based Decomposition
Official PyTorch Implementation

### [Paper](https://arxiv.org/abs/2310.05916) | [Project Page](https://yossigandelsman.github.io/clip_decomposition/)

[Yossi Gandelsman](https://yossigandelsman.github.io/), [Alexei A. Efros](https://people.eecs.berkeley.edu/~efros/) and [Jacob Steinhardt](https://jsteinhardt.stat.berkeley.edu/)

### Setup
We provide an [`environment.yml`](environment.yml) file that can be used to create a Conda environment:

```bash
conda env create -f environment.yml
conda activate prsclip
```
### Preprocessing
To compute the projected residual stream components for the ImageNet validation set, including the contributions from the multi-head attentions and MLPs, please run one of the following commands:

```bash
python compute_prs.py --dataset imagenet --device cuda:0 --model ViT-H-14 --pretrained laion2b_s32b_b79k --data_path <PATH>
python compute_prs.py --dataset imagenet --device cuda:0 --model ViT-L-14 --pretrained laion2b_s32b_b82k --data_path <PATH>
python compute_prs.py --dataset imagenet --device cuda:0 --model ViT-B-16 --pretrained laion2b_s34b_b88k --data_path <PATH>
```
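These commands write `.npy` files into the output directory, following the `{dataset}_{component}_{model}.npy` naming that the ablation script later in this PR relies on. A minimal sketch for inspecting the outputs, assuming the default `./output_dir` location:

```python
import numpy as np

# Shapes follow the annotations in the ablation script below; paths assume defaults.
attns = np.load("output_dir/imagenet_attn_ViT-B-16.npy", mmap_mode="r")  # [b, layers, heads, d]
mlps = np.load("output_dir/imagenet_mlp_ViT-B-16.npy", mmap_mode="r")    # [b, layers + 1, d]
print(attns.shape, mlps.shape)
```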
To precompute the text representations of the ImageNet class names, please run:
```bash
python compute_text_projection.py --dataset imagenet --device cuda:0 --model ViT-H-14 --pretrained laion2b_s32b_b79k
python compute_text_projection.py --dataset imagenet --device cuda:0 --model ViT-L-14 --pretrained laion2b_s32b_b82k
python compute_text_projection.py --dataset imagenet --device cuda:0 --model ViT-B-16 --pretrained laion2b_s34b_b88k
```
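Conceptually, this step encodes each class name with CLIP's text encoder and stacks the normalized embeddings into a zero-shot classifier matrix. A minimal OpenCLIP sketch of the idea (the prompt template and variable names are assumptions, not the repository's exact code):

```python
import torch
import open_clip

# Build the model and tokenizer; any of the model/pretrained pairs above works here.
model, _, _ = open_clip.create_model_and_transforms("ViT-B-16", pretrained="laion2b_s34b_b88k")
tokenizer = open_clip.get_tokenizer("ViT-B-16")

class_names = ["tench", "goldfish", "great white shark"]  # truncated example list
with torch.no_grad():
    tokens = tokenizer([f"a photo of a {name}" for name in class_names])
    text_embs = model.encode_text(tokens)
    text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True)  # [num_classes, d]
# image_representations @ text_embs.T then yields the zero-shot logits.
```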
### Mean-ablations
To verify that the MLPs and the attention from the class token to itself can be mean-ablated, please run:

```bash
python compute_ablations.py --model ViT-H-14
python compute_ablations.py --model ViT-L-14
python compute_ablations.py --model ViT-B-16
```
### Convert text labels to representations
To convert the text labels for <i>TextSpan</i> to CLIP text representations, please run:

```bash
python compute_text_set_projection.py --device cuda:0 --model ViT-L-14 --pretrained laion2b_s32b_b82k --data_path text_descriptions/google_3498_english.txt
python compute_text_set_projection.py --device cuda:0 --model ViT-L-14 --pretrained laion2b_s32b_b82k --data_path text_descriptions/image_descriptions_general.txt
```
### ImageNet segmentation
Please download the dataset from [here](http://calvin-vision.net/bigstuff/proj-imagenet/data/gtsegs_ijcv.mat):

```bash
mkdir imagenet_seg
cd imagenet_seg
wget http://calvin-vision.net/bigstuff/proj-imagenet/data/gtsegs_ijcv.mat
```

To get the evaluation results, please run:

```bash
python compute_segmentations.py --device cuda:0 --model ViT-H-14 --pretrained laion2b_s32b_b79k --data_path imagenet_seg/gtsegs_ijcv.mat --save_img
python compute_segmentations.py --device cuda:0 --model ViT-L-14 --pretrained laion2b_s32b_b82k --data_path imagenet_seg/gtsegs_ijcv.mat --save_img
python compute_segmentations.py --device cuda:0 --model ViT-B-16 --pretrained laion2b_s34b_b88k --data_path imagenet_seg/gtsegs_ijcv.mat --save_img
```
The `--save_img` flag additionally saves the segmentation visualizations to disk.
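For intuition about what is being evaluated: once the representation is decomposed per image token, each patch can be scored against the class text embedding and thresholded into a foreground mask. The sketch below is a schematic of that idea with assumed array names, not the repository's exact procedure:

```python
import numpy as np

def simple_mask(patch_contribs: np.ndarray, class_emb: np.ndarray) -> np.ndarray:
    """patch_contribs [num_patches, d], class_emb [d] -> boolean foreground mask."""
    heat = patch_contribs @ class_emb  # relevance of each patch to the class
    return heat > heat.mean()          # mean threshold as a simple baseline
```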
### TextSpan

To find meaningful directions for all the attention heads, run:
```bash
python compute_complete_text_set.py --device cuda:0 --model ViT-B-16 --texts_per_head 20 --num_of_last_layers 4 --text_descriptions image_descriptions_general
python compute_complete_text_set.py --device cuda:0 --model ViT-L-14 --texts_per_head 20 --num_of_last_layers 4 --text_descriptions image_descriptions_general
python compute_complete_text_set.py --device cuda:0 --model ViT-H-14 --texts_per_head 20 --num_of_last_layers 4 --text_descriptions image_descriptions_general
```
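TextSpan greedily selects, for each head, the text directions that explain the most variance of that head's contributions, projecting each chosen direction out before the next round. The following is a minimal sketch of that greedy loop as described in the paper, with hypothetical array names rather than the repository's implementation:

```python
import numpy as np

def textspan(head_out: np.ndarray, text_embs: np.ndarray, k: int = 20) -> list:
    """head_out: [N, d] head contributions; text_embs: [M, d] candidate directions."""
    A = head_out - head_out.mean(axis=0)  # center, so scores measure explained variance
    R = text_embs.astype(float).copy()
    chosen = []
    for _ in range(k):
        scores = ((A @ R.T) ** 2).sum(axis=0)  # variance explained by each candidate
        idx = int(np.argmax(scores))
        d = R[idx] / (np.linalg.norm(R[idx]) + 1e-8)
        chosen.append(idx)
        A = A - np.outer(A @ d, d)  # remove the explained component from the data
        R = R - np.outer(R @ d, d)  # keep remaining candidates orthogonal to it
    return chosen
```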
### Other datasets
To download the Waterbirds dataset, run:
```bash
wget https://nlp.stanford.edu/data/dro/waterbird_complete95_forest2water2.tar.gz
tar -xf waterbird_complete95_forest2water2.tar.gz
```
To compute the overall accuracy, run:
```bash
python compute_prs.py --dataset binary_waterbirds --device cuda:0 --model ViT-L-14 --pretrained laion2b_s32b_b82k --data_path <PATH>
python compute_text_projection.py --dataset binary_waterbirds --device cuda:0 --model ViT-L-14 --pretrained laion2b_s32b_b82k
python compute_use_specific_heads.py --model ViT-L-14 --dataset binary_waterbirds
```
### Spatial decomposition
Please see a demo of the spatial decomposition of CLIP in `demo.ipynb`.

### Nearest neighbors search
Please see the nearest neighbors search demo in `nns.ipynb`.
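The demo boils down to cosine-similarity retrieval over normalized representations (optionally restricted to a single head's contribution). A minimal sketch with hypothetical file and array names:

```python
import numpy as np

reps = np.load("output_dir/reps.npy")       # hypothetical [N, d] L2-normalized representations
query = reps[0]                             # use the first image as the query
nearest = np.argsort(-(reps @ query))[1:6]  # top-5 neighbors, skipping the query itself
print(nearest)
```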
### BibTeX

```bibtex
@inproceedings{
    gandelsman2024interpreting,
    title={Interpreting {CLIP}'s Image Representation via Text-Based Decomposition},
    author={Yossi Gandelsman and Alexei A. Efros and Jacob Steinhardt},
    booktitle={The Twelfth International Conference on Learning Representations},
    year={2024},
    url={https://openreview.net/forum?id=5Ca9sSzuDp}
}
```
---

**`compute_ablations.py`** (new file in this PR; the ablation script from the Mean-ablations section above, with the author's review comments inline)
```python
import argparse
import os.path
from pathlib import Path

import einops
import numpy as np
import torch

from utils.misc import accuracy


def get_args_parser():
    parser = argparse.ArgumentParser("Ablations part", add_help=False)

    # Model parameters
    parser.add_argument(
        "--model",
        default="ViT-H-14",
        type=str,
        metavar="MODEL",
        help="Name of model to use",
    )
    # Dataset parameters
    parser.add_argument("--num_workers", default=10, type=int)
    parser.add_argument(
        "--figures_dir", default="./output_dir", help="path where figures are saved"
    )
    parser.add_argument(
        "--input_dir", default="./output_dir", help="path where data is saved"
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="imagenet",
        help="imagenet, waterbirds, cub, binary_waterbirds",
    )
    return parser


def main(args):
    # Per-sample direct contributions precomputed by compute_prs.py.
    attns = np.load(
        os.path.join(args.input_dir, f"{args.dataset}_attn_{args.model}.npy"),
        mmap_mode="r",
    )  # [b, l, h, d]
    mlps = np.load(
        os.path.join(args.input_dir, f"{args.dataset}_mlp_{args.model}.npy"),
        mmap_mode="r",
    )  # [b, l+1, d]
```
> **Comment on lines +40 to +43 (Owner, Author):** The attention and MLP contributions are stored as NumPy arrays.
```python
    # Zero-shot classifier weights precomputed by compute_text_projection.py.
    with open(
        os.path.join(args.input_dir, f"{args.dataset}_classifier_{args.model}.npy"),
        "rb",
    ) as f:
        classifier = np.load(f)
```
> **Comment on lines +44 to +48 (Owner, Author):** There is also a classifier. Are these the final output logits?
```python
    if args.dataset == "imagenet":
        # Labels follow from sample order: five consecutive samples per class.
        labels = np.array([i // 5 for i in range(attns.shape[0])])
    else:
        with open(
            os.path.join(args.input_dir, f"{args.dataset}_labels.npy"), "rb"
        ) as f:
            labels = np.load(f)
```
> **Comment on lines +49 to +55 (Owner, Author):** Loads the labels.
```python
    # Baseline: the full representation is the sum of all direct contributions.
    baseline = attns.sum(axis=(1, 2)) + mlps.sum(axis=1)
    baseline_acc = (
        accuracy(
            torch.from_numpy(baseline @ classifier).float(), torch.from_numpy(labels)
        )[0]
        * 100
    )
    print("Baseline:", baseline_acc)
```
> **Comment on lines +56 to +63 (Owner, Author):** The baseline accuracy is obtained by summing `attns` over the layer and head axes, summing `mlps` over the (l+1) axis, and adding the two together. I am not sure why the computation decomposes this way; it probably becomes clear after understanding ["Attention Heads as Information Movement" in the transformer circuits framework](https://transformer-circuits.pub/2021/framework/index.html#summarizing-ovqk-matrices:~:text=as%20independently%20additive.-,Attention%20Heads%20as%20Information%20Movement,-But%20if%20attention).
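For context: the paper decomposes the CLS representation as a sum of per-head and per-MLP direct effects, which is why plain summation reconstructs the full representation here. A toy shape-level check of that bookkeeping (hypothetical values):

```python
import numpy as np

# Toy shapes only: 4 samples, 2 layers, 3 heads, width 8 (hypothetical values).
rng = np.random.default_rng(0)
attns = rng.normal(size=(4, 2, 3, 8))  # [b, l, h, d] per-head direct contributions
mlps = rng.normal(size=(4, 3, 8))      # [b, l+1, d] per-MLP (plus embedding) contributions

rep = attns.sum(axis=(1, 2)) + mlps.sum(axis=1)  # [b, d] full representation

# Grouping the sum differently changes nothing, since the terms are independent:
assert np.allclose(rep, attns.sum(axis=2).sum(axis=1) + mlps.sum(axis=1))
```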
```python
    # Mean-ablate the MLPs: replace each per-sample MLP contribution with its
    # mean over the dataset while keeping the attention contributions intact.
    mlps_mean = einops.repeat(mlps.mean(axis=0), "l d -> b l d", b=attns.shape[0])
    mlps_ablation = attns.sum(axis=(1, 2)) + mlps_mean.sum(axis=1)
    mlps_ablation_acc = (
        accuracy(
            torch.from_numpy(mlps_ablation @ classifier).float(),
            torch.from_numpy(labels),
        )[0]
        * 100
    )
    print("+ MLPs ablation:", mlps_ablation_acc)
```
> **Comment on lines +64 to +73 (Owner, Author):** `mlps_mean` is computed as shown above: take the mean of the MLP contributions over the batch dimension, repeat it along the batch, and sum it with the attention terms again.
```python
    # Mean-ablate the attention from the class token to itself: subtract the
    # per-sample CLS-to-CLS term and add back its mean over the dataset.
    mlps_no_layers = mlps.sum(axis=1)
    attns_no_cls = attns.sum(axis=2)
    with open(
        os.path.join(args.input_dir, f"{args.dataset}_cls_attn_{args.model}.npy"), "rb"
    ) as f:
        cls_attn = np.load(f)  # [b, l, d]
    attns_no_cls = attns_no_cls - cls_attn + cls_attn.mean(axis=0)[np.newaxis, :, :]
    no_cls_ablation = attns_no_cls.sum(axis=1) + mlps_no_layers
    no_cls_acc = (
        accuracy(
            torch.from_numpy(no_cls_ablation @ classifier).float(),
            torch.from_numpy(labels),
        )[0]
        * 100
    )
    print("+ CLS ablation:", no_cls_acc)

    # Apply both ablations at once: mean MLPs and mean CLS-to-CLS attention.
    mlp_and_no_cls_ablation = attns_no_cls.sum(axis=1) + mlps_mean.sum(axis=1)
    mlp_and_no_cls_ablation_acc = (
        accuracy(
            torch.from_numpy(mlp_and_no_cls_ablation @ classifier).float(),
            torch.from_numpy(labels),
        )[0]
        * 100
    )
    print("+ MLPs + CLS ablation:", mlp_and_no_cls_ablation_acc)
```
> **Comment on lines +74 to +98 (Owner, Author):** The CLS ablation works as shown above: load the attention to the CLS token and subtract it from `attns` (I don't fully understand this part).
```python
    # Layer-by-layer ablation: mean-ablate the attention of the current layer and
    # all earlier layers while keeping the later layers' contributions intact.
    no_heads_attentions = attns.sum(axis=2)
    all_accuracies = [baseline_acc]
    for layer in range(attns.shape[1]):
        current_model = (
            np.sum(
                np.mean(no_heads_attentions[:, :layer], axis=0, keepdims=True), axis=1
            )
            + np.mean(no_heads_attentions[:, layer], axis=0, keepdims=True)
            + np.sum(no_heads_attentions[:, layer + 1 :], axis=1)
        )
        current_accuracy = (
            accuracy(
                torch.from_numpy((mlps_no_layers + current_model) @ classifier).float(),
                torch.from_numpy(labels),
            )[0]
            * 100
        )
        all_accuracies.append(current_accuracy)
    print("Attention ablations:", all_accuracies)


if __name__ == "__main__":
    args = get_args_parser()
    args = args.parse_args()
    if args.figures_dir:
        Path(args.figures_dir).mkdir(parents=True, exist_ok=True)
    main(args)
```
> **Comment (Owner, Author):** So this is what the mean-ablation discussed in the paper actually is.
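Because the representation is a plain sum of component contributions, mean-ablating a component just swaps its per-sample term for the dataset mean. A one-function sketch of the pattern the script uses throughout (hypothetical array names):

```python
import numpy as np

def mean_ablate(rep: np.ndarray, contrib: np.ndarray) -> np.ndarray:
    """rep [b, d]: full representations; contrib [b, d]: one component's
    per-sample contribution. Replace that component with its dataset mean."""
    return rep - contrib + contrib.mean(axis=0, keepdims=True)
```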