From cffbf1e2a60f28ff7ba9e2aeb2b01d6da6f129ef Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Thu, 25 Jul 2024 09:45:54 -0700
Subject: [PATCH] update float8 integration after UX changes

Summary:

float8_experimental landed various BC-breaking UX changes last week. This PR
updates torchtitan to work with the version of float8_experimental after
https://github.com/pytorch-labs/float8_experimental/pull/332

Test Plan:

```
with-proxy CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 NGPU=8 CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --training.enable_float8_linear --training.compile
```

Reviewers:

Subscribers:

Tasks:

Tags:
---
 torchtitan/float8_linear.py | 40 ++++++++++++++-----------------------
 1 file changed, 15 insertions(+), 25 deletions(-)

diff --git a/torchtitan/float8_linear.py b/torchtitan/float8_linear.py
index 770531d50a..557fca64c8 100644
--- a/torchtitan/float8_linear.py
+++ b/torchtitan/float8_linear.py
@@ -12,7 +12,6 @@
 # Note: Performance
 # Float8 experimental is intended to be ran under `torch.compile`` for competitive performance
 
-import contextlib
 import functools
 
 from typing import Optional
@@ -24,20 +23,6 @@
 from torchtitan.logging_utils import logger
 
 
-@contextlib.contextmanager
-def set_enable_fsdp_float8_all_gather(enable_fsdp_fp8_all_gather: bool):
-    import float8_experimental.config as config
-
-    prev = config.enable_fsdp_fp8_all_gather
-    torch.distributed.barrier()
-    config.enable_fsdp_fp8_all_gather = enable_fsdp_fp8_all_gather
-    try:
-        yield
-    finally:
-        torch.distributed.barrier()
-        config.enable_fsdp_fp8_all_gather = prev
-
-
 @functools.lru_cache(None)
 def is_sm90_or_later():
     # Float8 is only supported on H100+ GPUs
@@ -63,21 +48,26 @@ def maybe_build_fp8_linear(
         )
         return
     try:
-        from float8_experimental.float8_linear import TensorScalingType
-        from float8_experimental.float8_linear_utils import (
-            swap_linear_with_float8_linear,
+        from float8_experimental import (
+            CastConfig,
+            convert_to_float8_training,
+            Float8LinearConfig,
+            ScalingType,
         )
 
         # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
         enable_fsdp_float8_all_gather = (
             job_config.training.enable_fsdp_float8_all_gather and dp_enabled
         )
-        with set_enable_fsdp_float8_all_gather(enable_fsdp_float8_all_gather):
-            swap_linear_with_float8_linear(
-                model,
-                scaling_type_w=TensorScalingType.DYNAMIC,
-                skip_fqn_list=["output"],
-            )
+        float8_config = Float8LinearConfig(
+            enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
+            cast_config_weight=CastConfig(scaling_type=ScalingType.DYNAMIC),
+        )
+        convert_to_float8_training(
+            model,
+            config=float8_config,
+            module_filter_fn=lambda mod, fqn: fqn != "output",
+        )
         logger.info(
             f"Swapped to Float8Linear layers with {enable_fsdp_float8_all_gather=}"
         )
@@ -102,6 +92,6 @@ def maybe_precompute_fp8_dynamic_scale_for_fsdp(
             "Skipped precomputing fp8 scales because SM90 or later is not available",
         )
         return
-    from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
+    from float8_experimental import precompute_float8_dynamic_scale_for_fsdp
 
     precompute_float8_dynamic_scale_for_fsdp(model)
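
For reference, a minimal standalone sketch of the post-#332 float8_experimental conversion flow that this patch adopts. The `ToyModel` below is hypothetical and stands in for torchtitan's Llama model; the `Float8LinearConfig`, `CastConfig`, `ScalingType`, and `convert_to_float8_training` usage mirrors the new code in `maybe_build_fp8_linear`, with FSDP fp8 all-gather left disabled for simplicity.

```python
# Sketch only: assumes float8_experimental at or after
# pytorch-labs/float8_experimental#332. The toy model is illustrative;
# in torchtitan the config values come from job_config.training.*.
import torch
import torch.nn as nn

from float8_experimental import (
    CastConfig,
    convert_to_float8_training,
    Float8LinearConfig,
    ScalingType,
)


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.ffn = nn.Linear(256, 256)
        # Kept in high precision, like torchtitan's "output" projection.
        self.output = nn.Linear(256, 128)

    def forward(self, x):
        return self.output(torch.relu(self.ffn(x)))


model = ToyModel()

# Dynamic per-tensor scaling for weights; fp8 all-gather disabled here.
config = Float8LinearConfig(
    enable_fsdp_float8_all_gather=False,
    cast_config_weight=CastConfig(scaling_type=ScalingType.DYNAMIC),
)

# Swap nn.Linear -> Float8Linear in place, skipping the module whose FQN is "output".
convert_to_float8_training(
    model,
    config=config,
    module_filter_fn=lambda mod, fqn: fqn != "output",
)
print(model)  # "ffn" is now Float8Linear; "output" remains nn.Linear
```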