diff --git a/configs/debug.yaml b/configs/debug.yaml new file mode 100644 index 0000000..443a9a2 --- /dev/null +++ b/configs/debug.yaml @@ -0,0 +1,42 @@ + +# — Model — +model_dim: 768 +n_layer: 12 +n_head: 6 +sequence_length: 128 + +# — Batching & Training — +num_iterations: 3000 +batch_size: 1024 +device_batch_size: 256 + +# — Learning‐rate schedule — +warmup_ratio: 0.0 +warmdown_ratio: 0.2 + +# — Validation & Checkpointing — +val_loss_every: 125 +val_tokens: 10485760 + +# — Weights & Biases logging — +wandb_project_name: debug +wandb_job_name: null +no_wandb: false + +# — Miscellaneous flags — +debug: false +no_compile: false + +# — Distributed training — +fs_size: 4 # FSDP size + +# — Optimizer & Hyperparameters — +optimizer: dion2 +scalar_opt: lion +mu: 0.95 +weight_decay: 0.01 +ortho_fraction: 0.25 +lr: 0.02 + +shard_independent: true +no_validation: false \ No newline at end of file diff --git a/dion/dion2.py b/dion/dion2.py index 68cea3b..9ff145f 100644 --- a/dion/dion2.py +++ b/dion/dion2.py @@ -1,3 +1,4 @@ +# dion2.py import math import torch import torch.distributed as dist @@ -51,9 +52,10 @@ class Dion2(Optimizer): use_triton: Whether to use Triton kernel for Newton-Schulz. Ignored if custom function is provided. newton_schulz_func: Use a custom Newton-Schulz function for orthogonalization. Signature is `func(input: Tensor, epsilon: float) -> Tensor`. - verbose: Whether to print debug information during updates. If True, it prints whether rows or columns are selected for the submatrix selection process. + shard_independent: If True, use shard-independent row selection for Dion2. + This incurs higher communication cost due to all-to-all of full matrices, - Dion2 optimizer by Ahn et al.: TBD + Dion2 optimizer by Ahn et al.: https://arxiv.org/abs/2512.16928 """ def __init__( @@ -70,7 +72,7 @@ def __init__( flatten: bool = False, use_triton: bool = False, newton_schulz_func: Optional[Callable] = None, - verbose: bool = False, + shard_independent: bool = False, ): # Validate hyperparameters if lr < 0.0: @@ -135,7 +137,7 @@ def __init__( self._newton_schulz_func = newton_schulz_triton else: self._newton_schulz_func = zeropower_via_newtonschulz5 - self.verbose = verbose + self.shard_independent = shard_independent @torch.no_grad() def step(self, closure=None): @@ -167,7 +169,7 @@ def step(self, closure=None): raise ValueError(f"Unknown algorithm: {algo}") # Create async tasks for each algorithm - dion2_tasks = self._create_dion2_tasks(dion2_groups, verbose=self.verbose) + dion2_tasks = self._create_dion2_tasks(dion2_groups) lion_tasks = self._create_lion_tasks(lion_groups) adamw_tasks = self._create_adamw_tasks(adamw_groups) @@ -192,7 +194,6 @@ def _get_or_initialize_state(self, param: Tensor, algo: str) -> dict: def _create_dion2_tasks( self, param_groups: List[dict], - verbose: bool = False, ) -> Generator["AsyncTask", None, None]: """ Helper function to create batches of Dion2 matrices and generate @@ -223,6 +224,7 @@ def _create_dion2_tasks( world_size=self._world_size, process_group=self._process_group, newton_schulz_func=self._newton_schulz_func, + shard_independent=self.shard_independent, ) # Create batches of parameters of size self._world_size @@ -264,6 +266,15 @@ def _create_dion2_tasks( (i, p) for i, p in shard_placements if p.dim in matrix_dims ] + # We currently do not support tensors sharded along the last dimension because Dion2 + # normalization later assumes a full trailing axis when computing means. 
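+            # Illustrative note: e.g. a [4096, 1024] weight sharded on its last dim across
+            # 4 ranks (local shards of [4096, 256]) would be rejected by the check below,
+            # while the same weight sharded on dim 0 (FSDP2's default) is supported.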
+ if any(p.dim == params[0].ndim - 1 for _, p in shard_placements): + raise NotImplementedError( + "Dion2 currently does not support parameters sharded along the last dimension. " + "Please avoid shards at dim -1." + "(Note: Default behavior of FSDP2 is to shard along dim-0.)" + ) + # Check that we have no more than 1 sharded matrix dimension # Note that non-flattened 3D tensors can have additional sharded batch dimensions # Flattened 3D tensors are limited to one sharded dimension out of all dimensions @@ -299,7 +310,6 @@ def _create_dion2_tasks( M=[m], shard_dim=None, # No sharded matrix dim **dion2_args, - verbose=verbose, ) ) # Otherwise, we parallelize the Muon update across devices @@ -311,7 +321,6 @@ def _create_dion2_tasks( M=pad_batch(momentums, self._world_size), shard_dim=sharded_tensor_dim, **dion2_args, - verbose=verbose, ) ) @@ -412,7 +421,7 @@ def dion2_update_batch_async( shard_dim: Optional[int] = None, # Shard dimension for DTensor (if applicable) process_group: Optional[ProcessGroup] = None, newton_schulz_func: Optional[Callable] = None, - verbose: bool = False, + shard_independent: bool = False, ) -> Generator[None, None, None]: """ Batched version of Dion2 update. Batch size should be equal to number of GPUs. @@ -422,137 +431,210 @@ def dion2_update_batch_async( assert len(X) == len(G) assert len(X) == len(M) - # Determine selection dimension based on sharding and tensor shape: - # For sharded matrices, we align select_dim with shard_dim - # For unsharded matrices (DDP or single-GPU), we select the shorter dimension - ndim = X[0].ndim - select_dim = None - - if shard_dim is not None: - # Normalize shard_dim to negative indexing for unified treatment - shard_dim = shard_dim if shard_dim < 0 else shard_dim - ndim - if shard_dim == -2: - select_dim = -2 # Row-sharded - elif shard_dim == -1: - select_dim = -1 # Column-sharded - - # Fall-back to shorter dimension when DDP, Single-GPU, or batch-sharded - if select_dim is None: - num_rows, num_cols = X[0].shape[-2:] - select_dim = -2 if num_rows <= num_cols else -1 - - # Print how the selection choice based on shard_dim and tensor shape - if verbose: - _print_selection_choice(X[0].shape, shard_dim, select_dim, ndim) - - # Update momentum and select top-α fraction along select_dim - U_selected, indices_list = dion2_pre_orthogonalize( - G=to_local(G), - M=to_local(M), - fraction=fraction, - ef_decay=ef_decay, - select_dim=select_dim, - ) - - # Get one whole matrix for each device to orthogonalize - if shard_dim is not None: - # Use all-to-all to transform from a batch of shards to a single whole matrix - # https://www.essential.ai/blog/infra - assert len(X) == world_size, "Batch size must equal world size" - assert ( - process_group is not None - ), "process_group must be provided for sharded DTensors" - assert isinstance(X[0], DTensor), "X should contain DTensors" - assert ( - X[0].size(shard_dim) % world_size == 0 - ), f"Shard dimension {shard_dim} size {X[0].size(shard_dim)} is not divisible by world size {world_size}." - - # Allocate buffers to receive shards of one whole submatrix from other devices - recv_shards = [torch.empty_like(u) for u in U_selected] - work = dist.all_to_all( - recv_shards, U_selected, group=process_group, async_op=True - ) - yield - work.wait() + # ALways select submtrix based on row-wise selection + # This way it chooses the submatrix neuron-wise + select_dim = -2 - # Concatentate shards to form a whole matrix to orthogonalize - # Only submatrix is orthogonalized! 
- full_submatrix = torch.cat(recv_shards, dim=select_dim) - full_submatrix = muon_update_newton_schulz( - full_submatrix, newton_schulz_func, flatten=flatten, epsilon=epsilon - ) + # Shard-independent path: + # The row-selection process is independent of the sharding configuration. + # However, this comes with the cost of all-to-all communication of full matrices. + if shard_dim is not None and shard_independent: + assert len(X) == world_size + assert process_group is not None - # Split result back into shards - # Contiguous is needed for all-to-all to work correctly - send_shards = [ - t.contiguous() - for t in torch.tensor_split(full_submatrix, world_size, dim=select_dim) - ] + M_local = to_local(M) + G_local = to_local(G) + dtype = M_local[0].dtype - # Redistribute the orthogonalized tensor back to original layout - U_ortho = [torch.empty_like(u) for u in U_selected] - work = dist.all_to_all(U_ortho, send_shards, group=process_group, async_op=True) + # Update momentum locally: M = M + G + G_casted = [g.to(dtype=dtype) for g in G_local] + torch._foreach_add_(M_local, G_casted) + + # All-to-all to get one full matrix per device + M_recv = [torch.empty_like(M_local[0]) for _ in range(world_size)] + work = dist.all_to_all(M_recv, M_local, group=process_group, async_op=True) yield work.wait() - # Matrices are not sharded, so we can distribute the batch across different devices - # Get a single matrix of the batch corresponding to this device - elif len(U_selected) > 1: - assert len(U_selected) == world_size, "Batch size must equal world size" - assert process_group is not None + # Concatenate to form full matrix + M_full = torch.cat(M_recv, dim=shard_dim) + full_rows, full_cols = M_full.shape[-2:] - single_matrix = U_selected[device_rank] - assert not isinstance(single_matrix, DTensor) + # Select top-k on full matrix + num_select = M_full.size(select_dim) + k = max(1, int(math.ceil(fraction * num_select))) - single_ortho = muon_update_newton_schulz( - single_matrix, - newton_schulz_func, - flatten=flatten, - epsilon=epsilon, + slice_norms = M_full.norm(p=1, dim=-1) + _, indices = torch.topk(slice_norms, k, dim=-1, sorted=False) + + # Extract and orthonormalize + U_selected = M_full.index_select(dim=select_dim, index=indices).to( + dtype=torch.bfloat16 + ) + U_ortho = muon_update_newton_schulz( + U_selected, newton_schulz_func, flatten=flatten, epsilon=epsilon ) - # Allocate empty tensors to receive updates from other devices - U_ortho = [torch.empty_like(u) for u in U_selected] - # All gather orthogonalized results from other devices into buffer - work = dist.all_gather( - U_ortho, single_ortho.contiguous(), group=process_group, async_op=True + # Construct dense update matrix (zeros on the unchosen rows/cols) + U_dense_full = torch.zeros( + M_full.shape, dtype=U_ortho.dtype, device=U_ortho.device ) + idx_exp = indices.unsqueeze(-1).expand(-1, full_cols) + U_dense_full.scatter_(-2, idx_exp, U_ortho) + + # All-to-all scatter dense U back to shards + U_send = [ + t.contiguous() + for t in torch.tensor_split(U_dense_full, world_size, dim=shard_dim) + ] + U_recv = [torch.empty_like(U_send[0]) for _ in range(world_size)] + + + work = dist.all_to_all(U_recv, U_send, group=process_group, async_op=True) yield work.wait() - # Single tensor with no sharded dimension. 
This happens in 2 cases: - # - Running on a single GPU - # - 3D+ tensors sharded along a batch dimension (different whole matrices per device) + # Infer selected indices from non-zero rows/cols and apply error-feedback decay locally + dion2_apply_ef_decay_from_dense( + M=M_local, + U_dense=U_recv, + ef_decay=ef_decay, + select_dim=select_dim, + ) + + # Compute scaled learning rate + # Do this before to_local(X) because we use the full tensor shape, not the shard shape + if adjust_lr is None: + adjusted_lr = lr + elif adjust_lr == "spectral_norm": + adjusted_lr = adjust_lr_spectral_norm(lr, X[0].shape, flatten=flatten) + elif adjust_lr == "rms_norm": + adjusted_lr = adjust_lr_rms_norm(lr, X[0].shape, flatten=flatten) + else: + raise ValueError(f"Unknown adjust_lr: {adjust_lr}") + + # Apply weight update using dense U + dion2_post_orthogonalize_dense( + X=to_local(X), + U_dense=U_recv, + base_lr=lr, + adjusted_lr=adjusted_lr, + weight_decay=weight_decay, + ) + + # This path corresponds to the case when `shard-independent` is False + # When shard_dim is not None, the row selection process depends on the sharding configuration. + # This is because each shard chooses its own top-k rows based on local momentum. + # This path only communicates the selected submatrices for orthogonalization, which leads to less communication. else: - assert len(U_selected) == 1 - U_ortho = [ - muon_update_newton_schulz( - U_selected[0], newton_schulz_func, flatten=flatten, epsilon=epsilon + # Update momentum and select top-fraction along select_dim + U_selected, indices_list = dion2_pre_orthogonalize( + G=to_local(G), + M=to_local(M), + fraction=fraction, + ef_decay=ef_decay, + select_dim=select_dim, + ) + + # Get one whole matrix for each device to orthogonalize + if shard_dim is not None: + # Use all-to-all to transform from a batch of shards to a single whole matrix + # https://www.essential.ai/blog/infra + assert len(X) == world_size, "Batch size must equal world size" + assert ( + process_group is not None + ), "process_group must be provided for sharded DTensors" + assert isinstance(X[0], DTensor), "X should contain DTensors" + assert ( + X[0].size(shard_dim) % world_size == 0 + ), f"Shard dimension {shard_dim} size {X[0].size(shard_dim)} is not divisible by world size {world_size}." + + # Allocate buffers to receive shards of one whole submatrix from other devices + recv_shards = [torch.empty_like(u) for u in U_selected] + work = dist.all_to_all( + recv_shards, U_selected, group=process_group, async_op=True + ) + yield + work.wait() + + # Concatentate shards to form a whole matrix to orthogonalize + # Only submatrix is orthogonalized! 
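+                # E.g. with world_size=4, a [1024, 768] weight row-sharded as [256, 768] and
+                # fraction=0.25: each rank selects its top 64 local rows, and after this
+                # all-to-all + concat one rank holds that matrix's stacked [256, 768]
+                # submatrix to orthogonalize.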
+ full_submatrix = torch.cat(recv_shards, dim=select_dim) + full_submatrix = muon_update_newton_schulz( + full_submatrix, newton_schulz_func, flatten=flatten, epsilon=epsilon ) - ] - # Compute scaled learning rate - # Do this before to_local(X) because we use the full tensor shape, not the shard shape - if adjust_lr is None: - adjusted_lr = lr - elif adjust_lr == "spectral_norm": - adjusted_lr = adjust_lr_spectral_norm(lr, X[0].shape, flatten=flatten) - elif adjust_lr == "rms_norm": - adjusted_lr = adjust_lr_rms_norm(lr, X[0].shape, flatten=flatten) - else: - raise ValueError(f"Unknown adjust_lr: {adjust_lr}") + # Split result back into shards + # Contiguous is needed for all-to-all to work correctly + send_shards = [ + t.contiguous() + for t in torch.tensor_split(full_submatrix, world_size, dim=select_dim) + ] + + # Redistribute the orthogonalized tensor back to original layout + U_ortho = [torch.empty_like(u) for u in U_selected] + work = dist.all_to_all(U_ortho, send_shards, group=process_group, async_op=True) + yield + work.wait() + + # Matrices are not sharded, so we can distribute the batch across different devices + # Get a single matrix of the batch corresponding to this device + elif len(U_selected) > 1: + assert len(U_selected) == world_size, "Batch size must equal world size" + assert process_group is not None + + single_matrix = U_selected[device_rank] + assert not isinstance(single_matrix, DTensor) + + single_ortho = muon_update_newton_schulz( + single_matrix, + newton_schulz_func, + flatten=flatten, + epsilon=epsilon, + ) + + # Allocate empty tensors to receive updates from other devices + U_ortho = [torch.empty_like(u) for u in U_selected] + # All gather orthogonalized results from other devices into buffer + work = dist.all_gather( + U_ortho, single_ortho.contiguous(), group=process_group, async_op=True + ) + yield + work.wait() - # Update model parameters with orthogonalized output - # Weight update is applied to selected slices only - dion2_post_orthogonalize( - X=to_local(X), - U=U_ortho, - indices=indices_list, - base_lr=lr, - adjusted_lr=adjusted_lr, - weight_decay=weight_decay, - select_dim=select_dim, - ) + # Single tensor with no sharded dimension. This happens in 2 cases: + # - Running on a single GPU + # - 3D+ tensors sharded along a batch dimension (different whole matrices per device) + else: + assert len(U_selected) == 1 + U_ortho = [ + muon_update_newton_schulz( + U_selected[0], newton_schulz_func, flatten=flatten, epsilon=epsilon + ) + ] + + # Compute scaled learning rate + # Do this before to_local(X) because we use the full tensor shape, not the shard shape + if adjust_lr is None: + adjusted_lr = lr + elif adjust_lr == "spectral_norm": + adjusted_lr = adjust_lr_spectral_norm(lr, X[0].shape, flatten=flatten) + elif adjust_lr == "rms_norm": + adjusted_lr = adjust_lr_rms_norm(lr, X[0].shape, flatten=flatten) + else: + raise ValueError(f"Unknown adjust_lr: {adjust_lr}") + + # Update model parameters with orthogonalized output + # Weight update is applied to selected slices only + dion2_post_orthogonalize( + X=to_local(X), + U=U_ortho, + indices=indices_list, + base_lr=lr, + adjusted_lr=adjusted_lr, + weight_decay=weight_decay, + select_dim=select_dim, + ) @torch.compile(fullgraph=True) @@ -633,6 +715,7 @@ def dion2_post_orthogonalize( Inputs and outputs should be lists of regular Tensor, not DTensor. This is a separate function for compatibility with torch.compile(). 
""" + # Apply weight decay torch._foreach_mul_(X, 1 - base_lr * weight_decay) # Convert U to match parameter dtype @@ -645,44 +728,48 @@ def dion2_post_orthogonalize( x.index_add_(dim=select_dim, index=idx, source=u_scaled) -# A helper function to print selection chocie for each matrix -# It only prints once `verbose` is set True -_printed_configs: set = set() +def dion2_apply_ef_decay_from_dense( + M: List[Tensor], + U_dense: List[Tensor], + ef_decay: Tensor, + select_dim: int, +): + """ + Infer selected indices from non-zero rows/cols in U_dense, + then apply error-feedback decay to those rows/cols in M. + This function is used for `shard-independent' path of the Dion2 update. + """ + norm_dim = -1 if select_dim == -2 else -2 + + for m, u in zip(M, U_dense): + # Check if ANY element in the row/col is non-zero + any_nonzero = (u != 0).any(dim=norm_dim) + indices = torch.where(any_nonzero)[0] + + # Apply error-feedback decay on the non-zero rows/cols + if indices.numel() > 0: + selected_slice = m.index_select(dim=select_dim, index=indices) + m.index_copy_( + dim=select_dim, index=indices, source=selected_slice * ef_decay + ) -def _print_selection_choice( - shape: torch.Size, - shard_dim: Optional[int], - select_dim: int, - ndim: int, +def dion2_post_orthogonalize_dense( + X: List[Tensor], + U_dense: List[Tensor], + base_lr: Tensor, + adjusted_lr: Tensor, + weight_decay: Tensor, ): - config_key = (tuple(shape), shard_dim, select_dim) - if config_key not in _printed_configs: - _printed_configs.add(config_key) - - num_rows, num_cols = shape[-2:] - select_info = "rows" if select_dim == -2 else "columns" - norm_info = "row norms" if select_dim == -2 else "col norms" - - if shard_dim is None: - mode = "DDP/Single-GPU" - shorter = "rows" if num_rows <= num_cols else "cols" - reason = f"shorter dim = {shorter} ({min(num_rows, num_cols)})" - else: - # Normalize shard_dim for display - normalized = shard_dim if shard_dim < 0 else shard_dim - ndim - if normalized == -2: - mode = "FSDP" - reason = f"row-sharded (shard_dim={shard_dim}→-2)" - elif normalized == -1: - mode = "FSDP" - reason = f"col-sharded (shard_dim={shard_dim}→-1)" - else: - mode = "FSDP batch-sharded" - shorter = "rows" if num_rows <= num_cols else "cols" - reason = f"shard_dim={shard_dim} (batch), shorter = {shorter}" + """ + Apply weight decay and weight update. + This function assumes that the update is dense (zeros on unselected rows/cols). + This is used for `shard-independent' path of the Dion2 update. 
+ """ + # Apply weight decay + torch._foreach_mul_(X, 1 - base_lr * weight_decay) - print( - f"[Dion2] Shape {tuple(shape)}: {mode}, {reason} → " - f"select top-α {select_info} by {norm_info}" - ) + # Dense weight update + dtype = X[0].dtype + U_casted = [u.to(dtype=dtype) for u in U_dense] + torch._foreach_add_(X, U_casted, alpha=-adjusted_lr) diff --git a/dion/normuon.py b/dion/normuon.py index e2a2479..5f9a35f 100644 --- a/dion/normuon.py +++ b/dion/normuon.py @@ -587,14 +587,16 @@ def normuon_normalization( """ V_dtype = V[0].dtype U = [u.to(dtype=V_dtype) for u in U] - + norm_U = [ u.norm(p=2, dim=(-2, -1), keepdim=True) for u in U ] # list of ||u||_F, shape [*, 1, 1] U_sq = torch._foreach_mul(U, U) # list of u*u, same shapes as U - neuron_norms = [u_sq.mean(dim=-1, keepdim=True) for u_sq in U_sq] # Shape: [*, rows, 1] - + neuron_norms = [ + u_sq.mean(dim=-1, keepdim=True) for u_sq in U_sq + ] # Shape: [*, rows, 1] + torch._foreach_lerp_( V, neuron_norms, 1 - muon_beta2 ) # Update variance neuron buffer @@ -606,17 +608,17 @@ def normuon_normalization( norm_U_new = [ nu.norm(p=2, dim=(-2, -1), keepdim=True) for nu in normalized_U ] # list of ||normalized_u||_F, shape [*, 1, 1] - + # Protect against division by zero when norm_U_new is zero. # This can happen when U is all zeros (e.g., zero gradients from zero-initialized weights). # In this case, norm_U is also zero, so after clamping norm_U_new to ε the ratio becomes 0/ε ≈ 0, # and normalized_U * ratio correctly remains zero, preserving the zero state. norm_U_new_safe = [nu.clamp(min=1e-8) for nu in norm_U_new] - + ratio = torch._foreach_div( norm_U, norm_U_new_safe ) # list of ||u||_F / ||normalized_u||_F, shape [*, 1, 1] - + torch._foreach_mul_(normalized_U, ratio) # normalized_u[i] *= ratio return normalized_U diff --git a/train.py b/train.py index 81654d1..622d731 100644 --- a/train.py +++ b/train.py @@ -52,6 +52,7 @@ class Hyperparameters: # Evaluation and logging val_loss_every: int = 125 val_tokens: int = 10485760 + no_validation: bool = False checkpoint_freq: int = 0 checkpoint_dir: str = None wandb_project_name: str = "dion-test" @@ -75,7 +76,7 @@ class Hyperparameters: adjust_lr: str = "spectral_norm" # for Muon only # For printing out selection choice in Dion2 - verbose: bool = True + shard_independent: bool = False # Helper function to only print on global rank 0 @@ -418,7 +419,7 @@ def init_optimizer( lr=hp.lr, mu=hp.mu, weight_decay=hp.weight_decay, - nesterov=True, + nesterov=False, adjust_lr=hp.adjust_lr, use_triton=(not cli_args.no_triton), ) @@ -431,13 +432,30 @@ def init_optimizer( outer_shard_mesh if outer_shard_mesh.size() > 1 else replicate_mesh ) comm_method = "all-to-all" if outer_shard_mesh.size() > 1 else "all-gather" + is_sharded = outer_shard_mesh.size() > 1 else: assert ddp_model is not None distributed_mesh = ddp_model.process_group # using ProcessGroup for DDP comm_method = "all-gather" + is_sharded = False print0(f"LR adjust method: {hp.adjust_lr}") print0(f"Triton Newton-Schulz kernels: {not cli_args.no_triton}") print0(f"Distributed Dion2 using: {comm_method}") + if is_sharded: + if hp.shard_independent: + print0( + "Using shard-independent Dion2 path.\n" + " - Selects top-k rows/cols on the FULL matrix\n" + " - Algorithm behavior is identical regardless of sharding configuration" + ) + else: + print0( + "Using shard-dependent Dion2 path.\n" + " - Selects top-k rows/cols on LOCAL shards\n" + " - Algorithm behavior may vary with different sharding configurations\n" + " - Benefit: Only communicates 
selected submatrices (less communication)" + ) + opt = Dion2( param_groups, distributed_mesh=distributed_mesh, @@ -447,7 +465,7 @@ def init_optimizer( weight_decay=hp.weight_decay, adjust_lr=hp.adjust_lr, use_triton=(not cli_args.no_triton), - verbose=hp.verbose, + shard_independent=hp.shard_independent, ) elif hp.optimizer == "normuon": if device_mesh is not None: @@ -472,7 +490,7 @@ def init_optimizer( mu=hp.mu, muon_beta2=0.95, weight_decay=hp.weight_decay, - nesterov=True, + nesterov=False, adjust_lr=hp.adjust_lr, use_triton=(not cli_args.no_triton), ) @@ -496,7 +514,7 @@ def init_optimizer( lr=hp.lr, mu=hp.mu, weight_decay=hp.weight_decay, - nesterov=True, + nesterov=False, adjust_lr=hp.adjust_lr, ) @@ -813,10 +831,13 @@ def get_lr(it): optimizer_name = f"{hp.ortho_fraction}-{hp.optimizer}" run_name = f"({optimizer_name}+{hp.scalar_opt})" + if device_mesh is not None: dp, fs, tp = device_mesh.size(0), device_mesh.size(1), device_mesh.size(2) run_name += f"_(dp={dp}, fs={fs}, tp={tp})" + if "dion2" in hp.optimizer and hp.shard_independent: + run_name += "_shard-indep" if cli_args.wandb_job_name: run_name += f"_{cli_args.wandb_job_name}" @@ -901,42 +922,43 @@ def get_lr(it): # --- Validation --- last_step = step == hp.num_iterations if last_step or (hp.val_loss_every > 0 and step % hp.val_loss_every == 0): - # Measure elapsed time for training - torch.cuda.synchronize() - training_time_ms += 1000 * (time.time() - t0) - - # Run validation - model.eval() - val_loader.reset() - val_loss = torch.tensor(0.0, device=x.device) - for _ in range(val_steps): - with torch.no_grad(): - x_val, y_val = val_loader.next_batch() - with autocast_ctx: - loss = model(x_val, y_val) - val_loss += loss - - # Average validation loss across devices - dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) - val_loss = val_loss.item() / val_steps - log_message = ( - f"step:{step}/{hp.num_iterations} val_loss:{val_loss:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps):.2f}ms" - ) - print0(log_message) - if MASTER_PROCESS and not cli_args.no_wandb and not cli_args.debug: - wandb.log( - { - "val/loss": val_loss, - "step": step, - "time/training_time_ms": training_time_ms, # Log total elapsed training time in ms - } + if not hp.no_validation: + # Measure elapsed time for training + torch.cuda.synchronize() + training_time_ms += 1000 * (time.time() - t0) + + # Run validation + model.eval() + val_loader.reset() + val_loss = torch.tensor(0.0, device=x.device) + for _ in range(val_steps): + with torch.no_grad(): + x_val, y_val = val_loader.next_batch() + with autocast_ctx: + loss = model(x_val, y_val) + val_loss += loss + + # Average validation loss across devices + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + val_loss = val_loss.item() / val_steps + log_message = ( + f"step:{step}/{hp.num_iterations} val_loss:{val_loss:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps):.2f}ms" ) - pbar.set_postfix(val_loss=f"{val_loss:.4f}") - - # Restart training time for the next iteration - torch.cuda.synchronize() - t0 = time.time() + print0(log_message) + if MASTER_PROCESS and not cli_args.no_wandb and not cli_args.debug: + wandb.log( + { + "val/loss": val_loss, + "step": step, + "time/training_time_ms": training_time_ms, # Log total elapsed training time in ms + } + ) + pbar.set_postfix(val_loss=f"{val_loss:.4f}") + + # Restart training time for the next iteration + torch.cuda.synchronize() + t0 = time.time() if last_step: break @@ -985,6 +1007,7 @@ def 
get_lr(it): { "train/loss": train_loss.item(), "train/grad_norm": grad_norm.item(), + "train/base_lr": lr_scheduler.get_last_lr()[0], "step": step, "time/training_time_ms": approx_time, # Log approximate elapsed training time in ms }
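
# --- Illustrative sketch (not part of the patch) ---
# Single-GPU, unsharded approximation of the row selection that the shard-independent
# Dion2 path performs above: take the top ceil(fraction * rows) rows of the momentum by
# L1 norm, orthogonalize only that submatrix, and scatter it back into a dense update.
# `torch.linalg.qr` stands in for the Newton-Schulz orthogonalization the optimizer
# actually uses; the function name and shapes here are hypothetical.
import math
import torch

def select_and_orthogonalize(M: torch.Tensor, fraction: float) -> torch.Tensor:
    """M is a single 2-D momentum matrix of shape [rows, cols] (assumes cols >= k)."""
    k = max(1, int(math.ceil(fraction * M.size(0))))
    row_norms = M.norm(p=1, dim=-1)                      # one L1 norm per row
    _, idx = torch.topk(row_norms, k, dim=-1, sorted=False)
    U_sub = M.index_select(0, idx)                       # [k, cols] selected submatrix
    Q, _ = torch.linalg.qr(U_sub.T)                      # orthonormal columns, [cols, k]
    U_dense = torch.zeros_like(M)                        # zeros on the unselected rows
    U_dense.index_copy_(0, idx, Q.T)                     # scatter orthogonalized rows back
    return U_dense

# Example: U = select_and_orthogonalize(momentum, fraction=0.25); X -= lr * U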