diff --git a/3rdparty/composable_kernel b/3rdparty/composable_kernel index 39ba03f25d..3152c14711 160000 --- a/3rdparty/composable_kernel +++ b/3rdparty/composable_kernel @@ -1 +1 @@ -Subproject commit 39ba03f25d4c4c4e9f551a2dcf001cadd0b86cbe +Subproject commit 3152c1471152f98b9c53cf16d956febba3789a84 diff --git a/aiter/configs/a8w8_blockscale_tuned_gemm.csv b/aiter/configs/a8w8_blockscale_tuned_gemm.csv index 50ce88f3de..d1474b388d 100755 --- a/aiter/configs/a8w8_blockscale_tuned_gemm.csv +++ b/aiter/configs/a8w8_blockscale_tuned_gemm.csv @@ -1,118 +1,118 @@ M,N,K,kernelId,splitK,us,kernelName -16,1536,7168,8,0,19.4246,a8w8_blockscale_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 -16,3072,1536,8,0,6.3894,a8w8_blockscale_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 -16,576,7168,8,0,19.3458,a8w8_blockscale_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 -16,7168,256,8,0,3.7914,a8w8_blockscale_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 -16,7168,2048,7,0,10.947,a8w8_blockscale_1x128x128_256x16x128x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_8_1x2_intrawave_v1 -16,4608,7168,8,0,19.9315,a8w8_blockscale_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 -16,7168,2304,8,0,11.3226,a8w8_blockscale_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 -16,512,7168,8,0,19.4527,a8w8_blockscale_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 -16,4096,512,8,0,4.0618,a8w8_blockscale_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 -32,1536,7168,8,0,19.927,a8w8_blockscale_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 -32,3072,1536,13,0,7.799,a8w8_blockscale_1x128x128_256x32x64x256_16x16_16x16_16x16x1_16x16x1_1x32x1x8_8_2x1_intrawave_v1 -32,576,7168,8,0,19.8035,a8w8_blockscale_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 -32,7168,256,13,0,4.8474,a8w8_blockscale_1x128x128_256x32x64x256_16x16_16x16_16x16x1_16x16x1_1x32x1x8_8_2x1_intrawave_v1 -32,7168,2048,12,0,13.5078,a8w8_blockscale_1x128x128_256x32x128x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 -32,4608,7168,7,0,27.3435,a8w8_blockscale_1x128x128_256x16x128x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_8_1x2_intrawave_v1 -32,7168,2304,12,0,14.359,a8w8_blockscale_1x128x128_256x32x128x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 -32,512,7168,8,0,19.7747,a8w8_blockscale_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 -32,4096,512,13,0,4.7998,a8w8_blockscale_1x128x128_256x32x64x256_16x16_16x16_16x16x1_16x16x1_1x32x1x8_8_2x1_intrawave_v1 -64,1536,7168,13,0,25.3771,a8w8_blockscale_1x128x128_256x32x64x256_16x16_16x16_16x16x1_16x16x1_1x32x1x8_8_2x1_intrawave_v1 -64,3072,1536,18,0,10.1946,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 -64,576,7168,8,0,19.875,a8w8_blockscale_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 -64,7168,256,16,0,6.443,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v1 -64,7168,2048,18,0,18.4746,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 -64,4608,7168,18,0,31.0863,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 
-64,7168,2304,18,0,19.2478,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 -64,512,7168,8,0,19.8418,a8w8_blockscale_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 -64,4096,512,3,0,6.4018,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -128,1536,7168,18,0,31.0711,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 -128,3072,1536,3,0,14.4742,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -128,576,7168,8,0,20.0614,a8w8_blockscale_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 -128,7168,256,16,0,9.2646,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v1 -128,7168,2048,18,0,27.5911,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 -128,4608,7168,18,0,48.6159,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 -128,7168,2304,3,0,27.7587,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -128,512,7168,8,0,20.0222,a8w8_blockscale_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 -128,4096,512,16,0,9.4706,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v1 -256,1536,7168,18,0,48.0159,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 -256,3072,1536,3,0,20.4438,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -256,576,7168,13,0,25.5535,a8w8_blockscale_1x128x128_256x32x64x256_16x16_16x16_16x16x1_16x16x1_1x32x1x8_8_2x1_intrawave_v1 -256,7168,256,16,0,14.019,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v1 -256,7168,2048,2,0,46.1811,a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -256,4608,7168,2,0,89.5716,a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -256,7168,2304,3,0,49.7975,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -256,512,7168,13,0,25.4523,a8w8_blockscale_1x128x128_256x32x64x256_16x16_16x16_16x16x1_16x16x1_1x32x1x8_8_2x1_intrawave_v1 -256,4096,512,2,0,14.0618,a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -512,1536,7168,18,0,75.9904,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 -512,3072,1536,3,0,32.1807,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -512,576,7168,18,0,31.4631,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 -512,7168,256,16,0,23.9571,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v1 -512,7168,2048,2,0,84.9144,a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -512,4608,7168,2,0,161.569,a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -512,7168,2304,2,0,90.31,a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -512,512,7168,18,0,31.4115,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 
-512,4096,512,3,0,21.8851,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -1024,1536,7168,18,0,121.7549,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 -1024,3072,1536,2,0,59.7275,a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -1024,576,7168,18,0,49.1011,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 -1024,7168,256,16,0,41.9815,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v1 -1024,7168,2048,2,0,165.671,a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -1024,4608,7168,2,0,317.4746,a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -1024,7168,2304,3,0,175.3202,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -1024,512,7168,18,0,48.4499,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 -1024,4096,512,16,0,37.3275,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v1 -1536,1536,7168,2,0,162.9982,a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -1536,3072,1536,3,0,88.3156,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -1536,576,7168,18,0,76.2496,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 -1536,7168,256,16,0,59.7435,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v1 -1536,7168,2048,2,0,237.0036,a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -1536,4608,7168,2,0,450.6713,a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -1536,7168,2304,2,0,253.8764,a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -1536,512,7168,18,0,75.9592,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 -1536,4096,512,3,0,54.9919,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -2048,1536,7168,2,0,212.3527,a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -2048,3072,1536,2,0,113.2165,a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -2048,576,7168,2,0,87.9612,a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -2048,7168,256,16,0,77.4816,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v1 -2048,7168,2048,2,0,317.2466,a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -2048,4608,7168,2,0,606.9813,a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -2048,7168,2304,3,0,338.8351,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -2048,512,7168,2,0,81.9908,a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -2048,4096,512,16,0,70.4172,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v1 
-4096,1536,7168,2,0,402.536,a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -4096,3072,1536,3,0,220.7904,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -4096,576,7168,2,0,172.5742,a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -4096,7168,256,16,0,149.649,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v1 -4096,7168,2048,2,0,619.029,a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -4096,4608,7168,2,0,1172.2924,a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -4096,7168,2304,2,0,663.4679,a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -4096,512,7168,2,0,158.9842,a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -4096,4096,512,16,0,134.6393,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v1 -8192,1536,7168,2,0,805.105,a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -8192,3072,1536,3,0,430.6929,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -8192,576,7168,2,0,340.7579,a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -8192,7168,256,16,0,293.9925,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v1 -8192,7168,2048,2,0,1232.6365,a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -8192,4608,7168,2,0,2352.2349,a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -8192,7168,2304,2,0,1321.7627,a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -8192,512,7168,2,0,291.7201,a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -8192,4096,512,16,0,262.15,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v1 -16384,1536,7168,2,0,1584.0766,a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -16384,3072,1536,2,0,852.8383,a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -16384,576,7168,18,0,670.6759,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 -16384,7168,256,16,0,581.5456,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v1 -16384,7168,2048,2,0,2456.8027,a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -16384,4608,7168,2,0,4718.6563,a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -16384,7168,2304,2,0,2638.6195,a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -16384,512,7168,2,0,534.7103,a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -16384,4096,512,16,0,515.5647,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v1 -20480,1536,7168,2,0,1954.8342,a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 
-20480,3072,1536,2,0,1063.7988,a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -20480,576,7168,18,0,832.3255,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 -20480,7168,256,16,0,726.03,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v1 -20480,7168,2048,2,0,3075.1182,a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -20480,4608,7168,2,0,5890.4539,a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -20480,7168,2304,2,0,3292.9571,a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -20480,512,7168,2,0,656.6522,a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 -20480,4096,512,16,0,643.5182,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v1 +16,1536,7168,8,0,23.8369,a8w8_blockscale_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 +16,3072,1536,8,0,7.9184,a8w8_blockscale_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 +16,576,7168,8,0,23.3269,a8w8_blockscale_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 +16,7168,256,8,0,3.8016,a8w8_blockscale_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 +16,7168,2048,7,0,12.8812,a8w8_blockscale_1x128x128_256x16x128x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_8_1x2_intrawave_v1 +16,4608,7168,8,0,24.7533,a8w8_blockscale_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 +16,7168,2304,7,0,12.224,a8w8_blockscale_1x128x128_256x16x128x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_8_1x2_intrawave_v1 +16,512,7168,8,0,23.3101,a8w8_blockscale_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 +16,4096,512,8,0,4.2568,a8w8_blockscale_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 +32,1536,7168,8,0,23.8553,a8w8_blockscale_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 +32,3072,1536,8,0,10.0604,a8w8_blockscale_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 +32,576,7168,8,0,23.2385,a8w8_blockscale_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 +32,7168,256,11,0,5.2576,a8w8_blockscale_1x128x128_256x32x64x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x1_intrawave_v1 +32,7168,2048,12,0,15.0405,a8w8_blockscale_1x128x128_256x32x128x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 +32,4608,7168,7,0,33.0525,a8w8_blockscale_1x128x128_256x16x128x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_8_1x2_intrawave_v1 +32,7168,2304,12,0,14.9417,a8w8_blockscale_1x128x128_256x32x128x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 +32,512,7168,8,0,23.2729,a8w8_blockscale_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 +32,4096,512,13,0,5.1852,a8w8_blockscale_1x128x128_256x32x64x256_16x16_16x16_16x16x1_16x16x1_1x32x1x8_8_2x1_intrawave_v1 +64,1536,7168,8,0,27.4433,a8w8_blockscale_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 +64,3072,1536,18,0,10.5876,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 
+64,576,7168,8,0,23.1609,a8w8_blockscale_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 +64,7168,256,11,0,7.7212,a8w8_blockscale_1x128x128_256x32x64x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x1_intrawave_v1 +64,7168,2048,18,0,21.4081,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 +64,4608,7168,18,0,33.4145,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 +64,7168,2304,18,0,20.1577,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 +64,512,7168,8,0,23.1005,a8w8_blockscale_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 +64,4096,512,18,0,7.1216,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 +128,1536,7168,18,0,32.3921,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 +128,3072,1536,18,0,16.4341,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 +128,576,7168,8,0,22.6425,a8w8_blockscale_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 +128,7168,256,16,0,9.514,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v1 +128,7168,2048,18,0,29.7597,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 +128,4608,7168,18,0,52.059,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 +128,7168,2304,18,0,30.3301,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 +128,512,7168,8,0,22.7425,a8w8_blockscale_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 +128,4096,512,18,0,9.5268,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 +256,1536,7168,18,0,48.4922,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 +256,3072,1536,18,0,23.6121,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 +256,576,7168,8,0,29.1325,a8w8_blockscale_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 +256,7168,256,16,0,14.4945,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v1 +256,7168,2048,18,0,53.0526,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 +256,4608,7168,18,0,98.0432,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 +256,7168,2304,18,0,55.3778,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 +256,512,7168,8,0,28.6397,a8w8_blockscale_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 +256,4096,512,16,0,15.6425,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v1 +512,1536,7168,18,0,78.3759,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 +512,3072,1536,18,0,36.985,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 +512,576,7168,18,0,32.2117,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 
+512,7168,256,16,0,24.5481,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v1 +512,7168,2048,18,0,100.978,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 +512,4608,7168,18,0,189.044,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 +512,7168,2304,0,0,100.9292,a8w8_blockscale_1x128x128_256x128x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 +512,512,7168,18,0,32.1833,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 +512,4096,512,16,0,24.0829,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v1 +1024,1536,7168,18,0,123.9025,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 +1024,3072,1536,16,0,64.8831,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v1 +1024,576,7168,18,0,49.8942,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 +1024,7168,256,16,0,42.5982,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v1 +1024,7168,2048,18,0,188.4792,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 +1024,4608,7168,18,0,350.7798,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 +1024,7168,2304,3,0,183.8311,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 +1024,512,7168,18,0,49.2014,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 +1024,4096,512,16,0,39.537,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v1 +1536,1536,7168,18,0,189.5424,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 +1536,3072,1536,16,0,92.874,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v1 +1536,576,7168,18,0,79.0635,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 +1536,7168,256,16,0,60.2766,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v1 +1536,7168,2048,2,0,262.0463,a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 +1536,4608,7168,18,0,510.856,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 +1536,7168,2304,3,0,265.2131,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 +1536,512,7168,18,0,78.6707,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 +1536,4096,512,16,0,56.7614,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v1 +2048,1536,7168,18,0,235.7629,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 +2048,3072,1536,16,0,120.8353,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v1 +2048,576,7168,18,0,97.866,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 +2048,7168,256,16,0,78.0207,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v1 
+2048,7168,2048,2,0,350.2674,a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 +2048,4608,7168,18,0,669.2975,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 +2048,7168,2304,3,0,347.8874,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 +2048,512,7168,18,0,96.8336,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 +2048,4096,512,16,0,72.0211,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v1 +4096,1536,7168,18,0,463.6419,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 +4096,3072,1536,16,0,226.6441,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v1 +4096,576,7168,18,0,188.8516,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 +4096,7168,256,16,0,149.7958,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v1 +4096,7168,2048,2,0,672.5115,a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 +4096,4608,7168,2,0,1320.8077,a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 +4096,7168,2304,3,0,680.5575,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 +4096,512,7168,18,0,170.5095,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 +4096,4096,512,16,0,135.8522,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v1 +8192,1536,7168,18,0,901.7256,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 +8192,3072,1536,16,0,439.0326,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v1 +8192,576,7168,18,0,353.605,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 +8192,7168,256,16,0,293.0472,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v1 +8192,7168,2048,0,0,1314.0525,a8w8_blockscale_1x128x128_256x128x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 +8192,4608,7168,2,0,2604.0725,a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 +8192,7168,2304,3,0,1344.1438,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 +8192,512,7168,18,0,312.3824,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 +8192,4096,512,16,0,263.0335,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v1 +16384,1536,7168,2,0,1771.0959,a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 +16384,3072,1536,16,0,863.8991,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v1 +16384,576,7168,18,0,680.1267,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 +16384,7168,256,16,0,579.8835,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v1 +16384,7168,2048,0,0,2605.2217,a8w8_blockscale_1x128x128_256x128x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 
+16384,4608,7168,2,0,5173.8068,a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 +16384,7168,2304,3,0,2671.0604,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 +16384,512,7168,18,0,612.0553,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 +16384,4096,512,16,0,518.4013,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v1 +20480,1536,7168,2,0,2183.4476,a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 +20480,3072,1536,16,0,1076.0683,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v1 +20480,576,7168,18,0,843.8382,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 +20480,7168,256,16,0,723.3741,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v1 +20480,7168,2048,0,0,3257.3851,a8w8_blockscale_1x128x128_256x128x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 +20480,4608,7168,2,0,6422.7243,a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 +20480,7168,2304,3,0,3335.873,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3 +20480,512,7168,18,0,758.6562,a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1 +20480,4096,512,16,0,645.1858,a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v1 diff --git a/aiter/configs/a8w8_blockscale_wpreshuffle_tuned_gemm.csv b/aiter/configs/a8w8_blockscale_wpreshuffle_tuned_gemm.csv new file mode 100755 index 0000000000..359adb4736 --- /dev/null +++ b/aiter/configs/a8w8_blockscale_wpreshuffle_tuned_gemm.csv @@ -0,0 +1,118 @@ +M,N,K,kernelId,splitK,us,kernelName +16,1536,7168,8,0,20.3733,a8w8_blockscale_wpreshuffle_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 +16,3072,1536,8,0,7.2748,a8w8_blockscale_wpreshuffle_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 +16,576,7168,8,0,19.6829,a8w8_blockscale_wpreshuffle_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 +16,7168,256,7,0,3.808,a8w8_blockscale_wpreshuffle_1x128x128_256x16x128x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_8_1x2_intrawave_v1 +16,7168,2048,8,0,10.7236,a8w8_blockscale_wpreshuffle_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 +16,4608,7168,8,0,20.9385,a8w8_blockscale_wpreshuffle_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 +16,7168,2304,8,0,11.3236,a8w8_blockscale_wpreshuffle_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 +16,512,7168,8,0,19.6857,a8w8_blockscale_wpreshuffle_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 +16,4096,512,8,0,3.6836,a8w8_blockscale_wpreshuffle_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 +32,1536,7168,8,0,20.6333,a8w8_blockscale_wpreshuffle_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 +32,3072,1536,13,0,8.5716,a8w8_blockscale_wpreshuffle_1x128x128_256x32x64x256_16x16_16x16_16x16x1_16x16x1_1x32x1x8_8_2x1_intrawave_v1 +32,576,7168,8,0,19.6709,a8w8_blockscale_wpreshuffle_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 
+32,7168,256,13,0,4.62,a8w8_blockscale_wpreshuffle_1x128x128_256x32x64x256_16x16_16x16_16x16x1_16x16x1_1x32x1x8_8_2x1_intrawave_v1 +32,7168,2048,13,0,13.9641,a8w8_blockscale_wpreshuffle_1x128x128_256x32x64x256_16x16_16x16_16x16x1_16x16x1_1x32x1x8_8_2x1_intrawave_v1 +32,4608,7168,13,0,24.8581,a8w8_blockscale_wpreshuffle_1x128x128_256x32x64x256_16x16_16x16_16x16x1_16x16x1_1x32x1x8_8_2x1_intrawave_v1 +32,7168,2304,13,0,14.4225,a8w8_blockscale_wpreshuffle_1x128x128_256x32x64x256_16x16_16x16_16x16x1_16x16x1_1x32x1x8_8_2x1_intrawave_v1 +32,512,7168,8,0,19.5977,a8w8_blockscale_wpreshuffle_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 +32,4096,512,8,0,5.0272,a8w8_blockscale_wpreshuffle_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 +64,1536,7168,8,0,24.1181,a8w8_blockscale_wpreshuffle_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 +64,3072,1536,8,0,10.6444,a8w8_blockscale_wpreshuffle_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 +64,576,7168,8,0,19.6377,a8w8_blockscale_wpreshuffle_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 +64,7168,256,7,0,7.3096,a8w8_blockscale_wpreshuffle_1x128x128_256x16x128x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_8_1x2_intrawave_v1 +64,7168,2048,13,0,18.8161,a8w8_blockscale_wpreshuffle_1x128x128_256x32x64x256_16x16_16x16_16x16x1_16x16x1_1x32x1x8_8_2x1_intrawave_v1 +64,4608,7168,18,0,35.849,a8w8_blockscale_wpreshuffle_1x128x128_256x64x64x256_16x16_16x16_16x16x1_16x16x1_1x32x1x8_8_2x2_intrawave_v1 +64,7168,2304,13,0,19.6389,a8w8_blockscale_wpreshuffle_1x128x128_256x32x64x256_16x16_16x16_16x16x1_16x16x1_1x32x1x8_8_2x1_intrawave_v1 +64,512,7168,8,0,19.6893,a8w8_blockscale_wpreshuffle_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 +64,4096,512,8,0,7.0608,a8w8_blockscale_wpreshuffle_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 +128,1536,7168,13,0,33.8481,a8w8_blockscale_wpreshuffle_1x128x128_256x32x64x256_16x16_16x16_16x16x1_16x16x1_1x32x1x8_8_2x1_intrawave_v1 +128,3072,1536,13,0,14.1409,a8w8_blockscale_wpreshuffle_1x128x128_256x32x64x256_16x16_16x16_16x16x1_16x16x1_1x32x1x8_8_2x1_intrawave_v1 +128,576,7168,8,0,19.6941,a8w8_blockscale_wpreshuffle_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 +128,7168,256,13,0,9.0728,a8w8_blockscale_wpreshuffle_1x128x128_256x32x64x256_16x16_16x16_16x16x1_16x16x1_1x32x1x8_8_2x1_intrawave_v1 +128,7168,2048,12,0,31.4937,a8w8_blockscale_wpreshuffle_1x128x128_256x32x128x256_16x16_16x16_16x16x1_16x16x1_1x32x1x8_8_2x2_intrawave_v1 +128,4608,7168,18,0,61.6147,a8w8_blockscale_wpreshuffle_1x128x128_256x64x64x256_16x16_16x16_16x16x1_16x16x1_1x32x1x8_8_2x2_intrawave_v1 +128,7168,2304,10,0,32.3685,a8w8_blockscale_wpreshuffle_1x128x128_256x32x128x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v1 +128,512,7168,8,0,19.7121,a8w8_blockscale_wpreshuffle_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 +128,4096,512,8,0,8.9284,a8w8_blockscale_wpreshuffle_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1 +256,1536,7168,13,0,50.8314,a8w8_blockscale_wpreshuffle_1x128x128_256x32x64x256_16x16_16x16_16x16x1_16x16x1_1x32x1x8_8_2x1_intrawave_v1 +256,3072,1536,13,0,21.5577,a8w8_blockscale_wpreshuffle_1x128x128_256x32x64x256_16x16_16x16_16x16x1_16x16x1_1x32x1x8_8_2x1_intrawave_v1 
+256,576,7168,13,0,24.1441,a8w8_blockscale_wpreshuffle_1x128x128_256x32x64x256_16x16_16x16_16x16x1_16x16x1_1x32x1x8_8_2x1_intrawave_v1 +256,7168,256,7,0,13.3049,a8w8_blockscale_wpreshuffle_1x128x128_256x16x128x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_8_1x2_intrawave_v1 +256,7168,2048,9,0,56.027,a8w8_blockscale_wpreshuffle_1x128x128_256x32x256x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v1 +256,4608,7168,14,0,109.4776,a8w8_blockscale_wpreshuffle_1x128x128_256x64x256x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v1 +256,7168,2304,15,0,56.3062,a8w8_blockscale_wpreshuffle_1x128x128_256x64x128x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v1 +256,512,7168,13,0,23.8285,a8w8_blockscale_wpreshuffle_1x128x128_256x32x64x256_16x16_16x16_16x16x1_16x16x1_1x32x1x8_8_2x1_intrawave_v1 +256,4096,512,7,0,12.9837,a8w8_blockscale_wpreshuffle_1x128x128_256x16x128x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_8_1x2_intrawave_v1 +512,1536,7168,13,0,80.5739,a8w8_blockscale_wpreshuffle_1x128x128_256x32x64x256_16x16_16x16_16x16x1_16x16x1_1x32x1x8_8_2x1_intrawave_v1 +512,3072,1536,10,0,37.0022,a8w8_blockscale_wpreshuffle_1x128x128_256x32x128x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v1 +512,576,7168,12,0,34.7126,a8w8_blockscale_wpreshuffle_1x128x128_256x32x128x256_16x16_16x16_16x16x1_16x16x1_1x32x1x8_8_2x2_intrawave_v1 +512,7168,256,7,0,21.8937,a8w8_blockscale_wpreshuffle_1x128x128_256x16x128x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_8_1x2_intrawave_v1 +512,7168,2048,0,0,93.6516,a8w8_blockscale_wpreshuffle_1x128x128_256x128x128x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v3 +512,4608,7168,0,0,179.194,a8w8_blockscale_wpreshuffle_1x128x128_256x128x128x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v3 +512,7168,2304,0,0,87.4988,a8w8_blockscale_wpreshuffle_1x128x128_256x128x128x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v3 +512,512,7168,18,0,35.4666,a8w8_blockscale_wpreshuffle_1x128x128_256x64x64x256_16x16_16x16_16x16x1_16x16x1_1x32x1x8_8_2x2_intrawave_v1 +512,4096,512,13,0,20.5525,a8w8_blockscale_wpreshuffle_1x128x128_256x32x64x256_16x16_16x16_16x16x1_16x16x1_1x32x1x8_8_2x1_intrawave_v1 +1024,1536,7168,18,0,140.8366,a8w8_blockscale_wpreshuffle_1x128x128_256x64x64x256_16x16_16x16_16x16x1_16x16x1_1x32x1x8_8_2x2_intrawave_v1 +1024,3072,1536,9,0,65.4555,a8w8_blockscale_wpreshuffle_1x128x128_256x32x256x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v1 +1024,576,7168,12,0,55.5282,a8w8_blockscale_wpreshuffle_1x128x128_256x32x128x256_16x16_16x16_16x16x1_16x16x1_1x32x1x8_8_2x2_intrawave_v1 +1024,7168,256,7,0,38.4886,a8w8_blockscale_wpreshuffle_1x128x128_256x16x128x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_8_1x2_intrawave_v1 +1024,7168,2048,0,0,170.8535,a8w8_blockscale_wpreshuffle_1x128x128_256x128x128x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v3 +1024,4608,7168,0,0,361.3259,a8w8_blockscale_wpreshuffle_1x128x128_256x128x128x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v3 +1024,7168,2304,0,0,161.2487,a8w8_blockscale_wpreshuffle_1x128x128_256x128x128x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v3 +1024,512,7168,12,0,59.4271,a8w8_blockscale_wpreshuffle_1x128x128_256x32x128x256_16x16_16x16_16x16x1_16x16x1_1x32x1x8_8_2x2_intrawave_v1 +1024,4096,512,12,0,34.851,a8w8_blockscale_wpreshuffle_1x128x128_256x32x128x256_16x16_16x16_16x16x1_16x16x1_1x32x1x8_8_2x2_intrawave_v1 +1536,1536,7168,0,0,185.4252,a8w8_blockscale_wpreshuffle_1x128x128_256x128x128x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v3 
+1536,3072,1536,0,0,83.9684,a8w8_blockscale_wpreshuffle_1x128x128_256x128x128x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v3 +1536,576,7168,12,0,82.9763,a8w8_blockscale_wpreshuffle_1x128x128_256x32x128x256_16x16_16x16_16x16x1_16x16x1_1x32x1x8_8_2x2_intrawave_v1 +1536,7168,256,4,0,54.707,a8w8_blockscale_wpreshuffle_1x128x128_256x16x256x128_8x16_16x16_16x16x1_8x32x1_1x16x1x16_8_1x2_intrawave_v1 +1536,7168,2048,0,0,253.9371,a8w8_blockscale_wpreshuffle_1x128x128_256x128x128x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v3 +1536,4608,7168,0,0,537.6935,a8w8_blockscale_wpreshuffle_1x128x128_256x128x128x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v3 +1536,7168,2304,0,0,237.1646,a8w8_blockscale_wpreshuffle_1x128x128_256x128x128x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v3 +1536,512,7168,13,0,81.02,a8w8_blockscale_wpreshuffle_1x128x128_256x32x64x256_16x16_16x16_16x16x1_16x16x1_1x32x1x8_8_2x1_intrawave_v1 +1536,4096,512,12,0,49.8538,a8w8_blockscale_wpreshuffle_1x128x128_256x32x128x256_16x16_16x16_16x16x1_16x16x1_1x32x1x8_8_2x2_intrawave_v1 +2048,1536,7168,9,0,265.9643,a8w8_blockscale_wpreshuffle_1x128x128_256x32x256x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v1 +2048,3072,1536,0,0,105.3044,a8w8_blockscale_wpreshuffle_1x128x128_256x128x128x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v3 +2048,576,7168,12,0,105.8376,a8w8_blockscale_wpreshuffle_1x128x128_256x32x128x256_16x16_16x16_16x16x1_16x16x1_1x32x1x8_8_2x2_intrawave_v1 +2048,7168,256,4,0,70.4467,a8w8_blockscale_wpreshuffle_1x128x128_256x16x256x128_8x16_16x16_16x16x1_8x32x1_1x16x1x16_8_1x2_intrawave_v1 +2048,7168,2048,0,0,331.0246,a8w8_blockscale_wpreshuffle_1x128x128_256x128x128x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v3 +2048,4608,7168,0,0,705.4126,a8w8_blockscale_wpreshuffle_1x128x128_256x128x128x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v3 +2048,7168,2304,0,0,312.3609,a8w8_blockscale_wpreshuffle_1x128x128_256x128x128x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v3 +2048,512,7168,14,0,108.9081,a8w8_blockscale_wpreshuffle_1x128x128_256x64x256x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v1 +2048,4096,512,12,0,63.5987,a8w8_blockscale_wpreshuffle_1x128x128_256x32x128x256_16x16_16x16_16x16x1_16x16x1_1x32x1x8_8_2x2_intrawave_v1 +4096,1536,7168,0,0,457.5092,a8w8_blockscale_wpreshuffle_1x128x128_256x128x128x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v3 +4096,3072,1536,0,0,198.6353,a8w8_blockscale_wpreshuffle_1x128x128_256x128x128x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v3 +4096,576,7168,0,0,195.5016,a8w8_blockscale_wpreshuffle_1x128x128_256x128x128x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v3 +4096,7168,256,4,0,133.113,a8w8_blockscale_wpreshuffle_1x128x128_256x16x256x128_8x16_16x16_16x16x1_8x32x1_1x16x1x16_8_1x2_intrawave_v1 +4096,7168,2048,0,0,623.5339,a8w8_blockscale_wpreshuffle_1x128x128_256x128x128x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v3 +4096,4608,7168,0,0,1308.8472,a8w8_blockscale_wpreshuffle_1x128x128_256x128x128x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v3 +4096,7168,2304,0,0,592.0273,a8w8_blockscale_wpreshuffle_1x128x128_256x128x128x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v3 +4096,512,7168,0,0,180.0848,a8w8_blockscale_wpreshuffle_1x128x128_256x128x128x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v3 +4096,4096,512,12,0,120.5713,a8w8_blockscale_wpreshuffle_1x128x128_256x32x128x256_16x16_16x16_16x16x1_16x16x1_1x32x1x8_8_2x2_intrawave_v1 
+8192,1536,7168,0,0,891.405,a8w8_blockscale_wpreshuffle_1x128x128_256x128x128x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v3 +8192,3072,1536,0,0,387.0061,a8w8_blockscale_wpreshuffle_1x128x128_256x128x128x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v3 +8192,576,7168,0,0,379.1368,a8w8_blockscale_wpreshuffle_1x128x128_256x128x128x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v3 +8192,7168,256,4,0,258.6827,a8w8_blockscale_wpreshuffle_1x128x128_256x16x256x128_8x16_16x16_16x16x1_8x32x1_1x16x1x16_8_1x2_intrawave_v1 +8192,7168,2048,0,0,1208.9008,a8w8_blockscale_wpreshuffle_1x128x128_256x128x128x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v3 +8192,4608,7168,0,0,2517.7124,a8w8_blockscale_wpreshuffle_1x128x128_256x128x128x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v3 +8192,7168,2304,0,0,1154.5041,a8w8_blockscale_wpreshuffle_1x128x128_256x128x128x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v3 +8192,512,7168,12,0,352.4495,a8w8_blockscale_wpreshuffle_1x128x128_256x32x128x256_16x16_16x16_16x16x1_16x16x1_1x32x1x8_8_2x2_intrawave_v1 +8192,4096,512,12,0,233.8998,a8w8_blockscale_wpreshuffle_1x128x128_256x32x128x256_16x16_16x16_16x16x1_16x16x1_1x32x1x8_8_2x2_intrawave_v1 +16384,1536,7168,0,0,1768.6828,a8w8_blockscale_wpreshuffle_1x128x128_256x128x128x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v3 +16384,3072,1536,0,0,749.7908,a8w8_blockscale_wpreshuffle_1x128x128_256x128x128x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v3 +16384,576,7168,0,0,743.4744,a8w8_blockscale_wpreshuffle_1x128x128_256x128x128x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v3 +16384,7168,256,7,0,532.4307,a8w8_blockscale_wpreshuffle_1x128x128_256x16x128x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_8_1x2_intrawave_v1 +16384,7168,2048,0,0,2454.5874,a8w8_blockscale_wpreshuffle_1x128x128_256x128x128x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v3 +16384,4608,7168,0,0,5013.9511,a8w8_blockscale_wpreshuffle_1x128x128_256x128x128x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v3 +16384,7168,2304,0,0,2296.1438,a8w8_blockscale_wpreshuffle_1x128x128_256x128x128x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v3 +16384,512,7168,0,0,649.6352,a8w8_blockscale_wpreshuffle_1x128x128_256x128x128x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v3 +16384,4096,512,0,0,455.798,a8w8_blockscale_wpreshuffle_1x128x128_256x128x128x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v3 +20480,1536,7168,0,0,2140.6012,a8w8_blockscale_wpreshuffle_1x128x128_256x128x128x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v3 +20480,3072,1536,0,0,923.4112,a8w8_blockscale_wpreshuffle_1x128x128_256x128x128x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v3 +20480,576,7168,0,0,926.0608,a8w8_blockscale_wpreshuffle_1x128x128_256x128x128x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v3 +20480,7168,256,4,0,635.9159,a8w8_blockscale_wpreshuffle_1x128x128_256x16x256x128_8x16_16x16_16x16x1_8x32x1_1x16x1x16_8_1x2_intrawave_v1 +20480,7168,2048,0,0,2992.9348,a8w8_blockscale_wpreshuffle_1x128x128_256x128x128x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v3 +20480,4608,7168,0,0,6238.7055,a8w8_blockscale_wpreshuffle_1x128x128_256x128x128x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v3 +20480,7168,2304,0,0,2895.9425,a8w8_blockscale_wpreshuffle_1x128x128_256x128x128x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v3 
+20480,512,7168,0,0,753.5264,a8w8_blockscale_wpreshuffle_1x128x128_256x128x128x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v3 +20480,4096,512,0,0,564.9256,a8w8_blockscale_wpreshuffle_1x128x128_256x128x128x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v3 diff --git a/aiter/configs/a8w8_blockscale_wpreshuffle_untuned_gemm.csv b/aiter/configs/a8w8_blockscale_wpreshuffle_untuned_gemm.csv new file mode 100644 index 0000000000..8a8bcd7154 --- /dev/null +++ b/aiter/configs/a8w8_blockscale_wpreshuffle_untuned_gemm.csv @@ -0,0 +1,234 @@ +M,N,K +16, 1536, 7168 + +16, 3072, 1536 + +16, 576, 7168 + +16, 7168, 256 + +16, 7168, 2048 + +16, 4608, 7168 + +16, 7168, 2304 + +16, 512, 7168 + +16, 4096, 512 + +32, 1536, 7168 + +32, 3072, 1536 + +32, 576, 7168 + +32, 7168, 256 + +32, 7168, 2048 + +32, 4608, 7168 + +32, 7168, 2304 + +32, 512, 7168 + +32, 4096, 512 + +64, 1536, 7168 + +64, 3072, 1536 + +64, 576, 7168 + +64, 7168, 256 + +64, 7168, 2048 + +64, 4608, 7168 + +64, 7168, 2304 + +64, 512, 7168 + +64, 4096, 512 + +128, 1536, 7168 + +128, 3072, 1536 + +128, 576, 7168 + +128, 7168, 256 + +128, 7168, 2048 + +128, 4608, 7168 + +128, 7168, 2304 + +128, 512, 7168 + +128, 4096, 512 + +256, 1536, 7168 + +256, 3072, 1536 + +256, 576, 7168 + +256, 7168, 256 + +256, 7168, 2048 + +256, 4608, 7168 + +256, 7168, 2304 + +256, 512, 7168 + +256, 4096, 512 + +512, 1536, 7168 + +512, 3072, 1536 + +512, 576, 7168 + +512, 7168, 256 + +512, 7168, 2048 + +512, 4608, 7168 + +512, 7168, 2304 + +512, 512, 7168 + +512, 4096, 512 + +1024, 1536, 7168 + +1024, 3072, 1536 + +1024, 576, 7168 + +1024, 7168, 256 + +1024, 7168, 2048 + +1024, 4608, 7168 + +1024, 7168, 2304 + +1024, 512, 7168 + +1024, 4096, 512 + +1536, 1536, 7168 + +1536, 3072, 1536 + +1536, 576, 7168 + +1536, 7168, 256 + +1536, 7168, 2048 + +1536, 4608, 7168 + +1536, 7168, 2304 + +1536, 512, 7168 + +1536, 4096, 512 + +2048, 1536, 7168 + +2048, 3072, 1536 + +2048, 576, 7168 + +2048, 7168, 256 + +2048, 7168, 2048 + +2048, 4608, 7168 + +2048, 7168, 2304 + +2048, 512, 7168 + +2048, 4096, 512 + +4096, 1536, 7168 + +4096, 3072, 1536 + +4096, 576, 7168 + +4096, 7168, 256 + +4096, 7168, 2048 + +4096, 4608, 7168 + +4096, 7168, 2304 + +4096, 512, 7168 + +4096, 4096, 512 + +8192, 1536, 7168 + +8192, 3072, 1536 + +8192, 576, 7168 + +8192, 7168, 256 + +8192, 7168, 2048 + +8192, 4608, 7168 + +8192, 7168, 2304 + +8192, 512, 7168 + +8192, 4096, 512 + +16384, 1536, 7168 + +16384, 3072, 1536 + +16384, 576, 7168 + +16384, 7168, 256 + +16384, 7168, 2048 + +16384, 4608, 7168 + +16384, 7168, 2304 + +16384, 512, 7168 + +16384, 4096, 512 + +20480, 1536, 7168 + +20480, 3072, 1536 + +20480, 576, 7168 + +20480, 7168, 256 + +20480, 7168, 2048 + +20480, 4608, 7168 + +20480, 7168, 2304 + +20480, 512, 7168 + +20480, 4096, 512 \ No newline at end of file diff --git a/aiter/jit/optCompilerConfig.json b/aiter/jit/optCompilerConfig.json index d815a0e110..8c07f4fce8 100644 --- a/aiter/jit/optCompilerConfig.json +++ b/aiter/jit/optCompilerConfig.json @@ -378,6 +378,36 @@ "verbose": "False", "blob_gen_cmd": "f'{AITER_CSRC_DIR}/ck_gemm_a8w8_blockscale/gen_instances.py --working_path {{}} --tune'" }, + "module_gemm_a8w8_blockscale_wpreshuffle": { + "srcs": [ + "f'{AITER_CSRC_DIR}/ck_gemm_a8w8_blockscale_wpreshuffle/include'", + "f'{AITER_CSRC_DIR}/pybind/gemm_a8w8_blockscale_wpreshuffle_pybind.cu'", + "f'{AITER_CSRC_DIR}/ck_gemm_a8w8_blockscale_wpreshuffle/gemm_a8w8_blockscale_wpreshuffle.cu'" + ], + "flags_extra_cc": [], + "flags_extra_hip": [], + "extra_ldflags": "None", + 
"extra_include": [], + "verbose": "False", + "is_python_module": "True", + "is_standalone": "False", + "blob_gen_cmd": "f'{AITER_CSRC_DIR}/ck_gemm_a8w8_blockscale_wpreshuffle/gen_instances.py --working_path {{}} --tune_file {AITER_CORE_DIR}/aiter/configs/a8w8_blockscale_wpreshuffle_tuned_gemm.csv'" + }, + "module_gemm_a8w8_blockscale_wpreshuffle_tune": { + "srcs": [ + "f'{AITER_CSRC_DIR}/pybind/gemm_a8w8_blockscale_wpreshuffle_tune_pybind.cu'", + "f'{AITER_CSRC_DIR}/ck_gemm_a8w8_blockscale_wpreshuffle/gemm_a8w8_blockscale_wpreshuffle_tune.cu'", + "f'{AITER_CSRC_DIR}/ck_gemm_a8w8_blockscale_wpreshuffle/include'" + ], + "flags_extra_cc": [], + "flags_extra_hip": [], + "extra_ldflags": "None", + "extra_include": [], + "verbose": "False", + "is_python_module": "True", + "is_standalone": "False", + "blob_gen_cmd": "f'{AITER_CSRC_DIR}/ck_gemm_a8w8_blockscale_wpreshuffle/gen_instances.py --working_path {{}} --tune'" + }, "module_aiter_operator": { "srcs": [ "f'{AITER_CSRC_DIR}/pybind/aiter_operator_pybind.cu'", diff --git a/aiter/ops/gemm_op_a8w8.py b/aiter/ops/gemm_op_a8w8.py index 09f3b864f7..5c7f093e38 100644 --- a/aiter/ops/gemm_op_a8w8.py +++ b/aiter/ops/gemm_op_a8w8.py @@ -64,6 +64,46 @@ def flatmm_a8w8_blockscale_asm( ): ... +@compile_ops("module_gemm_a8w8_blockscale_wpreshuffle", fc_name="gemm_a8w8_blockscale_wpreshuffle") +def gemm_a8w8_blockscale_wpreshuffle( + XQ: Tensor, + WQ: Tensor, + x_scale: Tensor, + w_scale: Tensor, + out: Tensor, +): ... + + +@compile_ops("module_gemm_a8w8_blockscale_wpreshuffle_tune",fc_name="gemm_a8w8_blockscale_wpreshuffle_tune") +def gemm_a8w8_blockscale_wpreshuffle_tune( + XQ: Tensor, + WQ: Tensor, + x_scale: Tensor, + w_scale: Tensor, + out: Tensor, + kernelId: int, + splitK = 0 +): ... + + +def gemm_a8w8_blockscale_wpreshuffle_CK( + XQ: Tensor, + WQ: Tensor, + x_scale: Tensor, + w_scale: Tensor, + dtype=torch.bfloat16 +): + assert dtype in [ + torch.bfloat16, + torch.float16, + ], f"Output {dtype=} is currently not supported in gemm_a8w8" + m = XQ.shape[0] + n = WQ.shape[0] + k = XQ.shape[-1] + Y = torch.empty(m, n, dtype=dtype, device=XQ.device) + return gemm_a8w8_blockscale_wpreshuffle(XQ, WQ, x_scale, w_scale, Y) + + @functools.lru_cache(maxsize=1024) def compute_gemm_SplitK(M: int, N: int, K: int, tile_m: int, tile_n: int, tile_k: int): diff --git a/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale.cu b/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale.cu index 6ea1b07c8d..71e8bc9b6a 100755 --- a/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale.cu +++ b/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale.cu @@ -29,6 +29,23 @@ using BlockwiseKernelMap = std::unordered_map< BlockwiseKernel, IntTupleHash>; +template +BlockwiseKernel blockwise_heuristic_dispatch(int M, int N, int K) +{ + if (M <= 16) + { + return a8w8_blockscale_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1; + } + else if (M <= 32) + { + return a8w8_blockscale_1x128x128_256x32x64x256_16x16_16x16_16x16x1_16x16x1_1x32x1x8_8_2x1_intrawave_v1; + } + else + { + return a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v1; + } +} + // Helper function to return the next largest power of 2 static constexpr int nextPow2(unsigned int num) { @@ -83,7 +100,7 @@ BlockwiseKernel blockscale_dispatch(int M, int N, int K) return it->second; } // Otherwise, use heuristics. 
+    return blockwise_heuristic_dispatch<DDataType>(M, N, K);
 }
 
 torch::Tensor gemm_a8w8_blockscale(
diff --git a/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_common.py b/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_common.py
index 4009c20828..768f2ccca0 100755
--- a/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_common.py
+++ b/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_common.py
@@ -94,6 +94,8 @@ def name(self) -> str:
 ###############| | | | | | | | | | | | | | | | | | | | | |
 # Compute friendly
-    (-1): kernelInstance( 256, 1, 128, 128, 16, 128, 256, 16, 16, 16, 16, 1, 2, [16, 16, 1], [16, 16, 1], 1, 2, [1, 16, 1, 16], [8], "Intrawave", 1 ),
+    (-1): kernelInstance( 256, 1, 128, 128, 16, 64, 256, 16, 16, 16, 16, 1, 1, [16, 16, 1], [16, 16, 1], 1, 1, [1, 16, 1, 16], [4], "Intrawave", 1),
+    (-2): kernelInstance( 256, 1, 128, 128, 32, 64, 256, 16, 16, 16, 16, 2, 1, [16, 16, 1], [16, 16, 1], 2, 1, [1, 32, 1, 8], [8], "Intrawave", 1),
+    (-3): kernelInstance( 256, 1, 128, 128, 64, 64, 128, 16, 16, 32, 32, 1, 1, [8, 32, 1], [8, 32, 1], 1, 1, [1, 32, 1, 8], [8], "Intrawave", 1),
 }
diff --git a/csrc/ck_gemm_a8w8_blockscale_wpreshuffle/README.md b/csrc/ck_gemm_a8w8_blockscale_wpreshuffle/README.md
new file mode 100755
index 0000000000..9266803f76
--- /dev/null
+++ b/csrc/ck_gemm_a8w8_blockscale_wpreshuffle/README.md
@@ -0,0 +1,18 @@
+# CK GEMM a8w8 blockscale wpreshuffle tune
+
+1. Install aiter:
+`python3 setup.py develop`
+
+2. Tune gemm a8w8:
+   First add GEMM shapes to `aiter/configs/a8w8_blockscale_wpreshuffle_untuned_gemm.csv`, then run the following command to start tuning. Please allow a few minutes, as it builds gemm_a8w8_blockscale_wpreshuffle_tune via JIT:
+`python3 csrc/ck_gemm_a8w8_blockscale_wpreshuffle/gemm_a8w8_blockscale_wpreshuffle_tune.py -i aiter/configs/a8w8_blockscale_wpreshuffle_untuned_gemm.csv -o aiter/configs/a8w8_blockscale_wpreshuffle_tuned_gemm.csv`
+If you want to use split-K kernels, add the `-k` parameter at the end; note that you should then change `bias` to `bias/(2^k)`.
+The tuning results are written to `aiter/configs/a8w8_blockscale_wpreshuffle_tuned_gemm.csv`.
+
+3. Test the performance: modify the test instance in `op_tests/test_gemm_a8w8_blockscale_wpreshuffle.py` and run it. Please allow a few minutes, as it builds the gemm_a8w8_blockscale_wpreshuffle kernels listed in `aiter/configs/a8w8_blockscale_wpreshuffle_tuned_gemm.csv` via JIT:
+`python3 op_tests/test_gemm_a8w8_blockscale_wpreshuffle.py`
+
+
+## More
+To re-install gemm_a8w8_blockscale_wpreshuffle, first remove `aiter/jit/module_gemm_a8w8_blockscale_wpreshuffle.so` and `aiter/jit/build/module_gemm_a8w8_blockscale_wpreshuffle`.
+If you install aiter with the flags `PREBUILD_KERNELS=1 USE_CK_A8W8=1`, it builds the gemm a8w8 kernels listed in `aiter/configs/a8w8_blockscale_wpreshuffle_tuned_gemm.csv` by default. To pick up new gemm_a8w8_blockscale_wpreshuffle_tune results, remove `build` and `*.so` first, then re-install aiter after tuning finishes.
\ No newline at end of file
diff --git a/csrc/ck_gemm_a8w8_blockscale_wpreshuffle/gemm_a8w8_blockscale_wpreshuffle.cu b/csrc/ck_gemm_a8w8_blockscale_wpreshuffle/gemm_a8w8_blockscale_wpreshuffle.cu
new file mode 100755
index 0000000000..02c14cffa2
--- /dev/null
+++ b/csrc/ck_gemm_a8w8_blockscale_wpreshuffle/gemm_a8w8_blockscale_wpreshuffle.cu
@@ -0,0 +1,137 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
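+
+// Dispatch for the a8w8 blockscale GEMM with a preshuffled weight layout
+// (wpreshuffle): shapes are first resolved against the lookup table generated
+// from a8w8_blockscale_wpreshuffle_tuned_gemm.csv, with a shape heuristic as
+// the fallback.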
+
+#include "gemm_a8w8_blockscale_wpreshuffle_common.cuh"
+#include "gemm_a8w8_blockscale_wpreshuffle_manifest.h"
+#include "gemm_a8w8_blockscale_wpreshuffle_lookup.h"
+#include <cmath>
+
+using BlockwiseKernel = std::function<
+    torch::Tensor(torch::Tensor &, torch::Tensor &,
+                  torch::Tensor &, torch::Tensor &,
+                  torch::Tensor &)>;
+
+// Define a custom hash function for std::tuple<int, int, int>
+struct IntTupleHash
+{
+    size_t operator()(const std::tuple<int, int, int> &t) const
+    {
+        auto hash1 = std::hash<int>{}(std::get<0>(t));
+        auto hash2 = std::hash<int>{}(std::get<1>(t));
+        auto hash3 = std::hash<int>{}(std::get<2>(t));
+        return hash1 ^ hash2 ^ hash3;
+    }
+};
+
+using BlockwiseKernelMap = std::unordered_map<
+    std::tuple<int, int, int>,
+    BlockwiseKernel,
+    IntTupleHash>;
+
+template <typename DDataType>
+BlockwiseKernel blockwise_heuristic_dispatch(int M, int N, int K)
+{
+    if (M <= 16 || (M <= 128 && N * K <= 512 * 7168))
+    {
+        return a8w8_blockscale_wpreshuffle_1x128x128_256x16x64x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1<DDataType>;
+    }
+    else if (M <= 32)
+    {
+        return a8w8_blockscale_wpreshuffle_1x128x128_256x32x64x256_16x16_16x16_16x16x1_16x16x1_1x32x1x8_8_2x1_intrawave_v1<DDataType>;
+    }
+    else if (K < 320)
+    {
+        return a8w8_blockscale_wpreshuffle_1x128x128_256x16x128x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_8_1x2_intrawave_v1<DDataType>;
+    }
+    else
+    {
+        return a8w8_blockscale_wpreshuffle_1x128x128_256x128x128x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v3<DDataType>;
+    }
+}
+
+// Helper function to return the next largest power of 2
+static constexpr int nextPow2(unsigned int num)
+{
+    if (num <= 1)
+        return 1;
+    return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
+}
+
+template <typename DDataType>
+BlockwiseKernel blockscale_dispatch(int M, int N, int K)
+{
+    // For a given shape, either find the best kernel via lookup or heuristic.
+    // For many small M shapes, we bucket them to the next largest kernel.
+    // This is fine since kernels are padded anyway.
+
+    static const auto lookup = []
+    {
+        if constexpr (std::is_same_v<DDataType, F16>) {
+            return BlockwiseKernelMap{GENERATE_LOOKUP_TABLE(DDataType,F16)};
+        } else if constexpr (std::is_same_v<DDataType, B16>) {
+            return BlockwiseKernelMap{GENERATE_LOOKUP_TABLE(DDataType,B16)};
+        } else {
+            static_assert(false, "blockscale_dispatch used with unsupported dtype!");
+        } }();
+
+    // First check if this shape(M,N,K) is available in the direct lookup.
+    auto it = lookup.find({M, N, K});
+    // If we found an optimal kernel, use it.
+    if (it != lookup.end())
+    {
+        return it->second;
+    }
+
+    int padded_m = M;
+    if (M > 1 && M <= 16)
+    {
+        padded_m = 16;
+    }
+    else if (M <= 16384)
+    {
+        padded_m = nextPow2(M);
+    }
+    else if (M <= 20480)
+    {
+        padded_m = 20480;
+    }
+    // Second check if this shape(padded_m,N,K) is available in the direct lookup.
+    it = lookup.find({padded_m, N, K});
+    // If we found an optimal kernel, use it.
+    if (it != lookup.end())
+    {
+        return it->second;
+    }
+    // Otherwise, use heuristics.
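+    // Illustrative example (hypothetical shape): M=48, N=1024, K=1024 misses the
+    // direct lookup; padded_m = nextPow2(48) = 64, and if (64, 1024, 1024) is not
+    // in the tuned table either, the heuristic picks the memory-friendly
+    // 256x16x64x256 instance, since M <= 128 and N * K <= 512 * 7168.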
+    return blockwise_heuristic_dispatch<DDataType, EDataType>(M, N, K);
+}
+
+torch::Tensor gemm_a8w8_blockscale_wpreshuffle(
+    torch::Tensor &XQ,
+    torch::Tensor &WQ,
+    torch::Tensor &x_scale,
+    torch::Tensor &w_scale,
+    torch::Tensor &Y)
+{
+    TORCH_CHECK(XQ.dtype() == WQ.dtype(), "Weights and activations should have the same dtype!");
+    TORCH_CHECK(x_scale.dtype() == w_scale.dtype(),
+                "Scales should have the same dtype!");
+
+    int M = XQ.size(0);
+    int N = WQ.size(0);
+    int K = XQ.size(1);
+
+    if (x_scale.dtype() == at::ScalarType::Float && Y.dtype() == at::ScalarType::Half)
+    {
+        blockscale_dispatch<F32, F16>(M, N, K)(XQ, WQ, x_scale, w_scale, Y);
+    }
+    else if (x_scale.dtype() == at::ScalarType::Float && Y.dtype() == at::ScalarType::BFloat16)
+    {
+        blockscale_dispatch<F32, B16>(M, N, K)(XQ, WQ, x_scale, w_scale, Y);
+    }
+    else
+    {
+        TORCH_CHECK(false, "Unsupported scales/output dtype!");
+    }
+    return Y;
+}
diff --git a/csrc/ck_gemm_a8w8_blockscale_wpreshuffle/gemm_a8w8_blockscale_wpreshuffle_common.py b/csrc/ck_gemm_a8w8_blockscale_wpreshuffle/gemm_a8w8_blockscale_wpreshuffle_common.py
new file mode 100755
index 0000000000..5a50706c17
--- /dev/null
+++ b/csrc/ck_gemm_a8w8_blockscale_wpreshuffle/gemm_a8w8_blockscale_wpreshuffle_common.py
@@ -0,0 +1,101 @@
+# SPDX-License-Identifier: MIT
+# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+from dataclasses import dataclass
+
+@dataclass
+class kernelInstance:
+    BLOCK_SIZE: int
+    ScaleBlockM: int
+    ScaleBlockN: int
+    ScaleBlockK: int
+    MPerBLOCK: int
+    NPerBLOCK: int
+    KPerBLOCK: int
+    AK1: int
+    BK1: int
+    MPerXDL: int
+    NPerXDL: int
+    WAVE_MAP_M: int
+    WAVE_MAP_N: int
+    ABLOCK_TRANSFER: list[int]
+    BBLOCK_TRANSFER: list[int]
+    CSHUFFLE_MX_PER_WAVE_PERSHUFFLE: int
+    CSHUFFLE_NX_PER_WAVE_PERSHUFFLE: int
+    CBLOCK_TRANSFER: list[int]
+    CBLOCK_SPV: list[int]
+    PIPELINE_Sched: str
+    PIPELINE_VERSION: int
+
+    @property
+    def name(self) -> str:
+        return ("_").join([
+            "a8w8_blockscale_wpreshuffle",
+            ("x").join(map(lambda x: str(x), [
+                self.ScaleBlockM, self.ScaleBlockN, self.ScaleBlockK])),
+            ("x").join(map(lambda x: str(x), [
+                self.BLOCK_SIZE, self.MPerBLOCK, self.NPerBLOCK, self.KPerBLOCK])),
+            ("x").join(map(lambda x: str(x), [
+                self.AK1, self.BK1])),
+            ("x").join(map(lambda x: str(x), [
+                self.MPerXDL, self.NPerXDL])),
+            ("x").join(map(lambda x: str(x), self.ABLOCK_TRANSFER)),
+            ("x").join(map(lambda x: str(x), self.BBLOCK_TRANSFER)),
+            ("x").join(map(lambda x: str(x), self.CBLOCK_TRANSFER)),
+            ("x").join(map(lambda x: str(x), self.CBLOCK_SPV)),
+            ("x").join(map(lambda x: str(x), [self.CSHUFFLE_MX_PER_WAVE_PERSHUFFLE, self.CSHUFFLE_NX_PER_WAVE_PERSHUFFLE])),
+            self.PIPELINE_Sched.lower(),
+            f"v{self.PIPELINE_VERSION}"
+        ])
+
+kernels_list = {
+    # clang-format off
+    ###############| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| BBlockTransfer| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
+    ###############|  Size| Block| Block| Block|Block|Block|Block|    |    | XDL|  XDL|  Per|  Per|  ThreadCluster|  ThreadCluster| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
+    ###############|      |     M|     N|     K|     |     |     |    |    |    |     | Wave| Wave|Lengths_K0_M_K1|Lengths_K0_N_K1|  PerShuffle|  PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl|   _NWaveNPerXdl| Scheduler| Version|
+    ###############|      |      |      |      |     |     |     |    |    |    |     |     |     |               |               |            |            |                                 |                |          |        |
+
+    # Compute friendly
+    0: kernelInstance( 256, 1, 128, 128, 128, 128, 128, 16, 16, 16, 16, 4, 4, [8, 32, 1], [8, 32, 1], 2, 2, [1, 32, 1, 8], [8], "Intrawave", 3 ),
+    1: kernelInstance( 256, 1, 128, 128, 128, 64, 128, 16, 16, 16, 16, 4, 2, [8, 32, 1], [8, 32, 1], 2, 2, [1, 32, 1, 8], [8], "Intrawave", 3 ),
+    2: kernelInstance( 256, 1, 128, 128, 64, 128, 128, 16, 16, 16, 16, 4, 2, [8, 32, 1], [8, 32, 1], 2, 2, [1, 32, 1, 8], [8], "Intrawave", 3 ),
+    3: kernelInstance( 256, 1, 128, 128, 64, 64, 128, 16, 16, 16, 16, 4, 1, [8, 32, 1], [8, 32, 1], 2, 1, [1, 32, 1, 8], [8], "Intrawave", 3 ),
+
+    # Memory friendly
+    4: kernelInstance( 256, 1, 128, 128, 16, 256, 128, 8, 16, 16, 16, 1, 4, [16, 16, 1], [8, 32, 1], 1, 2, [1, 16, 1, 16], [8], "Intrawave", 1 ),
+    5: kernelInstance( 256, 1, 128, 128, 16, 128, 128, 8, 16, 16, 16, 1, 2, [16, 16, 1], [8, 32, 1], 1, 2, [1, 16, 1, 16], [8], "Intrawave", 1 ),
+    6: kernelInstance( 256, 1, 128, 128, 16, 64, 128, 8, 16, 16, 16, 1, 1, [16, 16, 1], [8, 32, 1], 1, 1, [1, 16, 1, 16], [4], "Intrawave", 1 ),
+    7: kernelInstance( 256, 1, 128, 128, 16, 128, 256, 16, 16, 16, 16, 1, 2, [16, 16, 1], [16, 16, 1], 1, 2, [1, 16, 1, 16], [8], "Intrawave", 1 ),
+    8: kernelInstance( 256, 1, 128, 128, 16, 64, 256, 16, 16, 16, 16, 1, 1, [16, 16, 1], [16, 16, 1], 1, 1, [1, 16, 1, 16], [4], "Intrawave", 1 ),
+
+    9: kernelInstance( 256, 1, 128, 128, 32, 256, 128, 16, 16, 16, 16, 2, 4, [8, 32, 1], [8, 32, 1], 2, 2, [1, 32, 1, 8], [8], "Intrawave", 1 ),
+    10: kernelInstance( 256, 1, 128, 128, 32, 128, 128, 16, 16, 16, 16, 2, 2, [8, 32, 1], [8, 32, 1], 2, 2, [1, 32, 1, 8], [8], "Intrawave", 1 ),
+    11: kernelInstance( 256, 1, 128, 128, 32, 64, 128, 16, 16, 16, 16, 2, 1, [8, 32, 1], [8, 32, 1], 2, 1, [1, 32, 1, 8], [8], "Intrawave", 1 ),
+
+    12: kernelInstance( 256, 1, 128, 128, 32, 128, 256, 16, 16, 16, 16, 2, 2, [16, 16, 1], [16, 16, 1], 2, 2, [1, 32, 1, 8], [8], "Intrawave", 1 ),
+    13: kernelInstance( 256, 1, 128, 128, 32, 64, 256, 16, 16, 16, 16, 2, 1, [16, 16, 1], [16, 16, 1], 2, 1, [1, 32, 1, 8], [8], "Intrawave", 1 ),
+
+    14: kernelInstance( 256, 1, 128, 128, 64, 256, 128, 16, 16, 16, 16, 4, 4, [8, 32, 1], [8, 32, 1], 2, 2, [1, 32, 1, 8], [8], "Intrawave", 1 ),
+    15: kernelInstance( 256, 1, 128, 128, 64, 128, 128, 16, 16, 16, 16, 4, 2, [8, 32, 1], [8, 32, 1], 2, 2, [1, 32, 1, 8], [8], "Intrawave", 1 ),
+    16: kernelInstance( 256, 1, 128, 128, 64, 64, 128, 16, 16, 16, 16, 2, 2, [8, 32, 1], [8, 32, 1], 2, 2, [1, 32, 1, 8], [8], "Intrawave", 1 ),
+
+    17: kernelInstance( 256, 1, 128, 128, 64, 128, 256, 16, 16, 16, 16, 4, 2, [16, 16, 1], [16, 16, 1], 2, 2, [1, 32, 1, 8], [8], "Intrawave", 1 ),
+    18: kernelInstance( 256, 1, 128, 128, 64, 64, 256, 16, 16, 16, 16, 2, 2, [16, 16, 1], [16, 16, 1], 2, 2, [1, 32, 1, 8], [8], "Intrawave", 1 )
+    # clang-format on
+}
+
+
+
+
+default_kernels_dict = {
+    # clang-format off
+    ###############| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| BBlockTransfer| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
+    ###############|  Size| Block| Block| Block|Block|Block|Block|    |    | XDL|  XDL|  Per|  Per|  ThreadCluster|  ThreadCluster| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
+    ###############|      |     M|     N|     K|     |     |     |    |    |    |     | Wave| Wave|Lengths_K0_M_K1|Lengths_K0_N_K1|  PerShuffle|  PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl|   _NWaveNPerXdl| Scheduler| Version|
+    ###############|      |      |      |      |     |     |     |    |    |    |     |     |     |               |               |            |            |                                 |                |          |        |
+
+    # Compute friendly
+    (-1): kernelInstance( 256, 1, 128, 128, 16, 64, 256, 16, 16, 16, 16, 1, 1, [16, 16, 1], [16, 16, 1], 1, 1, [1, 16, 1, 16], [4], "Intrawave", 1),
+    (-2): kernelInstance( 256, 1, 128, 128, 32, 64, 256, 16, 16, 16, 16, 2, 1, [16, 16, 1], [16, 16, 1], 2, 1, [1, 32, 1, 8], [8], "Intrawave", 1),
+    (-3): kernelInstance( 256, 1, 128, 128, 128, 128, 128, 16, 16, 16, 16, 4, 4, [8, 32, 1], [8, 32, 1], 2, 2, [1, 32, 1, 8], [8], "Intrawave", 3),
+    (-4): kernelInstance( 256, 1, 128, 128, 16, 128, 256, 16, 16, 16, 16, 1, 2, [16, 16, 1], [16, 16, 1], 1, 2, [1, 16, 1, 16], [8], "Intrawave", 1),
+}
diff --git a/csrc/ck_gemm_a8w8_blockscale_wpreshuffle/gemm_a8w8_blockscale_wpreshuffle_tune.cu b/csrc/ck_gemm_a8w8_blockscale_wpreshuffle/gemm_a8w8_blockscale_wpreshuffle_tune.cu
new file mode 100755
index 0000000000..8399b078f3
--- /dev/null
+++ b/csrc/ck_gemm_a8w8_blockscale_wpreshuffle/gemm_a8w8_blockscale_wpreshuffle_tune.cu
@@ -0,0 +1,89 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "gemm_a8w8_blockscale_wpreshuffle_common.cuh"
+#include "gemm_a8w8_blockscale_wpreshuffle_manifest.h"
+#include "gemm_a8w8_blockscale_wpreshuffle_lookup.h"
+#include <cmath>
+
+using BlockwiseKernel = std::function<
+    torch::Tensor(torch::Tensor &, torch::Tensor &,
+                  torch::Tensor &, torch::Tensor &,
+                  torch::Tensor &)>;
+
+// For certain high priority shapes, we directly use the best kernel rather
+// than use heuristics.
+using BlockwiseKernelMap = std::unordered_map<
+    int,
+    BlockwiseKernel>;
+
+// Helper function to return the next largest power of 2
+static constexpr int nextPow2(unsigned int num)
+{
+    if (num <= 1)
+        return 1;
+    return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
+}
+
+template <typename DDataType, typename EDataType>
+BlockwiseKernel blockwise_dispatch(int id)
+{
+    // For a given kernel id, return the matching kernel from the generated
+    // lookup table; ids come straight from the tuning script.
+
+    // First check if this kernel id is available in the direct lookup.
+    static const auto lookup = []
+    {
+        if constexpr (std::is_same_v<EDataType, F16>) {
+            return BlockwiseKernelMap{GENERATE_LOOKUP_TABLE(DDataType,F16)};
+        } else if constexpr (std::is_same_v<EDataType, B16>) {
+            return BlockwiseKernelMap{GENERATE_LOOKUP_TABLE(DDataType,B16)};
+        } else {
+            static_assert(false, "blockwise_dispatch used with unsupported dtype!");
+        } }();
+
+    TORCH_CHECK(id < lookup.size(),
+                "Kernel id " + std::to_string(id) + " is out of range!");
+    auto it = lookup.find(id);
+    // If we found the requested kernel, use it.
+    if (it != lookup.end())
+    {
+        return it->second;
+    }
+    // Otherwise, fall back to kernel 0.
+    return lookup.find(0)->second;
+}
+
+
+
+torch::Tensor gemm_a8w8_blockscale_wpreshuffle_tune(
+    torch::Tensor &XQ,
+    torch::Tensor &WQ,
+    torch::Tensor &x_scale,
+    torch::Tensor &w_scale,
+    torch::Tensor &Y,
+    int kernelId,
+    int splitK)
+{
+    TORCH_CHECK(XQ.dtype() == at::ScalarType::Float8_e4m3fnuz && XQ.dtype() == WQ.dtype(),
+                "Weights and activations should both be fp8!");
+    TORCH_CHECK(x_scale.dtype() == w_scale.dtype(),
+                "Scales should have the same dtype!");
+    std::optional<torch::Tensor> bias = std::nullopt;
+
+    int M = XQ.size(0);
+    int N = WQ.size(0);
+    int K = XQ.size(1);
+    int KBatch = std::pow(2, splitK);
+
+    if (Y.dtype() == at::ScalarType::BFloat16)
+    {
+        blockwise_dispatch<F32, B16>(kernelId)(XQ, WQ, x_scale, w_scale, Y);
+    }
+    else
+    {
+        TORCH_CHECK(false, "Unsupported scales/output dtype!");
+    }
+    return Y;
+}
diff --git a/csrc/ck_gemm_a8w8_blockscale_wpreshuffle/gemm_a8w8_blockscale_wpreshuffle_tune.py b/csrc/ck_gemm_a8w8_blockscale_wpreshuffle/gemm_a8w8_blockscale_wpreshuffle_tune.py
new file mode 100755
index 0000000000..251f86e8d3
--- /dev/null
+++ b/csrc/ck_gemm_a8w8_blockscale_wpreshuffle/gemm_a8w8_blockscale_wpreshuffle_tune.py
@@ -0,0 +1,188 @@
+# SPDX-License-Identifier: MIT
+# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+import os
+import sys
+import aiter
+import pandas as pd
+import torch
+import torch.nn.functional as F
+from aiter.test_common import checkAllclose, perftest
+from aiter.ops.shuffle import shuffle_weight
+from gemm_a8w8_blockscale_wpreshuffle_common import kernelInstance, kernels_list
+import argparse
+from einops import rearrange
+
+block_shape = (128, 128)
+
+def checkClose(a, b, rtol=1e-3, atol=0.01):
+    isClose = torch.isclose(a, b, rtol=rtol, atol=atol)
+    mask = ~isClose
+    if isClose.all():
+        return True
+    else:
+        percent = (a[mask]).numel()/a.numel()
+        if percent > 0.01:
+            return False
+        else:
+            return True
+
+def run_torch(x, weight, x_scale, w_scale, bias=None, dtype=torch.bfloat16):
+    block_shape_n, block_shape_k = block_shape
+    m, k = x.shape
+    n = weight.shape[0]
+    scale_n = (n + block_shape_n - 1) // block_shape_n
+    scale_k = (k + block_shape_k - 1) // block_shape_k
+    # x_scale = rearrange(x_scale.view(-1, 1).repeat(1, block_shape_n*block_shape_k).view(m, scale_k, 1, block_shape_k),
+    #                     'num_blk_n num_blk_k blk_n blk_k ->(num_blk_n blk_n) (num_blk_k blk_k)')
+    x = x.to(x_scale.dtype).view(m, k//block_shape[1], block_shape[1]) * x_scale.unsqueeze(-1)
+    x = x.view(m, k)
+
+    w_scale = rearrange(w_scale.view(-1, 1).repeat(1, block_shape_n*block_shape_k).view(scale_n, scale_k, block_shape_n, block_shape_k),
+                        'num_blk_n num_blk_k blk_n blk_k -> (num_blk_n blk_n) (num_blk_k blk_k)')
+    w_scale = w_scale[:n, :k]
+    weight = weight.to(w_scale.dtype) * w_scale
+
+    out = F.linear(x.to(torch.float32), weight.to(torch.float32))
+    # scale = torch.matmul(x_scale, w_scale)
+    # out = torch.mul(x, scale)
+    if bias is not None:
+        out = out.to(bias) + bias
+    return out.to(dtype)
+
+def get_untuned_gemm_list(untuned_gemm_file):
+    assert os.path.exists(untuned_gemm_file), f"Untuned gemm csv file does not exist: {untuned_gemm_file}"
+    untunedf = pd.read_csv(untuned_gemm_file)
+    return untunedf
+
+def get_tuned_gemm_list(tuned_gemm_file):
+    if os.path.exists(tuned_gemm_file):
+        tunedf = pd.read_csv(tuned_gemm_file)
+    else:
+        tunedf = pd.DataFrame(columns=["M", "N", "K", "kernelId", "splitK", "us", "kernelName"])
+    return tunedf
+
+@perftest()
+def kernel_instance_test(x, weight, x_scale, w_scale, out, kernel_id, splitK=0):
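+    # Runs tuning candidate `kernel_id` once per timed iteration; the kernel
+    # writes into `out` in place and @perftest reports the averaged latency.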
+    aiter.gemm_a8w8_blockscale_wpreshuffle_tune(x, weight, x_scale, w_scale, out, kernel_id, splitK)
+    return out
+
+
+def tune_gemm(m, n, k, useSplitK=False):
+    dim = (m, n, k)
+    block_shape_n, block_shape_k = block_shape
+    scale_n = (n + block_shape_n - 1) // block_shape_n
+    scale_k = (k + block_shape_k - 1) // block_shape_k
+    x = (torch.rand((m, k), dtype=torch.float16, device="cuda")/10).to(torch.float8_e4m3fnuz)
+    weight = (torch.rand((n, k), dtype=torch.float16, device="cuda")/10).to(torch.float8_e4m3fnuz)
+    weight_shuffle = shuffle_weight(weight, layout=(16, 16))
+    x_scale = torch.rand([m, scale_k], dtype=torch.float32, device="cuda")
+    w_scale = torch.rand([scale_n, scale_k], dtype=torch.float32, device="cuda")
+    out = torch.empty(m, n, dtype=torch.bfloat16, device="cuda")
+
+    ref_out = run_torch(x, weight, x_scale, w_scale)
+
+    print(f"*******************M:{m} X N:{n} X K:{k}**************************")
+    print(f"Start tuning a8w8 gemm kernel for M:{m}, N:{n}, K:{k}:")
+    kernels_num = len(kernels_list)
+    best_kernelConfig = (-1, 0)
+    best_time = -1
+    for i in range(kernels_num):
+        kernel = kernels_list[i]
+        maxsplitK = aiter.compute_gemm_SplitK(m, n, k, kernel.MPerBLOCK, kernel.NPerBLOCK, kernel.KPerBLOCK) \
+            if useSplitK else 0
+        for splitK in range(maxsplitK+1):
+            try:
+                out, avg_t = kernel_instance_test(x, weight_shuffle, x_scale, w_scale, out, i, splitK)
+                isClosed = checkClose(ref_out, out, rtol=1e-2, atol=0.1)
+                if isClosed:
+                    print(f"{str(dim):<20} kernelid:{i:<3d}\t avg: {avg_t:<8.2f} us, {kernel.name}, {splitK=}")
+                    if best_time < 0 or avg_t < best_time:
+                        best_kernelConfig = (i, splitK)
+                        best_time = avg_t
+                else:
+                    print(f"{str(dim):<20} kernelid:{i:<3d}\t No pass , {kernel.name}, {splitK=}")
+            except RuntimeError as e:
+                print(e)
+                print(f"{str(dim):<20} kernelid:{i:<3d}\t No support , {kernel.name}, {splitK=}")
+
+    best_kernelId, splitK = best_kernelConfig
+    if best_kernelConfig[0] == -1:
+        print(f"No kernel can be used for M:{m}, N:{n}, K:{k}")
+        best_time = 'nan'
+    else:
+        best_time = round(best_time, 4)
+
+        print(f"Tuning result for M:{m}, N:{n}, K:{k} is kernelId={best_kernelId} {kernels_list[best_kernelId].name} {splitK=}, {best_time}us")
+    print(f"*******************M:{m} X N:{n} X K:{k}**************************")
+
+    return best_kernelId, splitK, best_time
+
+
+def tune_gemm_list(untunedf, tunedf, issorted=False, useSplitK=False):
+    for i in range(len(untunedf)):
+        M = untunedf.loc[i, "M"]
+        N = untunedf.loc[i, "N"]
+        K = untunedf.loc[i, "K"]
+
+        if tunedf[(tunedf["M"]==M) & (tunedf["N"]==N) & (tunedf["K"]==K)].empty:
+            kernelId, splitK, time = tune_gemm(M, N, K, useSplitK)
+            kernelName = 'None' if kernelId == -1 else kernels_list[kernelId].name
+            temp = pd.DataFrame({"M":[M], "N":[N], "K":[K], "kernelId":[kernelId], "splitK":[splitK],
+                                 "us":[time], "kernelName":[kernelName]})
+            tunedf = pd.concat([tunedf, temp], ignore_index=True)
+
+        else:
+            print(f"M:{M}, N:{N}, K:{K} is in tuned gemm, skip!!!")
+        print()
+        print()
+    if issorted:
+        tunedf = tunedf.sort_values(by=["M", "N", "K"])
+    print("Total tuning result:")
+    print(tunedf)
+    return tunedf
+
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        prog="generate",
+        description="gen API for CK gemm a8w8 kernel",
+    )
+
+    parser.add_argument(
+        "-i",
+        "--untune_file",
+        default="aiter/configs/a8w8_blockscale_wpreshuffle_untuned_gemm.csv",
+        required=False,
+        help="input: csv file with untuned GEMM shapes"
+    )
+
+    parser.add_argument(
+        "-o",
+        "--tune_file",
+        default="aiter/configs/a8w8_blockscale_wpreshuffle_tuned_gemm.csv",
+        required=False,
help="output: tuning result store this file" + ) + + parser.add_argument( + "-k", + "--splitK", + action='store_true', + required=False, + help="Use splitK kernels" + ) + + parser.add_argument( + "--sort", + action='store_true', + required=False, + help="Arranged according to the M N K size" + ) + + args = parser.parse_args() + untunedf = get_untuned_gemm_list(args.untune_file) + tunedf = get_tuned_gemm_list(args.tune_file) + tunedf = tune_gemm_list(untunedf, tunedf, args.sort, args.splitK) + tunedf.to_csv(args.tune_file, index=False) diff --git a/csrc/ck_gemm_a8w8_blockscale_wpreshuffle/gen_instances.py b/csrc/ck_gemm_a8w8_blockscale_wpreshuffle/gen_instances.py new file mode 100755 index 0000000000..7bd106c427 --- /dev/null +++ b/csrc/ck_gemm_a8w8_blockscale_wpreshuffle/gen_instances.py @@ -0,0 +1,272 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +import os +import sys +from dataclasses import dataclass +import copy +from pathlib import Path +import pandas as pd +import argparse +import shutil +from gemm_a8w8_blockscale_wpreshuffle_common import kernelInstance, kernels_list, default_kernels_dict + + +""" + +a8w8_blockscale_wpreshuffle_gemm instance gen + +""" + + +class gemm_a8w8_blockscale_wpreshuffle_codegen: + def __init__(self, working_path, istune=False): + self.working_path = working_path + self.impl_path = os.path.join(working_path, "impl") + self.instances_path = os.path.join(working_path, "instances") + self.istune = istune + + def gen_instance(self, k: kernelInstance): + INSTANCE_IMPL = f"""// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_a8w8_blockscale_wpreshuffle_common.cuh" + +template +torch::Tensor +{k.name}( + torch::Tensor &XQ, + torch::Tensor &WQ, + torch::Tensor &x_scale, + torch::Tensor &w_scale, + torch::Tensor &Y + ) +{{{{ + // The smallest kernel we have available. Works well for memory bound shapes. + + // Check if this input needs to be padded. + int M = size_to_dim_(XQ.dim() - 1, XQ.sizes()); + int N = WQ.size(0); + int K = WQ.size(1); + bool pad = (M % {k.MPerBLOCK} != 0) || (N % {k.NPerBLOCK} != 0) || (K % ({k.KPerBLOCK}) != 0); + if (pad) + {{{{ + // pad + {{INSTANCE_CONTENT_pad}} + // pad + }}}} + else + {{{{ + // no pad + {{INSTANCE_CONTENT_nopad}} + // no pad + }}}} +}}}} + +""" + + INSTANCE_CONTENT_nobias = f"""using DeviceGemmInstance = DeviceGemmHelperF8BlockScale< + DDataType, EDataType, + {k.BLOCK_SIZE}, + {k.ScaleBlockM}, {k.ScaleBlockN}, {k.ScaleBlockK}, + {k.MPerBLOCK}, {k.NPerBLOCK}, {k.KPerBLOCK}, + {k.AK1}, {k.BK1}, + {k.MPerXDL}, {k.NPerXDL}, + {k.WAVE_MAP_M}, {k.WAVE_MAP_N}, + S<{(", ").join(map(lambda x:str(x),k.ABLOCK_TRANSFER))}>, + S<{(", ").join(map(lambda x:str(x),k.BBLOCK_TRANSFER))}>, + {k.CSHUFFLE_MX_PER_WAVE_PERSHUFFLE}, + {k.CSHUFFLE_NX_PER_WAVE_PERSHUFFLE}, + S<{(", ").join(map(lambda x:str(x),k.CBLOCK_TRANSFER))}>, + S<{(", ").join(map(lambda x:str(x),k.CBLOCK_SPV))}>, + ck::BlockGemmPipelineScheduler::{k.PIPELINE_Sched}, + ck::BlockGemmPipelineVersion::v{k.PIPELINE_VERSION}, + ck::tensor_operation::device::GemmSpecialization::{{GemmSpec}}>; + // Run kernel instance. 
+    return gemm_a8w8_blockscale_wpreshuffle_impl<DDataType, EDataType, DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
+"""
+        INSTANCE_IMPL_str = INSTANCE_IMPL.format(INSTANCE_CONTENT_pad=INSTANCE_CONTENT_nobias.format(GemmSpec="MNKPadding"),
+                                                 INSTANCE_CONTENT_nopad=INSTANCE_CONTENT_nobias.format(GemmSpec="Default"))
+
+        Path(os.path.join(self.impl_path, f"{k.name}.cuh")).write_text(
+            INSTANCE_IMPL_str)
+
+        INSTANCE_template = """// SPDX-License-Identifier: MIT
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "{name}.cuh"
+
+template torch::Tensor
+{name}<{dtypes}>(
+    torch::Tensor &XQ,
+    torch::Tensor &WQ,
+    torch::Tensor &x_scale,
+    torch::Tensor &w_scale,
+    torch::Tensor &Y
+    );
+
+"""
+        INSTANCE_dFP32_eBF16 = INSTANCE_template.format(
+            name=k.name, dtypes="F32, B16")
+        INSTANCE_dFP32_eFP16 = INSTANCE_template.format(
+            name=k.name, dtypes="F32, F16")
+        # TODO: dFP8_eFP8
+
+        if self.istune:
+            Path(os.path.join(self.instances_path, f"{k.name}_dBF16_eBF16.cpp")).write_text(
+                INSTANCE_dFP32_eBF16)
+        else:
+            Path(os.path.join(self.instances_path, f"{k.name}_dFP32_eBF16.cpp")).write_text(
+                INSTANCE_dFP32_eBF16)
+            Path(os.path.join(self.instances_path, f"{k.name}_dFP32_eFP16.cpp")).write_text(
+                INSTANCE_dFP32_eFP16)
+
+    def gen_lookup_dict(self, kernels_dict):
+        LOOKUP_head = """#pragma once
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+
+#ifdef USE_ROCM
+
+#define GENERATE_LOOKUP_TABLE(DTYPE, ETYPE)                                   \\
+   {                                                                          \\"""
+
+        LOOKUP_template = """
+    {{{MNK},                                                                  \\
+     {kernel_name}<DTYPE, ETYPE>}},                                           \\"""
+
+        LOOKUP_end = """
+   }
+
+#endif // USE_ROCM
+"""
+        with open(os.path.join(self.working_path, "gemm_a8w8_blockscale_wpreshuffle_lookup.h"), "w") as f:
+            f.write(LOOKUP_head)
+            for mnk, k in kernels_dict.items():
+                # print((", ").join(map(lambda x: str(x), list(mnk))), ":", k.name)
+                if not self.istune and (isinstance(mnk, tuple) and mnk[0] > 0):
+                    f.write(LOOKUP_template.format(MNK="{"+(", ").join(
+                        map(lambda x: str(x), list(mnk))) + "}", kernel_name=k.name))
+                elif self.istune and isinstance(mnk, int):
+                    f.write(LOOKUP_template.format(MNK=mnk, kernel_name=k.name))
+            f.write(LOOKUP_end)
+
+    def gen_manifest_head(self, kernels_dict):
+        MAINFEST_head = """#pragma once
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+
+#ifdef USE_ROCM
+
+#include <cstdlib>
+
+#include <torch/extension.h>
+"""
+        MAINFEST_template = """
+template <typename DDataType, typename EDataType>
+torch::Tensor
+{kernel_name}(
+    torch::Tensor &XQ,
+    torch::Tensor &WQ,
+    torch::Tensor &x_scale,
+    torch::Tensor &w_scale,
+    torch::Tensor &Y);
+"""
+        MAINFEST_end = """
+
+#endif // USE_ROCM
+"""
+
+        with open(os.path.join(self.working_path, "gemm_a8w8_blockscale_wpreshuffle_manifest.h"), "w") as f:
+            f.write(MAINFEST_head)
+            for mnk, k in kernels_dict.items():
+                f.write(MAINFEST_template.format(kernel_name=k.name))
+            f.write(MAINFEST_end)
+
+    def gen_instances(self, kernels_dict):
+        if os.path.exists(self.impl_path):
+            shutil.rmtree(self.impl_path)
+        os.mkdir(self.impl_path)
+        if os.path.exists(self.instances_path):
+            shutil.rmtree(self.instances_path)
+        os.mkdir(self.instances_path)
+
+        for mnk, k in kernels_dict.items():
+            self.gen_instance(k)
+
+        self.gen_lookup_dict(kernels_dict)
+        self.gen_manifest_head(kernels_dict)
+
+
+
+def get_tune_dict(tune_dict_csv):
+    tune_dict = default_kernels_dict
+    if os.path.exists(tune_dict_csv):
+        tune_df = pd.read_csv(tune_dict_csv)
+        for i in range(len(tune_df)):
+            M = tune_df.loc[i, "M"]
+            N = tune_df.loc[i, "N"]
+            K = tune_df.loc[i, "K"]
+            kid = tune_df.loc[i, "kernelId"]
+            tune_dict[(M, N, K)] = kernels_list[kid]
+    return tune_dict
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        prog="generate",
+        description="gen API for CK gemm a8w8 kernel",
+    )
+
+    # the directory for list_blobs/gen_blobs to write files into
+    parser.add_argument(
+        "-w",
+        "--working_path",
+        default="./",
+        required=False,
+        help="the path where all the blobs are going to be generated"
+    )
+
+    parser.add_argument(
+        "-f",
+        "--tune_file",
+        default="aiter/configs/a8w8_blockscale_wpreshuffle_tuned_gemm.csv",
+        required=False,
+        help="tune_file holds the results produced by gemm_a8w8_blockscale_wpreshuffle_tune.py"
+    )
+
+    parser.add_argument(
+        "--tune",
+        action='store_true',
+        required=False,
+        help="generate tune instances"
+    )
+
+    # parser.add_argument(
+    #     "--out_type",
+    #     default="all",
+    #     required=False,
+    #     help="Specify the type of output\n \
+    #           all: [bf16, fp16] \n \
+    #           bf16, fp16"
+    # )
+
+    # parser.add_argument(
+    #     "--scale_type",
+    #     default="all",
+    #     required=False,
+    #     help="Specify the type of scale\n \
+    #           all: [fp32, same as out] \n \
+    #           same: [same as out]"
+    # )
+
+
+    args = parser.parse_args()
+    codegen = gemm_a8w8_blockscale_wpreshuffle_codegen(args.working_path, args.tune)
+
+    if args.tune:
+        codegen.gen_instances(kernels_list)
+    else:
+        codegen.gen_instances(get_tune_dict(args.tune_file))
diff --git a/csrc/ck_gemm_a8w8_blockscale_wpreshuffle/include/gemm_a8w8_blockscale_wpreshuffle.h b/csrc/ck_gemm_a8w8_blockscale_wpreshuffle/include/gemm_a8w8_blockscale_wpreshuffle.h
new file mode 100755
index 0000000000..ae8da31c03
--- /dev/null
+++ b/csrc/ck_gemm_a8w8_blockscale_wpreshuffle/include/gemm_a8w8_blockscale_wpreshuffle.h
@@ -0,0 +1,20 @@
+#pragma once
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
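+// Declarations for the JIT-built wpreshuffle GEMM module; the definitions live
+// in gemm_a8w8_blockscale_wpreshuffle.cu and gemm_a8w8_blockscale_wpreshuffle_tune.cu.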
+#include <optional>
+#include <torch/extension.h>
+torch::Tensor gemm_a8w8_blockscale_wpreshuffle(
+    torch::Tensor &XQ,
+    torch::Tensor &WQ,
+    torch::Tensor &x_scale,
+    torch::Tensor &w_scale,
+    torch::Tensor &Y);
+
+torch::Tensor gemm_a8w8_blockscale_wpreshuffle_tune(
+    torch::Tensor &XQ,
+    torch::Tensor &WQ,
+    torch::Tensor &x_scale,
+    torch::Tensor &w_scale,
+    torch::Tensor &Y,
+    int kernelId,
+    int splitK);
diff --git a/csrc/ck_gemm_a8w8_blockscale_wpreshuffle/include/gemm_a8w8_blockscale_wpreshuffle_common.cuh b/csrc/ck_gemm_a8w8_blockscale_wpreshuffle/include/gemm_a8w8_blockscale_wpreshuffle_common.cuh
new file mode 100755
index 0000000000..dae20f7012
--- /dev/null
+++ b/csrc/ck_gemm_a8w8_blockscale_wpreshuffle/include/gemm_a8w8_blockscale_wpreshuffle_common.cuh
@@ -0,0 +1,166 @@
+#pragma once
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+
+#ifdef USE_ROCM
+
+#undef __HIP_NO_HALF_OPERATORS__
+#undef __HIP_NO_HALF_CONVERSIONS__
+
+#include <cstdlib>
+#include <functional>
+#include <tuple>
+#include <unordered_map>
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <torch/extension.h>
+#include <torch/torch.h>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_blockscale_bpreshuffle.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
+
+#include "ck/library/utility/device_memory.hpp"
+#include "ck/library/utility/host_tensor.hpp"
+#include "ck/library/utility/host_tensor_generator.hpp"
+#include "ck/library/utility/literals.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
+#include "ck/library/utility/check_err.hpp"
+
+#include "ck/utility/blkgemmpipe_scheduler.hpp"
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using BF16 = ck::bhalf_t;
+using B16 = ck::bhalf_t;
+using FP8 = ck::f8_t;
+using F32 = float;
+using I8 = int8_t;
+using I32 = int;
+using F16 = ck::half_t;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+using A0DataType = FP8;
+using A1DataType = F32;
+using B0DataType = FP8;
+using B1DataType = F32;
+using AccDataType = F32;
+using CShuffleDataType = F32;
+using DsDataType = ck::Tuple<>;
+using EDataType = BF16;
+
+using A0Layout = Row;
+using B0Layout = Col;
+using D0Layout = Row;
+using D1Layout = Col;
+using DsLayout = ck::Tuple<>;
+using ELayout = Row;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+using AElementOp = PassThrough;
+using BElementOp = PassThrough;
+using CDEElementOp = PassThrough;
+
+// static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
+
+// static constexpr ck::index_t Scale_Block_M = 1;
+// static constexpr ck::index_t Scale_Block_N = 128;
+// static constexpr ck::index_t Scale_Block_K = 128;
+
+template <typename DDataType, typename EDataType,
+          ck::index_t BLOCK_SIZE,
+          ck::index_t Scale_Block_M, ck::index_t Scale_Block_N, ck::index_t Scale_Block_K,
+          ck::index_t MBLOCK, ck::index_t NBLOCK, ck::index_t KBLOCK,
+          ck::index_t AK1, ck::index_t BK1,
+          ck::index_t MPER_XDL, ck::index_t NPER_XDL,
+          ck::index_t MPER_WAVE, ck::index_t NPER_WAVE,
+          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
+          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
+          ck::index_t CSHUFFLE_MX_PER_WAVE_PERSHUFFLE,
+          ck::index_t CSHUFFLE_NX_PER_WAVE_PERSHUFFLE,
+          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
+          typename CDEShuffleBlockTransferScalarPerVectors,
+          ck::BlockGemmPipelineScheduler BlkGemmPipeSched,
+          ck::BlockGemmPipelineVersion BlkGemmPipelineVer,
+          ck::tensor_operation::device::GemmSpecialization GemmSpec>
+using DeviceGemmHelperF8BlockScale = ck::tensor_operation::device::DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle<
+    // clang-format off
+    A0Layout, B0Layout, DsLayout, ELayout,
+    A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
+    AElementOp, BElementOp, CDEElementOp,
+    GemmSpec, BLOCK_SIZE,
+    Scale_Block_M, Scale_Block_N, Scale_Block_K,
+    MBLOCK, NBLOCK, KBLOCK,
+    AK1, BK1,
+    MPER_XDL, NPER_XDL,
+    MPER_WAVE, NPER_WAVE,
+    ABlockTransferThreadClusterLengths_AK0_M_AK1,
+    S<1, 0, 2>, S<1, 0, 2>,
+    2, AK1, AK1, 0,
+    BBlockTransferThreadClusterLengths_BK0_N_BK1,
+    S<1, 0, 2>, S<1, 0, 2>,
+    2, BK1, BK1, 0,
+    CSHUFFLE_MX_PER_WAVE_PERSHUFFLE,
+    CSHUFFLE_NX_PER_WAVE_PERSHUFFLE,
+    CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
+    CDEShuffleBlockTransferScalarPerVectors,
+    BlkGemmPipeSched,
+    BlkGemmPipelineVer, A0DataType>;
+// clang-format on
+
+template <typename DDataType, typename EDataType, typename DeviceGemmInstance>
+__forceinline__ torch::Tensor gemm_a8w8_blockscale_wpreshuffle_impl(
+    torch::Tensor &XQ,
+    torch::Tensor &WQ,
+    torch::Tensor &x_scale,
+    torch::Tensor &w_scale,
+    torch::Tensor &Y)
+{
+    int M = XQ.size(0);
+    int N = WQ.size(0);
+    int K = XQ.size(1);
+
+    int StrideA = XQ.stride(-2);
+    int StrideB = WQ.stride(-2);
+    int StrideE = N;
+
+    auto a_element_op = AElementOp{};
+    auto b_element_op = BElementOp{};
+    auto cde_element_op = CDEElementOp{};
+
+    constexpr ck::index_t NumDTensor = DsDataType::Size();
+
+    // do GEMM
+    auto device_gemm = DeviceGemmInstance{};
+    auto invoker = device_gemm.MakeInvoker();
+    auto argument = device_gemm.MakeArgument(XQ.data_ptr(),
+                                             WQ.data_ptr(),
+                                             std::array<const void *, NumDTensor>{},
+                                             reinterpret_cast<EDataType *>(Y.data_ptr()),
+                                             M,
+                                             N,
+                                             K,
+                                             StrideA,
+                                             StrideB,
+                                             std::array<ck::index_t, NumDTensor>{},
+                                             StrideE,
+                                             reinterpret_cast<float *>(x_scale.data_ptr()),
+                                             reinterpret_cast<float *>(w_scale.data_ptr()),
+                                             a_element_op,
+                                             b_element_op,
+                                             cde_element_op);
+
+    TORCH_CHECK(device_gemm.IsSupportedArgument(argument), "This GEMM is not supported!");
+
+    invoker.Run(argument, StreamConfig{at::cuda::getCurrentCUDAStream().stream()});
+    return Y;
+}
+
+#endif // USE_ROCM
diff --git a/csrc/include/rocm_ops.hpp b/csrc/include/rocm_ops.hpp
index 51278b931a..5df0c78d90 100755
--- a/csrc/include/rocm_ops.hpp
+++ b/csrc/include/rocm_ops.hpp
@@ -197,6 +197,15 @@
           py::arg("x_scale"), py::arg("w_scale"), py::arg("Out"), py::arg("kernelId") = 0, \
           py::arg("splitK") = 0);
 
+#define GEMM_A8W8_BLOCKSCALE_WPRESHUFFLE_PYBIND                                                           \
+    m.def("gemm_a8w8_blockscale_wpreshuffle", &gemm_a8w8_blockscale_wpreshuffle, "fp8 blockscale gemm wpreshuffle", py::arg("XQ"), py::arg("WQ"), \
+          py::arg("x_scale"), py::arg("w_scale"), py::arg("Out"));
+
+#define GEMM_A8W8_BLOCKSCALE_WPRESHUFFLE_TUNE_PYBIND                                                      \
+    m.def("gemm_a8w8_blockscale_wpreshuffle_tune", &gemm_a8w8_blockscale_wpreshuffle_tune, "gemm_a8w8_blockscale_wpreshuffle_tune", py::arg("XQ"), py::arg("WQ"), \
+          py::arg("x_scale"), py::arg("w_scale"), py::arg("Out"), py::arg("kernelId") = 0,                \
+          py::arg("splitK") = 0);
+
 #define GEMM_A8W8_PYBIND                                                                                  \
     m.def("gemm_a8w8", &gemm_a8w8, "gemm_a8w8", py::arg("XQ"), py::arg("WQ"),                             \
           py::arg("x_scale"), py::arg("w_scale"), py::arg("Out"),                                         \
diff --git a/csrc/pybind/gemm_a8w8_blockscale_wpreshuffle_pybind.cu b/csrc/pybind/gemm_a8w8_blockscale_wpreshuffle_pybind.cu
new file mode 100755
index 0000000000..0fcf306c1a
--- /dev/null
+++ b/csrc/pybind/gemm_a8w8_blockscale_wpreshuffle_pybind.cu
@@ -0,0 +1,9 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+#include "rocm_ops.hpp"
+#include "gemm_a8w8_blockscale_wpreshuffle.h"
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+    GEMM_A8W8_BLOCKSCALE_WPRESHUFFLE_PYBIND;
+}
diff --git a/csrc/pybind/gemm_a8w8_blockscale_wpreshuffle_tune_pybind.cu b/csrc/pybind/gemm_a8w8_blockscale_wpreshuffle_tune_pybind.cu
new file mode 100644
index 0000000000..53048384ec
--- /dev/null
+++ b/csrc/pybind/gemm_a8w8_blockscale_wpreshuffle_tune_pybind.cu
@@ -0,0 +1,9 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+#include "rocm_ops.hpp"
+#include "gemm_a8w8_blockscale_wpreshuffle.h"
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+    GEMM_A8W8_BLOCKSCALE_WPRESHUFFLE_TUNE_PYBIND;
+}
diff --git a/csrc/rocm_ops.cpp b/csrc/rocm_ops.cpp
index 145f7384de..e5fd3a3843 100644
--- a/csrc/rocm_ops.cpp
+++ b/csrc/rocm_ops.cpp
@@ -10,6 +10,7 @@
 #include "custom_all_reduce.h"
 #include "communication_asm.h"
 #include "gemm_a8w8_blockscale.h"
+#include "gemm_a8w8_blockscale_wpreshuffle.h"
 #include "custom.h"
 #include "moe_op.h"
 #include "moe_sorting.h"
diff --git a/op_tests/test_gemm_a8w8_blockscale.py b/op_tests/test_gemm_a8w8_blockscale.py
index 47ac2a245b..ce2ff6d5db 100755
--- a/op_tests/test_gemm_a8w8_blockscale.py
+++ b/op_tests/test_gemm_a8w8_blockscale.py
@@ -8,6 +8,7 @@
 import sys
 import os
 import aiter
+from aiter.ops.shuffle import shuffle_weight
 from einops import rearrange
 from einops import repeat as eirp
 
@@ -44,6 +45,11 @@
     return aiter.gemm_a8w8_blockscale_CK(x, weight, x_scale, w_scale, dtype)
 
 
+@perftest()
+def run_gemm_ck_wpreshuffle(x, weight, x_scale, w_scale, dtype=torch.bfloat16):
+    return aiter.gemm_a8w8_blockscale_wpreshuffle_CK(x, weight, x_scale, w_scale, dtype)
+
+
 @benchmark()
 def test_gemm(dtype, m, n, k):
     dim = (m, n, k)
@@ -56,14 +62,17 @@
     weight = (torch.rand((n, k), dtype=torch.float16, device="cuda") / 10).to(
         torch.float8_e4m3fnuz
     )
+    weight_shuffle = shuffle_weight(weight, layout=(16, 16))
     x_scale = torch.rand([m, scale_k], dtype=torch.float32, device="cuda")
     w_scale = torch.rand([scale_n, scale_k], dtype=torch.float32, device="cuda")
 
     a, avg_a = run_torch(x, weight, x_scale, w_scale, dtype)
     b, avg_b = run_gemm_ck(x, weight, x_scale, w_scale, dtype)
+    c, avg_c = run_gemm_ck_wpreshuffle(x, weight_shuffle, x_scale, w_scale, dtype)
 
-    msg = f"[perf] dim: {str(dim):<20} dtype: {dtype}, torch avg: {avg_a:<8.2f} us, ck avg: {avg_b:<8.2f} us, uplift: {avg_a/avg_b -1:<5.1%}"
+    msg = f"[perf] dim: {str(dim):<20} dtype: {dtype}, torch avg: {avg_a:<8.2f} us, ck avg: {avg_b:<8.2f} us, ck wpreshuffle avg: {avg_c:<8.2f} us, uplift: {avg_a/min(avg_b, avg_c) -1:<5.1%}"
     checkAllclose(a, b, msg="a,b: " + msg, rtol=1e-2, atol=0.01)
+    checkAllclose(a, c, msg="ck_wpreshuffle: ", rtol=1e-2, atol=0.01)
 
 
 @perftest(num_iters=5)