
Conversation

long8v commented on May 7, 2024

long8v changed the title from "240507 feat text span" to "feat: add text span" on May 7, 2024

long8v left a comment


(24.05.08)
So far I've only read the compute_ablation part. It seems I need to understand transformer circuits to follow the rest.

Comment on lines +12 to +13
def get_args_parser():
    parser = argparse.ArgumentParser("Ablations part", add_help=False)

So this is what the paper's "mean ablation" actually is: a component's per-sample contribution is replaced by its mean over the dataset (see mlps_mean below).

Comment on lines +40 to +43
def main(args):

    attns = np.load(os.path.join(args.input_dir, f"{args.dataset}_attn_{args.model}.npy"), mmap_mode="r")  # [b, l, h, d]
    mlps = np.load(os.path.join(args.input_dir, f"{args.dataset}_mlp_{args.model}.npy"), mmap_mode="r")  # [b, l+1, d]

The attention and MLP contributions are precomputed and saved as numpy arrays. attns has shape [b, l, h, d]; I guess the b dimension runs over all the ImageNet samples?
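
A minimal shape check under that reading (the concrete filenames are hypothetical instantiations of the f"{args.dataset}_attn_{args.model}.npy" pattern above; mmap_mode="r" keeps the large arrays on disk):

```python
import numpy as np

# Hypothetical filenames following the naming pattern in the snippet above.
attns = np.load("imagenet_attn_ViT-B-16.npy", mmap_mode="r")
mlps = np.load("imagenet_mlp_ViT-B-16.npy", mmap_mode="r")

print(attns.shape)  # (b, l, h, d): samples, layers, heads, width
print(mlps.shape)   # (b, l+1, d): one extra slot beyond the l layers
assert attns.shape[0] == mlps.shape[0]  # same b, so same set of samples
```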

Comment on lines +44 to +48
with open(
    os.path.join(args.input_dir, f"{args.dataset}_classifier_{args.model}.npy"),
    "rb",
) as f:
    classifier = np.load(f)

There is also a classifier. Is this the final output logits? Judging from how it is used below (baseline @ classifier is fed straight into accuracy), it looks more like a [d, num_classes] projection matrix.
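
A toy sketch of that reading (all names and sizes here are made up):

```python
import numpy as np

rng = np.random.default_rng(0)
b, d, num_classes = 4, 8, 10                    # toy sizes
representation = rng.normal(size=(b, d))        # stands in for e.g. `baseline`
classifier = rng.normal(size=(d, num_classes))  # assumed [d, num_classes]

logits = representation @ classifier  # [b, d] @ [d, C] -> [b, C]
preds = logits.argmax(axis=-1)        # what top-1 accuracy compares to labels
```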

Comment on lines +49 to +55
if args.dataset == "imagenet":
    labels = np.array([i // 5 for i in range(attns.shape[0])])
else:
    with open(
        os.path.join(args.input_dir, f"{args.dataset}_labels.npy"), "rb"
    ) as f:
        labels = np.load(f)

Loads the labels. For ImageNet they are derived as i // 5, which implies the samples are stored class by class, five per class.

Comment on lines +56 to +63
baseline = attns.sum(axis=(1, 2)) + mlps.sum(axis=1)
baseline_acc = (
    accuracy(
        torch.from_numpy(baseline @ classifier).float(), torch.from_numpy(labels)
    )[0]
    * 100
)
print("Baseline:", baseline_acc)

Baseline accuracy: sum attns over the layer and head axes, sum mlps over its (l+1) axis, then add the two together:

  • attns: attns.sum(axis=(1, 2)): [b, l, h, d] -> [b, d]
  • mlps: mlps.sum(axis=1): [b, l+1, d] -> [b, d]

I don't fully understand why the computation works out this way; understanding the transformer circuits framework https://transformer-circuits.pub/2021/framework/index.html#summarizing-ovqk-matrices:~:text=as%20independently%20additive.-,Attention%20Heads%20as%20Information%20Movement,-But%20if%20attention should make it clear.
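
A toy numpy sketch of the additivity argument (my assumption: the saved arrays record per-component contributions from a purely additive residual stream, and the extra l+1 slot in mlps presumably holds something like the initial embedding):

```python
import numpy as np

rng = np.random.default_rng(0)
b, l, h, d = 4, 2, 3, 8                       # toy sizes

embed = rng.normal(size=(b, d))               # initial residual stream state
attn_contrib = rng.normal(size=(b, l, h, d))  # per-head attention outputs
mlp_contrib = rng.normal(size=(b, l, d))      # per-layer MLP outputs

# Run the residual stream: every component only ever adds to it.
stream = embed.copy()
for layer in range(l):
    stream += attn_contrib[:, layer].sum(axis=1)  # all heads of this layer
    stream += mlp_contrib[:, layer]

# Because every update is additive, the final state equals the flat sum of
# all recorded contributions, i.e. exactly attns.sum(...) + mlps.sum(...).
recon = embed + attn_contrib.sum(axis=(1, 2)) + mlp_contrib.sum(axis=1)
assert np.allclose(stream, recon)
```

This is the "residual stream as a sum of independent component outputs" view from the linked post.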

Comment on lines +64 to +73
mlps_mean = einops.repeat(mlps.mean(axis=0), "l d -> b l d", b=attns.shape[0])
mlps_ablation = attns.sum(axis=(1, 2)) + mlps_mean.sum(axis=1)
mlps_ablation_acc = (
    accuracy(
        torch.from_numpy(mlps_ablation @ classifier).float(),
        torch.from_numpy(labels),
    )[0]
    * 100
)
print("+ MLPs ablation:", mlps_ablation_acc)

mlps_mean is shown in the snippet above: take the mean of the MLP contributions over the batch dimension, repeat it back out to the batch size, and then sum it with the attention contributions as before.
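
The same recipe factored out as a hypothetical helper (not in the repo), to make the pattern explicit:

```python
import numpy as np

def mean_ablate(contrib):
    """Replace each sample's contribution with the batch mean (same shape out)."""
    return np.broadcast_to(contrib.mean(axis=0, keepdims=True), contrib.shape)

# Equivalent to the einops.repeat version above:
#   mlps_ablation = attns.sum(axis=(1, 2)) + mean_ablate(mlps).sum(axis=1)
```

Comparing mlps_ablation_acc against baseline_acc then tells you how much class-relevant, per-sample information the MLP contributions carry.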

Comment on lines +74 to +98
mlps_no_layers = mlps.sum(axis=1)
attns_no_cls = attns.sum(axis=2)
with open(
    os.path.join(args.input_dir, f"{args.dataset}_cls_attn_{args.model}.npy"), "rb"
) as f:
    cls_attn = np.load(f)  # [b, l, d]
attns_no_cls = attns_no_cls - cls_attn + cls_attn.mean(axis=0)[np.newaxis, :, :]
no_cls_ablation = attns_no_cls.sum(axis=1) + mlps_no_layers
no_cls_acc = (
    accuracy(
        torch.from_numpy(no_cls_ablation @ classifier).float(),
        torch.from_numpy(labels),
    )[0]
    * 100
)
print("+ CLS ablation:", no_cls_acc)
mlp_and_no_cls_ablation = attns_no_cls.sum(axis=1) + mlps_mean.sum(axis=1)
mlp_and_no_cls_ablation_acc = (
    accuracy(
        torch.from_numpy(mlp_and_no_cls_ablation @ classifier).float(),
        torch.from_numpy(labels),
    )[0]
    * 100
)
print("+ MLPs + CLS ablation:", mlp_and_no_cls_ablation_acc)

The CLS ablation is shown in the snippet above: load the attention contribution that flows through the CLS token (cls_attn) and subtract it from the head-summed attns, adding back its batch mean (I haven't fully understood this).
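
My reading: this is the same mean ablation applied to just one sub-component. The pattern total - part + part.mean(axis=0) replaces only part (here, the CLS-token contribution) with its dataset mean while the remainder of total stays per-sample. A toy sketch, with hypothetical arrays standing in for attns_no_cls and cls_attn:

```python
import numpy as np

rng = np.random.default_rng(0)
b, l, d = 4, 2, 8
attn_total = rng.normal(size=(b, l, d))  # per-layer attention, heads summed
cls_part = rng.normal(size=(b, l, d))    # the CLS-token share of attn_total

# Mean-ablate only the CLS part; the image-token remainder stays per-sample.
ablated = attn_total - cls_part + cls_part.mean(axis=0)[np.newaxis]

# The same thing read as "keep the rest, substitute the batch-mean CLS part":
rest = attn_total - cls_part
ablated2 = rest + np.broadcast_to(cls_part.mean(axis=0), cls_part.shape)
assert np.allclose(ablated, ablated2)
```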
