From 5d50b0098b26fd66bce2fce72768f8fa35911f34 Mon Sep 17 00:00:00 2001
From: Shubham Chandak
Date: Wed, 6 Aug 2025 17:02:31 +0530
Subject: [PATCH] Fix attention backward dropout

---
 src/nki_samples/reference/attention.py | 32 ++++++++++++++++++++------
 1 file changed, 25 insertions(+), 7 deletions(-)

diff --git a/src/nki_samples/reference/attention.py b/src/nki_samples/reference/attention.py
index c71de00..6c908af 100644
--- a/src/nki_samples/reference/attention.py
+++ b/src/nki_samples/reference/attention.py
@@ -895,13 +895,16 @@ def _flash_attn_bwd_core(
   # Dropout
   #####################################################################
   if dropout_p > 0.0:
+    softmax_y_dropped = nl.ndarray((par_dim(q_seq_tile_size), k_seq_tile_size), dtype=kernel_dtype, buffer=nl.sbuf)
     offset = local_i_k_seq_tile + local_i_q_seq_tile * k_seq_n_tiles \
               + head_id * k_seq_n_tiles * q_seq_n_tiles \
               + batch_id * nheads * k_seq_n_tiles * q_seq_n_tiles
     offset_seed = nl.add(seed_local[0, 0], offset, mask=mask)
     nl.random_seed(seed=offset_seed, mask=mask)
-    softmax_y[:, :] = nl.dropout(softmax_y[:, :], rate=dropout_p_local[:, 0], mask=mask)
-    softmax_y[:, :] = nl.multiply(softmax_y[:, :], 1 / (1 - dropout_p), mask=mask)
+    softmax_y_dropped[:, :] = nl.dropout(softmax_y[:, :], rate=dropout_p_local[:, 0], mask=mask)
+    softmax_y_dropped[:, :] = nl.multiply(softmax_y_dropped[:, :], 1 / (1 - dropout_p), mask=mask)
+  else:
+    softmax_y_dropped = softmax_y
 
   #####################################################################
   # Step 3.1 Calculate the backward gradients dL/dV, where y=softmax@V
@@ -911,24 +914,39 @@ def _flash_attn_bwd_core(
     trans_dy = nisa.nc_transpose(dy_local[i_d_head_tile, :, :],
                                  mask=mask)
     dv_psum[i_d_head_tile, :, :] += \
-      nisa.nc_matmul(trans_dy, softmax_y[:, :], mask=mask)
+      nisa.nc_matmul(trans_dy, softmax_y_dropped[:, :], mask=mask)
 
   #####################################################################
   # Step 3.2 Calculate the backward gradients dL/dsoftmax, where y=softmax@V
   # in value projection with matmul(stationary=dy, moving=v)
   #####################################################################
-  softmax_dy_psum = nl.zeros((par_dim(q_seq_tile_size), k_seq_tile_size),
+  softmax_dy_dropped_psum = nl.zeros((par_dim(q_seq_tile_size), k_seq_tile_size),
                              dtype=np.float32, buffer=nl.psum)
   for i_d_head_tile in nl.affine_range(d_head_n_tiles):
-    softmax_dy_psum[:, :] += \
+    softmax_dy_dropped_psum[:, :] += \
       nisa.nc_matmul(dy_local[i_d_head_tile, :, :],
                      v_local[i_d_head_tile, :, :],
                      mask=mask)
 
-  softmax_dy = nl.ndarray((par_dim(q_seq_tile_size), k_seq_tile_size), dtype=kernel_dtype, buffer=nl.sbuf)
-  softmax_dy[:, :] = nl.copy(softmax_dy_psum[:, :], dtype=kernel_dtype,
+  softmax_dy_dropped = nl.ndarray((par_dim(q_seq_tile_size), k_seq_tile_size), dtype=kernel_dtype, buffer=nl.sbuf)
+  softmax_dy_dropped[:, :] = nl.copy(softmax_dy_dropped_psum[:, :], dtype=kernel_dtype,
                              mask=mask)
 
+  #####################################################################
+  # Step 3.3 Apply Dropout to softmax_dy_dropped
+  #####################################################################
+  if dropout_p > 0.0:
+    softmax_dy = nl.ndarray((par_dim(q_seq_tile_size), k_seq_tile_size), dtype=kernel_dtype, buffer=nl.sbuf)
+    offset = local_i_k_seq_tile + local_i_q_seq_tile * k_seq_n_tiles \
+              + head_id * k_seq_n_tiles * q_seq_n_tiles \
+              + batch_id * nheads * k_seq_n_tiles * q_seq_n_tiles
+    offset_seed = nl.add(seed_local[0, 0], offset, mask=mask)
+    nl.random_seed(seed=offset_seed, mask=mask)
+    softmax_dy[:, :] = nl.dropout(softmax_dy_dropped[:, :], rate=dropout_p_local[:, 0], mask=mask)
+    softmax_dy[:, :] = nl.multiply(softmax_dy[:, :], 1 / (1 - dropout_p), mask=mask)
+  else:
+    softmax_dy = softmax_dy_dropped
+
   #####################################################################
   # Step 4 Calculate the softmax backward gradients dL/dx, where y=softmax(x)
   # dL/dx = y * (dL/dy - rowsum(dO_O)), where y = softmax(x)
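
Why both dropout applications are needed: with y = dropout(P) @ V in the
forward pass, the backward pass must use the dropped probabilities when
computing dV (dV = dropout(P)^T @ dO) and must re-apply the identical mask,
with the same 1/(1-p) scaling, to dP = dO @ V^T. Re-seeding the RNG with the
same per-tile offset before the second nl.dropout call is what makes the two
masks identical. Below is a minimal NumPy sketch of that mask algebra; the
names (P, V, dO, keep) and the explicit Bernoulli mask are illustrative
assumptions for the sketch, not the kernel's API:

import numpy as np

def dropout_pv_fwd_bwd(P, V, dO, p, seed):
    # P: softmax probabilities (s_q, s_k); V: values (s_k, d);
    # dO: upstream gradient (s_q, d); p: dropout probability.
    rng = np.random.default_rng(seed)   # fixed seed -> reproducible mask
    keep = rng.random(P.shape) >= p     # one Bernoulli mask, reused below
    scale = 1.0 / (1.0 - p)

    P_dropped = P * keep * scale        # forward: dropout on the softmax output
    y = P_dropped @ V                   # forward: y = dropout(P) @ V

    dV = P_dropped.T @ dO               # dV sees the dropped P (Step 3.1)
    dP = (dO @ V.T) * keep * scale      # dP gets the same mask (Step 3.3)
    return y, dV, dP

Because dropout is elementwise-linear once the mask is fixed, dP above is the
exact gradient with respect to the pre-dropout P, which is what Step 4
(dL/dx = y * (dL/dy - rowsum(dO_O))) consumes together with the now-unmodified
softmax_y.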