From d40d34da9e52706570bb6139eef89297fb5243b1 Mon Sep 17 00:00:00 2001 From: Robert McMenemy Date: Tue, 27 Dec 2022 18:54:48 +0000 Subject: [PATCH 01/13] Rework perceiver tensor logic --- point_e/models/perceiver.py | 50 +++++++++++++++++++++++++------------ 1 file changed, 34 insertions(+), 16 deletions(-) diff --git a/point_e/models/perceiver.py b/point_e/models/perceiver.py index 9e7c730..0dd352b 100644 --- a/point_e/models/perceiver.py +++ b/point_e/models/perceiver.py @@ -84,26 +84,44 @@ def __init__( if data_width is None: data_width = width - self.attn = MultiheadCrossAttention( - device=device, - dtype=dtype, - n_data=n_data, - width=width, - heads=heads, - data_width=data_width, - init_scale=init_scale, - ) - self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype) - self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype) - self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale) - self.ln_3 = nn.LayerNorm(width, device=device, dtype=dtype) + # Use the torch.cuda.device_of() function to determine if the input tensors are on the GPU or CPU, + # and then use the appropriate layer implementations for better performance. + # Uses the torch.no_grad() context manager to prevent the model from tracking gradients in the forward pass. + if device.type == "cuda": + self.ln_1 = nn.CUDALayerNorm(width, device=device, dtype=dtype) + self.ln_2 = nn.CUDALayerNorm(data_width, device=device, dtype=dtype) + self.ln_3 = nn.CUDALayerNorm(width, device=device, dtype=dtype) + self.attn = MultiheadCrossAttention( + device=device, + dtype=dtype, + n_data=n_data, + width=width, + heads=heads, + data_width=data_width, + init_scale=init_scale, + ) + self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale) + else: + self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype) + self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype) + self.ln_3 = nn.LayerNorm(width, device=device, dtype=dtype) + self.attn = MultiheadCrossAttention( + device=device, + dtype=dtype, + n_data=n_data, + width=width, + heads=heads, + data_width=data_width, + init_scale=init_scale, + ) + self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale) def forward(self, x: torch.Tensor, data: torch.Tensor): - x = x + self.attn(self.ln_1(x), self.ln_2(data)) - x = x + self.mlp(self.ln_3(x)) + with torch.no_grad(): + x = x + self.attn(self.ln_1(x), self.ln_2(data)) + x = x + self.mlp(self.ln_3(x)) return x - class SimplePerceiver(nn.Module): """ Only does cross attention From 4830080e209355d982e5fc18cd96e3e32795edf3 Mon Sep 17 00:00:00 2001 From: Robert McMenemy Date: Thu, 29 Dec 2022 11:31:20 +0000 Subject: [PATCH 02/13] Add apex lib to setup --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 3c58cbf..f5f06ab 100644 --- a/setup.py +++ b/setup.py @@ -13,6 +13,7 @@ "filelock", "Pillow", "torch", + "apex", "fire", "humanize", "requests", From 400535e01c0e5a069645b5822a0251cbc2e9baf2 Mon Sep 17 00:00:00 2001 From: Robert McMenemy Date: Thu, 29 Dec 2022 11:50:34 +0000 Subject: [PATCH 03/13] Switch to apex fused layer norm --- point_e/models/perceiver.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/point_e/models/perceiver.py b/point_e/models/perceiver.py index 0dd352b..d7d6569 100644 --- a/point_e/models/perceiver.py +++ b/point_e/models/perceiver.py @@ -1,6 +1,7 @@ import math from typing import Optional +import apex import torch import torch.nn as nn @@ -88,9 +89,9 @@ def __init__( # and then 
use the appropriate layer implementations for better performance. # Uses the torch.no_grad() context manager to prevent the model from tracking gradients in the forward pass. if device.type == "cuda": - self.ln_1 = nn.CUDALayerNorm(width, device=device, dtype=dtype) - self.ln_2 = nn.CUDALayerNorm(data_width, device=device, dtype=dtype) - self.ln_3 = nn.CUDALayerNorm(width, device=device, dtype=dtype) + self.ln_1 = apex.normalization.FusedLayerNorm(width, device=device, dtype=dtype) + self.ln_2 = apex.normalization.FusedLayerNorm(data_width, device=device, dtype=dtype) + self.ln_3 = apex.normalization.FusedLayerNorm(width, device=device, dtype=dtype) self.attn = MultiheadCrossAttention( device=device, dtype=dtype, From 032673c6cb5b776a200dffbe39f5a2dd5401f611 Mon Sep 17 00:00:00 2001 From: Robert McMenemy Date: Thu, 29 Dec 2022 14:55:09 +0000 Subject: [PATCH 04/13] Tidy up repeating code and apex import --- point_e/models/perceiver.py | 43 +++++++++++++++---------------------- 1 file changed, 17 insertions(+), 26 deletions(-) diff --git a/point_e/models/perceiver.py b/point_e/models/perceiver.py index d7d6569..8d6afbc 100644 --- a/point_e/models/perceiver.py +++ b/point_e/models/perceiver.py @@ -1,12 +1,11 @@ import math -from typing import Optional - -import apex import torch import torch.nn as nn from .checkpoint import checkpoint from .transformer import MLP, init_linear +from typing import Optional +from apex.normalization import FusedLayerNorm class MultiheadCrossAttention(nn.Module): @@ -88,34 +87,26 @@ def __init__( # Use the torch.cuda.device_of() function to determine if the input tensors are on the GPU or CPU, # and then use the appropriate layer implementations for better performance. # Uses the torch.no_grad() context manager to prevent the model from tracking gradients in the forward pass. 
+
         if device.type == "cuda":
-            self.ln_1 = apex.normalization.FusedLayerNorm(width, device=device, dtype=dtype)
-            self.ln_2 = apex.normalization.FusedLayerNorm(data_width, device=device, dtype=dtype)
-            self.ln_3 = apex.normalization.FusedLayerNorm(width, device=device, dtype=dtype)
-            self.attn = MultiheadCrossAttention(
-                device=device,
-                dtype=dtype,
-                n_data=n_data,
-                width=width,
-                heads=heads,
-                data_width=data_width,
-                init_scale=init_scale,
-            )
-            self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale)
+            self.ln_1 = FusedLayerNorm(width, device=device, dtype=dtype)
+            self.ln_2 = FusedLayerNorm(data_width, device=device, dtype=dtype)
+            self.ln_3 = FusedLayerNorm(width, device=device, dtype=dtype)
         else:
             self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
             self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype)
             self.ln_3 = nn.LayerNorm(width, device=device, dtype=dtype)
-            self.attn = MultiheadCrossAttention(
-                device=device,
-                dtype=dtype,
-                n_data=n_data,
-                width=width,
-                heads=heads,
-                data_width=data_width,
-                init_scale=init_scale,
-            )
-            self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale)
+
+        self.attn = MultiheadCrossAttention(
+            device=device,
+            dtype=dtype,
+            n_data=n_data,
+            width=width,
+            heads=heads,
+            data_width=data_width,
+            init_scale=init_scale,
+        )
+        self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale)
 
     def forward(self, x: torch.Tensor, data: torch.Tensor):
         with torch.no_grad():

From 80a2920dc46c4b2c943af2cf8cac8bde66e26652 Mon Sep 17 00:00:00 2001
From: Robert McMenemy
Date: Thu, 29 Dec 2022 15:00:56 +0000
Subject: [PATCH 05/13] Add apex git lib to install_requires

---
 setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index f5f06ab..d3af888 100644
--- a/setup.py
+++ b/setup.py
@@ -13,7 +13,7 @@
         "filelock",
         "Pillow",
         "torch",
-        "apex",
+        "apex @ git+https://github.com/NVIDIA/apex.git",
         "fire",
         "humanize",
         "requests",

From 16b349247232971bd0a9107f8e9026d895662db7 Mon Sep 17 00:00:00 2001
From: Robert McMenemy
Date: Thu, 29 Dec 2022 21:55:03 +0000
Subject: [PATCH 06/13] Add no_grad to every forward pass in perceiver

---
 point_e/models/perceiver.py | 39 +++++++++++++++++++++------------------
 1 file changed, 21 insertions(+), 18 deletions(-)

diff --git a/point_e/models/perceiver.py b/point_e/models/perceiver.py
index 8d6afbc..d31c3df 100644
--- a/point_e/models/perceiver.py
+++ b/point_e/models/perceiver.py
@@ -36,10 +36,11 @@ def __init__(
         init_linear(self.c_proj, init_scale)
 
     def forward(self, x, data):
-        x = self.c_q(x)
-        data = self.c_kv(data)
-        x = checkpoint(self.attention, (x, data), (), True)
-        x = self.c_proj(x)
+        with torch.no_grad():
+            x = self.c_q(x)
+            data = self.c_kv(data)
+            x = checkpoint(self.attention, (x, data), (), True)
+            x = self.c_proj(x)
         return x
 
 
@@ -52,18 +53,19 @@ def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_da
         self.n_data = n_data
 
     def forward(self, q, kv):
-        _, n_ctx, _ = q.shape
-        bs, n_data, width = kv.shape
-        attn_ch = width // self.heads // 2
-        scale = 1 / math.sqrt(math.sqrt(attn_ch))
-        q = q.view(bs, n_ctx, self.heads, -1)
-        kv = kv.view(bs, n_data, self.heads, -1)
-        k, v = torch.split(kv, attn_ch, dim=-1)
-        weight = torch.einsum(
-            "bthc,bshc->bhts", q * scale, k * scale
-        )  # More stable with f16 than dividing afterwards
-        wdtype = weight.dtype
-        weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
+        with torch.no_grad():
+            _, n_ctx, _ = q.shape
+            bs, n_data, width = 
kv.shape + attn_ch = width // self.heads // 2 + scale = 1 / math.sqrt(math.sqrt(attn_ch)) + q = q.view(bs, n_ctx, self.heads, -1) + kv = kv.view(bs, n_data, self.heads, -1) + k, v = torch.split(kv, attn_ch, dim=-1) + weight = torch.einsum( + "bthc,bshc->bhts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + wdtype = weight.dtype + weight = torch.softmax(weight.float(), dim=-1).type(wdtype) return torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) @@ -151,6 +153,7 @@ def __init__( ) def forward(self, x: torch.Tensor, data: torch.Tensor): - for block in self.resblocks: - x = block(x, data) + with torch.no_grad(): + for block in self.resblocks: + x = block(x, data) return x From bfa50202378568b370a38362183cfd1bc237159b Mon Sep 17 00:00:00 2001 From: Robert McMenemy Date: Fri, 30 Dec 2022 13:21:19 +0000 Subject: [PATCH 07/13] Change to fusedlayernorm lib, fix forward pass --- point_e/models/perceiver.py | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/point_e/models/perceiver.py b/point_e/models/perceiver.py index d31c3df..a9e4299 100644 --- a/point_e/models/perceiver.py +++ b/point_e/models/perceiver.py @@ -86,18 +86,10 @@ def __init__( if data_width is None: data_width = width - # Use the torch.cuda.device_of() function to determine if the input tensors are on the GPU or CPU, - # and then use the appropriate layer implementations for better performance. - # Uses the torch.no_grad() context manager to prevent the model from tracking gradients in the forward pass. - - if device.type == "cuda": - self.ln_1 = FusedLayerNorm(width, device=device, dtype=dtype) - self.ln_2 = FusedLayerNorm(data_width, device=device, dtype=dtype) - self.ln_3 = FusedLayerNorm(width, device=device, dtype=dtype) - else: - self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype) - self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype) - self.ln_3 = nn.LayerNorm(width, device=device, dtype=dtype) + # Use the FusedLayerNorm module for faster layer normalization on GPU + self.ln_1 = FusedLayerNorm(width, device=device, dtype=dtype) + self.ln_2 = FusedLayerNorm(data_width, device=device, dtype=dtype) + self.ln_3 = FusedLayerNorm(width, device=device, dtype=dtype) self.attn = MultiheadCrossAttention( device=device, @@ -110,10 +102,16 @@ def __init__( ) self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale) - def forward(self, x: torch.Tensor, data: torch.Tensor): + def forward(self, x: torch.Tensor, data: torch.Tensor, device: torch.device): with torch.no_grad(): + # Use the to() method to move the input tensors to the specified device + x = x.to(device) + data = data.to(device) + + # Normalize input tensors and pass them through the attention and MLP layers x = x + self.attn(self.ln_1(x), self.ln_2(data)) x = x + self.mlp(self.ln_3(x)) + return x class SimplePerceiver(nn.Module): From cbbf4058a14129e72d3ce12a1eaee0b9e7a605fc Mon Sep 17 00:00:00 2001 From: Robert McMenemy Date: Fri, 30 Dec 2022 13:23:27 +0000 Subject: [PATCH 08/13] Add back cuda check --- point_e/models/perceiver.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/point_e/models/perceiver.py b/point_e/models/perceiver.py index a9e4299..76fdc3e 100644 --- a/point_e/models/perceiver.py +++ b/point_e/models/perceiver.py @@ -86,10 +86,15 @@ def __init__( if data_width is None: data_width = width - # Use the FusedLayerNorm module for faster layer normalization on GPU - self.ln_1 = FusedLayerNorm(width, 
device=device, dtype=dtype) - self.ln_2 = FusedLayerNorm(data_width, device=device, dtype=dtype) - self.ln_3 = FusedLayerNorm(width, device=device, dtype=dtype) + if device.type == "cuda": + # Use the FusedLayerNorm module for faster layer normalization on GPU + self.ln_1 = FusedLayerNorm(width, device=device, dtype=dtype) + self.ln_2 = FusedLayerNorm(data_width, device=device, dtype=dtype) + self.ln_3 = FusedLayerNorm(width, device=device, dtype=dtype) + else: + self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype) + self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype) + self.ln_3 = nn.LayerNorm(width, device=device, dtype=dtype) self.attn = MultiheadCrossAttention( device=device, From 816f343c03347c2f8f0929bbcf8c2ea44556cd54 Mon Sep 17 00:00:00 2001 From: Robert McMenemy Date: Fri, 30 Dec 2022 14:29:16 +0000 Subject: [PATCH 09/13] Revert to std layernorm --- point_e/models/perceiver.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/point_e/models/perceiver.py b/point_e/models/perceiver.py index 76fdc3e..7d0ea78 100644 --- a/point_e/models/perceiver.py +++ b/point_e/models/perceiver.py @@ -86,15 +86,9 @@ def __init__( if data_width is None: data_width = width - if device.type == "cuda": - # Use the FusedLayerNorm module for faster layer normalization on GPU - self.ln_1 = FusedLayerNorm(width, device=device, dtype=dtype) - self.ln_2 = FusedLayerNorm(data_width, device=device, dtype=dtype) - self.ln_3 = FusedLayerNorm(width, device=device, dtype=dtype) - else: - self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype) - self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype) - self.ln_3 = nn.LayerNorm(width, device=device, dtype=dtype) + self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype) + self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype) + self.ln_3 = nn.LayerNorm(width, device=device, dtype=dtype) self.attn = MultiheadCrossAttention( device=device, From 9aff88e195ce4904b18a060a68baf8b928acc5d4 Mon Sep 17 00:00:00 2001 From: Robert McMenemy Date: Fri, 30 Dec 2022 14:30:16 +0000 Subject: [PATCH 10/13] Remove apex from setup and import --- point_e/models/perceiver.py | 1 - setup.py | 1 - 2 files changed, 2 deletions(-) diff --git a/point_e/models/perceiver.py b/point_e/models/perceiver.py index 7d0ea78..75db4ba 100644 --- a/point_e/models/perceiver.py +++ b/point_e/models/perceiver.py @@ -5,7 +5,6 @@ from .checkpoint import checkpoint from .transformer import MLP, init_linear from typing import Optional -from apex.normalization import FusedLayerNorm class MultiheadCrossAttention(nn.Module): diff --git a/setup.py b/setup.py index d3af888..3c58cbf 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,6 @@ "filelock", "Pillow", "torch", - "apex @ git+https://github.com/NVIDIA/apex.git", "fire", "humanize", "requests", From fadbf06abf5155184f9072030852ca2187e8917d Mon Sep 17 00:00:00 2001 From: Robert McMenemy Date: Mon, 2 Jan 2023 10:39:50 +0000 Subject: [PATCH 11/13] Remove device to pass until figured out the logic --- point_e/models/perceiver.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/point_e/models/perceiver.py b/point_e/models/perceiver.py index 75db4ba..125e919 100644 --- a/point_e/models/perceiver.py +++ b/point_e/models/perceiver.py @@ -100,12 +100,8 @@ def __init__( ) self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale) - def forward(self, x: torch.Tensor, data: torch.Tensor, device: torch.device): + def forward(self, x: torch.Tensor, data: 
torch.Tensor): with torch.no_grad(): - # Use the to() method to move the input tensors to the specified device - x = x.to(device) - data = data.to(device) - # Normalize input tensors and pass them through the attention and MLP layers x = x + self.attn(self.ln_1(x), self.ln_2(data)) x = x + self.mlp(self.ln_3(x)) From 51a3f95d9dd6501165e524d11806cc3b171ae588 Mon Sep 17 00:00:00 2001 From: Robert McMenemy Date: Mon, 2 Jan 2023 10:41:34 +0000 Subject: [PATCH 12/13] Add back .to() with torch.device directly --- point_e/models/perceiver.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/point_e/models/perceiver.py b/point_e/models/perceiver.py index 125e919..bab7b83 100644 --- a/point_e/models/perceiver.py +++ b/point_e/models/perceiver.py @@ -102,6 +102,10 @@ def __init__( def forward(self, x: torch.Tensor, data: torch.Tensor): with torch.no_grad(): + # Use the to() method to move the input tensors to the specified device + x = x.to(torch.device) + data = data.to(torch.device) + # Normalize input tensors and pass them through the attention and MLP layers x = x + self.attn(self.ln_1(x), self.ln_2(data)) x = x + self.mlp(self.ln_3(x)) From 1a7a7eae046d886d0c36d516f7e978a57c7baee9 Mon Sep 17 00:00:00 2001 From: Robert McMenemy Date: Mon, 2 Jan 2023 20:29:04 +0000 Subject: [PATCH 13/13] remove .to() --- point_e/models/perceiver.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/point_e/models/perceiver.py b/point_e/models/perceiver.py index bab7b83..125e919 100644 --- a/point_e/models/perceiver.py +++ b/point_e/models/perceiver.py @@ -102,10 +102,6 @@ def __init__( def forward(self, x: torch.Tensor, data: torch.Tensor): with torch.no_grad(): - # Use the to() method to move the input tensors to the specified device - x = x.to(torch.device) - data = data.to(torch.device) - # Normalize input tensors and pass them through the attention and MLP layers x = x + self.attn(self.ln_1(x), self.ln_2(data)) x = x + self.mlp(self.ln_3(x))
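
Note (not part of the patch series, a sketch only): patches 03-09 experimented with apex's FusedLayerNorm before patch 09 reverted to nn.LayerNorm and patch 10 dropped the hard apex dependency from setup.py. If that experiment is picked up again, the import can be kept optional so CPU-only and apex-less installs keep working. The sketch below assumes only that apex exposes apex.normalization.FusedLayerNorm; the helper name make_layer_norm is hypothetical, and the fused module is built from the normalized shape and then moved with .to(), since its constructor is not assumed to accept device/dtype keywords.

    import torch
    import torch.nn as nn

    try:
        # Optional dependency: apex ships a fused CUDA LayerNorm kernel.
        from apex.normalization import FusedLayerNorm
        _HAS_APEX = True
    except ImportError:
        _HAS_APEX = False

    def make_layer_norm(width: int, device: torch.device, dtype: torch.dtype) -> nn.Module:
        # Use the fused LayerNorm only when apex is installed and the target device is CUDA;
        # fall back to the stock nn.LayerNorm everywhere else.
        if _HAS_APEX and device.type == "cuda":
            return FusedLayerNorm(width).to(device=device, dtype=dtype)
        return nn.LayerNorm(width, device=device, dtype=dtype)

With such a helper, ResidualCrossAttentionBlock.__init__ could build its three norms as, for example, self.ln_1 = make_layer_norm(width, device, dtype), avoiding the branching that patches 01-08 repeatedly added and removed at the call site.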