diff --git a/sgl-kernel/csrc/cpu/common.h b/sgl-kernel/csrc/cpu/common.h index 34aa86ca1efa..43ccfe16ebc1 100644 --- a/sgl-kernel/csrc/cpu/common.h +++ b/sgl-kernel/csrc/cpu/common.h @@ -112,6 +112,73 @@ inline void parallel_for(int n, const func_t& f) { #endif } +// for 1d parallel, use `actual_nth` +// for 2d parallel, use even nths, e.g. 43->42 +int inline adjust_num_threads(int m) { + int actual_nth = at::get_num_threads(); + if (m == 1) { + return actual_nth; + } + return std::max(1, (actual_nth >> 1) * 2); +} + +template +inline void parallel_2d(int m, int n, const func_t& f) { + + // make sure we have even num_threads + int nth = adjust_num_threads(m); + + // [NOTE] thread blocking: + // + // 1) prefer square block per thread + // 2) use even number of CPU cores + // 3) use all `num_threads` cores + // + // we have: + // TM * TN = T + // BM / TM = BN / TN + // then: + // TM = ((BM / BN) * T) ^ 0.5 + // + float r = float(m) / n; + int nth_m = std::ceil(std::sqrt(r * nth)); + int nth_n = 1; + for (; nth_m > 0; --nth_m) { + nth_n = nth / nth_m; + if (nth_m * nth_n == nth) { + break; + } + } + +#if defined(_OPENMP) +#pragma omp parallel num_threads(nth) +{ + int ith = omp_get_thread_num(); + int ith_m = ith / nth_n; + int ith_n = ith % nth_n; + + int thread_block_m = div_up(m, nth_m); + int thread_block_n = div_up(n, nth_n); + + int begin_m = ith_m * thread_block_m; + int end_m = std::min(m, begin_m + thread_block_m); + int begin_n = ith_n * thread_block_n; + int end_n = std::min(n, begin_n + thread_block_n); + + f(begin_m, end_m, begin_n, end_n); +} +#else + f(0, m, 0, n); +#endif +} + +template +int get_cache_blocks(int BLOCK_SIZE, int K) { + // L2 2MB and ratio of 50% + const int L2_size = 2048 * 1024 >> 1; + return std::max(1, int(L2_size / (BLOCK_SIZE * K * sizeof(T)))); +} + // data indexing for dimension collapse template inline T data_index_init(T offset) { diff --git a/sgl-kernel/csrc/cpu/gemm.cpp b/sgl-kernel/csrc/cpu/gemm.cpp index 5bee42ec0bfd..107eff5bc0e4 100644 --- a/sgl-kernel/csrc/cpu/gemm.cpp +++ b/sgl-kernel/csrc/cpu/gemm.cpp @@ -286,17 +286,20 @@ void weight_packed_linear_kernel_impl( // use avx512-bf16 when a) M is small; b) dtype is bfloat16, otherwise use amx const bool use_brgemm = (M > 4) || (!std::is_same_v); + // l2 cache block for n + int64_t cache_blocks_nb = get_cache_blocks(BLOCK_N, K); + // parallel on [MB, NB] AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] { - at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { - int64_t mb{0}, nb{0}; - data_index_init(begin, mb, MB, nb, NB); + parallel_2d(MB, NB, [&](int64_t begin_mb, int64_t end_mb, int64_t begin_nb, int64_t end_nb) { // for brgemm, use float32 for accumulate alignas(64) float Ctmp[BLOCK_M * BLOCK_N]; - for (int64_t i = begin; i < end; ++i) { - UNUSED(i); + for (int64_t nbb = begin_nb; nbb < end_nb; nbb += cache_blocks_nb) { + for (int64_t mb = begin_mb; mb < end_mb; ++mb) { + for (int64_t nb = nbb; nb < std::min(nbb + cache_blocks_nb, end_nb); ++nb) { + int64_t mb_start = mb * BLOCK_M; int64_t mb_size = std::min(M - mb_start, BLOCK_M); int64_t nb_start = nb * BLOCK_N; @@ -315,10 +318,7 @@ void weight_packed_linear_kernel_impl( /* ldb */ nb_size, /* ldc */ out_strideM, /* brg */ use_brgemm); - - // move to the next index - data_index_step(mb, MB, nb, NB); - } + }}} if (use_brgemm) { at::native::cpublas::brgemm_release(); diff --git a/sgl-kernel/csrc/cpu/gemm_fp8.cpp b/sgl-kernel/csrc/cpu/gemm_fp8.cpp index 94d72e9e2d45..4eb5f857ae7c 100644 --- a/sgl-kernel/csrc/cpu/gemm_fp8.cpp +++ b/sgl-kernel/csrc/cpu/gemm_fp8.cpp @@ -44,7 +44,16 @@ inline void copy_add_stub(scalar_t* __restrict__ out, const float* __restrict__ out[d] = static_cast(input[d] + bias[d]); } } - +inline void unpack_B( + at::Half* __restrict__ Btmp, + const at::Float8_e4m3fn* __restrict__ packed_B, + int N, + int K, + int ldb, + int ldb_tmp, + float scale) { + TORCH_CHECK(false, "unpack_B: Half not supported!"); +} inline void unpack_B( at::BFloat16* __restrict__ Btmp, const at::Float8_e4m3fn* __restrict__ packed_B, @@ -257,13 +266,6 @@ struct brgemm { // [K, BLOCK_N] -> [K / 2, BLOCK_N * 2] const int ldb_tmp = BLOCK_N; - for (int k = 0; k < K; k += BLOCK_K) { - int kb_size = std::min(BLOCK_K, K - k); - - int idx = k >> 7; // k / BLOCK_K where BLOCK_K = 128 - unpack_B(Btmp + k * ldb_tmp, B + k * ldb, N, kb_size, ldb, ldb_tmp, scale[idx]); - } - at::native::cpublas::brgemm( M, N, K, lda, ldb_tmp, BLOCK_N, /* add_C */ false, A, Btmp, Ctmp); @@ -351,19 +353,39 @@ void fp8_scaled_mm_kernel_impl( const int64_t blocks_n_per_group = block_size_N / BLOCK_N; const bool use_brgemm = can_use_brgemm(M); - + scalar_t* __restrict__ Btmp = buffer; + + if (use_brgemm) { + at::parallel_for(0, NB, 0, [&](int64_t begin, int64_t end) { + int64_t nb{0}; + data_index_init(begin, nb, NB); + for (int64_t i = begin; i < end; ++i) { + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(N - nb_start, BLOCK_N); + const float* scale_ptr = scales2 + (nb / blocks_n_per_group) * scale_size_K; + for (int64_t k = 0; k < K; k += BLOCK_K) { + int64_t kb_size = std::min(static_cast(BLOCK_K), K - k); + const int64_t idx = k >> 7; // k / BLOCK_K where BLOCK_K = 128 + auto ldb = nb_size; + auto B = mat2 + nb_start * K; + unpack_B(Btmp + nb_start*K+ k * BLOCK_N, B + k * ldb, nb_size, kb_size, ldb, BLOCK_N, scale_ptr[idx]); + } + data_index_step( nb, NB); + } + }); + } + // l2 cache block for n + int64_t cache_blocks_nb = get_cache_blocks(BLOCK_N, K); // parallel on [MB, NB] AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] { - at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { - int64_t mb{0}, nb{0}; - data_index_init(begin, mb, MB, nb, NB); + parallel_2d(MB, NB, [&](int64_t begin_mb, int64_t end_mb, int64_t begin_nb, int64_t end_nb) { - int tid = at::get_thread_num(); - scalar_t* __restrict__ Btmp = buffer + tid * buffer_size_per_thread; - float* __restrict__ Ctmp = (float*)((void*)(Btmp + BLOCK_N * K)); + // int tid = at::get_thread_num(); + alignas(64) float Ctmp[BLOCK_M * BLOCK_N]; - for (int64_t i = begin; i < end; ++i) { - UNUSED(i); + for (int64_t nbb = begin_nb; nbb < end_nb; nbb += cache_blocks_nb) { + for (int64_t mb = begin_mb; mb < end_mb; ++mb) { + for (int64_t nb = nbb; nb < std::min(nbb + cache_blocks_nb, end_nb); ++nb) { const float* scale_ptr = scales2 + (nb / blocks_n_per_group) * scale_size_K; int64_t mb_start = mb * BLOCK_M; @@ -375,7 +397,7 @@ void fp8_scaled_mm_kernel_impl( /* A */ mat1 + mb_start * mat1_strideM, /* B */ mat2 + nb_start * K, // nb * BLOCK_N * K /* C */ out + mb_start * out_strideM + nb_start, - /* Btmp */ Btmp, + /* Btmp */ Btmp + nb_start * K, /* Ctmp */ Ctmp, /* scale */ scale_ptr, /* bias */ bias + nb_start, @@ -388,9 +410,7 @@ void fp8_scaled_mm_kernel_impl( /* brg */ use_brgemm, /* block_size_K */ block_size_K); - // move to the next index - data_index_step(mb, MB, nb, NB); - } + }}} if (use_brgemm) { at::native::cpublas::brgemm_release(); @@ -500,8 +520,8 @@ at::Tensor fp8_scaled_mm_cpu(at::Tensor& mat1, at::Tensor& mat2, at::Tensor& sca // Btmp : [T, BLOCK_N * K] // Ctmp : [T, BLOCK_M * BLOCK_N] int num_threads = at::get_num_threads(); - int64_t size_per_thread = BLOCK_N * K + BLOCK_M * BLOCK_N * 2; - auto buffer = at::empty({num_threads, size_per_thread}, mat1.options()); + int64_t size_per_thread = N * K ; + auto buffer = at::empty({size_per_thread}, mat1.options()); AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "fp8_scaled_mm_kernel_impl", [&] { fp8_scaled_mm_kernel_impl(