Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions sgl-kernel/csrc/cpu/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,73 @@ inline void parallel_for(int n, const func_t& f) {
#endif
}

// Choose the thread count used by the parallel helpers:
//   - 1d parallel (m == 1): the full count from `at::get_num_threads()`
//   - 2d parallel: that count rounded down to an even number, e.g. 43->42,
//     but never less than 1
int inline adjust_num_threads(int m) {
  const int available = at::get_num_threads();
  if (m != 1) {
    // dropping the lowest bit rounds down to the nearest even value
    const int even = available & ~1;
    return even < 1 ? 1 : even;
  }
  return available;
}

// Run `f` over the 2d iteration space [0, m) x [0, n), split into a grid of
// nth_m x nth_n rectangular blocks.
//
//   f(begin_m, end_m, begin_n, end_n) handles rows [begin_m, end_m) and
//   columns [begin_n, end_n). `f` may be called with an empty range
//   (begin >= end) when the grid does not divide the space evenly.
//
// If `m <= 0` or `n <= 0` there is nothing to iterate and `f` is not called.
// Without OpenMP the whole space is handled by a single call `f(0, m, 0, n)`.
template <typename func_t>
inline void parallel_2d(int m, int n, const func_t& f) {
  // An empty space has no work; returning here also keeps the ratio m / n and
  // the div_up() calls below away from a zero divisor.
  if (m <= 0 || n <= 0) {
    return;
  }

  // make sure we have even num_threads
  const int nth = adjust_num_threads(m);

  // [NOTE] thread blocking:
  //
  //   1) prefer square block per thread
  //   2) use even number of CPU cores
  //   3) use all `num_threads` cores
  //
  //   we have:
  //     TM * TN = T
  //     BM / TM = BN / TN
  //   then:
  //     TM = ((BM / BN) * T) ^ 0.5
  //
  const float r = static_cast<float>(m) / static_cast<float>(n);
  int nth_m = static_cast<int>(std::ceil(std::sqrt(r * static_cast<float>(nth))));
  // a row split wider than the thread count can never divide `nth`, so start
  // the search at `nth` at most
  nth_m = std::min(nth_m, nth);
  int nth_n = 1;
  // walk down to the nearest row split that divides `nth` exactly
  for (; nth_m > 0; --nth_m) {
    nth_n = nth / nth_m;
    if (nth_m * nth_n == nth) {
      break;
    }
  }

#if defined(_OPENMP)
  // block extents are identical for every thread, compute them once
  const int thread_block_m = div_up(m, nth_m);
  const int thread_block_n = div_up(n, nth_n);

#pragma omp parallel num_threads(nth)
  {
    // `num_threads(nth)` is only a request: the runtime may hand out a smaller
    // team. Stride over the `nth` grid cells by the real team size so that no
    // block is left unprocessed; with a full team each thread runs exactly
    // one block.
    const int team_size = omp_get_num_threads();
    for (int ith = omp_get_thread_num(); ith < nth; ith += team_size) {
      const int ith_m = ith / nth_n;
      const int ith_n = ith % nth_n;

      const int begin_m = ith_m * thread_block_m;
      const int end_m = std::min(m, begin_m + thread_block_m);
      const int begin_n = ith_n * thread_block_n;
      const int end_n = std::min(n, begin_n + thread_block_n);

      f(begin_m, end_m, begin_n, end_n);
    }
  }
#else
  f(0, m, 0, n);
#endif
}

// Number of [BLOCK_SIZE, K] blocks of element type `T` that fit in the L2
// budget used for cache blocking (L2 2MB and ratio of 50%, i.e. 1MB).
//
// Always returns at least 1, including when a single block is larger than the
// budget or when `BLOCK_SIZE` / `K` is not positive.
template <typename T>
int get_cache_blocks(int BLOCK_SIZE, int K) {
  // L2 2MB and ratio of 50%
  constexpr int L2_size = (2048 * 1024) >> 1;

  // A degenerate block has no footprint to divide by; report the minimum
  // instead of dividing by zero.
  if (BLOCK_SIZE <= 0 || K <= 0) {
    return 1;
  }

  // For positive integers floor(floor(floor(L / s) / B) / K) equals
  // floor(L / (s * B * K)), so dividing step by step gives the same result as
  // dividing by the byte size of one block, without forming a product that
  // could overflow `int`.
  const int elems_in_budget = L2_size / static_cast<int>(sizeof(T));
  return std::max(1, elems_in_budget / BLOCK_SIZE / K);
}

// data indexing for dimension collapse
template <typename T>
inline T data_index_init(T offset) {
Expand Down
18 changes: 9 additions & 9 deletions sgl-kernel/csrc/cpu/gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,17 +286,20 @@ void weight_packed_linear_kernel_impl(
// use avx512-bf16 when a) M is small; b) dtype is bfloat16, otherwise use amx
const bool use_brgemm = (M > 4) || (!std::is_same_v<scalar_t, at::BFloat16>);

// l2 cache block for n
int64_t cache_blocks_nb = get_cache_blocks<scalar_t>(BLOCK_N, K);

// parallel on [MB, NB]
AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
int64_t mb{0}, nb{0};
data_index_init(begin, mb, MB, nb, NB);
parallel_2d(MB, NB, [&](int64_t begin_mb, int64_t end_mb, int64_t begin_nb, int64_t end_nb) {

// for brgemm, use float32 for accumulate
alignas(64) float Ctmp[BLOCK_M * BLOCK_N];

for (int64_t i = begin; i < end; ++i) {
UNUSED(i);
for (int64_t nbb = begin_nb; nbb < end_nb; nbb += cache_blocks_nb) {
for (int64_t mb = begin_mb; mb < end_mb; ++mb) {
for (int64_t nb = nbb; nb < std::min(nbb + cache_blocks_nb, end_nb); ++nb) {

int64_t mb_start = mb * BLOCK_M;
int64_t mb_size = std::min(M - mb_start, BLOCK_M);
int64_t nb_start = nb * BLOCK_N;
Expand All @@ -315,10 +318,7 @@ void weight_packed_linear_kernel_impl(
/* ldb */ nb_size,
/* ldc */ out_strideM,
/* brg */ use_brgemm);

// move to the next index
data_index_step(mb, MB, nb, NB);
}
}}}

if (use_brgemm) {
at::native::cpublas::brgemm_release();
Expand Down
66 changes: 43 additions & 23 deletions sgl-kernel/csrc/cpu/gemm_fp8.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,16 @@ inline void copy_add_stub(scalar_t* __restrict__ out, const float* __restrict__
out[d] = static_cast<scalar_t>(input[d] + bias[d]);
}
}

// at::Half overload of unpack_B. Not implemented: every call fails the
// TORCH_CHECK below, and none of the arguments are read.
inline void unpack_B(
    at::Half* __restrict__ /* Btmp */,
    const at::Float8_e4m3fn* __restrict__ /* packed_B */,
    int /* N */,
    int /* K */,
    int /* ldb */,
    int /* ldb_tmp */,
    float /* scale */) {
  TORCH_CHECK(false, "unpack_B: Half not supported!");
}
inline void unpack_B(
at::BFloat16* __restrict__ Btmp,
const at::Float8_e4m3fn* __restrict__ packed_B,
Expand Down Expand Up @@ -257,13 +266,6 @@ struct brgemm<at::BFloat16, at::Float8_e4m3fn, has_bias> {
// [K, BLOCK_N] -> [K / 2, BLOCK_N * 2]
const int ldb_tmp = BLOCK_N;

for (int k = 0; k < K; k += BLOCK_K) {
int kb_size = std::min(BLOCK_K, K - k);

int idx = k >> 7; // k / BLOCK_K where BLOCK_K = 128
unpack_B(Btmp + k * ldb_tmp, B + k * ldb, N, kb_size, ldb, ldb_tmp, scale[idx]);
}

at::native::cpublas::brgemm(
M, N, K, lda, ldb_tmp, BLOCK_N, /* add_C */ false, A, Btmp, Ctmp);

Expand Down Expand Up @@ -351,19 +353,39 @@ void fp8_scaled_mm_kernel_impl(
const int64_t blocks_n_per_group = block_size_N / BLOCK_N;

const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(M);

scalar_t* __restrict__ Btmp = buffer;

if (use_brgemm) {
at::parallel_for(0, NB, 0, [&](int64_t begin, int64_t end) {
int64_t nb{0};
data_index_init(begin, nb, NB);
for (int64_t i = begin; i < end; ++i) {
int64_t nb_start = nb * BLOCK_N;
int64_t nb_size = std::min(N - nb_start, BLOCK_N);
const float* scale_ptr = scales2 + (nb / blocks_n_per_group) * scale_size_K;
for (int64_t k = 0; k < K; k += BLOCK_K) {
int64_t kb_size = std::min(static_cast<int64_t>(BLOCK_K), K - k);
const int64_t idx = k >> 7; // k / BLOCK_K where BLOCK_K = 128
auto ldb = nb_size;
auto B = mat2 + nb_start * K;
unpack_B(Btmp + nb_start*K+ k * BLOCK_N, B + k * ldb, nb_size, kb_size, ldb, BLOCK_N, scale_ptr[idx]);
}
data_index_step( nb, NB);
}
});
}
// l2 cache block for n
int64_t cache_blocks_nb = get_cache_blocks<scalar_t>(BLOCK_N, K);
// parallel on [MB, NB]
AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
int64_t mb{0}, nb{0};
data_index_init(begin, mb, MB, nb, NB);
parallel_2d(MB, NB, [&](int64_t begin_mb, int64_t end_mb, int64_t begin_nb, int64_t end_nb) {

int tid = at::get_thread_num();
scalar_t* __restrict__ Btmp = buffer + tid * buffer_size_per_thread;
float* __restrict__ Ctmp = (float*)((void*)(Btmp + BLOCK_N * K));
// int tid = at::get_thread_num();
alignas(64) float Ctmp[BLOCK_M * BLOCK_N];

for (int64_t i = begin; i < end; ++i) {
UNUSED(i);
for (int64_t nbb = begin_nb; nbb < end_nb; nbb += cache_blocks_nb) {
for (int64_t mb = begin_mb; mb < end_mb; ++mb) {
for (int64_t nb = nbb; nb < std::min(nbb + cache_blocks_nb, end_nb); ++nb) {
const float* scale_ptr = scales2 + (nb / blocks_n_per_group) * scale_size_K;

int64_t mb_start = mb * BLOCK_M;
Expand All @@ -375,7 +397,7 @@ void fp8_scaled_mm_kernel_impl(
/* A */ mat1 + mb_start * mat1_strideM,
/* B */ mat2 + nb_start * K, // nb * BLOCK_N * K
/* C */ out + mb_start * out_strideM + nb_start,
/* Btmp */ Btmp,
/* Btmp */ Btmp + nb_start * K,
/* Ctmp */ Ctmp,
/* scale */ scale_ptr,
/* bias */ bias + nb_start,
Expand All @@ -388,9 +410,7 @@ void fp8_scaled_mm_kernel_impl(
/* brg */ use_brgemm,
/* block_size_K */ block_size_K);

// move to the next index
data_index_step(mb, MB, nb, NB);
}
}}}

if (use_brgemm) {
at::native::cpublas::brgemm_release();
Expand Down Expand Up @@ -500,8 +520,8 @@ at::Tensor fp8_scaled_mm_cpu(at::Tensor& mat1, at::Tensor& mat2, at::Tensor& sca
// Btmp : [T, BLOCK_N * K]
// Ctmp : [T, BLOCK_M * BLOCK_N]
int num_threads = at::get_num_threads();
int64_t size_per_thread = BLOCK_N * K + BLOCK_M * BLOCK_N * 2;
auto buffer = at::empty({num_threads, size_per_thread}, mat1.options());
int64_t size_per_thread = N * K ;
auto buffer = at::empty({size_per_thread}, mat1.options());

AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "fp8_scaled_mm_kernel_impl", [&] {
fp8_scaled_mm_kernel_impl<scalar_t>(
Expand Down
Loading