Describe the bug
As previously discussed, the current optimized NKI matrix multiplication kernel is still slower than matrix multiplication using torch.matmul or torch.einsum, as measured in the benchmark_sampling of NxDI, for two reasons:

- In torch.matmul, the tensor data already resident in the state buffer (SBUF) is loaded directly into the Tensor engine for the matrix multiplication. In the NKI kernel, however, NKI enforces that all inputs and outputs of a kernel reside in HBM. The tensor data must therefore be written from the state buffer to HBM outside the kernel scope, and then loaded from HBM back into the state buffer within the kernel scope.
- Constructing the NKI kernel in the Python call stack incurs significant per-call overhead (see the timing sketch after this list).
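One way to tease apart the two effects is to exclude a warm-up pass from the measured window, so that one-time tracing/compilation does not pollute the numbers. The following is a minimal sketch, not part of the original report; it assumes the same torch_xla setup as the reproduction scripts below, and the `bench` helper name is hypothetical:

```python
import time
from torch_xla.core import xla_model as xm

def bench(fn, iters=100, warmup=10):
    """Hypothetical helper: time fn per iteration, excluding warm-up
    iterations that absorb lazy-graph tracing and compilation.
    Usage sketch: bench(lambda: torch.matmul(A, B))"""
    for _ in range(warmup):          # not timed: graph build + compile
        fn()
        xm.mark_step()
        xm.wait_device_ops()
    times = []
    for _ in range(iters):           # timed: steady-state execution only
        start = time.perf_counter()
        fn()
        xm.mark_step()               # flush the lazy XLA graph
        xm.wait_device_ops()         # block until device work completes
        times.append((time.perf_counter() - start) * 1000)
    return min(times), sum(times) / len(times)
```

Note that a warm-up loop only removes one-time tracing and compilation from the measurement; any per-call Python-side kernel construction cost described above will still show up in the steady-state numbers.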
Expected Behavior
Measured with Python's time.time(), the NKI matrix multiplication kernel should be faster than torch.matmul, and the NKI kernel-construction overhead in the Python call stack should be minimal.
Current Behavior
Measured with Python's time.time(), the NKI matrix multiplication kernel is much slower than torch.matmul.
Reproduction Steps
torch_matmul.py:
```python
from neuronxcc import nki
import neuronxcc.nki.language as nl
import neuronxcc.nki.isa as nisa
import numpy as np
import torch
import time
from torch_xla.core import xla_model as xm
import os

os.environ["NEURON_FRAMEWORK_DEBUG"] = "1"
os.environ["NEURON_CC_FLAGS"] = " --disable-dge "

if __name__ == "__main__":
    K, M, N = 128, 128, 512
    device = xm.xla_device()
    cpu = torch.device('cpu')

    A = torch.rand((M, K), dtype=torch.bfloat16, device=device)
    B = torch.rand((K, N), dtype=torch.bfloat16, device=device)

    for _ in range(100):
        start_time = time.time()
        output_torch = torch.matmul(A, B)
        xm.mark_step()
        xm.wait_device_ops()
        end_time = time.time()
        print(f"output_torch={output_torch}")
        print("torch matmul time (ms): ", (end_time - start_time) * 1000)
```
nki_matmul.py:
```python
from neuronxcc import nki
import neuronxcc.nki.language as nl
import neuronxcc.nki.isa as nisa
import numpy as np
import torch
import time
from torch_xla.core import xla_model as xm
import os

os.environ["NEURON_FRAMEWORK_DEBUG"] = "1"
os.environ["NEURON_CC_FLAGS"] = " --disable-dge "


@nki.jit
def nki_matmul_basic_(lhsT, rhs):
    """NKI kernel to compute a 128x128x512 matrix multiplication operation.

    Args:
      lhsT: an input tensor of shape [128,128], the left-hand-side argument
        of the matrix multiplication, delivered transposed for optimal
        performance
      rhs: an input tensor of shape [128,512], the right-hand-side argument
        of the matrix multiplication

    Returns:
      result: the resulting output tensor of shape [128,512]
    """
    result = nl.ndarray((128, 512), dtype=lhsT.dtype, buffer=nl.shared_hbm)

    i_lhsT_p, i_lhsT_f = nl.mgrid[0:128, 0:128]
    i_rhs_p, i_rhs_f = nl.mgrid[0:128, 0:512]
    i_out_p, i_out_f = nl.mgrid[0:128, 0:512]

    # Load the inputs from HBM into the state buffer (SBUF).
    lhs_tile = nl.load(lhsT[i_lhsT_p, i_lhsT_f])
    rhs_tile = nl.load(rhs[i_rhs_p, i_rhs_f])

    # Matrix multiplication on the Tensor engine; the result lands in PSUM.
    result_psum = nisa.nc_matmul(lhs_tile, rhs_tile)

    # Copy from PSUM back to SBUF, then store to the HBM output tensor.
    result_sbuf = nl.copy(result_psum, dtype=result.dtype)
    nl.store(result[i_out_p, i_out_f], value=result_sbuf)
    return result


if __name__ == "__main__":
    K, M, N = 128, 128, 512
    device = xm.xla_device()
    cpu = torch.device('cpu')

    A = torch.rand((M, K), dtype=torch.bfloat16, device=device)
    B = torch.rand((K, N), dtype=torch.bfloat16, device=device)

    for _ in range(100):
        start_time = time.time()
        output_nki = nki_matmul_basic_(A.T, B)
        xm.mark_step()
        xm.wait_device_ops()
        end_time = time.time()
        print(f"output_nki={output_nki}")
        print("nki matmul time (ms): ", (end_time - start_time) * 1000)
```
Regression Issue
- [ ] Select this option if this issue appears to be a regression.
Possible Solution
No response
Additional Information/Context
No response
neuronx-cc version used
NeuronX Compiler version 2.16.372.0+4a9b2326
Framework(s) and their versions used (JAX, PyTorch, etc.)
No response