feat: add text span #175
Conversation
long8v
left a comment
(24.05.08)
I only read the compute_ablation part. I'll probably need to understand transformer circuits first.
def get_args_parser():
    parser = argparse.ArgumentParser("Ablations part", add_help=False)
So this is what the "mean ablation" described in the paper actually is.
def main(args):
    attns = np.load(os.path.join(args.input_dir, f"{args.dataset}_attn_{args.model}.npy"), mmap_mode="r")  # [b, l, h, d]
    mlps = np.load(os.path.join(args.input_dir, f"{args.dataset}_mlp_{args.model}.npy"), mmap_mode="r")  # [b, l+1, d]
The attention and MLP outputs were saved ahead of time as numpy arrays. The attns array has shape [b, l, h, d]; I guess the b dimension runs over all the ImageNet samples?
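Side note to myself on mmap_mode="r" (general numpy behavior, not specific to this repo): with one row per sample these .npy dumps can be very large, so memory-mapping them means slices are only read from disk when accessed. A tiny sketch with a made-up file name:

import numpy as np

# create a small toy .npy file (stand-in for the real, much larger dumps)
toy = np.lib.format.open_memmap("toy_attn.npy", mode="w+", dtype=np.float32, shape=(10, 2, 3, 8))
toy[:] = 0.0
del toy

attns = np.load("toy_attn.npy", mmap_mode="r")  # nothing is read into RAM yet
print(attns.shape)      # (10, 2, 3, 8) -> [b, l, h, d]
print(attns[0].mean())  # only this slice is pulled from disk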
with open(
    os.path.join(args.input_dir, f"{args.dataset}_classifier_{args.model}.npy"),
    "rb",
) as f:
    classifier = np.load(f)
There is also a classifier. Is this the final output logits...?
if args.dataset == "imagenet":
    labels = np.array([i // 5 for i in range(attns.shape[0])])
else:
    with open(
        os.path.join(args.input_dir, f"{args.dataset}_labels.npy"), "rb"
    ) as f:
        labels = np.load(f)
Labels are loaded here.
baseline = attns.sum(axis=(1, 2)) + mlps.sum(axis=1)
baseline_acc = (
    accuracy(
        torch.from_numpy(baseline @ classifier).float(), torch.from_numpy(labels)
    )[0]
    * 100
)
print("Baseline:", baseline_acc)
The baseline accuracy is obtained by summing attns over the layer/head dimensions, summing mlps over the (l+1) dimension, and adding the two together:
- attns: attns.sum(axis=(1, 2)) [b, l, h, d] -> [b, d]
- mlps: mlps.sum(axis=1) [b, l+1, d] -> [b, d]
I'm not sure why the computation works out this way, but I think it will make sense after understanding the transformer circuits framework https://transformer-circuits.pub/2021/framework/index.html#summarizing-ovqk-matrices:~:text=as%20independently%20additive.-,Attention%20Heads%20as%20Information%20Movement,-But%20if%20attention (a sketch of my current reading follows below).
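My rough mental model, as a toy numpy sketch (all shapes and data are made up, not from the repo): if the residual stream is treated as the sum of every attention-head and MLP contribution, then the classifier projection distributes over that sum, so each head/MLP adds its own term to the logits.

import numpy as np

b, l, h, d, num_classes = 4, 2, 3, 8, 5        # toy sizes
attns = np.random.randn(b, l, h, d)            # per-head contributions to the residual stream
mlps = np.random.randn(b, l + 1, d)            # per-MLP (+ embedding) contributions
classifier = np.random.randn(d, num_classes)   # linear / zero-shot classifier

representation = attns.sum(axis=(1, 2)) + mlps.sum(axis=1)  # [b, d]
logits = representation @ classifier                        # [b, num_classes]

# the projection distributes over the sum: each head and MLP contributes additively
per_head = np.einsum("blhd,dc->blhc", attns, classifier)
per_mlp = np.einsum("bld,dc->blc", mlps, classifier)
assert np.allclose(logits, per_head.sum(axis=(1, 2)) + per_mlp.sum(axis=1))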
mlps_mean = einops.repeat(mlps.mean(axis=0), "l d -> b l d", b=attns.shape[0])
mlps_ablation = attns.sum(axis=(1, 2)) + mlps_mean.sum(axis=1)
mlps_ablation_acc = (
    accuracy(
        torch.from_numpy(mlps_ablation @ classifier).float(),
        torch.from_numpy(labels),
    )[0]
    * 100
)
print("+ MLPs ablation:", mlps_ablation_acc)
mlps_mean works as in the quoted code: the MLP contributions are simply averaged over the batch dimension, repeated back to the batch size, and then summed with the attention term again.
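A tiny self-contained sketch of that mean ablation (toy shapes, not from the repo): every sample gets the same batch-averaged MLP term, so whatever sample-specific information the MLPs carry is removed while their average effect on the representation is kept.

import numpy as np
import einops

b, l, d = 4, 3, 8                         # toy sizes
attns_summed = np.random.randn(b, d)      # stand-in for attns.sum(axis=(1, 2))
mlps = np.random.randn(b, l + 1, d)

mlps_mean = einops.repeat(mlps.mean(axis=0), "l d -> b l d", b=b)  # same mean MLP term for every sample
mlps_ablation = attns_summed + mlps_mean.sum(axis=1)               # [b, d], fed to the classifier as before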
mlps_no_layers = mlps.sum(axis=1)
attns_no_cls = attns.sum(axis=2)
with open(
    os.path.join(args.input_dir, f"{args.dataset}_cls_attn_{args.model}.npy"), "rb"
) as f:
    cls_attn = np.load(f)  # [b, l, d]
attns_no_cls = attns_no_cls - cls_attn + cls_attn.mean(axis=0)[np.newaxis, :, :]
no_cls_ablation = attns_no_cls.sum(axis=1) + mlps_no_layers
no_cls_acc = (
    accuracy(
        torch.from_numpy(no_cls_ablation @ classifier).float(),
        torch.from_numpy(labels),
    )[0]
    * 100
)
print("+ CLS ablation:", no_cls_acc)
mlp_and_no_cls_ablation = attns_no_cls.sum(axis=1) + mlps_mean.sum(axis=1)
mlp_and_no_cls_ablation_acc = (
    accuracy(
        torch.from_numpy(mlp_and_no_cls_ablation @ classifier).float(),
        torch.from_numpy(labels),
    )[0]
    * 100
)
print("+ MLPs + CLS ablation:", mlp_and_no_cls_ablation_acc)
The CLS ablation works as in the quoted code: the attention attributed to the CLS token is loaded, subtracted from attns, and its batch mean is added back in its place (I don't fully understand this part).
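My reading of that step as a toy sketch (shapes made up; this is only how I interpret the line, so it may be off): only the CLS-token part of each layer's attention output is mean-ablated, while the rest of the attention contribution stays per-sample.

import numpy as np

b, l, d = 4, 3, 8                        # toy sizes
attns_no_cls = np.random.randn(b, l, d)  # stand-in for attns.sum(axis=2): per-layer attention output
cls_attn = np.random.randn(b, l, d)      # the part of that output attributed to the CLS token

# remove the per-sample CLS contribution and put back its batch mean instead
attns_no_cls = attns_no_cls - cls_attn + cls_attn.mean(axis=0)[np.newaxis, :, :]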
#172