[DISCUSSION] fix float8 all-gather in FSDP2 + TP: DTensor(WeightWithDynamicFloat8CastTensor)#326
weifengpy wants to merge 10 commits into meta-pytorch:main
Conversation
self.assertTrue(
    isinstance(colwise_param, DTensor)
    and isinstance(
        colwise_param._local_tensor, WeightWithDynamicFloat8CastTensor
edited: without this PR, torch.chunk returns a bf16 tensor. FSDP2 wrapping happens after TP, so FSDP2 only sees Float8Linear(weight=DTensor(_local_tensor=Tensor)).
with this PR, torch.chunk returns WeightWithDynamicFloat8CastTensor
Can you explain where the bf16 came from?
correcting my wording to be accurate: without this PR, torch.chunk returns a plain Tensor (can be fp32 or bf16) instead of WeightWithDynamicFloat8CastTensor
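As a minimal sketch of the behavior being discussed (the module handle `linear`, the mesh `tp_mesh`, and the float8_experimental import path are assumptions for illustration), one could inspect the local shard type right after TP sharding:

```python
import torch
from torch.distributed._tensor import Shard, distribute_tensor
# import path assumed; WeightWithDynamicFloat8CastTensor lives in the float8 FSDP utilities
from float8_experimental.fsdp_utils import WeightWithDynamicFloat8CastTensor

# `linear` is assumed to be a Float8Linear whose weight was wrapped in
# WeightWithDynamicFloat8CastTensor; `tp_mesh` is the TP device mesh.
sharded_weight = distribute_tensor(linear.weight, tp_mesh, [Shard(0)])

# without this PR: torch.chunk (called inside distribute_tensor) returns a plain
# Tensor (fp32/bf16), so FSDP2, which wraps after TP, only sees
# Float8Linear(weight=DTensor(_local_tensor=Tensor))
# with this PR: the local shard stays a WeightWithDynamicFloat8CastTensor
print(type(sharded_weight._local_tensor))
```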
torch.ops.aten.as_strided.default,
torch.ops.aten._to_copy.default,
torch.ops.aten._pin_memory.default,
torch.ops.aten.split.Tensor,
aten.split comes from torch.chunk, which is called from distribute_tensor during TP init
edited: @awgu curious if you still remember the reason for returning a plain Tensor from torch.chunk instead of WeightWithDynamicFloat8CastTensor. Is it for padding? Any concerns if I make torch.chunk return WeightWithDynamicFloat8CastTensor?
@awgu curious if you still remember the reason to return bf16 from torch.chunk.
I thought that dtype and whether it is a WeightWithDynamicFloat8CastTensor are orthogonal. Do you mean the latter (whether it is a WeightWithDynamicFloat8CastTensor or not)?
I think originally I only added the ops I saw were needed. Adding aten.split and aten.clone seems okay to me.
whether it is a WeightWithDynamicFloat8CastTensor or not
exactly, WeightWithDynamicFloat8CastTensor or not is the key. I edited my previous comments to say that right now torch.chunk returns a plain Tensor
I think originally I only added the ops I saw were needed
changing torch.chunk affects both TP and FSDP2; I will double-check FSDP2 after the change. See the sketch below for the dispatch idea.
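To make the dispatch idea concrete, here is a simplified, self-contained illustration (not the actual WeightWithDynamicFloat8CastTensor code): a wrapper subclass that re-wraps outputs only for a whitelist of ops, so torch.chunk (which hits aten.split.Tensor) keeps the subclass instead of returning a plain Tensor.

```python
import torch
from torch.utils._pytree import tree_map_only

# ops whose outputs should stay wrapped in the subclass
_OPS_TO_PRESERVE_SUBCLASS = {
    torch.ops.aten.detach.default,
    torch.ops.aten.clone.default,
    torch.ops.aten.as_strided.default,
    torch.ops.aten._to_copy.default,
    torch.ops.aten._pin_memory.default,
    torch.ops.aten.split.Tensor,  # hit by torch.chunk inside distribute_tensor
}

class WrapperWeight(torch.Tensor):
    """Toy stand-in for WeightWithDynamicFloat8CastTensor."""

    @staticmethod
    def __new__(cls, data: torch.Tensor):
        return torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=data.dtype,
            device=data.device, requires_grad=data.requires_grad,
        )

    def __init__(self, data: torch.Tensor):
        self._data = data

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # unwrap to the inner tensor before running the op
        args = tree_map_only(WrapperWeight, lambda t: t._data, args)
        kwargs = tree_map_only(WrapperWeight, lambda t: t._data, kwargs)
        out = func(*args, **kwargs)
        if func in _OPS_TO_PRESERVE_SUBCLASS:
            # re-wrap so TP init (torch.chunk -> aten.split) and FSDP2
            # keep seeing the subclass instead of a plain Tensor
            return tree_map_only(torch.Tensor, lambda t: WrapperWeight(t), out)
        return out

w = WrapperWeight(torch.randn(8, 4))
chunks = torch.chunk(w, 2, dim=0)
print(type(chunks[0]))  # WrapperWeight, because aten.split.Tensor is whitelisted
```

If aten.split.Tensor is removed from the whitelist, the chunks come back as plain Tensors, which mirrors the behavior before this PR.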
elif isinstance(out, DTensor) and isinstance(
    out._local_tensor, Float8Tensor
):
    out._local_tensor._scale = scale
not sure about this change yet. just want to have something sketchy to discuss first
draft this PR for discussion, before having something landable
we see 2 problems in float8 all-gather with FSDP2 + TP:

1. all-reduce is issued for weight, but we expect all-reduce only for input. The crux is how we dispatch torch.chunk, which is called from distribute_tensor during TP init. Without this PR, torch.chunk returns a plain Tensor, so FSDP2 (which happens after TP) only sees Float8Linear(weight=DTensor(_local_tensor=Tensor)). With this PR, torch.chunk returns WeightWithDynamicFloat8CastTensor.
2. profiler trace without this PR: AR (all-reduce) for input -> AG (all-gather) -> 4 ARs for wq,k,v,o -> 1 AR for input. The 4 ARs for wq,k,v,o should not happen if we precompute amax/scales for model.parameters() after opt.step(). A sketch of that training-loop shape follows.
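A rough sketch of the intended training loop (the helper name, its import path, and the surrounding loop are assumptions for illustration, not the exact code in this PR): compute dynamic amax/scales for all weights once per step, after the optimizer step, so the forward does not issue one all-reduce per wq/wk/wv/wo.

```python
# helper name and import path are assumed from the float8 FSDP2 utilities
from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp

for batch in dataloader:
    loss = model(batch).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # one batched amax/scale computation over model.parameters(),
    # replacing the per-weight ARs seen in the profiler trace
    precompute_float8_dynamic_scale_for_fsdp(model)
```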