From 0a99d09f07105bc6979b3f7fe1b49949f5c899d0 Mon Sep 17 00:00:00 2001 From: drisspg Date: Tue, 2 Jul 2024 21:58:44 -0700 Subject: [PATCH 1/2] Update [ghstack-poisoned] --- test/test_dtensor.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/test/test_dtensor.py b/test/test_dtensor.py index 354f831..ab1890d 100644 --- a/test/test_dtensor.py +++ b/test/test_dtensor.py @@ -248,6 +248,12 @@ def test_fp8_mlp_tensor_parallelism_base( x_fp32_sp_input = distribute_tensor(x_fp32.clone(), mesh, [Shard(0)]) tp_out = tp_model(x_fp32_tp_input) + assert ( + tp_model.ffn.w1.weight.requires_grad + ), "Expecting gradients to be enabled for TP model." + assert tp_out.requires_grad, "Expecting gradients to be enabled for TP model." + awaited_out = tp_out.wait() + assert awaited_out.requires_grad, "Expecting awaited out to require gradients" tp_out.sum().backward() sp_out = sp_model(x_fp32_sp_input) sp_out.sum().backward() @@ -281,12 +287,12 @@ def test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16): # cases in the main func. device_mesh = setup_distributed() tests = [ - test_scaled_mm, - test_fp8_redistribute, - test_dtensor_cast_to_fp8, - test_dtensor_fp8_autograd, + # test_scaled_mm, + # test_fp8_redistribute, + # test_dtensor_cast_to_fp8, + # test_dtensor_fp8_autograd, test_fp8_mlp_tensor_parallelism_base, - test_fp8_mlp_tensor_parallelism_compile, + # test_fp8_mlp_tensor_parallelism_compile, ] for test in tqdm(tests, desc="Running tests"): From ea89bff956036a635e5d44d7fdf52991064eccf3 Mon Sep 17 00:00:00 2001 From: drisspg Date: Tue, 2 Jul 2024 22:05:12 -0700 Subject: [PATCH 2/2] Update [ghstack-poisoned] --- test/test_dtensor.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/test_dtensor.py b/test/test_dtensor.py index ab1890d..31fce1e 100644 --- a/test/test_dtensor.py +++ b/test/test_dtensor.py @@ -287,12 +287,12 @@ def test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16): # cases in the main func. device_mesh = setup_distributed() tests = [ - # test_scaled_mm, - # test_fp8_redistribute, - # test_dtensor_cast_to_fp8, - # test_dtensor_fp8_autograd, + test_scaled_mm, + test_fp8_redistribute, + test_dtensor_cast_to_fp8, + test_dtensor_fp8_autograd, test_fp8_mlp_tensor_parallelism_base, - # test_fp8_mlp_tensor_parallelism_compile, + test_fp8_mlp_tensor_parallelism_compile, ] for test in tqdm(tests, desc="Running tests"):