diff --git a/include/cuco/detail/open_addressing/kernels.cuh b/include/cuco/detail/open_addressing/kernels.cuh index 7aff8a1c2..3317dcf0b 100644 --- a/include/cuco/detail/open_addressing/kernels.cuh +++ b/include/cuco/detail/open_addressing/kernels.cuh @@ -69,6 +69,8 @@ CUCO_KERNEL void insert_if_n(InputIt first, AtomicT* num_successes, Ref ref) { + namespace cg = cooperative_groups; + using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; typename Ref::size_type thread_num_successes = 0; @@ -84,7 +86,13 @@ CUCO_KERNEL void insert_if_n(InputIt first, } else { auto const tile = cooperative_groups::tiled_partition(cooperative_groups::this_thread_block()); +#if defined(CUCO_HAS_CG_INVOKE_ONE) + if (ref.insert(tile, insert_element)) { + cg::invoke_one(tile, [&]() { thread_num_successes++; }); + } +#else if (ref.insert(tile, insert_element) && tile.thread_rank() == 0) { thread_num_successes++; } +#endif } } idx += loop_stride; @@ -93,9 +101,15 @@ CUCO_KERNEL void insert_if_n(InputIt first, // compute number of successfully inserted elements for each block // and atomically add to the grand total auto const block_num_successes = BlockReduce(temp_storage).Sum(thread_num_successes); +#if defined(CUCO_HAS_CG_INVOKE_ONE) + cg::invoke_one(cg::this_thread_block(), [&]() { + num_successes->fetch_add(block_num_successes, cuda::std::memory_order_relaxed); + }); +#else if (threadIdx.x == 0) { num_successes->fetch_add(block_num_successes, cuda::std::memory_order_relaxed); } +#endif } /** @@ -248,7 +262,11 @@ CUCO_KERNEL void contains_if_n(InputIt first, if (idx < n) { typename std::iterator_traits::value_type const& key = *(first + idx); auto const found = pred(*(stencil + idx)) ? ref.contains(tile, key) : false; +#if defined(CUCO_HAS_CG_INVOKE_ONE) + cg::invoke_one(tile, [&]() { *(output_begin + idx) = found; }); +#else if (tile.thread_rank() == 0) { *(output_begin + idx) = found; } +#endif } } idx += loop_stride; @@ -270,6 +288,7 @@ CUCO_KERNEL void contains_if_n(InputIt first, template CUCO_KERNEL void size(StorageRef storage, Predicate is_filled, AtomicT* count) { + namespace cg = cooperative_groups; using size_type = typename StorageRef::size_type; auto const loop_stride = cuco::detail::grid_stride(); @@ -290,7 +309,12 @@ CUCO_KERNEL void size(StorageRef storage, Predicate is_filled, AtomicT* count) using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; auto const block_count = BlockReduce(temp_storage).Sum(thread_count); +#if defined(CUCO_HAS_CG_INVOKE_ONE) + cg::invoke_one(cg::this_thread_block(), + [&]() { count->fetch_add(block_count, cuda::std::memory_order_relaxed); }); +#else if (threadIdx.x == 0) { count->fetch_add(block_count, cuda::std::memory_order_relaxed); } +#endif } template @@ -315,7 +339,11 @@ CUCO_KERNEL void rehash(typename ContainerRef::storage_ref_type storage_ref, auto const n = storage_ref.num_windows(); while (idx - thread_rank < n) { +#if defined(CUCO_HAS_CG_INVOKE_ONE) + cg::invoke_one(block, [&]() { buffer_size = 0; }); +#else if (thread_rank == 0) { buffer_size = 0; } +#endif block.sync(); // gather values in shmem buffer diff --git a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh index dcbac0907..c7e7fff37 100644 --- a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh +++ b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh @@ -300,7 +300,12 @@ class open_addressing_ref_impl { auto const num_windows = static_cast(this->window_extent()); #if defined(CUDA_HAS_CUDA_BARRIER) __shared__ cuda::barrier barrier; + +#if defined(CUCO_HAS_CG_INVOKE_ONE) + cooperative_groups::invoke_one(g, [&]() { init(&barrier, g.size()); }); +#else if (g.thread_rank() == 0) { init(&barrier, g.size()); } +#endif g.sync(); cuda::memcpy_async( diff --git a/include/cuco/detail/static_map/kernels.cuh b/include/cuco/detail/static_map/kernels.cuh index 6e034567b..f95954168 100644 --- a/include/cuco/detail/static_map/kernels.cuh +++ b/include/cuco/detail/static_map/kernels.cuh @@ -118,9 +118,15 @@ CUCO_KERNEL void find(InputIt first, cuco::detail::index_type n, OutputIt output auto const tile = cg::tiled_partition(block); auto const found = ref.find(tile, key); +#if defined(CUCO_HAS_CG_INVOKE_ONE) + cg::invoke_one(tile, [&]() { + *(output_begin + idx) = found == ref.end() ? ref.empty_value_sentinel() : (*found).second; + }); +#else if (tile.thread_rank() == 0) { *(output_begin + idx) = found == ref.end() ? ref.empty_value_sentinel() : (*found).second; } +#endif } } idx += loop_stride; diff --git a/include/cuco/detail/static_set/kernels.cuh b/include/cuco/detail/static_set/kernels.cuh index b3b84e306..5678816c9 100644 --- a/include/cuco/detail/static_set/kernels.cuh +++ b/include/cuco/detail/static_set/kernels.cuh @@ -81,9 +81,15 @@ CUCO_KERNEL void find(InputIt first, cuco::detail::index_type n, OutputIt output auto const tile = cg::tiled_partition(block); auto const found = ref.find(tile, key); +#if defined(CUCO_HAS_CG_INVOKE_ONE) + cg::invoke_one(tile, [&]() { + *(output_begin + idx) = found == ref.end() ? ref.empty_key_sentinel() : *found; + }); +#else if (tile.thread_rank() == 0) { *(output_begin + idx) = found == ref.end() ? ref.empty_key_sentinel() : *found; } +#endif } } idx += loop_stride;