diff --git a/csrc/vertical_slash_index.cu b/csrc/vertical_slash_index.cu index 45af042..7839343 100644 --- a/csrc/vertical_slash_index.cu +++ b/csrc/vertical_slash_index.cu @@ -60,12 +60,20 @@ __global__ void convert_vertical_slash_indexes_kernel( int tmp_col_cnt = 0, tmp_blk_cnt = 0; int s = 0, v = 0; - int v_idx = vertical_indexes[v++]; - int s_idx = slash_indexes[s++]; - while (s_idx >= end_m) { + + // in case of vs are empty + int v_idx = (v < NNZ_V) ? vertical_indexes[v++] : (end_m + BLOCK_SIZE_M); + int s_idx = (s < NNZ_S) ? slash_indexes[s++] : -1; + + // make sure s_idx is valid + while (s_idx >= end_m && s < NNZ_S) { s_idx = slash_indexes[s++]; } - s_idx = max(end_m - s_idx, BLOCK_SIZE_M); + if (s_idx >= end_m) { + s_idx = end_m + BLOCK_SIZE_M; + } else { + s_idx = max(end_m - s_idx, BLOCK_SIZE_M); + } int range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx; while (1) { if (v_idx < range_end) {