400 changes: 400 additions & 0 deletions clip_text_span/LICENSE.txt

Large diffs are not rendered by default.

109 changes: 109 additions & 0 deletions clip_text_span/README.md
@@ -0,0 +1,109 @@
## Interpreting CLIP's Image Representation via Text-Based Decomposition
Official PyTorch Implementation

### [Paper](https://arxiv.org/abs/2310.05916) | [Project Page](https://yossigandelsman.github.io/clip_decomposition/)

[Yossi Gandelsman](https://yossigandelsman.github.io/), [Alexei A. Efros](https://people.eecs.berkeley.edu/~efros/) and [Jacob Steinhardt](https://jsteinhardt.stat.berkeley.edu/)

![Teaser](images/teaser.png)

### Setup
We provide an [`environment.yml`](environment.yml) file that can be used to create a Conda environment:

```bash
conda env create -f environment.yml
conda activate prsclip
```
### Preprocessing
To obtain the projected residual stream components for the ImageNet validation set, including the contributions from multi-head attentions and MLPs, please run one of the following commands:

```bash
python compute_prs.py --dataset imagenet --device cuda:0 --model ViT-H-14 --pretrained laion2b_s32b_b79k --data_path <PATH>
python compute_prs.py --dataset imagenet --device cuda:0 --model ViT-L-14 --pretrained laion2b_s32b_b82k --data_path <PATH>
python compute_prs.py --dataset imagenet --device cuda:0 --model ViT-B-16 --pretrained laion2b_s34b_b88k --data_path <PATH>
```

To obtain the precomputed text representations of the ImageNet classes, please run:
```bash
python compute_text_projection.py --dataset imagenet --device cuda:0 --model ViT-H-14 --pretrained laion2b_s32b_b79k
python compute_text_projection.py --dataset imagenet --device cuda:0 --model ViT-L-14 --pretrained laion2b_s32b_b82k
python compute_text_projection.py --dataset imagenet --device cuda:0 --model ViT-B-16 --pretrained laion2b_s34b_b88k
```

### Mean-ablations
To verify that the MLPs and the attention from the class token to itself can be mean-ablated, please run:

```bash
python compute_ablations.py --model ViT-H-14
python compute_ablations.py --model ViT-L-14
python compute_ablations.py --model ViT-B-16
```

### Convert text labels to representation
To convert the text labels for <i>TextSpan</i> to CLIP text representations, please run:

```bash
python compute_text_set_projection.py --device cuda:0 --model ViT-L-14 --pretrained laion2b_s32b_b82k --data_path text_descriptions/google_3498_english.txt
python compute_text_set_projection.py --device cuda:0 --model ViT-L-14 --pretrained laion2b_s32b_b82k --data_path text_descriptions/image_descriptions_general.txt
```

### ImageNet segmentation
Please download the dataset from [here](http://calvin-vision.net/bigstuff/proj-imagenet/data/gtsegs_ijcv.mat):

```bash
mkdir imagenet_seg
cd imagenet_seg
wget http://calvin-vision.net/bigstuff/proj-imagenet/data/gtsegs_ijcv.mat
```

To get the evaluation results, please run:

```bash
python compute_segmentations.py --device cuda:0 --model ViT-H-14 --pretrained laion2b_s32b_b79k --data_path imagenet_seg/gtsegs_ijcv.mat --save_img
python compute_segmentations.py --device cuda:0 --model ViT-L-14 --pretrained laion2b_s32b_b82k --data_path imagenet_seg/gtsegs_ijcv.mat --save_img
python compute_segmentations.py --device cuda:0 --model ViT-B-16 --pretrained laion2b_s34b_b88k --data_path imagenet_seg/gtsegs_ijcv.mat --save_img
```
The `--save_img` flag saves the output segmentation visualizations.


### TextSpan

To find meaningful directions for all the attention heads, run:
```bash
python compute_complete_text_set.py --device cuda:0 --model ViT-B-16 --texts_per_head 20 --num_of_last_layers 4 --text_descriptions image_descriptions_general
python compute_complete_text_set.py --device cuda:0 --model ViT-L-14 --texts_per_head 20 --num_of_last_layers 4 --text_descriptions image_descriptions_general
python compute_complete_text_set.py --device cuda:0 --model ViT-H-14 --texts_per_head 20 --num_of_last_layers 4 --text_descriptions image_descriptions_general
```

### Other datasets
To download the Waterbirds dataset, run:
```bash
wget https://nlp.stanford.edu/data/dro/waterbird_complete95_forest2water2.tar.gz
tar -xf waterbird_complete95_forest2water2.tar.gz
```
To compute the overall accuracy, run:
```bash
python compute_prs.py --dataset binary_waterbirds --device cuda:0 --model ViT-L-14 --pretrained laion2b_s32b_b82k --data_path <PATH>
python compute_text_projection.py --dataset binary_waterbirds --device cuda:0 --model ViT-L-14 --pretrained laion2b_s32b_b82k
python compute_use_specific_heads.py --model ViT-L-14 --dataset binary_waterbirds
```

### Spatial decomposition
Please see a demo for the spatial decomposition of CLIP in `demo.ipynb`.


### Nearest neighbors search
Please see the nearest neighbors search demo in `nns.ipynb`.

### BibTeX

```bibtex
@inproceedings{
gandelsman2024interpreting,
title={Interpreting {CLIP}'s Image Representation via Text-Based Decomposition},
author={Yossi Gandelsman and Alexei A. Efros and Jacob Steinhardt},
booktitle={The Twelfth International Conference on Learning Representations},
year={2024},
url={https://openreview.net/forum?id=5Ca9sSzuDp}
}
```
125 changes: 125 additions & 0 deletions clip_text_span/compute_ablations.py
@@ -0,0 +1,125 @@
import numpy as np
import torch
import os.path
import argparse
import einops
from pathlib import Path

import tqdm
from utils.misc import accuracy


def get_args_parser():
    parser = argparse.ArgumentParser("Ablations part", add_help=False)
Comment on lines +12 to +13
This is the mean ablation discussed in the paper.
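
As a rough sketch of what mean-ablation means here (my reading of the paper, with illustrative names; a component's direct contributions are assumed to be stored as a `[num_images, d]` array): each image's contribution is replaced by the average over the dataset, removing any image-specific information that component carried.

```python
import numpy as np

def mean_ablate(contribution: np.ndarray) -> np.ndarray:
    """Replace every image's contribution with the dataset mean.

    contribution: [num_images, d] direct contributions of a single component
    (one attention head, one MLP, ...) to the output representation.
    """
    mean = contribution.mean(axis=0, keepdims=True)           # [1, d]
    return np.broadcast_to(mean, contribution.shape).copy()   # [num_images, d]
```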


    # Model parameters
    parser.add_argument(
        "--model",
        default="ViT-H-14",
        type=str,
        metavar="MODEL",
        help="Name of model to use",
    )
    # Dataset parameters
    parser.add_argument("--num_workers", default=10, type=int)
    parser.add_argument(
        "--figures_dir", default="./output_dir", help="path where data is saved"
    )
    parser.add_argument(
        "--input_dir", default="./output_dir", help="path where data is saved"
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="imagenet",
        help="imagenet, waterbirds, cub, binary_waterbirds",
    )
    return parser


def main(args):

    attns = np.load(os.path.join(args.input_dir, f"{args.dataset}_attn_{args.model}.npy"), mmap_mode="r") # [b, l, h, d]
    mlps = np.load(os.path.join(args.input_dir, f"{args.dataset}_mlp_{args.model}.npy"), mmap_mode="r") # [b, l+1, d]
Comment on lines +40 to +43
The attention and MLP contributions are saved as numpy arrays.
The attentions have shape [b, l, h, d]; presumably the b dimension runs over all the ImageNet samples?
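
A quick way to sanity-check those shapes (a sketch; the paths assume the default `./output_dir` and the ViT-B-16 outputs of `compute_prs.py`):

```python
import numpy as np

# mmap_mode="r" keeps the large arrays on disk instead of loading them into RAM
attns = np.load("output_dir/imagenet_attn_ViT-B-16.npy", mmap_mode="r")
mlps = np.load("output_dir/imagenet_mlp_ViT-B-16.npy", mmap_mode="r")

print(attns.shape)  # expected: (num_images, num_layers, num_heads, embed_dim)
print(mlps.shape)   # expected: (num_images, num_layers + 1, embed_dim); the extra entry is presumably the initial embedding
```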

    with open(
        os.path.join(args.input_dir, f"{args.dataset}_classifier_{args.model}.npy"),
        "rb",
    ) as f:
        classifier = np.load(f)
Comment on lines +44 to +48
There is also a classifier. Is this the final output logits..?
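
Judging from the usage below (`baseline @ classifier`), this looks like the zero-shot text classifier saved by `compute_text_projection.py`: a `[d, num_classes]` matrix whose columns are the class-name text embeddings, so the logits are just dot products. A sketch with made-up shapes, not confirmed against the repo:

```python
import numpy as np

image_reps = np.random.randn(8, 512)     # [b, d] summed residual-stream contributions
classifier = np.random.randn(512, 1000)  # [d, num_classes] class-name text embeddings
logits = image_reps @ classifier         # [b, num_classes] zero-shot logits
preds = logits.argmax(axis=-1)
```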

    if args.dataset == "imagenet":
        labels = np.array([i // 5 for i in range(attns.shape[0])])
    else:
        with open(
            os.path.join(args.input_dir, f"{args.dataset}_labels.npy"), "rb"
        ) as f:
            labels = np.load(f)
Comment on lines +49 to +55
Loads the labels.

    baseline = attns.sum(axis=(1, 2)) + mlps.sum(axis=1)
    baseline_acc = (
        accuracy(
            torch.from_numpy(baseline @ classifier).float(), torch.from_numpy(labels)
        )[0]
        * 100
    )
    print("Baseline:", baseline_acc)
Comment on lines +56 to +63
The baseline accuracy is obtained by summing attns over the layer / head axes, summing mlps over the (l+1) axis, and adding the two:

  • attns: attns.sum(axis=(1, 2)) [b, l, h, d] -> [b, d]
  • mlps: mlps.sum(axis=1) [b, l+1, d] -> [b, d]

I'm not entirely sure why the computation works like this, but understanding the transformer-circuits framing https://transformer-circuits.pub/2021/framework/index.html#summarizing-ovqk-matrices:~:text=as%20independently%20additive.-,Attention%20Heads%20as%20Information%20Movement,-But%20if%20attention (attention heads as additive information movement) should make it clear.
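
Why the plain sums reconstruct the prediction (my understanding, not stated in this file): the residual stream is additive, so the class-token representation is exactly the sum of every attention head's direct contribution plus every MLP's direct contribution (the extra `l+1` entry presumably being the initial embedding). Regrouping the sum over components changes nothing, which is what makes per-component ablation possible. A tiny sketch with stand-in arrays:

```python
import numpy as np

b, l, h, d = 4, 12, 8, 64            # illustrative sizes only
attns = np.random.randn(b, l, h, d)  # per-head direct contributions
mlps = np.random.randn(b, l + 1, d)  # per-MLP contributions (+ initial embedding)

rep = attns.sum(axis=(1, 2)) + mlps.sum(axis=1)  # [b, d] reconstructed representation

# regrouping the sum over any component leaves it unchanged
rep_regrouped = (
    attns[:, :, 0].sum(axis=1)          # head 0 of every layer
    + attns[:, :, 1:].sum(axis=(1, 2))  # all remaining heads
    + mlps.sum(axis=1)
)
assert np.allclose(rep, rep_regrouped)
```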

    mlps_mean = einops.repeat(mlps.mean(axis=0), "l d -> b l d", b=attns.shape[0])
    mlps_ablation = attns.sum(axis=(1, 2)) + mlps_mean.sum(axis=1)
    mlps_ablation_acc = (
        accuracy(
            torch.from_numpy(mlps_ablation @ classifier).float(),
            torch.from_numpy(labels),
        )[0]
        * 100
    )
    print("+ MLPs ablation:", mlps_ablation_acc)
Comment on lines +64 to +73
mlps_mean: just take the mean of the MLP contributions over the batch dimension, repeat it back to the batch size, and then sum it with the attentions again.
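
The same recipe (sum the kept terms, swap the ablated term for its dataset mean, project with the classifier, measure accuracy) repeats for every ablation below; a small helper capturing it, as a sketch rather than repo code:

```python
import numpy as np
import torch

from utils.misc import accuracy  # the same helper this script imports

def ablated_accuracy(attn_term: np.ndarray, mlp_term: np.ndarray,
                     classifier: np.ndarray, labels: np.ndarray) -> float:
    """Top-1 accuracy (%) of a possibly-ablated representation.

    attn_term, mlp_term: [b, d] summed attention / MLP contributions;
    classifier: [d, num_classes]; labels: [b].
    """
    rep = attn_term + mlp_term
    logits = torch.from_numpy(rep @ classifier).float()
    return accuracy(logits, torch.from_numpy(labels))[0] * 100
```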

    mlps_no_layers = mlps.sum(axis=1)
    attns_no_cls = attns.sum(axis=2)
    with open(
        os.path.join(args.input_dir, f"{args.dataset}_cls_attn_{args.model}.npy"), "rb"
    ) as f:
        cls_attn = np.load(f) # [b, l, d]
    attns_no_cls = attns_no_cls - cls_attn + cls_attn.mean(axis=0)[np.newaxis, :, :]
    no_cls_ablation = attns_no_cls.sum(axis=1) + mlps_no_layers
    no_cls_acc = (
        accuracy(
            torch.from_numpy(no_cls_ablation @ classifier).float(),
            torch.from_numpy(labels),
        )[0]
        * 100
    )
    print("+ CLS ablation:", no_cls_acc)
    mlp_and_no_cls_ablation = attns_no_cls.sum(axis=1) + mlps_mean.sum(axis=1)
    mlp_and_no_cls_ablation_acc = (
        accuracy(
            torch.from_numpy(mlp_and_no_cls_ablation @ classifier).float(),
            torch.from_numpy(labels),
        )[0]
        * 100
    )
    print("+ MLPs + CLS ablation:", mlp_and_no_cls_ablation_acc)
Comment on lines +74 to +98
The CLS ablation: take the attention contribution to the CLS token and subtract it from attns (I don't fully understand this part).
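
My reading of that line, matching the README's "attention from the class token to itself": `cls_attn` holds, per image and per layer, the part of the attention output at the class token that comes from the class token attending to itself. Subtracting it and adding back its dataset mean mean-ablates only that self-attention path, while the patch→CLS contributions stay per-image. A sketch with stand-in arrays:

```python
import numpy as np

b, l, d = 4, 12, 64                    # illustrative sizes only
attn_total = np.random.randn(b, l, d)  # per-layer attention output at the class token
cls_attn = np.random.randn(b, l, d)    # portion from CLS attending to itself

# mean-ablate only the CLS->CLS part, keep the patch->CLS part per-image
patch_part = attn_total - cls_attn
ablated = patch_part + cls_attn.mean(axis=0, keepdims=True)

# identical to the one-liner used in the script
assert np.allclose(
    ablated, attn_total - cls_attn + cls_attn.mean(axis=0)[np.newaxis, :, :]
)
```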

    no_heads_attentions = attns.sum(axis=(2))  # [b, l, d]: per-layer attention contribution, heads summed
    all_accuracies = [baseline_acc]
    for layer in range(attns.shape[1]):
        # layers 0..layer get replaced by their dataset-mean attention contribution;
        # layers after `layer` keep their per-image values
        current_model = (
            np.sum(
                np.mean(no_heads_attentions[:, :layer], axis=0, keepdims=True), axis=1
            )
            + np.mean(no_heads_attentions[:, layer], axis=0, keepdims=True)
            + np.sum(no_heads_attentions[:, layer + 1 :], axis=1)
        )
        current_accuracy = (
            accuracy(
                torch.from_numpy((mlps_no_layers + current_model) @ classifier).float(),
                torch.from_numpy(labels),
            )[0]
            * 100
        )
        all_accuracies.append(current_accuracy)
    print("Attention ablations:", all_accuracies)


if __name__ == "__main__":
    args = get_args_parser()
    args = args.parse_args()
    if args.figures_dir:
        Path(args.figures_dir).mkdir(parents=True, exist_ok=True)
    main(args)