
The substantial overhead of NKI kernel construction in the Python call stack. #72

@dinghongsong

Description


Describe the bug

As previously discussed, the current optimized NKI matrix multiplication kernel is still slower than matrix multiplication via torch.matmul or torch.einsum, as measured in the benchmark_sampling of NxDI, for two reasons:

  1. In torch.matmul, the tensor data on the state buffer (SBUF) is loaded directly into the TensorMatrix engine for the matrix multiplication. In the NKI kernel, however, NKI requires that all kernel inputs and outputs reside in HBM. The tensor data therefore has to be written from the state buffer to HBM outside the kernel scope, and then loaded back from HBM into the state buffer within the kernel scope.

  2. The NKI kernel construction in the Python call stack adds significant per-call overhead (see the timing sketch below).
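
To attribute the measured time between these two effects, the Python-side kernel construction can be timed separately from device dispatch and execution. A minimal sketch reusing the definitions from nki_matmul.py below; the assumption that everything before xm.mark_step() is dominated by kernel construction/tracing is mine, not an official NKI measurement:

t0 = time.time()
output_nki = nki_matmul_basic_(A.T, B)  # Python call stack: kernel construction / tracing
t1 = time.time()
xm.mark_step()                          # dispatch the lazy XLA graph
xm.wait_device_ops()                    # block until device execution finishes
t2 = time.time()
print(f"construction (ms): {(t1 - t0) * 1000:.3f}")
print(f"dispatch + execution (ms): {(t2 - t1) * 1000:.3f}")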

Expected Behavior

Measured with Python's time.time(), the NKI matrix multiplication kernel should be faster than torch.matmul, and the NKI kernel construction overhead in the Python call stack should be minimal.

Current Behavior

Measured with Python's time.time(), the NKI matrix multiplication kernel is much slower than torch.matmul.

Reproduction Steps

torch_matmul.py:

import os
# Set Neuron flags before any Neuron/XLA work happens.
os.environ["NEURON_FRAMEWORK_DEBUG"] = "1"
os.environ["NEURON_CC_FLAGS"] = " --disable-dge "

import time
import torch
from torch_xla.core import xla_model as xm


if __name__ == "__main__":
  K, M, N = 128, 128, 512
  device = xm.xla_device()

  A = torch.rand((M, K), dtype=torch.bfloat16, device=device)
  B = torch.rand((K, N), dtype=torch.bfloat16, device=device)

  # Note: the first iteration also pays one-time graph compilation.
  for _ in range(100):
    start_time = time.time()
    output_torch = torch.matmul(A, B)
    xm.mark_step()        # dispatch the lazy XLA graph
    xm.wait_device_ops()  # block until device execution finishes
    end_time = time.time()
    print(f"output_torch={output_torch}")
    print("torch matmul time (ms): ", (end_time - start_time) * 1000)

nki_matmul.py:

import os
# Set Neuron flags before any Neuron/XLA work happens.
os.environ["NEURON_FRAMEWORK_DEBUG"] = "1"
os.environ["NEURON_CC_FLAGS"] = " --disable-dge "

import time
import torch
from torch_xla.core import xla_model as xm

from neuronxcc import nki
import neuronxcc.nki.language as nl
import neuronxcc.nki.isa as nisa

@nki.jit
def nki_matmul_basic_(lhsT, rhs):
  """NKI kernel to compute a 128x128x512 matrix multiplication operation

  Args:
      lhsT: an input tensor of shape [128,128], a left hand side argument of the
        matrix multiplication, delivered transposed for optimal performance
      rhs: an input tensor of shape [128,512], a right hand side argument of the
        matrix multiplication
  Returns:
      result: the resulting output tensor of shape [128,512]
  """
  # Allocate the output tensor in shared HBM, as NKI requires for kernel outputs
  result = nl.ndarray((128, 512), dtype=lhsT.dtype, buffer=nl.shared_hbm)

  # Index grids over the partition (p) and free (f) dimensions of each tile
  i_lhsT_p, i_lhsT_f = nl.mgrid[0:128, 0:128]
  i_rhs_p, i_rhs_f = nl.mgrid[0:128, 0:512]
  i_out_p, i_out_f = nl.mgrid[0:128, 0:512]

  # Load both input tiles from HBM into the state buffer (SBUF)
  lhs_tile = nl.load(lhsT[i_lhsT_p, i_lhsT_f])
  rhs_tile = nl.load(rhs[i_rhs_p, i_rhs_f])

  # Multiply the tiles; the result lands in PSUM
  result_psum = nisa.nc_matmul(lhs_tile, rhs_tile)

  # Copy the result from PSUM to SBUF, then store it back to HBM
  result_sbuf = nl.copy(result_psum, dtype=result.dtype)
  nl.store(result[i_out_p, i_out_f], value=result_sbuf)

  return result


if __name__ == "__main__":
  K, M, N = 128, 128, 512
  device = xm.xla_device()

  A = torch.rand((M, K), dtype=torch.bfloat16, device=device)
  B = torch.rand((K, N), dtype=torch.bfloat16, device=device)

  # Note: the first iteration also pays one-time graph compilation.
  for _ in range(100):
    start_time = time.time()
    output_nki = nki_matmul_basic_(A.T, B)  # LHS is passed transposed, as the kernel expects
    xm.mark_step()        # dispatch the lazy XLA graph
    xm.wait_device_ops()  # block until device execution finishes
    end_time = time.time()
    print(f"output_nki={output_nki}")
    print("nki matmul time (ms): ", (end_time - start_time) * 1000)


Possible Solution

No response

Additional Information/Context

No response

neuronx-cc version used

NeuronX Compiler version 2.16.372.0+4a9b2326

Framework(s) and their versions used (JAX, PyTorch, etc.)

No response
