diff --git a/.github/workflows/regression_tests.yml b/.github/workflows/regression_tests.yml index 1820d17c6..ce60313d6 100644 --- a/.github/workflows/regression_tests.yml +++ b/.github/workflows/regression_tests.yml @@ -107,7 +107,16 @@ jobs: - name: Run containerized workload run: | docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d wmt -f jax -s algorithms/archived_paper_baselines/adamw/jax/submission.py -w wmt -t algorithms/archived_paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false --data_bucket mlcommons-data --logs_bucket mlcommons-runs --data_bucket mlcommons-data --logs_bucket mlcommons-runs + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d wmt -f jax -s algorithms/archived_paper_baselines/adamw/jax/submission.py -w wmt -t algorithms/archived_paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false --data_bucket mlcommons-data --logs_bucket mlcommons-runs --data_bucket mlcommons-data --logs_bucket mlcommons-runs + finewebedu_lm_jax: + runs-on: self-hosted + needs: build_and_push_jax_docker_image + steps: + - uses: actions/checkout@v2 + - name: Run containerized workload + run: | + docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d fineweb_edu_10B -f jax -s algorithms/archived_paper_baselines/adamw/jax/submission.py -w finewebedu_lm -t algorithms/archived_paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false --data_bucket mlcommons-data --logs_bucket mlcommons-runs --data_bucket mlcommons-data --logs_bucket mlcommons-runs fastmri_pytorch: runs-on: self-hosted needs: build_and_push_pytorch_docker_image @@ -181,3 +190,12 @@ jobs: run: | docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d wmt -f pytorch -s algorithms/archived_paper_baselines/adamw/pytorch/submission.py -w wmt -t algorithms/archived_paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false --data_bucket mlcommons-data --logs_bucket mlcommons-runs --data_bucket mlcommons-data --logs_bucket mlcommons-runs + finewebedu_lm_pytorch: + runs-on: self-hosted + needs: build_and_push_pytorch_docker_image + steps: + - uses: actions/checkout@v2 + - name: Run containerized workload + run: | + docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d fineweb_edu_10B -f pytorch -s algorithms/archived_paper_baselines/adamw/pytorch/submission.py -w finewebedu_lm -t algorithms/archived_paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false --data_bucket mlcommons-data --logs_bucket mlcommons-runs --data_bucket mlcommons-data --logs_bucket mlcommons-runs diff --git a/.gitignore b/.gitignore index 7d35f0ccc..916a29ff4 100644 --- a/.gitignore +++ b/.gitignore @@ -25,4 +25,4 @@ scoring/plots/ !scoring/test_data/experiment_dir/study_0/mnist_jax/trial_0/eval_measurements.csv !scoring/test_data/experiment_dir/study_0/mnist_jax/trial_1/eval_measurements.csv -algoperf/_version.py +algoperf/_version.py \ No newline at end of file diff --git a/README.md b/README.md index f8e6763b4..71595e11b 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,7 @@ The MLCommonsβ„’ **AlgoPerf: Training Algorithms benchmark** is designed to find When training neural nets, practitioners face many critical yet often opaque decisions: What optimizer to choose? How should its learning rate be tuned? What learning rate schedule should be used? These choices can make or break training, yet the community has lacked a clear, standardized way to identify the state of the art. Unlike benchmarks focused on hardware or model architecture, AlgoPerf isolates the **training algorithm** itself, which includes the optimizer, regularization, data selection, and hyperparameters like the learning rate schedule. By standardizing the benchmark process, AlgoPerf offers a meaningful apples-to-apples comparison of training algorithms and follows the following **key principles**: -- 🎯 **Fixed Target, Model & Hardware:** Submitted training algorithms must train a set of [**fixed models**](/docs/DOCUMENTATION.md#workloads) to a pre-defined validation performance target as fast as possible. All submissions use the same model architecture and are run on the same [**standardized hardware**](/docs/DOCUMENTATION.md#benchmarking-hardware) (8x NVIDIA V100 GPUs). This isolates the training algorithm's performance and allows a fair apples-to-apples comparison. +- 🎯 **Fixed Target, Model & Hardware:** Submitted training algorithms must train a set of [**fixed models**](/docs/DOCUMENTATION.md#workloads) to a pre-defined validation performance target as fast as possible. All submissions use the same model architecture and are run on the same [**standardized hardware**](/docs/DOCUMENTATION.md#benchmarking-hardware) (4x A100 (40GB) GPUs). This isolates the training algorithm's performance and allows a fair apples-to-apples comparison. - ⏱️ **Time-To-Result:** Submissions are evaluated based on the total wall-clock time required to reach the target, rewarding practical and efficient algorithms. - 🧠 **Diverse Workloads:** The benchmark includes [**8 diverse deep learning workloads**](/docs/DOCUMENTATION.md#workloads) across domains like image classification, speech recognition, and machine translation. A submission's score is computed by aggregating its performance, using [**performance profiles**](/docs/DOCUMENTATION.md#benchmark-score-using-performance-profiles), across all workloads to ensure general-purpose algorithms. - πŸ“¦ **Fully-Specified Algorithms:** Submissions must be complete procedures and thus hyperparameter tuning is treated as part of the algorithm. Submissions can either provide a search space for automated tuning ([**External tuning ruleset**](/docs/DOCUMENTATION.md#external-tuning-ruleset)) or be hyperparameter-free ([**Self-tuning ruleset**](/docs/DOCUMENTATION.md#self-tuning-ruleset)) with any tuning done automatically and "on the clock". This measures an algorithm's _total_ practical cost and provides practitioners with a complete method, eliminating the guesswork of how to apply it. diff --git a/algoperf/checkpoint_utils.py b/algoperf/checkpoint_utils.py index 2c8441d9c..af05111cd 100644 --- a/algoperf/checkpoint_utils.py +++ b/algoperf/checkpoint_utils.py @@ -5,14 +5,16 @@ """ import os -from typing import Sequence, Tuple +from typing import Optional, Sequence, Tuple import numpy as np +import orbax.checkpoint as ocp import torch from absl import logging from flax import jax_utils from flax.training import checkpoints as flax_checkpoints from flax.training.checkpoints import latest_checkpoint +from orbax.checkpoint.type_handlers import NumpyHandler from tensorflow.io import gfile # pytype: disable=import-error from algoperf import spec @@ -30,6 +32,51 @@ ] +class BoolHandler(NumpyHandler): + """ + An implementation of TypeHandler for np.bool_ that inherits from NumpyHandler. + It works by treating the scalar as a 0-dimensional array. + """ + + def typestr(self) -> str: + """Unique string identifier for this handler.""" + return 'np.bool_' + + async def serialize( + self, + values: Sequence[np.bool_], + infos: Sequence, + args: Optional[Sequence[ocp.SaveArgs]] = None, + ): + """ + Serializes a sequence of np.bool_ scalars by first converting them + to 0-dim numpy arrays and then calling the parent NumpyHandler. + """ + # Convert each scalar np.bool_ to a 0-dimensional np.ndarray + array_values = [np.asarray(v, dtype=np.bool_) for v in values] + # Use the parent class's robust serialization logic + return await super().serialize(array_values, infos, args) + + async def deserialize( + self, + infos: Sequence, + args: Optional[Sequence[ocp.RestoreArgs]] = None, + ) -> Sequence[np.bool_]: + """ + Deserializes into a sequence of np.bool_ scalars by calling the + parent handler and then converting the resulting 0-dim arrays. + """ + # Parent deserialize will return a sequence of 0-dimensional np.ndarray + results = await super().deserialize(infos, args) + + # Convert each 0-d array back to an np.bool_ scalar using .item() + scalar_results = [np.bool_(r.item()) for r in results] + return scalar_results + + +ocp.type_handlers.register_type_handler(np.bool_, BoolHandler(), override=True) + + def maybe_restore_checkpoint( framework: str, optimizer_state: spec.OptimizerState, diff --git a/algoperf/param_utils.py b/algoperf/param_utils.py index 908ef0f27..26a351bb4 100644 --- a/algoperf/param_utils.py +++ b/algoperf/param_utils.py @@ -44,6 +44,8 @@ def pytorch_param_types( param_types[name] = spec.ParameterType.ATTENTION_BIAS elif 'in_proj' in name: param_types[name] = spec.ParameterType.ATTENTION_QKV + elif 'qkv' in name: + param_types[name] = spec.ParameterType.ATTENTION_QKV elif 'kv_proj' in name: param_types[name] = spec.ParameterType.ATTENTION_KV elif 'k_proj' in name or 'key' in name: diff --git a/algoperf/pytorch_utils.py b/algoperf/pytorch_utils.py index af09e67fc..706a4fffd 100644 --- a/algoperf/pytorch_utils.py +++ b/algoperf/pytorch_utils.py @@ -20,6 +20,8 @@ def pytorch_setup() -> Tuple[bool, int, torch.device, int]: + torch.set_float32_matmul_precision('high') + use_pytorch_ddp = 'LOCAL_RANK' in os.environ rank = int(os.environ['LOCAL_RANK']) if use_pytorch_ddp else 0 device = torch.device(f'cuda:{rank}' if torch.cuda.is_available() else 'cpu') @@ -27,7 +29,9 @@ def pytorch_setup() -> Tuple[bool, int, torch.device, int]: return use_pytorch_ddp, rank, device, n_gpus -def pytorch_init(use_pytorch_ddp: bool, rank: int, profiler: Profiler) -> None: +def pytorch_init( + use_pytorch_ddp: bool, rank: int, profiler: Profiler, limit_tf_threads=True +) -> None: # Make sure no GPU memory is preallocated to Jax. os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' # Only use CPU for Jax to avoid memory issues. @@ -39,7 +43,7 @@ def pytorch_init(use_pytorch_ddp: bool, rank: int, profiler: Profiler) -> None: if use_pytorch_ddp: # Avoid tf input pipeline creating too many threads. - if rank != 0: + if rank != 0 and limit_tf_threads: tf.config.threading.set_intra_op_parallelism_threads(1) tf.config.threading.set_inter_op_parallelism_threads(1) diff --git a/algoperf/random_utils.py b/algoperf/random_utils.py index 1dc773e80..07efa2bdf 100644 --- a/algoperf/random_utils.py +++ b/algoperf/random_utils.py @@ -35,13 +35,13 @@ def _signed_to_unsigned(seed: SeedType) -> SeedType: def _fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]: rng = np.random.RandomState(seed=_signed_to_unsigned(seed)) - new_seed = rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32) + new_seed = rng.randint(MIN_INT32, MAX_INT32, dtype=np.uint32) return [new_seed, data] def _split(seed: SeedType, num: int = 2) -> SeedType: rng = np.random.RandomState(seed=_signed_to_unsigned(seed)) - return rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32, size=[num, 2]) + return rng.randint(MIN_INT32, MAX_INT32, dtype=np.uint32, size=[num, 2]) def _PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name diff --git a/algoperf/workloads/cifar/cifar_pytorch/workload.py b/algoperf/workloads/cifar/cifar_pytorch/workload.py index a6e8569cc..f053fd828 100644 --- a/algoperf/workloads/cifar/cifar_pytorch/workload.py +++ b/algoperf/workloads/cifar/cifar_pytorch/workload.py @@ -110,12 +110,12 @@ def _build_dataset( batch_size=ds_iter_batch_size, shuffle=not USE_PYTORCH_DDP and is_train, sampler=sampler, - num_workers=4 if is_train else self.eval_num_workers, + num_workers=2 * N_GPUS if is_train else self.eval_num_workers, pin_memory=True, drop_last=is_train, ) - dataloader = data_utils.PrefetchedWrapper(dataloader, DEVICE) dataloader = data_utils.cycle(dataloader, custom_sampler=USE_PYTORCH_DDP) + dataloader = data_utils.dataloader_iterator_wrapper(dataloader, DEVICE) return dataloader def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: diff --git a/algoperf/workloads/criteo1tb/workload.py b/algoperf/workloads/criteo1tb/workload.py index 2cb7e5450..4d2196cd5 100644 --- a/algoperf/workloads/criteo1tb/workload.py +++ b/algoperf/workloads/criteo1tb/workload.py @@ -95,11 +95,11 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 7_703 # ~2.1 hours. + return 8_915 # ~2.4 hours. @property def eval_period_time_sec(self) -> int: - return 2 * 60 # 2 mins. + return 356 # approx 25 evals def _build_input_queue( self, diff --git a/algoperf/workloads/fastmri/workload.py b/algoperf/workloads/fastmri/workload.py index 0b1ecfaa1..b87dfc755 100644 --- a/algoperf/workloads/fastmri/workload.py +++ b/algoperf/workloads/fastmri/workload.py @@ -95,11 +95,11 @@ def accelerations(self): @property def max_allowed_runtime_sec(self) -> int: - return 4_430 # ~1.2 hours + return 2_745 # ~0.7 hours @property def eval_period_time_sec(self) -> int: - return 80 + return 110 # approx 25 evals @property def step_hint(self) -> int: diff --git a/algoperf/workloads/finewebedu_lm/__init__.py b/algoperf/workloads/finewebedu_lm/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/__init__.py b/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/models.py b/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/models.py new file mode 100644 index 000000000..d08e9b7bf --- /dev/null +++ b/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/models.py @@ -0,0 +1,397 @@ +""" +Originally based on code from the NanoDO repository under the Apache 2.0 license: +https://github.com/google-deepmind/nanodo +""" + +import dataclasses +from functools import partial + +import jax +import jax.numpy as jnp +from flax import linen as nn + + +@dataclasses.dataclass +class ModelConfig: + """Hyper-parameters for Transformer decoder-only.""" + + model_dim: int # model/embed dim = qkv dim + num_heads: int # num attention heads + seq_len: int # max context/sequence length + num_layers: int # number of transformer block layers + vocab_size: int # vocab size + expanded_model_dim: int # FF inner dimension + multiple_of: int = 256 + rmsnorm_epsilon: float = 1e-6 + use_residual_scaling: bool = True + tie_embeddings: bool = True # Whether to tie input and output embed + qknorm_epsilon: float = 1e-6 + + dtype: jnp.dtype = jnp.float32 + attention_init: nn.initializers.Initializer = nn.initializers.normal( + stddev=0.02 + ) + linear_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02) + embed_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02) + + def __post_init__(self): + self.residual_init = nn.initializers.normal( + stddev=0.02 / jnp.sqrt(2 * self.num_layers) + ) + + +class Mlp(nn.Module): + """Multilayer perceptron with GLU activation.""" + + cfg: ModelConfig + + @nn.compact + def __call__(self, x_BxLxD: jax.Array): + cfg = self.cfg + linear = partial( + nn.Dense, kernel_init=cfg.linear_init, use_bias=False, dtype=cfg.dtype + ) + # Adjust hidden dimension to keep the number of parameters invariant to + # the activation function used since the GLU MLP has 3 * hidden_dim * D + # parameters instead of 2 * hidden_dim * D parameters + hidden_dim = cfg.expanded_model_dim * 2 / 3 + hidden_dim = cfg.multiple_of * ( + (cfg.expanded_model_dim + cfg.multiple_of - 1) // cfg.multiple_of + ) + # Double the hidden dimension for GLU + x_BxLx2F = linear(2 * hidden_dim)(x_BxLxD) + # Apply GLU activation + x_BxLxF = nn.glu(x_BxLx2F, axis=-1) + x_BxLxD = nn.Dense( + cfg.model_dim, + use_bias=False, + dtype=cfg.dtype, + kernel_init=cfg.residual_init + if cfg.use_residual_scaling + else cfg.linear_init, + )(x_BxLxF) + return x_BxLxD + + +@partial(jax.jit, static_argnums=(0, 1, 2)) +def init_rope(dim=256, seq_len=128, n_heads=4): + """Initialize rotary embeddings.""" + + def precompute_freqs_cis_jax(dim, end, theta=10000.0): + inv_freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2) / dim)) + t = jnp.arange(end) / 1.0 + freqs = jnp.outer(t, inv_freqs).astype(jnp.float32) + return jnp.stack( + [jnp.cos(freqs)[None, :, None, :], jnp.sin(freqs)[None, :, None, :]], + axis=3, + ) + + freqs_cis = precompute_freqs_cis_jax(dim // n_heads, seq_len, theta=500000) + return freqs_cis.transpose(0, 1, 2, 4, 3) + + +@jax.jit +def apply_rope(q, k, freqs_cis): + """Apply rotary embeddings to Q and K.""" + + def rotate_tensor(x): + # Split into real and imaginary parts + x_r2 = x.reshape(*x.shape[:-1], -1, 2) + L = x.shape[1] + freqs = freqs_cis[:, :L, :, :, :] + + # Apply rotation + rotated_x_r2 = jnp.stack( + [ + x_r2[..., 0] * freqs[..., 0] - x_r2[..., 1] * freqs[..., 1], + x_r2[..., 1] * freqs[..., 0] + x_r2[..., 0] * freqs[..., 1], + ], + axis=-1, + ) + + return rotated_x_r2.reshape(*x.shape) + + # Apply rotation to Q and K separately + rotated_q = rotate_tensor(q) + rotated_k = rotate_tensor(k) + + return rotated_q, rotated_k + + +class CausalAttn(nn.Module): + """Causal attention layer with rotary embeddings.""" + + cfg: ModelConfig + + def setup(self): + cfg = self.cfg + assert cfg.model_dim % cfg.num_heads == 0, ( + f'D {cfg.model_dim} not divisible by H {cfg.num_heads}' + ) + self.Dh = cfg.model_dim // cfg.num_heads + self.eps = cfg.qknorm_epsilon + + # Initialize rotary embeddings + self.freqs_cis = init_rope(cfg.model_dim, cfg.seq_len, cfg.num_heads) + + # Maps D -> (H, Dh) + self.multilinear = partial( + nn.DenseGeneral, + axis=-1, + features=(cfg.num_heads, self.Dh), + kernel_init=cfg.attention_init, + use_bias=False, + dtype=cfg.dtype, + ) + self.multilinear_query = self.multilinear(name='query') + self.multilinear_key = self.multilinear(name='key') + self.multilinear_value = self.multilinear(name='value') + # See Henry et al. (2020) "Query Key Normalization for Transformers" + seq_len = cfg.seq_len + attn_scale0 = jnp.log2(seq_len**2 - seq_len) + self.attn_scale = self.param( + 'attn_scale', nn.initializers.constant(attn_scale0), () + ) + self.output_projection = nn.DenseGeneral( + features=cfg.model_dim, + name='attn_out_proj', + # axis=(-2, -1), # + kernel_init=cfg.residual_init + if cfg.use_residual_scaling + else cfg.linear_init, + use_bias=False, + dtype=cfg.dtype, + ) + + def __call__(self, x_BxLxD: jax.Array): + cfg = self.cfg + + # Project inputs to Q, K, V + q_BxLxHxDh = self.multilinear_query(x_BxLxD) + k_BxLxHxDh = self.multilinear_key(x_BxLxD) + v_BxLxHxDh = self.multilinear_value(x_BxLxD) + + # Apply rotary embeddings to Q and K + q_BxLxHxDh, k_BxLxHxDh = apply_rope(q_BxLxHxDh, k_BxLxHxDh, self.freqs_cis) + + # Apply QK normalization + q_BxLxHxDh /= jnp.linalg.norm(q_BxLxHxDh, axis=-1, keepdims=True) + self.eps + k_BxLxHxDh /= jnp.linalg.norm(k_BxLxHxDh, axis=-1, keepdims=True) + self.eps + + # Compute attention scores + att_BxHxLxL = jnp.einsum('...qhd,...khd->...hqk', q_BxLxHxDh, k_BxLxHxDh) + + # Causal attention mask + L = x_BxLxD.shape[1] + mask_1x1xLxL = jnp.tril(jnp.ones((1, 1, L, L), dtype=jnp.bool_)) + + # Apply mask and softmax + _NEG_INF = jnp.finfo(cfg.dtype).min + att_BxHxLxL = jnp.where(mask_1x1xLxL, att_BxHxLxL, _NEG_INF) + att_BxHxLxL = ( + self.attn_scale * att_BxHxLxL + ) # Learned scaling factor for QK norm + att_BxHxLxL = jax.nn.softmax(att_BxHxLxL, axis=-1) + att_BxHxLxL = att_BxHxLxL.astype(cfg.dtype) + + # Compute attention output + out_BxLxHxDh = jnp.einsum('...hqk,...khd->...qhd', att_BxHxLxL, v_BxLxHxDh) + + # Reshape and project output + out_BxLxD = out_BxLxHxDh.reshape(*x_BxLxD.shape) + + # Output projection + out_BxLxD = self.output_projection(out_BxLxD) + + return out_BxLxD + + +class TBlock(nn.Module): + """Transformer Block.""" + + docfg: ModelConfig + + @nn.compact + def __call__(self, in_BxLxD: jax.Array): + cfg = self.docfg + + # x = x + attn( attn_norm(x) ) + x_BxLxD = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)( + in_BxLxD + ) + x_BxLxD = CausalAttn(cfg)(x_BxLxD) + x_BxLxD += in_BxLxD + + # x = x + mlp( mlp_norm(x) ) + z_BxLxD = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)( + x_BxLxD + ) + z_BxLxD = Mlp(cfg)(z_BxLxD) + + return x_BxLxD + z_BxLxD + + +class TransformerDo(nn.Module): + """Transformer decoder-only.""" + + docfg: ModelConfig + + def setup(self): + cfg = self.docfg + self.embed = nn.Embed( + num_embeddings=cfg.vocab_size, + features=cfg.model_dim, + embedding_init=cfg.embed_init, + ) + + self.blocks = [TBlock(cfg) for _ in range(cfg.num_layers)] + self.out_ln = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon) + + # Output projection - tied to input embeddings if configured + if cfg.tie_embeddings: + self.output_proj = lambda x: self.embed.attend(x.astype(jnp.float32)) + else: + self.output_proj = nn.Dense( + cfg.vocab_size, + kernel_init=cfg.embed_init, + dtype=cfg.dtype, + name='output_proj', + ) + + def __call__(self, y_BxL: jax.Array): + # For training on concatenated examples. + y_BxLxD = self.embed(y_BxL) + for block in self.blocks: + y_BxLxD = block(y_BxLxD) + y_BxLxD = self.out_ln(y_BxLxD) + logits_BxLxV = self.output_proj(y_BxLxD) + return logits_BxLxV + + def predict(self, y_BxL: jax.Array, k: int = 1): + """Generate k tokens autoregressively. + + Args: + y_BxL: Input token sequence of shape (batch_size, seq_len) + k: Number of tokens to predict + + Returns: + Tuple of (input_ids, predicted_ids) + """ + cfg = self.docfg + seq_len = y_BxL.shape[1] + + # Store original input + original_input = y_BxL + + # Make sure we don't exceed the model's context length + if seq_len + k > cfg.seq_len: + raise ValueError( + f"Total sequence length ({seq_len + k}) exceeds model's context length ({cfg.seq_len})" + ) + + # Generate k tokens autoregressively + for _ in range(k): + # Get logits for the entire sequence + logits = self(y_BxL) + + # Get the logits for the last token in each sequence + next_token_logits = logits[:, -1, :] + last_token_id = y_BxL[:, -1] + # Prevent predicting the same token consecutively + next_token_logits = next_token_logits.at[ + jnp.arange(len(last_token_id)), last_token_id + ].set(float('-inf')) + + # Get the most likely token + next_token = jnp.argmax(next_token_logits, axis=-1) + + # Append the predicted token to the sequence + y_BxL = jnp.concatenate([y_BxL, next_token[:, None]], axis=1) + + # Return original input and the k predicted tokens + return original_input, y_BxL[:, -k:] + + +# =========== Demo Code ========== + + +def main(): + """Create and run the DecoderOnly Transformer model.""" + # Initialize model configuration with smaller parameters for demo + B, L = (2, 128) # Batch size, sequence length + cfg = ModelConfig( + model_dim=128, + num_heads=4, + seq_len=L, + num_layers=2, + vocab_size=256, + expanded_model_dim=4 * 128, + ) + model = TransformerDo(cfg) + + # Print model info + print('\nModel Configuration:') + print(f' - Model dimension (D): {cfg.model_dim}') + print(f' - Number of heads (H): {cfg.num_heads}') + print(f' - Max sequence length (L): {cfg.seq_len}') + print(f' - Number of layers (N): {cfg.num_layers}') + print(f' - Vocabulary size (V): {cfg.vocab_size}') + print(f' - Feed forward dimension (F): {cfg.expanded_model_dim}') + + # Create random input tokens (simulated token IDs) + rng_key = jax.random.PRNGKey(42) + input_rng, init_rng = jax.random.split(rng_key) + + # Generate random token IDs (integers between 0 and vocab_size-1) + x_BxL = jax.random.randint( + input_rng, shape=(B, L), minval=0, maxval=cfg.vocab_size, dtype=jnp.int32 + ) + + # Initialize model parameters + print('\nInitializing model parameters...') + params = model.init(init_rng, x_BxL) + + # Print parameter count + param_count = sum(x.size for x in jax.tree_util.tree_leaves(params)) + print(f'Total parameters: {param_count:,}') + + # Make a prediction (forward pass) + print('\nRunning forward pass...') + logits = model.apply(params, x_BxL) + + # Print output shape and sample values + print( + f'\nOutput shape: {logits.shape} (batch_size, sequence_length, vocab_size)' + ) + print(f'Output data type: {logits.dtype}') + + # Print sample logits (first 5 positions of the first sequence) + print('\nSample logits (first sequence, first 5 positions, first 5 values):') + for position in range(min(5, L)): + print(f' Position {position}: {logits[0, position, :5]}') + + # Get predictions (token with highest logit at each position) + predictions = jnp.argmax(logits, axis=-1) + print('\nPredicted token IDs (first sequence, first 10 positions):') + print(predictions[0, :10]) + + # Test the predict function + print('\nTesting predict function...') + # Use a shorter + short_seq = x_BxL[:, :10] + print(f'Input sequence shape: {short_seq.shape}') + + # Predict 5 tokens + k = 5 + original, predicted = model.apply(params, short_seq, k, method=model.predict) + + # Get predictions (token with highest logit at each position) + predictions = jnp.argmax(logits, axis=-1) + print('\nPredicted token IDs (first sequence, first 10 positions):') + print(predictions[0, :10]) + + print('\nDone!') + + +if __name__ == '__main__': + main() diff --git a/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/workload.py b/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/workload.py new file mode 100644 index 000000000..ee4cffbbc --- /dev/null +++ b/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/workload.py @@ -0,0 +1,169 @@ +"""LM workload implemented in Jax.""" + +from typing import Any, Dict, Optional, Tuple + +import jax +import jax.numpy as jnp + +from algoperf import jax_sharding_utils, param_utils, spec +from algoperf.workloads.finewebedu_lm.finewebedu_lm_jax.models import ( + ModelConfig, + TransformerDo, +) +from algoperf.workloads.finewebedu_lm.input_pipeline import get_data_iter +from algoperf.workloads.finewebedu_lm.workload import BaseLmWorkload + + +class LmWorkload(BaseLmWorkload): + """LM JAX workload.""" + + def _build_input_queue( + self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, + num_batches: Optional[int] = None, + ): + """Build an input queue using pre-cached FineWeb dataset.""" + del cache, repeat_final_dataset + ds = get_data_iter( + data_rng=data_rng, + split=split, + data_dir=data_dir, + batch_size=global_batch_size, + num_batches=num_batches, + ) + ds = map(jax_sharding_utils.shard_along_batch_dim, ds) + return ds + + def init_model_fn( + self, + rng: spec.RandomState, + dropout_rate: Optional[float] = None, + aux_dropout_rate: Optional[float] = None, + ) -> spec.ModelInitState: + # Initialize NanoDO transformer model + cfg = ModelConfig( + model_dim=self._emb_dim, # embedding dim + num_heads=self._n_heads, # num heads + seq_len=self._seq_len, + num_layers=self._n_layers, # num layers + vocab_size=self._vocab_size, + expanded_model_dim=self._mlp_dim, # feedforward dim + dtype=jnp.float32, + ) + self._model = TransformerDo(cfg) + input_shape = (1, self._seq_len) # For token IDs + + params_rng, init_rng = jax.random.split(rng) + variables = jax.jit(self._model.init)( + {'params': params_rng}, jnp.ones(input_shape, jnp.int32) + ) + params = variables['params'] + self._param_shapes = param_utils.jax_param_shapes(params) + self._param_types = param_utils.jax_param_types(self._param_shapes) + params = jax_sharding_utils.replicate(params) + model_state = None + return params, model_state + + def model_fn( + self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + dropout_rate: float = 0.0, + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + del mode, rng, update_batch_norm, model_state, dropout_rate + inputs = batch['inputs'] + # Convert one-hot inputs to token IDs if needed + if inputs.ndim == 3: # one-hot encoded + inputs = jnp.argmax(inputs, axis=-1) + logits = self._model.apply({'params': params}, inputs) + return logits, None + + def loss_fn( + self, + label_batch: spec.Tensor, + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: # differentiable + """Compute weighted cross entropy. + + Args: + label_batch: categorical targets [batch, length] int array. + logits_batch: [batch, length, num_classes] float array. + mask_batch: weights array of shape [batch, length]. + label_smoothing: Label smoothing factor in [0, 1]. When > 0, the target + distribution becomes (1 - label_smoothing) for the correct class and + label_smoothing / vocab_size for all other classes. Default is 0.0 (no smoothing). + + Returns: + {'summed': scalar summed loss, 'n_valid_examples': scalar number of + valid examples in batch, 'per_example': 2d array of per-example losses} + """ + if logits_batch.ndim != label_batch.ndim + 1: + raise ValueError( + f'Incorrect shapes. Got shape {logits_batch.shape} logits and ' + f'{label_batch.shape} targets.' + ) + # Compute log probabilities + log_probs = jax.nn.log_softmax(logits_batch, axis=-1) + # Extract log probability of the target class + # Shape: [batch, length] + target_log_probs = jnp.take_along_axis( + log_probs, label_batch[..., None], axis=-1 + ).squeeze(-1) + # Cross-entropy with smoothing: -(1 - Ξ±) * log_p[target] - Ξ± * mean(log_p) + # The above formula is easy to derive from the definition of label smoothing and cross-entropy loss. + confidence = 1.0 - label_smoothing + smoothing_term = label_smoothing / self._vocab_size + per_example_losses = -1.0 * ( + confidence * target_log_probs + smoothing_term * log_probs.sum(axis=-1) + ) + if mask_batch is not None: + per_example_losses = mask_batch * per_example_losses + n_valid_examples = mask_batch.sum() + else: + n_valid_examples = label_batch.shape[0] * label_batch.shape[1] + summed_loss = per_example_losses.sum() + return { + 'summed': summed_loss, + 'n_valid_examples': n_valid_examples, + 'per_example': per_example_losses, + } + + def _eval_batch( + self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + ) -> spec.Tensor: + """Evaluate the model on a single batch.""" + logits, _ = self.model_fn( + params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False + ) + metrics = self.loss_fn( + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch['weights'], + ) + return { + 'loss': metrics['summed'], + 'denominator': metrics['n_valid_examples'], + } + + def _normalize_eval_metrics( + self, num_examples: int, total_metrics: Dict[str, Any] + ) -> Dict[str, float]: + """Normalize eval metrics.""" + del num_examples + eval_denominator = total_metrics.pop('denominator') + return jax.tree.map(lambda x: float(x / eval_denominator), total_metrics) diff --git a/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/__init__.py b/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/models.py b/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/models.py new file mode 100644 index 000000000..edee8318c --- /dev/null +++ b/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/models.py @@ -0,0 +1,344 @@ +""" +Originally based on the plainLM codebase: +https://github.com/Niccolo-Ajroldi/plainLM +under the MIT license https://github.com/Niccolo-Ajroldi/plainLM/blob/main/LICENSE. +""" + +import math +from dataclasses import dataclass +from typing import Tuple + +import torch +import torch.nn.functional as F +from torch import nn + + +@dataclass +class ModelConfig: + model_dim: int + num_heads: int + seq_len: int + num_layers: int + vocab_size: int + expanded_model_dim: int + multiple_of: int = 256 + rmsnorm_epsilon: float = 1e-6 + qknorm_epsilon: float = 1e-6 + use_residual_scaling: bool = True + tie_embeddings: bool = True + + +class MLP(nn.Module): + def __init__(self, dim: int, hidden_dim: int, multiple_of: int = 256): + super().__init__() + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + self.fc1 = nn.Linear(dim, 2 * hidden_dim, bias=False) + self.fc2 = nn.Linear(hidden_dim, dim, bias=False) + self.glu = nn.GLU(dim=2) + nn.init.normal_(self.fc1.weight, std=0.02) + nn.init.normal_(self.fc2.weight, std=0.02) + + def forward(self, x): + # x: (bsz, T, dim) + return self.fc2(self.glu(self.fc1(x))) + + +def precompute_freqs_cis( + dim: int, end: int, theta: float = 10000.0, condense_ratio: int = 1 +): + inv_freqs = 1.0 / ( + theta + ** ( + torch.arange(0, dim, 2, dtype=torch.float32, device=torch.device('cpu')) + / dim + ) + ) + t = ( + torch.arange(end, dtype=torch.float32, device=inv_freqs.device) + / condense_ratio + ) + freqs = torch.outer(t, inv_freqs).float() + return torch.stack( + [torch.cos(freqs)[None, :, None, :], torch.sin(freqs)[None, :, None, :]], + dim=4, + ) + + +def apply_rotary_emb_complex_like( + q: torch.Tensor, k: torch.Tensor, freqs_cis: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + # Rotate query and key vectors using RoPE + qk_r2 = torch.cat([q, k], dim=2).unflatten(dim=-1, sizes=(-1, 2)).float() + rotated_qk_r2 = torch.stack( + [ + qk_r2[..., 0] * freqs_cis[..., 0] - qk_r2[..., 1] * freqs_cis[..., 1], + qk_r2[..., 1] * freqs_cis[..., 0] + qk_r2[..., 0] * freqs_cis[..., 1], + ], + -1, + ).flatten(3) + rotated_qk = rotated_qk_r2 + return torch.split(rotated_qk.type_as(q), q.shape[2], dim=2) + + +class Attention(nn.Module): + def __init__(self, cfg: ModelConfig): + super().__init__() + assert cfg.model_dim % cfg.num_heads == 0 + self.dim = cfg.model_dim + self.n_heads = cfg.num_heads + self.head_dim = cfg.model_dim // cfg.num_heads + + self.w_qkv = nn.Linear(cfg.model_dim, 3 * cfg.model_dim, bias=False) + self.w_out = nn.Linear(cfg.model_dim, cfg.model_dim, bias=False) + # Split into Q, K, V sections + wq, wk, wv = torch.chunk(self.w_qkv.weight, 3, dim=0) + for w in [wq, wk, wv]: + nn.init.normal_(w, std=0.02) + nn.init.normal_(self.w_out.weight, std=0.02) + + self.eps = cfg.qknorm_epsilon # e.g., 1e-6 + seq_len = cfg.seq_len + attn_scale0 = math.log2(seq_len**2 - seq_len) + self.attn_scale = nn.Parameter(torch.tensor(attn_scale0)) + + def forward(self, x, freqs_cis): + bsz, seqlen, d = x.shape # (bsz, seqlen, d) + + q, k, v = self.w_qkv(x).split(d, dim=2) # (bsz, seqlen, d) + q = q.view( + bsz, seqlen, self.n_heads, self.head_dim + ) # (bsz, seqlen, nh, h_dim) + k = k.view( + bsz, seqlen, self.n_heads, self.head_dim + ) # (bsz, seqlen, nh, h_dim) + v = v.view( + bsz, seqlen, self.n_heads, self.head_dim + ) # (bsz, seqlen, nh, h_dim) + + q, k = apply_rotary_emb_complex_like( + q, k, freqs_cis=freqs_cis + ) # (bsz, seqlen, nh, h_dim) + + q = q.transpose(1, 2) # (bsz, nh, seqlen, h_dim) + k = k.transpose(1, 2) # (bsz, nh, seqlen, h_dim) + v = v.transpose(1, 2) # (bsz, nh, seqlen, h_dim) + + # Apply QK normalization + q = q / torch.norm(q, dim=-1, keepdim=True) + self.eps + k = k / torch.norm(k, dim=-1, keepdim=True) + self.eps + q *= self.attn_scale + + out = F.scaled_dot_product_attention( + q, k, v, is_causal=True, scale=1.0 + ) # (bsz, nh, seqlen, h_dim) + out = ( + out.transpose(1, 2).contiguous().view(bsz, seqlen, d) + ) # (bsz, seqlen, d) + + return self.w_out(out) + + +class Block(nn.Module): + def __init__(self, layer_id: int, cfg: ModelConfig): + super().__init__() + self.attn = Attention(cfg) + self.attn_norm = nn.RMSNorm(cfg.model_dim, eps=cfg.rmsnorm_epsilon) + self.mlp = MLP( + dim=cfg.model_dim, + hidden_dim=cfg.expanded_model_dim, + multiple_of=cfg.multiple_of, + ) + self.mlp_norm = nn.RMSNorm(cfg.model_dim, eps=cfg.rmsnorm_epsilon) + self.layer_id = layer_id + + def forward(self, x, freqs_cis): + # x: (bsz, seqlen, dim) + x = x + self.attn(self.attn_norm(x), freqs_cis) + x = x + self.mlp(self.mlp_norm(x)) + return x + + +class Transformer(nn.Module): + def __init__(self, cfg: ModelConfig): + super().__init__() + self.n_layers = cfg.num_layers + self.cfg = cfg + head_dim = cfg.model_dim // cfg.num_heads + assert cfg.model_dim % cfg.num_heads == 0 + + self.embed_tokens = nn.Embedding(cfg.vocab_size, cfg.model_dim) + self.layers = nn.ModuleList( + [Block(idx, cfg) for idx in range(cfg.num_layers)] + ) + self.out_norm = nn.RMSNorm(cfg.model_dim, eps=cfg.rmsnorm_epsilon) + self.lm_head = nn.Linear(cfg.model_dim, cfg.vocab_size, bias=False) + + # Initialize freqs_cis on CPU first (more memory efficient) + self.register_buffer( + 'freqs_cis', + precompute_freqs_cis(head_dim, cfg.seq_len, 500000)[0 : cfg.seq_len], + persistent=False, + ) + + # init all weights, scale residual branches + self.apply(self._init_weights) + self._scale_residual_branches() + + # Move model to device (which will also move freqs_cis) + if torch.cuda.is_available(): + self.cuda() + + if cfg.tie_embeddings: + self.tie_weights() + + def forward(self, x, targets=None): + # x: (bsz, seqlen) + x = self.embed_tokens(x) # (bsz, seqlen, dim) + L = x.shape[1] + + # Make sure we have enough precomputed frequencies + if L > self.freqs_cis.shape[1]: + # Need to recompute for longer sequence + head_dim = self.cfg.model_dim // self.cfg.num_heads + new_freqs = precompute_freqs_cis( + head_dim, max(L, self.cfg.seq_len), 500000 + ) + self.register_buffer( + 'freqs_cis', new_freqs[0 : max(L, self.cfg.seq_len)], persistent=False + ) + if torch.cuda.is_available(): + self.freqs_cis = self.freqs_cis.cuda() + + # Select the frequencies for current sequence length and ensure correct device + freqs_cis = self.freqs_cis[:, :L, :].to(x.device) + + for layer in self.layers: + x = layer(x, freqs_cis) # (bsz, seqlen, dim) + out = self.lm_head(self.out_norm(x)) # (bsz, seqlen, vocab_size) + if targets is not None: + loss = F.cross_entropy( + out.view(-1, out.size(-1)), targets.view(-1), ignore_index=-100 + ) + return out, loss + return out + + def predict(self, x, k=1): + """Generate k tokens autoregressively. + + Args: + x: Input token sequence of shape (batch_size, seq_len) + k: Number of tokens to predict + + Returns: + Tuple of (input_ids, predicted_ids) + """ + + # Store original input + original_input = x.clone() + generated_input = x.clone() + + # Generate k tokens autoregressively + for i in range(k): + # Get logits for the entire sequence + logits = self(generated_input) + + # Get the logits for the last token in each sequence + next_token_logits = logits[:, -1, :] + + # Zero out the last token ID to prevent repetition + # This is a common issue - the model gets stuck repeating the last token + last_token_id = generated_input[:, -1] + next_token_logits.scatter_(1, last_token_id.unsqueeze(1), float('-inf')) + + # Get the most likely token + next_token = torch.argmax(next_token_logits, dim=-1) + + # Append the predicted token to the sequence + next_token = next_token.unsqueeze(1) # Add sequence dimension + generated_input = torch.cat([generated_input, next_token], dim=1) + + # For debugging, print predictions for the first item in the batch + print('\nPyTorch detailed prediction (first item in batch):') + predicted_sequence = generated_input[0, -k:].tolist() + print(f' Predicted token IDs: {predicted_sequence}') + for i, token_id in enumerate(predicted_sequence): + print(f' Step {i + 1}: Predicted token {token_id}') + + # Return all tokens, not just the last k + return original_input, generated_input[:, -k:] + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + torch.nn.init.normal_(module.weight, std=0.02) + + def _scale_residual_branches(self): + for n, p in self.named_parameters(): + if n.endswith('fc2.weight') or n.endswith( + 'w_out.weight' + ): # mlp/glu output layer + torch.nn.init.normal_( + p, mean=0.0, std=0.02 / math.sqrt(2 * self.n_layers) + ) + + def tie_weights(self): + self.lm_head.weight = self.embed_tokens.weight + + def count_params(self, non_embedding=True): + n_params = sum(p.numel() for p in self.parameters()) + if non_embedding: + n_params -= self.embed_tokens.weight.numel() + if ( + self.lm_head.weight is not self.embed_tokens.weight + ): # if no weight tying + n_params -= self.lm_head.weight.numel() + return n_params + + +def main(): + print('Initializing transformer model and running forward pass...') + + seq_length = 1024 + + # Define model configuration + config = ModelConfig( + vocab_size=50257, # Common vocab size for tokenizers like BPE or SentencePiece + seq_len=seq_length, # Maximum sequence length + model_dim=1024, # Embedding dimension + expanded_model_dim=4.0, # MLP expansion factor + num_layers=12, # Number of transformer layers + num_heads=8, # Number of attention heads + rmsnorm_epsilon=1e-6, # RMSNorm epsilon + tie_embeddings=True, # Tie embedding and output weights + ) + + # Instantiate the model + model = Transformer(config) + print(f'Model has {model.count_params():,} parameters.') + + # Create some random input data + batch_size = 2 + input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_length)) + + # Move data to the same device as the model + if torch.cuda.is_available(): + input_ids = input_ids.cuda() + + # Run a forward pass + print(f'Running forward pass with input shape: {input_ids.shape}') + logits = model(input_ids) + print(f'Output logits shape: {logits.shape}') + + # Run prediction + print('Running prediction...') + original_input, predicted_ids = model.predict(input_ids[:, :10], k=5) + print(f'Original input shape for prediction: {original_input.shape}') + print(f'Predicted IDs shape: {predicted_ids.shape}') + print(f'Predicted IDs: {predicted_ids}') + + +if __name__ == '__main__': + main() diff --git a/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/workload.py b/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/workload.py new file mode 100644 index 000000000..a25ca334a --- /dev/null +++ b/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/workload.py @@ -0,0 +1,221 @@ +"""LM workload implemented in PyTorch.""" + +import contextlib +from itertools import islice +from typing import Any, Dict, Iterator, Optional, Tuple + +import jax +import torch +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP + +from algoperf import param_utils, pytorch_utils, spec +from algoperf.workloads.finewebedu_lm.finewebedu_lm_pytorch.models import ( + ModelConfig, + Transformer, +) +from algoperf.workloads.finewebedu_lm.input_pipeline import get_data_iter +from algoperf.workloads.finewebedu_lm.workload import BaseLmWorkload + +USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() + + +class LmWorkload(BaseLmWorkload): + """LM PyTorch workload.""" + + def init_model_fn( + self, + rng: spec.RandomState, + dropout_rate: Optional[float] = None, + aux_dropout_rate: Optional[float] = None, + ) -> spec.ModelInitState: + if hasattr(self, '_model'): + # Reinitialize weights but keep same config + self._model.apply(self._model._init_weights) + self._model._scale_residual_branches() + return self._model, None + + torch.manual_seed(rng[0]) + cfg = ModelConfig( + vocab_size=self._vocab_size, + seq_len=self._seq_len, + model_dim=self._emb_dim, # Model dimension + expanded_model_dim=self._mlp_dim, # MLP expansion factor + num_layers=self._n_layers, # Number of transformer layers + num_heads=self._n_heads, # Number of attention heads + rmsnorm_epsilon=1e-6, + tie_embeddings=True, + ) + self._model = Transformer(cfg) + self._param_shapes = param_utils.pytorch_param_shapes(self._model) + self._param_types = param_utils.pytorch_param_types(self._param_shapes) + self._model.to(DEVICE) + + if N_GPUS > 1: + if USE_PYTORCH_DDP: + self._model = DDP(self._model, device_ids=[RANK], output_device=RANK) + else: + self._model = torch.nn.DataParallel(self._model) + + return self._model, None + + def model_fn( + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + dropout_rate: float = 0.0, + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + del model_state, rng, update_batch_norm, dropout_rate + model = params + + # Set model to eval or train mode based on the mode parameter + if mode == spec.ForwardPassMode.EVAL: + model.eval() + elif mode == spec.ForwardPassMode.TRAIN: + model.train() + contexts = { + spec.ForwardPassMode.EVAL: torch.no_grad, + spec.ForwardPassMode.TRAIN: contextlib.nullcontext, + } + with contexts[mode](): + # Convert one-hot inputs to token IDs if needed + inputs = augmented_and_preprocessed_input_batch['inputs'] + if inputs.dim() == 3: # one-hot encoded + inputs = inputs.argmax(dim=-1) + + logits = model(inputs) + + return logits, None + + def _build_input_queue( + self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, + num_batches: Optional[int] = None, + ) -> Iterator[Dict[str, spec.Tensor]]: + """Build an input queue for the given split.""" + del cache, repeat_final_dataset + local_batch_size = global_batch_size // N_GPUS + loader = get_data_iter( + data_rng=data_rng, + split=split, + data_dir=data_dir, + batch_size=local_batch_size, + num_batches=num_batches, + ) + if USE_PYTORCH_DDP: + loader = islice(loader, RANK, None, N_GPUS) + dtype = torch.int32 + for batch in loader: + batch = { + 'inputs': torch.tensor(batch['inputs'], device=DEVICE, dtype=dtype), + 'targets': torch.tensor( + batch['targets'], device=DEVICE, dtype=torch.int64 + ), + 'weights': torch.tensor( + batch['weights'], device=DEVICE, dtype=torch.float32 + ) + if batch['weights'] is not None + else None, + } + yield batch + + def is_output_params(self, param_name: str) -> bool: + """Return whether the given parameter is an output parameter.""" + return 'lm_head.weight' in param_name or 'lm_head.bias' in param_name + + def loss_fn( + self, + label_batch: spec.Tensor, + logits_batch: spec.Tensor, + mask_batch: spec.Tensor, + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: + """Compute weighted cross-entropy loss. + + Args: + label_batch: Target labels of shape [batch, length] (int). + logits_batch: Predicted logits of shape [batch, length, vocab_size] (float). + mask_batch: Optional weights of shape [batch, length] (float). Used to mask + out padding tokens or weight examples differently. If None, all examples + are weighted equally. + label_smoothing: Label smoothing factor in [0, 1]. When > 0, the target + distribution becomes (1 - label_smoothing) for the correct class and + label_smoothing / vocab_size for all other classes. Default is 0.0 (no smoothing). + + Returns: + Dictionary containing: + - 'summed': Scalar tensor with the sum of all weighted losses. + - 'n_valid_examples': Scalar tensor with the count of valid (non-masked) examples. + - 'per_example': Tensor of shape [batch, length] with individual losses per example. + """ + vocab_size = logits_batch.size(-1) + + # Compute cross-entropy loss with label smoothing + per_example_losses = torch.nn.functional.cross_entropy( + logits_batch.view(-1, vocab_size), + label_batch.view(-1), + reduction='none', + label_smoothing=label_smoothing, + ) + per_example_losses = per_example_losses.view_as(label_batch) + + # Apply weights if provided + if mask_batch is not None: + per_example_losses = per_example_losses * mask_batch + + # Calculate number of valid examples + n_valid_examples = ( + mask_batch.sum() + if mask_batch is not None + else torch.tensor( + label_batch.numel(), dtype=torch.float32, device=label_batch.device + ) + ) + + return { + 'summed': per_example_losses.sum(), + 'n_valid_examples': n_valid_examples, + 'per_example': per_example_losses, + } + + def _eval_batch( + self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + ) -> spec.Tensor: + """Evaluate the model on a single batch.""" + logits, _ = self.model_fn( + params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False + ) + metrics = self.loss_fn( + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch['weights'], + ) + return { + 'loss': metrics['summed'].detach(), + 'denominator': metrics['n_valid_examples'].detach(), + } + + def _normalize_eval_metrics( + self, num_examples: int, total_metrics: Dict[str, Any] + ) -> Dict[str, float]: + """Normalize eval metrics.""" + del num_examples + if USE_PYTORCH_DDP: + for metric in total_metrics.values(): + dist.all_reduce(metric) + total_metrics = {k: v.item() for k, v in total_metrics.items()} + eval_denominator = total_metrics.pop('denominator') + return jax.tree.map(lambda x: float(x / eval_denominator), total_metrics) diff --git a/algoperf/workloads/finewebedu_lm/input_pipeline.py b/algoperf/workloads/finewebedu_lm/input_pipeline.py new file mode 100644 index 000000000..3007371fc --- /dev/null +++ b/algoperf/workloads/finewebedu_lm/input_pipeline.py @@ -0,0 +1,153 @@ +"""Input pipeline for a LM dataset.""" + +import functools +import os +from typing import Optional + +import jax +import tensorflow as tf + +from algoperf import data_utils + +AUTOTUNE = tf.data.experimental.AUTOTUNE +PAD_ID = tf.constant(-1, dtype=tf.int64) + +TFDS_SPLIT_NAME = {'train': 'train', 'eval_train': 'train', 'validation': 'val'} + +SEQUENCE_LENGTH = 1024 +MAX_CORPUS_CHARS = 1_000_000_000 +SHUFFLE_BUFFER_SIZE = 1000 +VOCAB_SIZE = 50_257 + + +def batch_with_padding( + dataset: tf.data.Dataset, + batch_size, + padded_shapes=None, + padding_id=PAD_ID, +): + """Batches a tf.data.Dataset and adds padding if len(dataset) is not divisible by the batch size. + + Args: + dataset: tf.data.Dataset + batch_size: batch size of resulting batched dataset + padded_shapes: shapes of the padded batches + padding_id: value for padding, for elements in new batch + + Returns: + """ + batched_dataset = dataset.batch(batch_size, drop_remainder=False) + + # tf.data.Dataset.padded.batch pads elements in the batch so we call it + # again with batch_size=1 to pad each element in original batch. + padded_batched_dataset = batched_dataset.padded_batch( + 1, padded_shapes=padded_shapes, padding_values=padding_id + ) + + # Remove extra dimension resulting from the batch_size=1. + padded_batched_dataset = padded_batched_dataset.unbatch() + + return padded_batched_dataset + + +def get_data_iter( + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + batch_size: int, + num_batches: Optional[int] = None, +): + ds = get_lm_dataset(data_rng, split, data_dir, batch_size, num_batches) + + it = map( + functools.partial( + data_utils.shard_and_maybe_pad_np, global_batch_size=batch_size + ), + ds, + ) + + return iter(it) + + +def get_lm_dataset( + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + batch_size: int, + num_batches: Optional[int] = None, +): + """Load preprocessed TF dataset.""" + if split not in TFDS_SPLIT_NAME: + raise NotImplementedError + + shuffle_seed = jax.random.randint(data_rng, (), -(2**31), 2**31 - 1) + + data_dir = os.path.join(data_dir, TFDS_SPLIT_NAME[split]) + tokens_ds = tf.data.Dataset.load(data_dir) + + # tokens + tokens_ds = tokens_ds.flat_map(tf.data.Dataset.from_tensor_slices) + + # sequences + sequences_ds = tokens_ds.batch(SEQUENCE_LENGTH + 1, drop_remainder=True) + + # get inputs and outputs + sequences_ds = sequences_ds.map( + lambda x: { + 'inputs': x['input_ids'][:SEQUENCE_LENGTH], + 'targets': x['input_ids'][1:], + }, + num_parallel_calls=AUTOTUNE, + ) + if split == 'train': + ds = sequences_ds.shuffle(SHUFFLE_BUFFER_SIZE, seed=shuffle_seed) + ds = ds.batch(batch_size, drop_remainder=False) + ds = ds.take(num_batches) if num_batches is not None else ds + ds = ds.repeat() + ds = ds.map( + lambda x: { + 'inputs': x['inputs'], + 'targets': x['targets'], + 'weights': None, + } + ) + ds = ds.prefetch(tf.data.experimental.AUTOTUNE) + elif split == 'eval_train': + ds = batch_with_padding( + sequences_ds, + batch_size, + padded_shapes={ + 'inputs': (batch_size, None), + 'targets': (batch_size, None), + }, + ) + ds = ds.take(num_batches) if num_batches is not None else ds + ds = ds.repeat() + ds = ds.map( + lambda x: { + 'inputs': x['inputs'], + 'targets': x['targets'], + 'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0), + } + ) + ds = ds.prefetch(tf.data.experimental.AUTOTUNE) + elif split == 'validation': + ds = batch_with_padding( + sequences_ds, + batch_size, + padded_shapes={ + 'inputs': (batch_size, None), + 'targets': (batch_size, None), + }, + ) + ds = ds.take(num_batches) if num_batches is not None else ds + ds = ds.repeat() + ds = ds.map( + lambda x: { + 'inputs': x['inputs'], + 'targets': x['targets'], + 'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0), + } + ) + ds = ds.prefetch(tf.data.experimental.AUTOTUNE) + return ds diff --git a/algoperf/workloads/finewebedu_lm/workload.py b/algoperf/workloads/finewebedu_lm/workload.py new file mode 100644 index 000000000..59f70380f --- /dev/null +++ b/algoperf/workloads/finewebedu_lm/workload.py @@ -0,0 +1,193 @@ +"""LM workload parent class.""" + +import abc +import math +import os +from typing import Any, Dict, Iterator, Optional + +import jax +import numpy as np +from absl import flags + +from algoperf import spec + +FLAGS = flags.FLAGS + +USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ + + +class BaseLmWorkload(spec.Workload): + """LM workload.""" + + _vocab_size: int = 50257 + _seq_len: int = 1024 + _emb_dim: int = 1024 + _n_heads: int = 8 + _n_layers: int = 12 + _mlp_dim: int = 4096 + warmup_factor: float = 0.1 + + def __init__(self) -> None: + super().__init__() + self._param_shapes = None + self._param_types = None + + @property + def target_metric_name(self) -> str: + """The name of the target metric (useful for scoring/processing code).""" + return 'ppl' + + def has_reached_validation_target(self, eval_result: float) -> bool: + return eval_result['validation/ppl'] <= self.validation_target_value + + @property + def validation_target_value(self) -> float: + return 22.2995 # Target perplexity + + def has_reached_test_target(self, eval_result: Dict[str, float]) -> bool: + return True # No test targets + + @property + def test_target_value(self) -> float: + return None # No test targets + + @property + def loss_type(self) -> spec.LossType: + return spec.LossType.SOFTMAX_CROSS_ENTROPY + + @property + def num_train_examples(self) -> int: + return 8_749_870 # sequences of 1024 tokens each + + @property + def num_eval_train_examples(self) -> int: + return 10_000 # Subset for evaluation. + + @property + def num_validation_examples(self) -> int: + return 100_000 # sequences + + @property + def num_test_examples(self) -> int: + return 0 + + @property + def eval_batch_size(self) -> int: + return 256 + + @property + def train_mean(self): + raise NotImplementedError + + @property + def train_stddev(self): + raise NotImplementedError + + @property + def max_allowed_runtime_sec(self) -> int: + return 31_967 # 8.9 hours + + @property + def eval_period_time_sec(self) -> int: + return 2_571 # approximately 25 evals + + @property + def step_hint(self) -> int: + """Approx. steps the baseline can do in the allowed runtime budget.""" + return 72_000 + + @property + def pre_ln(self) -> bool: + return True + + @property + def attention_temp(self) -> float: + return 1.0 + + @property + def activation(self) -> str: + return 'silu' + + @property + def glu(self) -> bool: + return True + + @abc.abstractmethod + def _build_input_queue( + self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, + num_batches: Optional[int] = None, + ) -> Iterator[Dict[str, Any]]: + """Build an input queue for the given split.""" + + @abc.abstractmethod + def _eval_batch( + self, + params: spec.ParameterContainer, + eval_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + ) -> Dict[str, float]: + """Evaluate the model on a single batch.""" + + def _eval_model_on_split( + self, + split: str, + num_examples: int, + global_batch_size: int, + params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + data_dir: str, + global_step: int = 0, + ) -> Dict[str, float]: + """Run a full evaluation of the model.""" + num_batches = int(math.ceil(num_examples / global_batch_size)) + + # Handle edge case where num_batches is 0 (e.g., test split with 0 examples) + if num_batches == 0: + return {'loss': 0.0, 'ppl': 1.0} + + if split not in self._eval_iters: + # These iterators will repeat indefinitely. + self._eval_iters[split] = self._build_input_queue( + rng, split, data_dir, global_batch_size, num_batches=num_batches + ) + + eval_metrics = {} + for _ in range(num_batches): + eval_batch = next(self._eval_iters[split]) + metrics = self._eval_batch(params, eval_batch, model_state, rng) + for metric_name, metric_value in metrics.items(): + if metric_name not in eval_metrics: + eval_metrics[metric_name] = 0.0 + eval_metrics[metric_name] += metric_value + + eval_results = self._normalize_eval_metrics(num_examples, eval_metrics) + eval_results['ppl'] = np.exp(eval_results['loss']).item() + return eval_results + + @abc.abstractmethod + def _normalize_eval_metrics( + self, num_examples: int, total_metrics: Dict[str, Any] + ) -> Dict[str, float]: + """Normalize eval metrics.""" + + @abc.abstractmethod + def loss_fn( + self, + label_batch: spec.Tensor, + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: + """Compute cross-entropy loss for language modeling.""" + + def is_output_params(self, param_name: str) -> bool: + """Return whether the given parameter is an output parameter.""" + return param_name.contains('output') diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py index d5366c60d..cd476e37f 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py @@ -3,10 +3,13 @@ import contextlib import functools import itertools +import json import math import os import random -from typing import Dict, Iterator, Optional, Tuple +import time +from pathlib import Path +from typing import Any, Callable, Dict, Iterator, Optional, Tuple, Union import numpy as np import torch @@ -14,7 +17,11 @@ import torch.nn.functional as F from torch.nn.parallel import DistributedDataParallel as DDP from torchvision import transforms -from torchvision.datasets.folder import ImageFolder +from torchvision.datasets.folder import ( + IMG_EXTENSIONS, + ImageFolder, + default_loader, +) import algoperf.random_utils as prng from algoperf import data_utils, param_utils, pytorch_utils, spec @@ -28,6 +35,100 @@ USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() +class CachedImageFolder(ImageFolder): + """ImageFolder that caches the file listing to avoid repeated filesystem scans.""" + + def __init__( + self, + root: Union[str, Path], + cache_file: Optional[Union[str, Path]] = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + loader: Callable[[str], Any] = default_loader, + is_valid_file: Optional[Callable[[str], bool]] = None, + allow_empty: bool = False, + rebuild_cache: bool = False, + cache_build_timeout_minutes: int = 30, + ): + self.root = os.path.abspath(root) + self.transform = transform + self.target_transform = target_transform + self.loader = loader + self.extensions = IMG_EXTENSIONS if is_valid_file is None else None + + # Default cache location: .cache_index.json in the root directory + if cache_file is None: + cache_file = os.path.join(self.root, '.cache_index.json') + self.cache_file = cache_file + + is_distributed = dist.is_available() and dist.is_initialized() + rank = dist.get_rank() if is_distributed else 0 + + cache_exists = os.path.exists(self.cache_file) + needs_rebuild = rebuild_cache or not cache_exists + + if needs_rebuild: + # We only want one process to build the cache + # and others to wait for it to finish. + if rank == 0: + self._build_and_save_cache(is_valid_file, allow_empty) + if is_distributed: + self._wait_for_cache(timeout_minutes=cache_build_timeout_minutes) + dist.barrier() + + self._load_from_cache() + + self.targets = [s[1] for s in self.samples] + self.imgs = self.samples + + def _wait_for_cache(self, timeout_minutes: int): + """Poll for cache file to exist.""" + timeout_seconds = timeout_minutes * 60 + poll_interval = 5 + elapsed = 0 + + while not os.path.exists(self.cache_file): + if elapsed >= timeout_seconds: + raise TimeoutError( + f'Timed out waiting for cache file after {timeout_minutes} minutes: {self.cache_file}' + ) + time.sleep(poll_interval) + elapsed += poll_interval + + def _load_from_cache(self): + """Load classes and samples from cache file.""" + with open(os.path.abspath(self.cache_file), 'r') as f: + cache = json.load(f) + self.classes = cache['classes'] + self.class_to_idx = cache['class_to_idx'] + # Convert relative paths back to absolute + self.samples = [ + (os.path.join(self.root, rel_path), idx) + for rel_path, idx in cache['samples'] + ] + + def _build_and_save_cache(self, is_valid_file, allow_empty): + """Scan filesystem, build index, and save to cache.""" + self.classes, self.class_to_idx = self.find_classes(self.root) + self.samples = self.make_dataset( + self.root, + class_to_idx=self.class_to_idx, + extensions=self.extensions, + is_valid_file=is_valid_file, + allow_empty=allow_empty, + ) + + cache = { + 'classes': self.classes, + 'class_to_idx': self.class_to_idx, + 'samples': [ + (os.path.relpath(path, self.root), idx) for path, idx in self.samples + ], + } + with open(os.path.abspath(self.cache_file), 'w') as f: + json.dump(cache, f) + + def imagenet_v2_to_torch( batch: Dict[str, spec.Tensor], ) -> Dict[str, spec.Tensor]: @@ -119,8 +220,10 @@ def _build_dataset( ) folder = 'train' if 'train' in split else 'val' - dataset = ImageFolder( - os.path.join(data_dir, folder), transform=transform_config + dataset = CachedImageFolder( + os.path.join(data_dir, folder), + transform=transform_config, + cache_file='.imagenet_{}_cache_index.json'.format(split), ) if split == 'eval_train': @@ -145,16 +248,16 @@ def _build_dataset( sampler = data_utils.DistributedEvalSampler( dataset, num_replicas=N_GPUS, rank=RANK, shuffle=False ) - dataloader = torch.utils.data.DataLoader( dataset, batch_size=ds_iter_batch_size, shuffle=not USE_PYTORCH_DDP and is_train, sampler=sampler, - num_workers=4 if is_train else self.eval_num_workers, + num_workers=5 * N_GPUS, pin_memory=True, drop_last=is_train, persistent_workers=is_train, + prefetch_factor=N_GPUS, ) dataloader = data_utils.PrefetchedWrapper(dataloader, DEVICE) dataloader = data_utils.cycle( @@ -163,7 +266,6 @@ def _build_dataset( use_mixup=use_mixup, mixup_alpha=0.2, ) - return dataloader def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: diff --git a/algoperf/workloads/imagenet_resnet/workload.py b/algoperf/workloads/imagenet_resnet/workload.py index ef696e328..de8458c92 100644 --- a/algoperf/workloads/imagenet_resnet/workload.py +++ b/algoperf/workloads/imagenet_resnet/workload.py @@ -103,11 +103,11 @@ def resize_size(self) -> int: @property def max_allowed_runtime_sec(self) -> int: - return 66_159 # ~18.4 hours + return 49_918 # ~13.8 hours @property def eval_period_time_sec(self) -> int: - return 510 # 8.5 minutes. + return 1_996 # approx 25 evals def _build_dataset( self, diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py index fc2a3cd46..06df7ea75 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py +++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py @@ -5,7 +5,6 @@ and https://github.com/lucidrains/vit-pytorch. """ -import math from typing import Any, Optional, Tuple, Union import torch @@ -126,13 +125,14 @@ def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: value_layer = self.transpose_for_scores(self.value(x)) query_layer = self.transpose_for_scores(mixed_query_layer) - attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - attention_scores = attention_scores / math.sqrt(self.head_dim) - - attention_probs = F.softmax(attention_scores, dim=-1) - attention_probs = F.dropout(attention_probs, dropout_rate, self.training) + # Use built-in scaled_dot_product_attention (Flash Attention when available) + context_layer = F.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + dropout_p=dropout_rate if self.training else 0.0, + ) - context_layer = torch.matmul(attention_probs, value_layer) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_dim,) context_layer = context_layer.view(new_context_layer_shape) diff --git a/algoperf/workloads/imagenet_vit/workload.py b/algoperf/workloads/imagenet_vit/workload.py index 2a0070ba4..4da02614f 100644 --- a/algoperf/workloads/imagenet_vit/workload.py +++ b/algoperf/workloads/imagenet_vit/workload.py @@ -88,11 +88,11 @@ def eval_batch_size(self) -> int: @property def max_allowed_runtime_sec(self) -> int: - return 69_768 # ~19.4 hours + return 64_292 # ~17.8 hours @property def eval_period_time_sec(self) -> int: - return 7 * 60 # 7 mins. + return 2_571 # 7 mins. def _build_dataset( self, diff --git a/algoperf/workloads/librispeech_conformer/workload.py b/algoperf/workloads/librispeech_conformer/workload.py index 791270719..5a0a546e4 100644 --- a/algoperf/workloads/librispeech_conformer/workload.py +++ b/algoperf/workloads/librispeech_conformer/workload.py @@ -80,11 +80,11 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 58_015 # ~16.1 hours + return 43_680 # ~16.1 hours @property def eval_period_time_sec(self) -> int: - return 24 * 60 + return 1747 # approx 25 evals @property def step_hint(self) -> int: diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 3a320b0dd..2a8fd29d0 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -100,7 +100,11 @@ def step_hint(self) -> int: @property def max_allowed_runtime_sec(self) -> int: - return 44_405 # ~12.3 hours + return 36_949 # ~12.3 hours + + @property + def eval_period_time_sec(self) -> int: + return 1447 # approx 25 evals @property def use_tanh(self) -> bool: diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index 672f3440f..c6bb149f7 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -96,7 +96,11 @@ def step_hint(self) -> int: @property def max_allowed_runtime_sec(self) -> int: - return 44_405 # ~12.3 hours + return 36_949 # 10.3 hours + + @property + def eval_period_time_sec(self) -> int: + return 1447 # approx 25 evals @property def use_tanh(self) -> bool: diff --git a/algoperf/workloads/ogbg/workload.py b/algoperf/workloads/ogbg/workload.py index 8717e46d6..771b103a0 100644 --- a/algoperf/workloads/ogbg/workload.py +++ b/algoperf/workloads/ogbg/workload.py @@ -88,11 +88,11 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 12_011 # ~3.3 hours + return 11_303 # ~3.1 hours @property def eval_period_time_sec(self) -> int: - return 4 * 60 + return 452 # approx 25 evals def _build_input_queue( self, diff --git a/algoperf/workloads/wmt/workload.py b/algoperf/workloads/wmt/workload.py index 40e4262dd..2e232214e 100644 --- a/algoperf/workloads/wmt/workload.py +++ b/algoperf/workloads/wmt/workload.py @@ -89,11 +89,11 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 43_336 # ~12.0 hours + return 16_114 # ~12.0 hours @property def eval_period_time_sec(self) -> int: - return 14 * 60 + return 644 @property def step_hint(self) -> int: diff --git a/algoperf/workloads/workloads.py b/algoperf/workloads/workloads.py index 4dd4717e9..e90300a36 100644 --- a/algoperf/workloads/workloads.py +++ b/algoperf/workloads/workloads.py @@ -113,6 +113,10 @@ 'workload_path': 'librispeech_deepspeech/librispeech', 'workload_class_name': 'LibriSpeechDeepSpeechNormAndSpecAugWorkload', }, + 'finewebedu_lm': { + 'workload_path': 'finewebedu_lm/finewebedu_lm', + 'workload_class_name': 'LmWorkload', + }, 'mnist': { 'workload_path': 'mnist/mnist', 'workload_class_name': 'MnistWorkload', @@ -152,6 +156,7 @@ 'imagenet_vit', 'librispeech_conformer', 'librispeech_deepspeech', + 'finewebedu_lm', 'ogbg', 'wmt', ] diff --git a/algorithms/archived_paper_baselines/adamw/jax/submission.py b/algorithms/archived_paper_baselines/adamw/jax/submission.py index b8ea5d30a..c0ffe7601 100644 --- a/algorithms/archived_paper_baselines/adamw/jax/submission.py +++ b/algorithms/archived_paper_baselines/adamw/jax/submission.py @@ -254,6 +254,8 @@ def get_batch_size(workload_name): return 16 elif workload_name == 'cifar': return 32 + elif workload_name == 'finewebedu_lm': + return 64 else: raise ValueError(f'Unsupported workload name: {workload_name}.') diff --git a/algorithms/archived_paper_baselines/adamw/pytorch/submission.py b/algorithms/archived_paper_baselines/adamw/pytorch/submission.py index 761ce5cb1..7c50ff4ff 100644 --- a/algorithms/archived_paper_baselines/adamw/pytorch/submission.py +++ b/algorithms/archived_paper_baselines/adamw/pytorch/submission.py @@ -189,6 +189,8 @@ def get_batch_size(workload_name): return 128 elif workload_name == 'mnist': return 16 + elif workload_name == 'finewebedu_lm': + return 64 else: raise ValueError(f'Unsupported workload name: {workload_name}.') diff --git a/algorithms/archived_paper_baselines/nesterov/jax/submission.py b/algorithms/archived_paper_baselines/nesterov/jax/submission.py index e199fb2b9..061acc3de 100644 --- a/algorithms/archived_paper_baselines/nesterov/jax/submission.py +++ b/algorithms/archived_paper_baselines/nesterov/jax/submission.py @@ -292,6 +292,8 @@ def get_batch_size(workload_name): return 16 elif workload_name == 'cifar': return 128 + elif workload_name == 'finewebedu_lm': + return 64 else: raise ValueError(f'Unsupported workload name: {workload_name}.') diff --git a/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py b/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py index 0577cd4e0..6d2808593 100644 --- a/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py +++ b/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py @@ -340,12 +340,6 @@ def update_params( dropout_rate, ) ) - - # Log loss, grad_norm. - if global_step % 100 == 0 and workload.metrics_logger is not None: - workload.metrics_logger.append_scalar_metrics( - {'loss': loss.item(), 'grad_norm': grad_norm.item()}, global_step - ) return (new_optimizer_state, opt_update_fn), new_params, new_model_state @@ -394,6 +388,8 @@ def get_batch_size(workload_name): return 512 elif workload_name == 'wmt': return 128 + elif workload_name == 'finewebedu_lm': + return 64 elif workload_name == 'mnist': return 16 else: diff --git a/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py b/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py index 0b32199ba..92027887f 100644 --- a/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py +++ b/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py @@ -5,7 +5,6 @@ import torch import torch.distributed.nn as dist_nn -from absl import logging from torch import Tensor from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR @@ -300,28 +299,6 @@ def update_params( optimizer_state['optimizer'].step() optimizer_state['scheduler'].step() - # Log training metrics - loss, grad_norm, batch_size. - if global_step <= 100 or global_step % 500 == 0: - with torch.no_grad(): - parameters = [p for p in current_model.parameters() if p.grad is not None] - grad_norm = torch.norm( - torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2 - ) - if workload.metrics_logger is not None: - workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss.item(), - 'grad_norm': grad_norm.item(), - }, - global_step, - ) - logging.info( - '%d) loss = %0.3f, grad_norm = %0.3f', - global_step, - loss.item(), - grad_norm.item(), - ) - return (optimizer_state, current_param_container, new_model_state) @@ -372,6 +349,8 @@ def get_batch_size(workload_name): return 128 elif workload_name == 'mnist': return 16 + elif workload_name == 'finewebedu_lm': + return 64 else: raise ValueError(f'Unsupported workload name: {workload_name}.') diff --git a/algorithms/target_setting_algorithms/fineweb_edu_lm/jax_nadamw_target_setting.py b/algorithms/target_setting_algorithms/fineweb_edu_lm/jax_nadamw_target_setting.py new file mode 100644 index 000000000..b7adf6cd6 --- /dev/null +++ b/algorithms/target_setting_algorithms/fineweb_edu_lm/jax_nadamw_target_setting.py @@ -0,0 +1,427 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in Jax.""" + +from typing import ( + Any, + Callable, + Dict, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + Union, +) + +# isort: on +import chex +import jax +import jax.numpy as jnp +import optax + +from algoperf import jax_sharding_utils, spec + +_GRAD_CLIP_EPS = 1e-6 + + +# Forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py +def nadamw( + learning_rate: Union[float, optax.Schedule], + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + weight_decay: float = 0.0, + weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], Any]]] = None, +) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the official PyTorch + implementation also follows this). + Current code implements a simpler version with no momentum decay and slightly + different bias correction terms. The exact description can be found here + https://arxiv.org/pdf/1910.05446.pdf (Table 1). + + Args: + learning_rate: A fixed global scaling factor. + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + weight_decay: Strength of the weight decay regularization. Note that this + weight decay is multiplied with the learning rate. This is consistent with + other frameworks such as PyTorch, but different from (Loshchilov et al, + 2019) where the weight decay is only multiplied with the "schedule + multiplier", but not the base learning rate. + weight_decay_mask: A tree with same structure as (or a prefix of) the params + PyTree, or a Callable that returns such a pytree given the params/updates. + The leaves should be booleans, `True` for leaves/subtrees you want to + apply the weight decay to, and `False` for those you want to skip. Note + that the Nadam gradient transformations are applied to all parameters. + + Returns: + An (init_fn, update_fn) tuple. + """ + return optax.chain( + scale_by_nadam(b1, b2, eps, eps_root, debias), + optax.add_decayed_weights(weight_decay, weight_decay_mask), + scale_by_learning_rate(learning_rate), + ) + + +# All functions below are forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py +def scale_by_nadam( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + power: float = 0.5, +) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the pytorch imp. also + follows this). + + Current code implements a simpler version with no momentum decay and slightly + different (standard Adam) bias correction terms. The exact description can be + found here https://arxiv.org/pdf/1910.05446.pdf (Table 1) + + Args: + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + power: The power to use in the preconditioner (0.5 in default adam). + Returns: + An (init_fn, update_fn) tuple. + """ + raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) + + def init_fn(params): + mu = jax.tree.map(jnp.zeros_like, params) # First moment + nu = jax.tree.map(jnp.zeros_like, params) # Second moment + return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) + + def update_fn(updates, state, params=None): + del params + mu = _update_moment(updates, state.mu, b1, 1) + nu = _update_moment(updates, state.nu, b2, 2) + count = state.count + jnp.array(1, dtype=jnp.int32) + mu_hat = _update_moment(updates, mu, b1, 1) + mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) + nu_hat = nu if not debias else _bias_correction(nu, b2, count) + updates = jax.tree.map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat + ) + return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) + + return optax.GradientTransformation(init_fn, update_fn) + + +class ScaleByAdamState(NamedTuple): + """State for the NAdam algorithm.""" + + count: chex.Array # shape=(), dtype=jnp.int32. + mu: optax.Updates + nu: optax.Updates + + +def _update_moment(updates, moments, decay, order): + """Compute the exponential moving average of the `order-th` moment.""" + return jax.tree.map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments + ) + + +def _bias_correction(moment, decay, count): + """Perform bias correction. This becomes a no-op as count goes to infinity.""" + beta = 1 - decay**count + return jax.tree.map(lambda t: t / beta.astype(t.dtype), moment) + + +def scale_by_learning_rate(learning_rate, flip_sign=True): + m = -1 if flip_sign else 1 + if callable(learning_rate): + return optax.scale_by_schedule(lambda count: m * learning_rate(count)) + return optax.scale(m * learning_rate) + + +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_params + del model_state + del rng + + def jax_cosine_warmup(step_hint: int, hyperparameters): + step_hint = 0.75 * step_hint + # Create learning rate schedule. + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup_fn = optax.linear_schedule( + init_value=0.0, + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps, + ) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_fn = optax.cosine_decay_schedule( + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps + ) + schedule_fn = optax.join_schedules( + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps] + ) + return schedule_fn + + # Create optimizer + LR schedule. + lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters) + opt_init_fn, opt_update_fn = nadamw( + learning_rate=lr_schedule_fn, + b1=1.0 - hyperparameters.one_minus_beta1, + b2=hyperparameters.beta2, + eps=1e-8, + weight_decay=hyperparameters.weight_decay, + ) + params_zeros_like = jax.tree.map( + lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes + ) + optimizer_state = opt_init_fn(params_zeros_like) + + return optimizer_state, opt_update_fn + + +def train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing, + dropout_rate, +): + def _loss_fn(params): + """Loss function used for training.""" + logits, new_model_state = workload.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True, + dropout_rate=dropout_rate, + ) + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + return summed_loss, (n_valid_examples, new_model_state) + + grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) + (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( + current_param_container + ) + # Compute mean loss and grad + loss = summed_loss / n_valid_examples + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) + + grad_norm = jnp.sqrt( + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)) + ) + + if grad_clip is not None: + grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) + grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) + + updates, new_optimizer_state = opt_update_fn( + grad, optimizer_state, current_param_container + ) + updated_params = optax.apply_updates(current_param_container, updates) + return new_optimizer_state, updated_params, new_model_state, loss, grad_norm + + +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del train_state + del eval_results + + optimizer_state, opt_update_fn = optimizer_state + if hasattr(hyperparameters, 'label_smoothing'): + label_smoothing = hyperparameters.label_smoothing + else: + label_smoothing = 0.0 + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + dropout_rate = hyperparameters.dropout_rate + + # Create shardings for each argument + replicated = jax_sharding_utils.get_replicate_sharding() # No partitioning + sharded = ( + jax_sharding_utils.get_batch_dim_sharding() + ) # Partition along batch dimension + + # Create the sharding rules for each argument + arg_shardings = ( + # workload is static + # opt_update_fn is static + replicated, # model_state + replicated, # optimizer_state + replicated, # current_param_container + sharded, # batch + replicated, # rng + replicated, # grad_clip + replicated, # label_smoothing + replicated, # dropout_rate + ) + out_shardings = ( + replicated, # new_optimizer_state + replicated, # updated_params + replicated, # new_model_state + replicated, # loss + replicated, # grad_norm + ) + # Jit with shardings + jitted_train_step = jax.jit( + train_step, + static_argnums=(0, 1), + donate_argnums=(2, 3, 4), + in_shardings=arg_shardings, + out_shardings=out_shardings, + ) + + new_optimizer_state, new_params, new_model_state, loss, grad_norm = ( + jitted_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing, + dropout_rate, + ) + ) + + # Log loss, grad_norm. + if global_step % 100 == 0 and workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + {'loss': loss.item(), 'grad_norm': grad_norm.item()}, global_step + ) + return (new_optimizer_state, opt_update_fn), new_params, new_model_state + + +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_resnet_silu': + return 512 + elif workload_name == 'imagenet_resnet_gelu': + return 512 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'finewebedu_lm': + return 64 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. + """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/algorithms/target_setting_algorithms/fineweb_edu_lm/pytorch_nadamw_target_setting.py b/algorithms/target_setting_algorithms/fineweb_edu_lm/pytorch_nadamw_target_setting.py new file mode 100644 index 000000000..b881747d8 --- /dev/null +++ b/algorithms/target_setting_algorithms/fineweb_edu_lm/pytorch_nadamw_target_setting.py @@ -0,0 +1,403 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" + +import math +from typing import Any, Dict, Iterator, List, Optional, Tuple + +import torch +import torch.distributed.nn as dist_nn +from absl import logging +from torch import Tensor +from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR + +from algoperf import spec +from algoperf.pytorch_utils import pytorch_setup + +USE_PYTORCH_DDP = pytorch_setup()[0] + + +# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py. +class NAdamW(torch.optim.Optimizer): + r"""Implements NAdamW algorithm. + + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of + the NAdam algorithm (there is also a comment in the code which highlights + the only difference of NAdamW and AdamW). + For further details regarding the algorithm we refer to + `Decoupled Weight Decay Regularization`_. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__( + self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2 + ): + if not 0.0 <= lr: + raise ValueError(f'Invalid learning rate: {lr}') + if not 0.0 <= eps: + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f'Invalid beta parameter at index 0: {betas[0]}') + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f'Invalid beta parameter at index 1: {betas[1]}') + if not 0.0 <= weight_decay: + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + defaults = { + 'lr': lr, + 'betas': betas, + 'eps': eps, + 'weight_decay': weight_decay, + } + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + state_values = list(self.state.values()) + step_is_tensor = (len(state_values) != 0) and torch.is_tensor( + state_values[0]['step'] + ) + if not step_is_tensor: + for s in state_values: + s['step'] = torch.tensor(float(s['step'])) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group['betas'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('NAdamW does not support sparse gradients') + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = torch.tensor(0.0) + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + state_steps.append(state['step']) + + nadamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps'], + ) + + return loss + + +def nadamw( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, +) -> None: + r"""Functional API that performs NAdamW algorithm computation. + See NAdamW class for details. + """ + + if not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError( + 'API has changed, `state_steps` argument must contain a list of' + + ' singleton tensors' + ) + + for i, param in enumerate(params): + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + + # Update step. + step_t += 1 + + # Perform stepweight decay. + param.mul_(1 - lr * weight_decay) + + # Decay the first and second moment running average coefficient. + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # Only difference between NAdamW and AdamW in this implementation. + # The official PyTorch implementation of NAdam uses a different algorithm. + # We undo these ops later on, which could cause numerical issues but saves + # us from having to make an extra copy of the gradients. + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + step = step_t.item() + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + step_size = lr / bias_correction1 + + bias_correction2_sqrt = math.sqrt(bias_correction2) + denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + + param.addcdiv_(exp_avg, denom, value=-step_size) + exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) + + +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_state + del rng + + optimizer_state = { + 'optimizer': NAdamW( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(1.0 - hyperparameters.one_minus_beta1, hyperparameters.beta2), + eps=1e-8, + weight_decay=hyperparameters.weight_decay, + ), + } + + def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): + step_hint = step_hint * 0.75 + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup = LinearLR( + optimizer, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps + ) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) + return SequentialLR( + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps] + ) + + optimizer_state['scheduler'] = pytorch_cosine_warmup( + workload.step_hint, hyperparameters, optimizer_state['optimizer'] + ) + + return optimizer_state + + +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del train_state + del eval_results + + current_model = current_param_container + current_model.train() + optimizer_state['optimizer'].zero_grad() + + logits_batch, new_model_state = workload.model_fn( + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True, + dropout_rate=hyperparameters.dropout_rate, + ) + + label_smoothing = ( + hyperparameters.label_smoothing + if hasattr(hyperparameters, 'label_smoothing') + else 0.0 + ) + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + if USE_PYTORCH_DDP: + # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. + summed_loss = dist_nn.all_reduce(summed_loss) + n_valid_examples = dist_nn.all_reduce(n_valid_examples) + loss = summed_loss / n_valid_examples + + loss.backward() + + if grad_clip is not None: + torch.nn.utils.clip_grad_norm_( + current_model.parameters(), max_norm=grad_clip + ) + optimizer_state['optimizer'].step() + optimizer_state['scheduler'].step() + + # Log training metrics - loss, grad_norm, batch_size. + if global_step <= 100 or global_step % 500 == 0: + with torch.no_grad(): + parameters = [p for p in current_model.parameters() if p.grad is not None] + grad_norm = torch.norm( + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2 + ) + if workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, + global_step, + ) + logging.info( + '%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item(), + ) + + return (optimizer_state, current_param_container, new_model_state) + + +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_resnet_silu': + return 512 + elif workload_name == 'imagenet_resnet_gelu': + return 512 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + elif workload_name == 'finewebedu_lm': + return 64 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. + """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/algorithms/target_setting_algorithms/fineweb_edu_lm/tuning_search_space.json b/algorithms/target_setting_algorithms/fineweb_edu_lm/tuning_search_space.json new file mode 100644 index 000000000..ce0f75623 --- /dev/null +++ b/algorithms/target_setting_algorithms/fineweb_edu_lm/tuning_search_space.json @@ -0,0 +1,11 @@ +[ + { + "dropout_rate": 0.0, + "label_smoothing": 0.0, + "learning_rate": 0.00038418421332238876, + "one_minus_beta1": 0.01564758865, + "beta2": 0.992362328914093, + "weight_decay": 0.25551270901641954, + "warmup_factor": 0.05 + } +] \ No newline at end of file diff --git a/datasets/README.md b/dataset/README.md similarity index 96% rename from datasets/README.md rename to dataset/README.md index 1aeb83239..d08f4cf67 100644 --- a/datasets/README.md +++ b/dataset/README.md @@ -16,6 +16,7 @@ - [LibriSpeech](#librispeech) - [Training SPM Tokenizer](#training-spm-tokenizer) - [Preprocessing Script](#preprocessing-script) + - [Fineweb-edu 10B](#fineweb-edu-10b) ## General Setup @@ -24,7 +25,7 @@ This document provides instructions on downloading and preparing all datasets ut *TL;DR to download and prepare a dataset, run `dataset_setup.py`:* ```bash -python3 datasets/dataset_setup.py \ +python3 dataset/dataset_setup.py \ --data_dir=~/data \ -- -- @@ -88,7 +89,7 @@ By default, a user will be prompted before any files are deleted. If you do not From `algorithmic-efficiency` run: ```bash -python3 datasets/dataset_setup.py \ +python3 dataset/dataset_setup.py \ --data_dir $DATA_DIR \ --ogbg ``` @@ -124,7 +125,7 @@ In total, it should contain 13 files (via `find -type f | wc -l`) for a total of From `algorithmic-efficiency` run: ```bash -python3 datasets/dataset_setup.py \ +python3 dataset/dataset_setup.py \ --data_dir $DATA_DIR \ --wmt ``` @@ -194,7 +195,7 @@ you should get an email containing the URLS for "knee_singlecoil_train", "knee_singlecoil_val" and "knee_singlecoil_test". ```bash -python3 datasets/dataset_setup.py \ +python3 dataset/dataset_setup.py \ --data_dir $DATA_DIR \ --fastmri \ --fastmri_knee_singlecoil_train_url '' \ @@ -229,13 +230,13 @@ In total, it should contain 1280 files (via `find -type f | wc -l`) for a total Register on and follow directions to obtain the URLS for the ILSVRC2012 train and validation images. -The script will additionally automatically download the `matched-frequency` version of [ImageNet v2](https://www.tensorflow.org/datasets/catalog/imagenet_v2#imagenet_v2matched-frequency_default_config), which is used as the test set of the ImageNet workloads. +The script will additionally automatically download the `matched-frequency` version of [ImageNet v2](https://www.tensorflow.org/dataset/catalog/imagenet_v2#imagenet_v2matched-frequency_default_config), which is used as the test set of the ImageNet workloads. The ImageNet data pipeline differs between the PyTorch and JAX workloads. Therefore, you will have to specify the framework (either `pytorch` or `jax`) through the framework flag. ```bash -python3 datasets/dataset_setup.py \ +python3 dataset/dataset_setup.py \ --data_dir $DATA_DIR \ --imagenet \ --temp_dir $DATA_DIR/tmp \ @@ -349,7 +350,7 @@ In total, it should contain 20 files (via `find -type f | wc -l`) for a total of ### Criteo1TB ```bash -python3 datasets/dataset_setup.py \ +python3 dataset/dataset_setup.py \ --data_dir $DATA_DIR \ --temp_dir $DATA_DIR/tmp \ --criteo1tb @@ -378,7 +379,7 @@ In total, it should contain 885 files (via `find -type f | wc -l`) for a total o To download, train a tokenizer and preprocess the librispeech dataset: ```bash -python3 datasets/dataset_setup.py \ +python3 dataset/dataset_setup.py \ --data_dir $DATA_DIR \ --temp_dir $DATA_DIR/tmp \ --librispeech @@ -453,3 +454,13 @@ The preprocessing script will generate `.npy` files for audio data, `features.cs ```bash python3 librispeech_preprocess.py --data_dir=$DATA_DIR/librispeech --tokenizer_vocab_path=$DATA_DIR/librispeech/spm_model.vocab ``` + +### Fineweb-EDU 10B +From `algorithmic-efficiency` run: + +```bash +python3 dataset/dataset_setup.py \ +--data_dir $DATA_DIR \ +--temp_dir $DATA_DIR/tmp \ +--fineweb_edu +``` \ No newline at end of file diff --git a/datasets/dataset_setup.py b/dataset/dataset_setup.py similarity index 89% rename from datasets/dataset_setup.py rename to dataset/dataset_setup.py index e110930cd..de5e9d271 100644 --- a/datasets/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -56,7 +56,7 @@ Example command: -python3 datasets/dataset_setup.py \ +python3 dataset/dataset_setup.py \ --data_dir=~/data \ --temp_dir=/tmp/mlcommons_data --imagenet \ @@ -73,8 +73,11 @@ from algoperf.workloads.wmt import tokenizer from algoperf.workloads.wmt.input_pipeline import normalize_feature_names -from datasets import librispeech_preprocess -from datasets import librispeech_tokenizer +from dataset import librispeech_preprocess +from dataset import librispeech_tokenizer + +import datasets as hf_datasets +from transformers import AutoTokenizer import functools import os @@ -82,6 +85,7 @@ import subprocess import tarfile +from typing import Dict, List, Any from absl import app from absl import flags from absl import logging @@ -121,6 +125,9 @@ flags.DEFINE_boolean( 'fastmri', False, 'If --all=false, whether or not to download FastMRI.' ) +flags.DEFINE_boolean( + 'finewebedu', False, 'If --all=false, whether or not to download FineWebEdu.' +) flags.DEFINE_boolean( 'imagenet', False, 'If --all=false, whether or not to download Imagenet.' ) @@ -194,6 +201,9 @@ flags.DEFINE_string('framework', None, 'Can be either jax or pytorch.') flags.DEFINE_boolean('skip_download', False, 'Skips data download.') +flags.DEFINE_boolean( + 'skip_tokenization', False, 'Skip Fineweb-edu tokenization.' +) FLAGS = flags.FLAGS @@ -767,6 +777,102 @@ def download_wmt(data_dir): ) +def download_finewebedu( + data_dir, tmp_dir=None, skip_download=False, skip_tokenization=False +): + """Download FineWebEdu-10B.""" + + if not skip_download: + data_dir = os.path.join(data_dir, 'fineweb_edu_10B') + tmp_dir = tmp_dir if tmp_dir is not None else '/tmp' + cache_dir = ( + os.path.join(tmp_dir, 'lm') + if tmp_dir is not None + else os.path.expanduser('~/.cache/huggingface/datasets') + ) + + _maybe_mkdir(data_dir) + _maybe_mkdir(tmp_dir) + _maybe_mkdir(cache_dir) + + os.environ['TMPDIR'] = tmp_dir + + ds = hf_datasets.load_dataset( + 'HuggingFaceFW/fineweb-edu', + name='sample-10BT', + split='train', + cache_dir=cache_dir, + ) + ds.save_to_disk(os.path.join(tmp_dir, 'fwedu_10B_raw')) + else: + ds = hf_datasets.load_from_disk(os.path.join(tmp_dir, 'fwedu_10B_raw')) + + if not skip_tokenization: + # Tokenize + lm_tokenizer = AutoTokenizer.from_pretrained('gpt2') + logging.info(f'Vocab size of lm_tokenizer = {len(lm_tokenizer)}') + + def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + def add_eos(seq): + return seq + lm_tokenizer.eos_token if seq else seq + + def add_eos_batched(seqs): + return [add_eos(seq) for seq in seqs] + + return lm_tokenizer( + add_eos_batched(examples['text']), + return_special_tokens_mask=False, + return_attention_mask=False, + ) + + lm_tokenizer.model_max_length = ( + 1e30 # prevent truncation during tokenization + ) + logging.info('Tokenizing...') + tokenized_dataset = ds.map( + tokenize, + remove_columns=[ + 'text', + 'id', + 'dump', + 'url', + 'file_path', + 'language', + 'language_score', + 'token_count', + 'score', + 'int_score', + ], + batched=True, + batch_size=1024, + num_proc=8, + ) + + tokenized_dataset.save_to_disk( + os.path.join(data_dir, 'fwedu_10B_tokenized') + ) + else: + tokenized_dataset = hf_datasets.load_from_disk( + os.path.join(data_dir, 'fwedu_10B_tokenized') + ) + + # Convert to tensorflow_datasets.Dataset objects + tokenized_dataset = tokenized_dataset.to_tf_dataset() + + # Shuffle dataset + dataset_size = tokenized_dataset.cardinality().numpy() + shuffled_dataset = tokenized_dataset.shuffle(dataset_size, seed=0) + train_size = int(0.9 * dataset_size) + train_dataset = shuffled_dataset.take(train_size) + val_dataset = shuffled_dataset.skip(train_size) + + # Split in train and valid. + train_dataset.save(os.path.join(data_dir, 'train')) + val_dataset.save(os.path.join(data_dir, 'val')) + + return + + def main(_): data_dir = FLAGS.data_dir tmp_dir = FLAGS.temp_dir @@ -854,6 +960,12 @@ def main(_): logging.info('Downloading WMT...') download_wmt(data_dir) + if FLAGS.all or FLAGS.finewebedu: + logging.info('Downloading FineWebEdu-10B...') + download_finewebedu( + data_dir, tmp_dir, FLAGS.skip_download, FLAGS.skip_tokenization + ) + # pylint: enable=logging-format-interpolation # pylint: enable=consider-using-with diff --git a/datasets/librispeech_preprocess.py b/dataset/librispeech_preprocess.py similarity index 99% rename from datasets/librispeech_preprocess.py rename to dataset/librispeech_preprocess.py index 1c216db46..878f10f2a 100644 --- a/datasets/librispeech_preprocess.py +++ b/dataset/librispeech_preprocess.py @@ -14,7 +14,7 @@ from absl import logging from pydub import AudioSegment -from datasets import librispeech_tokenizer +from dataset import librispeech_tokenizer gfile = tf.io.gfile copy = tf.io.gfile.copy diff --git a/datasets/librispeech_tokenizer.py b/dataset/librispeech_tokenizer.py similarity index 100% rename from datasets/librispeech_tokenizer.py rename to dataset/librispeech_tokenizer.py diff --git a/docker/build_docker_images.sh b/docker/build_docker_images.sh index 6b5e67ceb..aa94222ea 100644 --- a/docker/build_docker_images.sh +++ b/docker/build_docker_images.sh @@ -27,7 +27,7 @@ then GIT_BRANCH='main' # Set default argument fi -FRAMEWORKS=( "jax" "pythorch" "both" ) +FRAMEWORKS=( "jax" "pytorch") if [[ -n "$FRAMEWORK" ]]; then diff --git a/docker/scripts/startup.sh b/docker/scripts/startup.sh index 35ac30461..1cd676d2a 100644 --- a/docker/scripts/startup.sh +++ b/docker/scripts/startup.sh @@ -174,7 +174,7 @@ fi # Check if arguments are valid VALID_DATASETS=("criteo1tb" "imagenet" "fastmri" "ogbg" "librispeech" \ - "wmt" "mnist") + "wmt" "mnist" "fineweb_edu_10B") VALID_WORKLOADS=("criteo1tb" "imagenet_resnet" "imagenet_resnet_silu" "imagenet_resnet_gelu" \ "imagenet_resnet_large_bn_init" "imagenet_vit" "imagenet_vit_glu" \ "imagenet_vit_post_ln" "imagenet_vit_map" "fastmri" "ogbg" \ @@ -185,7 +185,7 @@ VALID_WORKLOADS=("criteo1tb" "imagenet_resnet" "imagenet_resnet_silu" "imagenet_ "librispeech_conformer_gelu" "fastmri_model_size" "fastmri_tanh" \ "librispeech_deepspeech_tanh" \ "librispeech_deepspeech_no_resnet" "librispeech_deepspeech_norm_and_spec_aug" - "fastmri_layernorm" "ogbg_gelu" "ogbg_silu" "ogbg_model_size") + "fastmri_layernorm" "ogbg_gelu" "ogbg_silu" "ogbg_model_size" "finewebedu_lm") VALID_RULESETS=("self" "external") # Set data and experiment paths @@ -221,7 +221,7 @@ TUNING_RULESET_FLAG="--tuning_ruleset=${TUNING_RULESET}" if [[ "${FRAMEWORK}" == "jax" ]]; then COMMAND_PREFIX="python" else - COMMAND_PREFIX="torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc_per_node=8" + COMMAND_PREFIX="torchrun --redirects 1:0,2:0,3:0 --standalone --nnodes=1 --nproc_per_node=4" fi # Set data directory and bucket (bucket is only relevant in internal mode) diff --git a/docs/DOCUMENTATION.md b/docs/DOCUMENTATION.md index f7ac5e659..49e738408 100644 --- a/docs/DOCUMENTATION.md +++ b/docs/DOCUMENTATION.md @@ -55,7 +55,7 @@ The **AlgoPerf: Training Algorithms benchmark** challenges participants to submi The benchmarking process follows these **key principles**: -- 🎯 **Fixed Target, Model & Hardware:** Submitted training algorithms must train a set of [**fixed models**](#workloads) to a pre-defined validation performance target as fast as possible. All submissions use the same model architecture and are run on the same [**standardized hardware**](#benchmarking-hardware) (currently `8x NVIDIA V100 GPUs`). This isolates the training algorithm's performance and allows a fair apples-to-apples comparison. +- 🎯 **Fixed Target, Model & Hardware:** Submitted training algorithms must train a set of [**fixed models**](#workloads) to a pre-defined validation performance target as fast as possible. All submissions use the same model architecture and are run on the same [**standardized hardware**](#benchmarking-hardware) (currently `4x NVIDIA A100 GPUs`). This isolates the training algorithm's performance and allows a fair apples-to-apples comparison. - ⏱️ **Time-To-Result:** Submissions are evaluated based on the total wall-clock time required to reach the target, rewarding practical and efficient algorithms. - 🧠 **Diverse Workloads:** The benchmark includes [**8 diverse deep learning workloads**](#workloads) across domains like image classification, speech recognition, and machine translation. A submission's score is computed by aggregating its performance across all workloads, using [**performance profiles**](#algoperf-benchmark-score-via-integrated-performance-profiles), to ensure general-purpose algorithms. - πŸ“¦ **Fully-Specified Algorithms:** Submissions must be [**complete procedures**](#submission-api) and thus hyperparameter tuning is treated as part of the algorithm. Depending on the [**ruleset**](#tuning-rulesets), submissions may use parallel tuning resources. This ensures that the benchmark measures the _total_ practical cost of a training algorithm and provides practitioners with a complete method, eliminating the guesswork of how to apply it. @@ -542,7 +542,7 @@ All officially scored runs will be performed on the same benchmarking hardware t This benchmarking hardware is chosen to be easily accessible via common cloud computing providers and will likely change with each iteration of the benchmark. The specs of the benchmarking hardware for this iteration of the benchmark are: -- 8Γ— NVIDIA V100 (16 GB) GPUs +- 4Γ— NVIDIA A100 (40 GB) GPUs - 240 GB in RAM - 2 TB in storage (for datasets). @@ -595,7 +595,7 @@ Furthermore, all submitters must sign the following agreements:
My machine only has one GPU. How can I use this repo? -> You can run this repo on a machine with an arbitrary number of GPUs. However, the default batch sizes of our algorithms collection (e.g. `algorithms/`) are tuned for a machine with 8Γ— NVIDIA V100 (16 GB) GPUs. You may run into OOMs if you run these algorithms with fewer than 8 GPUs. If you run into these issues because you are using a machine with less total GPU memory, please reduce the batch sizes for the submission. Note that your final submission must 'fit' on the [**benchmarking hardware**](#benchmarking-hardware), so if you are using fewer GPUs with higher per-GPU memory, please monitor your memory usage to make sure it will fit on 8Γ— NVIDIA V100 GPUs with 16 GB of VRAM per card. +> You can run this repo on a machine with an arbitrary number of GPUs. However, the default batch sizes of our algorithms collection (e.g. `algorithms/`) are tuned for a machine with 4Γ— NVIDIA A100 (40 GB) GPUs. You may run into OOMs if you run these algorithms with fewer than 8 GPUs. If you run into these issues because you are using a machine with less total GPU memory, please reduce the batch sizes for the submission. Note that your final submission must 'fit' on the [**benchmarking hardware**](#benchmarking-hardware), so if you are using fewer GPUs with higher per-GPU memory, please monitor your memory usage to make sure it will fit on 4Γ— NVIDIA A100 GPUs with 40 GB of VRAM per card.
diff --git a/pyproject.toml b/pyproject.toml index e4de98f89..ae2f2c8fc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,7 +70,9 @@ version_file = "algoperf/_version.py" ############################################################################### [project.optional-dependencies] # All workloads -full = ["algoperf[criteo1tb,fastmri,ogbg,librispeech_conformer,wmt]"] +full = [ + "algoperf[criteo1tb,fastmri,ogbg,librispeech_conformer,wmt,lm]", +] # All workloads plus development dependencies full_dev = ["algoperf[full,dev]"] # Dependencies for developing the package @@ -88,6 +90,7 @@ librispeech_conformer = [ "pydub==0.25.1", ] wmt = ["sentencepiece==0.2.0", "tensorflow-text==2.19.0"] +lm = ["transformers==4.26", "datasets==3.6.0"] # Frameworks jax_core_deps = [ @@ -105,16 +108,15 @@ jax_cpu = [ jax_gpu = [ "jax[cuda12]==0.7.0", "algoperf[jax_core_deps]", - "nvidia-cudnn-cu12==9.10.2.21", # temporary workaround for https://github.com/jax-ml/jax/issues/30663 ] pytorch_cpu = [ - "torch==2.5.1", - "torchvision==0.20.1" + "torch==2.9.0", + "torchvision==0.24.0" ] pytorch_gpu = [ - "torch==2.5.1", - "torchvision==0.20.1", + "torch==2.9.0", + "torchvision==0.24.0", ] # Note: omit the cuda suffix and installing from the appropriate wheel will result in using locally installed CUDA. ############################################################################### diff --git a/scoring/performance_profile.py b/scoring/performance_profile.py index 4f2ae9c57..043a65791 100644 --- a/scoring/performance_profile.py +++ b/scoring/performance_profile.py @@ -71,6 +71,7 @@ 'wer', 'l1_loss', 'loss', + 'ppl', ] MAX_EVAL_METRICS = ['mean_average_precision', 'ssim', 'accuracy', 'bleu'] diff --git a/scoring/score_submissions.py b/scoring/score_submissions.py index 3423df2e1..efe276a33 100644 --- a/scoring/score_submissions.py +++ b/scoring/score_submissions.py @@ -67,10 +67,16 @@ '', 'Optional comma seperated list of names of submissions to exclude from scoring.', ) +flags.DEFINE_string( + 'include_submissions', + '', + 'Optional comma seperated list of names of submissions to include from scoring.', +) FLAGS = flags.FLAGS def get_summary_df(workload, workload_df, include_test_split=False): + print(f' WORKLOAD: {workload}') validation_metric, validation_target = ( scoring_utils.get_workload_metrics_and_targets(workload, split='validation') ) @@ -119,9 +125,22 @@ def get_summary_df(workload, workload_df, include_test_split=False): axis=1, ) - summary_df['step_time (s)'] = ( - workload_df['accumulated_submission_time'] / workload_df['global_step'] - ).iloc[-1][-1] + # compute the step times + def delta(series): + return series.shift(1, fill_value=0) - series + + accumulated_time_intervals = delta(workload_df['accumulated_submission_time']) + step_intervals = delta(workload_df['global_step']) + if len(accumulated_time_intervals) < 2: + print( + f'WARNING: The number of evals may be too low to calculate reliable step time for {workload}' + ) + + summary_df['step_time (s)'] = np.median( + (accumulated_time_intervals / step_intervals).iloc[0] + ) + + summary_df['step_hint'] = scoring_utils.get_workload_stephint(workload) # test metrics if include_test_split: @@ -157,7 +176,7 @@ def get_summary_df(workload, workload_df, include_test_split=False): return summary_df -def get_submission_summary(df, include_test_split=True): +def get_submission_summary(df, include_test_split=False): """Summarizes the submission results into metric and time tables organized by workload. """ @@ -203,18 +222,25 @@ def main(_): ) as f: results = pickle.load(f) else: - for submission in os.listdir(FLAGS.submission_directory): + all_submission_dirs = list(os.listdir(FLAGS.submission_directory)) + if not FLAGS.include_submissions: + include_submissions = all_submission_dirs + else: + include_submissions = FLAGS.include_submissions.split(',') + + for submission in all_submission_dirs: print(submission) - if submission in FLAGS.exclude_submissions.split(','): - continue - experiment_path = os.path.join(FLAGS.submission_directory, submission) - df = scoring_utils.get_experiment_df(experiment_path) - results[submission] = df - summary_df = get_submission_summary(df) - with open( - os.path.join(FLAGS.output_dir, f'{submission}_summary.csv'), 'w' - ) as fout: - summary_df.to_csv(fout) + if submission not in FLAGS.exclude_submissions.split(',') and ( + submission in include_submissions + ): + experiment_path = os.path.join(FLAGS.submission_directory, submission) + df = scoring_utils.get_experiment_df(experiment_path) + results[submission] = df + summary_df = get_submission_summary(df) + with open( + os.path.join(FLAGS.output_dir, f'{submission}_summary.csv'), 'w' + ) as fout: + summary_df.to_csv(fout) # Optionally save results to filename if FLAGS.save_results_to_filename: diff --git a/scoring/scoring_utils.py b/scoring/scoring_utils.py index 5be6c790c..cb63eab4b 100644 --- a/scoring/scoring_utils.py +++ b/scoring/scoring_utils.py @@ -240,3 +240,23 @@ def get_workload_metrics_and_targets(workload, split='validation'): metric = f'test/{metric_name}' target = workload_obj.test_target_value return metric, target + + +def get_workload_stephint(workload): + workload_name = re.match(WORKLOAD_NAME_PATTERN, workload).group(1) + framework = re.match(WORKLOAD_NAME_PATTERN, workload).group(2) + workload_metadata = copy.copy(WORKLOADS[workload_name]) + + # Extend path according to framework. + workload_metadata['workload_path'] = os.path.join( + BASE_WORKLOADS_DIR, + workload_metadata['workload_path'] + f'{framework}', + 'workload.py', + ) + workload_init_kwargs = {} + workload_obj = workloads_registry.import_workload( + workload_path=workload_metadata['workload_path'], + workload_class_name=workload_metadata['workload_class_name'], + workload_init_kwargs=workload_init_kwargs, + ) + return workload_obj.step_hint diff --git a/scoring/utils/run_workloads.py b/scoring/utils/run_workloads.py index 273881c5a..d8e0172fa 100644 --- a/scoring/utils/run_workloads.py +++ b/scoring/utils/run_workloads.py @@ -241,7 +241,8 @@ def main(_): # For each runnable workload check if there are any containers running and if not launch next container command for workload in workloads: - run_key = prng.fold_in(rng_subkey, hash(workload)) + workload_foldin = hash(workload) % 9 + run_key = prng.fold_in(rng_subkey, workload_foldin) run_seed = run_key[0] # arbitrary base_workload_name = get_base_workload_name(workload) wait_until_container_not_running() @@ -270,6 +271,7 @@ def main(_): 'docker run -t -d -v /home/kasimbeg/data/:/data/ ' '-v /home/kasimbeg/experiment_runs/:/experiment_runs ' '-v /home/kasimbeg/experiment_runs/logs:/logs ' + '-v /home/kasimbeg/algorithmic-efficiency:/algorithmic-efficiency ' f'{mount_repo_flag}' '--gpus all --ipc=host ' f'{docker_image_url} ' diff --git a/scoring/utils/workload_metadata_external_tuning.json b/scoring/utils/workload_metadata_external_tuning.json index c7d4ae195..0ba0d99ee 100644 --- a/scoring/utils/workload_metadata_external_tuning.json +++ b/scoring/utils/workload_metadata_external_tuning.json @@ -24,11 +24,15 @@ "dataset": "librispeech" }, "criteo1tb": { - "max_steps": 10666, + "max_steps": 15666, "dataset": "criteo1tb" }, "librispeech_conformer": { "max_steps": 80000, "dataset": "librispeech" + }, + "finewebedu_lm" : { + "max_steps": 55000, + "dataset":"fineweb_edu_10B" } } diff --git a/submission_runner.py b/submission_runner.py index 552c99b79..b557c4f40 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -256,7 +256,6 @@ def train_once( 'librispeech_conformer', 'ogbg', 'criteo1tb', - 'imagenet_vit', 'librispeech_deepspeech', ] eager_backend_workloads = [] @@ -266,6 +265,8 @@ def train_once( 'librispeech_deepspeech', 'ogbg', 'wmt', + 'finewebedu_lm', + 'imagenet_vit', ] base_workload = workloads.get_base_workload_name(workload_name) if base_workload in compile_error_workloads: @@ -352,7 +353,6 @@ def train_once( log_dir, flags.FLAGS, hyperparameters ) workload.attach_metrics_logger(metrics_logger) - global_start_time = get_time() train_state['last_step_end_time'] = global_start_time @@ -783,7 +783,10 @@ def main(_): os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256' if FLAGS.framework == 'pytorch': - pytorch_init(USE_PYTORCH_DDP, RANK, profiler) + limit_tf_threads = base_workload != 'finewebedu_lm' + pytorch_init( + USE_PYTORCH_DDP, RANK, profiler, limit_tf_threads=limit_tf_threads + ) # TODO: remove once issue resolved. if FLAGS.pytorch_eval_num_workers != 0: @@ -799,6 +802,7 @@ def main(_): 'librispeech_deepspeech', 'imagenet_vit', 'criteo1tb', + 'finewebedu_lm', ]: os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.80' diff --git a/tests/modeldiffs/lm/compare.py b/tests/modeldiffs/lm/compare.py new file mode 100644 index 000000000..709e3125f --- /dev/null +++ b/tests/modeldiffs/lm/compare.py @@ -0,0 +1,892 @@ +""" +Test file to verify that JAX and PyTorch implementations produce identical outputs +when given the same weights and inputs. + +Tests are performed module-by-module: +1. RMSNorm +2. RoPE (Rotary Position Embeddings) +3. MLP +4. Attention +5. Transformer Block +6. Full Model +""" + +import os +import sys + +# Disable GPU access for both jax and pytorch. +os.environ['CUDA_VISIBLE_DEVICES'] = '' + +import jax +import jax.numpy as jnp +import numpy as np +import torch +import torch.nn.functional as F +from absl import flags, logging +from absl.testing import absltest, parameterized + +# Import JAX implementation +from algoperf.workloads.finewebedu_lm.finewebedu_lm_jax.models import ( + CausalAttn, + Mlp, + TBlock, + TransformerDo, + apply_rope, + init_rope, +) +from algoperf.workloads.finewebedu_lm.finewebedu_lm_jax.models import ( + ModelConfig as JaxModelConfig, +) + +# Import PyTorch implementation +from algoperf.workloads.finewebedu_lm.finewebedu_lm_pytorch.models import ( + MLP, + Attention, + Block, + Transformer, + apply_rotary_emb_complex_like, + precompute_freqs_cis, +) +from algoperf.workloads.finewebedu_lm.finewebedu_lm_pytorch.models import ( + ModelConfig as PyTorchModelConfig, +) + +FLAGS = flags.FLAGS +# Needed to avoid UnparsedFlagAccessError +FLAGS(sys.argv) + +# ============================================================================ +# Helper Functions +# ============================================================================ + + +def assert_close(jax_output, torch_output, rtol=1e-5, atol=1e-6, name=''): + """Assert that JAX and PyTorch outputs are close.""" + jax_np = np.array(jax_output) + torch_np = torch_output.detach().cpu().numpy() + + mse = np.mean((jax_np - torch_np) ** 2) + max_diff = np.max(np.abs(jax_np - torch_np)) + + logging.info(f'\n{name} Comparison:') + logging.info(f' MSE: {mse:.8e}') + logging.info(f' Max Difference: {max_diff:.8e}') + + np.testing.assert_allclose( + jax_np, + torch_np, + rtol=rtol, + atol=atol, + err_msg=f'{name} outputs do not match', + ) + + +# ============================================================================ +# Test Functions (unchanged) +# ============================================================================ + + +def test_rmsnorm(): + """Test that RMSNorm produces identical outputs.""" + logging.info('\n' + '=' * 70) + logging.info('Testing RMSNorm') + logging.info('=' * 70) + + batch_size, seq_len, dim = 2, 10, 256 + eps = 1e-6 + + # Create random input + np_input = np.random.randn(batch_size, seq_len, dim).astype(np.float32) + + # Initialize PyTorch RMSNorm + torch_norm = torch.nn.RMSNorm(dim, eps=eps) + torch_input = torch.tensor(np_input) + + # Initialize JAX RMSNorm (using Flax's RMSNorm from nanodo) + from flax import linen as nn + + flax_norm = nn.RMSNorm(epsilon=eps) + jax_input = jnp.array(np_input) + flax_params = flax_norm.init(jax.random.PRNGKey(0), jax_input) + + # Copy weights from PyTorch to JAX + with torch.no_grad(): + flax_params['params']['scale'] = jnp.array(torch_norm.weight.numpy()) + + # Forward pass + with torch.no_grad(): + torch_output = torch_norm(torch_input) + + jax_output = flax_norm.apply(flax_params, jax_input) + + # Compare + assert_close(jax_output, torch_output, name='RMSNorm') + logging.info('βœ“ RMSNorm test passed') + + +def test_rope(): + """Test that RoPE produces identical outputs.""" + logging.info('\n' + '=' * 70) + logging.info('Testing RoPE (Rotary Position Embeddings)') + logging.info('=' * 70) + + batch_size, seq_len, n_heads, dim = 2, 16, 4, 128 + head_dim = dim // n_heads + + # Initialize RoPE + torch_freqs = precompute_freqs_cis(head_dim, seq_len, theta=500000) + jax_freqs = init_rope(dim, seq_len, n_heads) + + # Create random Q and K + np_q = np.random.randn(batch_size, seq_len, n_heads, head_dim).astype( + np.float32 + ) + np_k = np.random.randn(batch_size, seq_len, n_heads, head_dim).astype( + np.float32 + ) + + # PyTorch forward + torch_q = torch.tensor(np_q) + torch_k = torch.tensor(np_k) + with torch.no_grad(): + torch_q_rot, torch_k_rot = apply_rotary_emb_complex_like( + torch_q, torch_k, freqs_cis=torch_freqs + ) + + # JAX forward + jax_q = jnp.array(np_q) + jax_k = jnp.array(np_k) + jax_q_rot, jax_k_rot = apply_rope(jax_q, jax_k, jax_freqs) + + # Compare + assert_close(jax_q_rot, torch_q_rot, name='RoPE Q') + assert_close(jax_k_rot, torch_k_rot, name='RoPE K') + logging.info('βœ“ RoPE test passed') + + +def copy_mlp_params(pytorch_mlp, flax_params): + """Copy MLP parameters from PyTorch to JAX.""" + new_params = flax_params.copy() + + # Handle compiled models + if hasattr(pytorch_mlp, '_orig_mod'): + pytorch_mlp = pytorch_mlp._orig_mod + + # Copy fc1 and fc2 weights (transposed for JAX) + new_params['params']['Dense_0']['kernel'] = ( + pytorch_mlp.fc1.weight.detach().numpy().T + ) + new_params['params']['Dense_1']['kernel'] = ( + pytorch_mlp.fc2.weight.detach().numpy().T + ) + + return new_params + + +def test_mlp(): + """Test that MLP produces identical outputs.""" + logging.info('\n' + '=' * 70) + logging.info('Testing MLP') + logging.info('=' * 70) + + batch_size, seq_len, dim = 2, 10, 256 + hidden_dim = 1024 + + # Initialize PyTorch MLP + pytorch_mlp = MLP(dim=dim, hidden_dim=hidden_dim) + + # Initialize JAX MLP + cfg = JaxModelConfig( + model_dim=dim, + num_heads=4, + seq_len=128, + num_layers=2, + vocab_size=1000, + expanded_model_dim=hidden_dim, + dtype=jnp.float32, + rmsnorm_epsilon=1e-6, + ) + flax_mlp = Mlp(cfg) + + # Initialize JAX params + dummy_input = jnp.ones((batch_size, seq_len, dim)) + flax_params = flax_mlp.init(jax.random.PRNGKey(0), dummy_input) + + # Copy weights + flax_params = copy_mlp_params(pytorch_mlp, flax_params) + + # Create input + np_input = np.random.randn(batch_size, seq_len, dim).astype(np.float32) + torch_input = torch.tensor(np_input) + jax_input = jnp.array(np_input) + + # Forward pass + with torch.no_grad(): + torch_output = pytorch_mlp(torch_input) + + jax_output = flax_mlp.apply(flax_params, jax_input) + + # Compare + assert_close(jax_output, torch_output, name='MLP') + logging.info('βœ“ MLP test passed') + + +def copy_attention_params(pytorch_attn, flax_params): + """Copy attention parameters from PyTorch to JAX.""" + # Handle compiled models + if hasattr(pytorch_attn, '_orig_mod'): + pytorch_attn = pytorch_attn._orig_mod + + n_heads = pytorch_attn.n_heads + head_dim = pytorch_attn.head_dim + dim = pytorch_attn.dim + + # Split PyTorch's combined qkv weights + w_qkv = pytorch_attn.w_qkv.weight + q_weight, k_weight, v_weight = [ + u.detach().numpy() for u in w_qkv.split(dim, dim=0) + ] + + # Reshape for Flax's DenseGeneral format [D, H, Dh] + def reshape_for_flax(w, n_heads, head_dim): + return w.reshape(n_heads, head_dim, -1).transpose(2, 0, 1) + + new_params = { + 'query': {'kernel': reshape_for_flax(q_weight, n_heads, head_dim)}, + 'key': {'kernel': reshape_for_flax(k_weight, n_heads, head_dim)}, + 'value': {'kernel': reshape_for_flax(v_weight, n_heads, head_dim)}, + 'attn_out_proj': {'kernel': pytorch_attn.w_out.weight.detach().numpy().T}, + } + + return {'params': new_params} + + +def test_attention(): + """Test that Attention produces identical outputs.""" + logging.info('\n' + '=' * 70) + logging.info('Testing Attention') + logging.info('=' * 70) + + batch_size, seq_len, dim, n_heads = 2, 16, 256, 4 + + # Initialize PyTorch Attention + config = PyTorchModelConfig( + vocab_size=1000, + seq_len=seq_len, + model_dim=dim, + expanded_model_dim=1024, + num_layers=1, + num_heads=n_heads, + rmsnorm_epsilon=1e-6, + ) + pytorch_attn = Attention(config) + freqs_cis = precompute_freqs_cis(dim // n_heads, seq_len, theta=500000) + + # Initialize JAX Attention + cfg = JaxModelConfig( + model_dim=dim, + num_heads=n_heads, + seq_len=seq_len, + num_layers=1, + vocab_size=1000, + expanded_model_dim=1024, + dtype=jnp.float32, + rmsnorm_epsilon=1e-6, + ) + flax_attn = CausalAttn(cfg) + + # Initialize JAX params + dummy_input = jnp.ones((batch_size, seq_len, dim)) + flax_params = flax_attn.init(jax.random.PRNGKey(0), dummy_input) + + # Copy weights + flax_params = copy_attention_params(pytorch_attn, flax_params) + + # Create input + np_input = np.random.randn(batch_size, seq_len, dim).astype(np.float32) + torch_input = torch.tensor(np_input) + jax_input = jnp.array(np_input) + + # Forward pass + with torch.no_grad(): + torch_output = pytorch_attn(torch_input, freqs_cis) + + jax_output = flax_attn.apply(flax_params, jax_input) + + # Compare + assert_close(jax_output, torch_output, rtol=1e-4, atol=1e-5, name='Attention') + logging.info('βœ“ Attention test passed') + + +def copy_block_params(pytorch_block, flax_params): + """Copy block parameters from PyTorch to JAX.""" + # Copy attention parameters + attn_params = copy_attention_params(pytorch_block.attn, {'params': {}})[ + 'params' + ] + + # Copy MLP parameters + pytorch_mlp = pytorch_block.mlp + mlp_params = { + 'Dense_0': {'kernel': pytorch_mlp.fc1.weight.detach().numpy().T}, + 'Dense_1': {'kernel': pytorch_mlp.fc2.weight.detach().numpy().T}, + } + + # Copy RMSNorm parameters + norm_params = { + 'attn_norm': {'scale': pytorch_block.attn_norm.weight.detach().numpy()}, + 'mlp_norm': {'scale': pytorch_block.mlp_norm.weight.detach().numpy()}, + } + + return { + 'params': { + 'CausalAttn_0': attn_params, + 'Mlp_0': mlp_params, + 'RMSNorm_0': norm_params['attn_norm'], + 'RMSNorm_1': norm_params['mlp_norm'], + } + } + + +def test_block(): + """Test that Transformer Block produces identical outputs.""" + logging.info('\n' + '=' * 70) + logging.info('Testing Transformer Block') + logging.info('=' * 70) + + batch_size, seq_len, dim, n_heads = 2, 16, 256, 4 + expand = 4.0 + + # Initialize PyTorch Block + config = PyTorchModelConfig( + vocab_size=1000, + seq_len=seq_len, + model_dim=dim, + expanded_model_dim=int(dim * expand), + num_layers=1, + num_heads=n_heads, + rmsnorm_epsilon=1e-6, + ) + pytorch_block = Block(layer_id=0, cfg=config) + freqs_cis = precompute_freqs_cis(dim // n_heads, seq_len, theta=500000) + + # Initialize JAX Block + cfg = JaxModelConfig( + model_dim=dim, + num_heads=n_heads, + seq_len=seq_len, + num_layers=1, + vocab_size=1000, + expanded_model_dim=int(dim * expand), + dtype=jnp.float32, + rmsnorm_epsilon=1e-6, + ) + flax_block = TBlock(cfg) + + # Initialize JAX params + dummy_input = jnp.ones((batch_size, seq_len, dim)) + flax_params = flax_block.init(jax.random.PRNGKey(0), dummy_input) + + # Copy weights + flax_params = copy_block_params(pytorch_block, flax_params) + + # Create input + np_input = np.random.randn(batch_size, seq_len, dim).astype(np.float32) + torch_input = torch.tensor(np_input) + jax_input = jnp.array(np_input) + + # Forward pass + with torch.no_grad(): + torch_output = pytorch_block(torch_input, freqs_cis) + + jax_output = flax_block.apply(flax_params, jax_input) + + # Compare + assert_close(jax_output, torch_output, rtol=1e-4, atol=1e-5, name='Block') + logging.info('βœ“ Block test passed') + + +def copy_full_model_params(pytorch_model, flax_params, config): + """Copy all parameters from PyTorch model to JAX model.""" + # Handle tied embeddings case + if hasattr(pytorch_model, '_orig_mod'): + pytorch_model = pytorch_model._orig_mod + + n_layers = config.num_layers + n_heads = config.num_heads + dim = config.model_dim + head_dim = dim // n_heads + + new_params = {'params': {}} + + # Copy embedding weights + new_params['params']['embed'] = { + 'embedding': pytorch_model.embed_tokens.weight.detach().numpy() + } + + # Copy each transformer block + for i in range(n_layers): + pytorch_block = pytorch_model.layers[i] + + # Attention params + w_qkv = pytorch_block.attn.w_qkv.weight + q_weight, k_weight, v_weight = [ + u.detach().numpy() for u in w_qkv.split(dim, dim=0) + ] + + def reshape_for_flax(w, n_heads, head_dim): + return w.reshape(n_heads, head_dim, -1).transpose(2, 0, 1) + + attn_params = { + 'query': {'kernel': reshape_for_flax(q_weight, n_heads, head_dim)}, + 'key': {'kernel': reshape_for_flax(k_weight, n_heads, head_dim)}, + 'value': {'kernel': reshape_for_flax(v_weight, n_heads, head_dim)}, + 'attn_out_proj': { + 'kernel': pytorch_block.attn.w_out.weight.detach().numpy().T + }, + } + + # MLP params + mlp_params = { + 'Dense_0': {'kernel': pytorch_block.mlp.fc1.weight.detach().numpy().T}, + 'Dense_1': {'kernel': pytorch_block.mlp.fc2.weight.detach().numpy().T}, + } + + # Norm params + attn_norm = {'scale': pytorch_block.attn_norm.weight.detach().numpy()} + mlp_norm = {'scale': pytorch_block.mlp_norm.weight.detach().numpy()} + + # Assemble block params + block_key = f'blocks_{i}' + new_params['params'][block_key] = { + 'CausalAttn_0': attn_params, + 'Mlp_0': mlp_params, + 'RMSNorm_0': attn_norm, + 'RMSNorm_1': mlp_norm, + } + + # Copy output norm + new_params['params']['out_ln'] = { + 'scale': pytorch_model.out_norm.weight.detach().numpy() + } + + # Handle output projection (tied or untied) + if not config.tie_embeddings: + new_params['params']['output_proj'] = { + 'kernel': pytorch_model.lm_head.weight.detach().numpy().T + } + + return new_params + + +def test_full_model(): + """Test that full Transformer model produces identical outputs.""" + logging.info('\n' + '=' * 70) + logging.info('Testing Full Transformer Model') + logging.info('=' * 70) + + batch_size, seq_len = 2, 32 + vocab_size = 256 + dim = 128 + n_heads = 4 + n_layers = 2 + expand = 4.0 + + # Initialize PyTorch model + pytorch_config = PyTorchModelConfig( + vocab_size=vocab_size, + seq_len=seq_len, + model_dim=dim, + expanded_model_dim=int(dim * expand), + num_layers=n_layers, + num_heads=n_heads, + rmsnorm_epsilon=1e-6, + tie_embeddings=True, + ) + pytorch_model = Transformer(pytorch_config) + pytorch_model.eval() + + # Initialize JAX model + jax_config = JaxModelConfig( + model_dim=dim, + num_heads=n_heads, + seq_len=seq_len, + num_layers=n_layers, + vocab_size=vocab_size, + expanded_model_dim=int(dim * expand), + dtype=jnp.float32, + rmsnorm_epsilon=1e-6, + tie_embeddings=True, + ) + jax_model = TransformerDo(jax_config) + + # Create input tokens + np_tokens = np.random.randint( + 0, vocab_size, size=(batch_size, seq_len), dtype=np.int32 + ) + torch_tokens = torch.tensor(np_tokens, dtype=torch.long) + jax_tokens = jnp.array(np_tokens, dtype=jnp.int32) + + # Initialize JAX params + jax_params = jax_model.init(jax.random.PRNGKey(0), jax_tokens) + + # Copy weights from PyTorch to JAX + jax_params = copy_full_model_params(pytorch_model, jax_params, pytorch_config) + + # Forward pass + with torch.no_grad(): + torch_logits = pytorch_model(torch_tokens) + + jax_logits = jax_model.apply(jax_params, jax_tokens) + + # Compare + assert_close( + jax_logits, torch_logits, rtol=1e-4, atol=1e-5, name='Full Model' + ) + logging.info('βœ“ Full Model test passed') + + +def test_prediction(): + """Test that autoregressive generation produces identical outputs.""" + logging.info('\n' + '=' * 70) + logging.info('Testing Autoregressive Prediction') + logging.info('=' * 70) + + batch_size, seq_len = 1, 10 + vocab_size = 256 + dim = 128 + n_heads = 4 + n_layers = 2 + expand = 4.0 + k = 5 # Number of tokens to predict + + # Initialize PyTorch model + pytorch_config = PyTorchModelConfig( + vocab_size=vocab_size, + seq_len=seq_len + k, + model_dim=dim, + expanded_model_dim=int(dim * expand), + num_layers=n_layers, + num_heads=n_heads, + rmsnorm_epsilon=1e-6, + tie_embeddings=True, + ) + pytorch_model = Transformer(pytorch_config) + pytorch_model.eval() + + # Initialize JAX model + jax_config = JaxModelConfig( + model_dim=dim, + num_heads=n_heads, + seq_len=seq_len + k, + num_layers=n_layers, + vocab_size=vocab_size, + expanded_model_dim=int(dim * expand), + dtype=jnp.float32, + rmsnorm_epsilon=1e-6, + tie_embeddings=True, + ) + jax_model = TransformerDo(jax_config) + + # Create input tokens + np_tokens = np.random.randint( + 0, vocab_size, size=(batch_size, seq_len), dtype=np.int32 + ) + torch_tokens = torch.tensor(np_tokens, dtype=torch.long) + jax_tokens = jnp.array(np_tokens, dtype=jnp.int32) + + # Initialize JAX params + jax_params = jax_model.init(jax.random.PRNGKey(0), jax_tokens) + + # Copy weights from PyTorch to JAX + jax_params = copy_full_model_params(pytorch_model, jax_params, pytorch_config) + + # Predict k tokens + with torch.no_grad(): + _, torch_predictions = pytorch_model.predict(torch_tokens, k=k) + + _, jax_predictions = jax_model.apply( + jax_params, jax_tokens, k, method=jax_model.predict + ) + + # Compare predictions + torch_pred_np = torch_predictions.cpu().numpy() + jax_pred_np = np.array(jax_predictions) + + logging.info(f'\nPyTorch predictions: {torch_pred_np[0]}') + logging.info(f'JAX predictions: {jax_pred_np[0]}') + + # Check if predictions match exactly + if np.array_equal(torch_pred_np, jax_pred_np): + logging.info('βœ“ Predictions match exactly!') + else: + matching = np.sum(torch_pred_np == jax_pred_np) + total = torch_pred_np.size + logging.info( + f'⚠ Predictions differ: {matching}/{total} tokens match ({matching / total * 100:.1f}%)' + ) + logging.info( + ' (Note: Small numerical differences can lead to different argmax results)' + ) + + +def test_initialization_statistics(): + """Verify initialization follows expected distributions.""" + logging.info('\n' + '=' * 70) + logging.info('Testing Initialization Statistics') + logging.info('=' * 70) + + # Initialize models + jax_cfg = JaxModelConfig( + model_dim=512, + num_heads=8, + seq_len=1024, + num_layers=12, + vocab_size=50000, + expanded_model_dim=2048, + dtype=jnp.float32, + ) + jax_model = TransformerDo(jax_cfg) + jax_params = jax_model.init( + jax.random.PRNGKey(42), jnp.ones((1, 10), dtype=jnp.int32) + ) + + pytorch_cfg = PyTorchModelConfig( + vocab_size=50000, + seq_len=1024, + model_dim=512, + expanded_model_dim=2048, + num_layers=12, + num_heads=8, + ) + pytorch_model = Transformer(pytorch_cfg) + + logging.info('Initialization Statistics Check:') + + # Check embedding + jax_embed = jax_params['params']['embed']['embedding'] + torch_embed = pytorch_model.embed_tokens.weight.detach().numpy() + + logging.info('\nToken Embedding (should be ~0.02 std):') + logging.info( + f' JAX: mean={jax_embed.mean():.6f}, std={jax_embed.std():.6f}' + ) + logging.info( + f' PyTorch: mean={torch_embed.mean():.6f}, std={torch_embed.std():.6f}' + ) + + # Assert embedding std is close to 0.02 + assert abs(jax_embed.std() - 0.02) < 0.005, ( + f'JAX embedding std {jax_embed.std():.6f} not close to 0.02' + ) + assert abs(torch_embed.std() - 0.02) < 0.005, ( + f'PyTorch embedding std {torch_embed.std():.6f} not close to 0.02' + ) + assert abs(jax_embed.mean()) < 0.01, ( + f'JAX embedding mean {jax_embed.mean():.6f} not close to 0' + ) + assert abs(torch_embed.mean()) < 0.01, ( + f'PyTorch embedding mean {torch_embed.mean():.6f} not close to 0' + ) + + # Check first layer attention Q + jax_q = jax_params['params']['blocks_0']['CausalAttn_0']['query']['kernel'] + torch_q_weight = ( + pytorch_model.layers[0].attn.w_qkv.weight[:512].detach().numpy() + ) + + logging.info('\nAttention Q:') + logging.info(f' JAX: mean={jax_q.mean():.6f}, std={jax_q.std():.6f}') + logging.info( + f' PyTorch: mean={torch_q_weight.mean():.6f}, std={torch_q_weight.std():.6f}' + ) + + # Check means are close to 0 + assert abs(jax_q.mean()) < 0.01, ( + f'JAX Q mean {jax_q.mean():.6f} not close to 0' + ) + assert abs(torch_q_weight.mean()) < 0.01, ( + f'PyTorch Q mean {torch_q_weight.mean():.6f} not close to 0' + ) + + # Check stds are similar + # Allow 20% difference due to random initialization + assert abs(jax_q.std() - torch_q_weight.std()) / torch_q_weight.std() < 0.2, ( + f'Q std differs too much: JAX {jax_q.std():.6f} vs PyTorch {torch_q_weight.std():.6f}' + ) + + # Check first layer attention output (should be scaled) + jax_attn_out = jax_params['params']['blocks_0']['CausalAttn_0'][ + 'attn_out_proj' + ]['kernel'] + torch_attn_out = pytorch_model.layers[0].attn.w_out.weight.detach().numpy() + + logging.info('\nAttention Output:') + logging.info( + f' JAX: mean={jax_attn_out.mean():.6f}, std={jax_attn_out.std():.6f}' + ) + logging.info( + f' PyTorch: mean={torch_attn_out.mean():.6f}, std={torch_attn_out.std():.6f}' + ) + + # Check means are close to 0 + assert abs(jax_attn_out.mean()) < 0.01, ( + f'JAX attn out mean {jax_attn_out.mean():.6f} not close to 0' + ) + assert abs(torch_attn_out.mean()) < 0.01, ( + f'PyTorch attn out mean {torch_attn_out.mean():.6f} not close to 0' + ) + + # Check stds are similar + assert ( + abs(jax_attn_out.std() - torch_attn_out.std()) / torch_attn_out.std() < 0.2 + ), ( + f'Attention output std differs too much: JAX {jax_attn_out.std():.6f} vs PyTorch {torch_attn_out.std():.6f}' + ) + + # Check MLP fc2 (should be scaled) + jax_mlp_out = jax_params['params']['blocks_0']['Mlp_0']['Dense_1']['kernel'] + torch_mlp_out = pytorch_model.layers[0].mlp.fc2.weight.detach().numpy() + + logging.info('\nMLP Output:') + logging.info( + f' JAX: mean={jax_mlp_out.mean():.6f}, std={jax_mlp_out.std():.6f}' + ) + logging.info( + f' PyTorch: mean={torch_mlp_out.mean():.6f}, std={torch_mlp_out.std():.6f}' + ) + + # Check means are close to 0 + assert abs(jax_mlp_out.mean()) < 0.01, ( + f'JAX MLP out mean {jax_mlp_out.mean():.6f} not close to 0' + ) + assert abs(torch_mlp_out.mean()) < 0.01, ( + f'PyTorch MLP out mean {torch_mlp_out.mean():.6f} not close to 0' + ) + + # Check stds are similar + assert ( + abs(jax_mlp_out.std() - torch_mlp_out.std()) / torch_mlp_out.std() < 0.2 + ), ( + f'MLP output std differs too much: JAX {jax_mlp_out.std():.6f} vs PyTorch {torch_mlp_out.std():.6f}' + ) + + logging.info('\nβœ“ Initialization statistics test passed') + + +def test_initialization_impact(): + """Test that initialization produces similar initial losses.""" + logging.info('\n' + '=' * 70) + logging.info('Testing Initialization Impact') + logging.info('=' * 70) + + # Create identical inputs + batch_size, seq_len = 4, 128 + vocab_size = 50000 + + np.random.seed(42) + tokens = np.random.randint(0, vocab_size, size=(batch_size, seq_len)) + + # Initialize both models with same seed + jax_cfg = JaxModelConfig( + model_dim=512, + num_heads=8, + seq_len=seq_len, + num_layers=12, + vocab_size=vocab_size, + expanded_model_dim=2048, + ) + jax_model = TransformerDo(jax_cfg) + jax_params = jax_model.init( + jax.random.PRNGKey(42), jnp.array(tokens, dtype=jnp.int32) + ) + + torch.manual_seed(42) + pytorch_cfg = PyTorchModelConfig( + vocab_size=vocab_size, + seq_len=seq_len, + model_dim=512, + expanded_model_dim=2048, + num_layers=12, + num_heads=8, + ) + pytorch_model = Transformer(pytorch_cfg) + + # Forward pass + jax_logits = jax_model.apply(jax_params, jnp.array(tokens, dtype=jnp.int32)) + + with torch.no_grad(): + torch_logits = pytorch_model(torch.tensor(tokens, dtype=torch.long)) + + # Compute losses + targets = tokens[:, 1:] + jax_loss = -jax.nn.log_softmax(jax_logits[:, :-1]).mean() + torch_loss = F.cross_entropy( + torch_logits[:, :-1].reshape(-1, vocab_size), + torch.tensor(targets.reshape(-1), dtype=torch.long), + ) + + logging.info('\nInitial Loss Comparison:') + logging.info(f' JAX: {jax_loss:.4f}') + logging.info(f' PyTorch: {torch_loss.item():.4f}') + logging.info(f' Difference: {abs(jax_loss - torch_loss.item()):.6f}') + + # Check that losses are in reasonable range for random init + # With vocab_size=50000, random init should give loss around log(50000) β‰ˆ 10.82 + expected_loss = np.log(vocab_size) + + assert 8.0 < jax_loss < 13.0, ( + f'JAX loss {jax_loss:.4f} outside expected range [8.0, 13.0]' + ) + assert 8.0 < torch_loss.item() < 13.0, ( + f'PyTorch loss {torch_loss.item():.4f} outside expected range [8.0, 13.0]' + ) + + # Both losses should be within 10% of log(vocab_size) + assert abs(jax_loss - expected_loss) / expected_loss < 0.1, ( + f'JAX loss {jax_loss:.4f} too far from expected {expected_loss:.4f}' + ) + assert abs(torch_loss.item() - expected_loss) / expected_loss < 0.1, ( + f'PyTorch loss {torch_loss.item():.4f} too far from expected {expected_loss:.4f}' + ) + + logging.info( + '\nNote: Losses are in expected range for random initialization.' + ) + logging.info(f' Expected ~log(vocab_size) = {expected_loss:.4f}') + logging.info('\nβœ“ Initialization impact test passed') + + +# ============================================================================ +# Test Class +# ============================================================================ + +named_parameters = [ + dict(testcase_name='rmsnorm', test_fn=test_rmsnorm), + dict(testcase_name='rope', test_fn=test_rope), + dict(testcase_name='mlp', test_fn=test_mlp), + dict(testcase_name='attention', test_fn=test_attention), + dict(testcase_name='block', test_fn=test_block), + dict(testcase_name='full_model', test_fn=test_full_model), + dict(testcase_name='prediction', test_fn=test_prediction), + dict( + testcase_name='initialization_statistics', + test_fn=test_initialization_statistics, + ), + dict( + testcase_name='initialization_impact', test_fn=test_initialization_impact + ), +] + + +class ModelMatchingTest(parameterized.TestCase): + """Tests for JAX vs PyTorch model matching.""" + + @parameterized.named_parameters(*named_parameters) + def test_model_matching(self, test_fn): + """Run individual model matching test.""" + test_fn() + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/test_step_times.py b/tests/test_step_times.py new file mode 100644 index 000000000..22868d67d --- /dev/null +++ b/tests/test_step_times.py @@ -0,0 +1,199 @@ +"""Tests that JAX and PyTorch step times are within 20% of each other. + +This test runs each workload for a number of steps with both JAX and PyTorch, +captures the step_time_ms metric, and asserts they are within 20%. +""" + +import re +import subprocess +import sys +import tempfile +from pathlib import Path + +from absl import flags, logging +from absl.testing import absltest, parameterized + +FLAGS = flags.FLAGS +FLAGS(sys.argv) + +MAX_STEPS = 101 +TOLERANCE = 0.25 + +WORKLOADS = [ + 'imagenet_vit', +] + +DATA_DIRS = { + 'imagenet_resnet': '/opt/data/imagenet/', + 'imagenet_vit': '/opt/data/imagenet/', + 'librispeech_conformer': '/opt/data/librispeech', + 'librispeech_deepspeech': '/opt/data/librispeech', + 'criteo1tb': '/opt/data/criteo1tb', + 'fastmri': '/opt/data/fastmri', + 'ogbg': '/opt/data/ogbg', + 'wmt': '/opt/data/wmt', +} + +CONDA_ENVS = { + 'jax': 'ap11_jax', + 'pytorch': 'ap11_torch_latest', +} + + +def get_data_dir(workload: str, framework: str) -> str: + """Map workload to its data directory.""" + base_dir = DATA_DIRS.get(workload, '/opt/data') + if workload in ['imagenet_resnet', 'imagenet_vit']: + return base_dir + framework + return base_dir + + +def run_workload(workload: str, framework: str, output_file: Path) -> bool: + """Run a workload and capture output to file.""" + data_dir = get_data_dir(workload, framework) + experiment_dir = tempfile.mkdtemp(prefix=f'{workload}_{framework}_') + + submission_path = ( + f'algorithms/baselines/external_tuning/{framework}_nadamw_full_budget.py' + ) + tuning_search_space = ( + 'algorithms/baselines/external_tuning/tuning_search_space.json' + ) + + if framework == 'jax': + cmd = [ + 'python', + 'submission_runner.py', + f'--framework={framework}', + f'--workload={workload}', + f'--data_dir={data_dir}', + f'--experiment_dir={experiment_dir}', + f'--experiment_name={workload}_benchmark', + f'--submission_path={submission_path}', + f'--tuning_search_space={tuning_search_space}', + f'--max_global_steps={MAX_STEPS}', + '--skip_evals', + '--nosave_checkpoints', + '--nosave_intermediate_checkpoints', + ] + else: + cmd = [ + 'torchrun', + '--nproc_per_node=4', + '--standalone', + 'submission_runner.py', + f'--framework={framework}', + f'--workload={workload}', + f'--data_dir={data_dir}', + f'--experiment_dir={experiment_dir}', + f'--experiment_name={workload}_benchmark', + f'--submission_path={submission_path}', + f'--tuning_search_space={tuning_search_space}', + f'--max_global_steps={MAX_STEPS}', + '--skip_evals', + '--nosave_checkpoints', + '--nosave_intermediate_checkpoints', + ] + + conda_env = CONDA_ENVS[framework] + activate_cmd = ( + f'source $(conda info --base)/etc/profile.d/conda.sh && ' + f'conda activate {conda_env} && ' + ) + full_cmd = activate_cmd + ' '.join(cmd) + + logging.info(f'Running: {workload} with {framework}') + logging.info(f'Output will be saved to: {output_file}') + + with open(output_file, 'w') as f: + result = subprocess.run( + full_cmd, + shell=True, + executable='/bin/bash', + stdout=f, + stderr=subprocess.STDOUT, + cwd=str(Path(__file__).parent.parent), + ) + + return result.returncode == 0 + + +def parse_step_time(output_file: Path) -> float | None: + """Parse the last step_time_ms from output file.""" + if not output_file.exists(): + return None + + with open(output_file, 'r') as f: + content = f.read() + + # Find all step_time_ms values + # Pattern matches: step_time_ms=123.456 or 'step_time_ms': 123.456 + pattern = r'step_time_ms[=:]\s*([\d.]+)' + matches = re.findall(pattern, content) + + if matches: + # Return the last value (most recent EMA) + return float(matches[-1]) + return None + + +named_parameters = [ + dict(testcase_name=workload, workload=workload) for workload in WORKLOADS +] + + +class StepTimeTest(parameterized.TestCase): + """Tests that JAX and PyTorch step times are within tolerance.""" + + @parameterized.named_parameters(*named_parameters) + def test_step_times_within_tolerance(self, workload): + """Test that JAX and PyTorch step times are within 20% of each other.""" + results = {} + + with tempfile.TemporaryDirectory() as tmpdir: + tmpdir = Path(tmpdir) + + for framework in ['jax', 'pytorch']: + output_file = tmpdir / f'{workload}_{framework}.out' + + success = run_workload(workload, framework, output_file) + self.assertTrue(success, f'Failed to run {workload} with {framework}') + + step_time = parse_step_time(output_file) + self.assertIsNotNone( + step_time, + f'Could not parse step_time_ms for {workload} with {framework}', + ) + + results[framework] = step_time + logging.info(f'{workload} {framework}: {step_time:.2f} ms') + + jax_time = results['jax'] + pytorch_time = results['pytorch'] + ratio = pytorch_time / jax_time + + logging.info( + f'{workload}: JAX={jax_time:.2f}ms, PyTorch={pytorch_time:.2f}ms, ' + f'ratio={ratio:.2f}' + ) + + # Check that ratio is within tolerance (0.8 to 1.2 for 20% tolerance) + lower_bound = 1.0 - TOLERANCE + upper_bound = 1.0 + TOLERANCE + + self.assertGreaterEqual( + ratio, + lower_bound, + f'{workload}: PyTorch is more than {TOLERANCE * 100:.0f}% faster than JAX ' + f'(ratio={ratio:.2f}, expected >= {lower_bound:.2f})', + ) + self.assertLessEqual( + ratio, + upper_bound, + f'{workload}: PyTorch is more than {TOLERANCE * 100:.0f}% slower than JAX ' + f'(ratio={ratio:.2f}, expected <= {upper_bound:.2f})', + ) + + +if __name__ == '__main__': + absltest.main()