#ifdef __KERNEL_ONEAPI__

template<typename IsActiveOp>
void gpu_parallel_active_index_array_impl(const uint num_states,
                                          ccl_global int *indices,
                                          ccl_global int *num_indices,
                                          IsActiveOp is_active_op)
{
#  ifdef WITH_ONEAPI_SYCL_HOST_TASK
  /* Host-task fallback: build the index array serially on the CPU. */
  int write_index = 0;
  for (int state_index = 0; state_index < num_states; state_index++) {
    if (is_active_op(state_index)) {
      indices[write_index++] = state_index;
    }
  }
  *num_indices = write_index;
  return;
#  endif /* WITH_ONEAPI_SYCL_HOST_TASK */
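  /* GPU path: one work-item per state; per-warp counts are combined through
   * the group-local memory allocated below. */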
  const sycl::nd_item<1> &item_id = sycl::ext::oneapi::this_work_item::get_nd_item<1>();
  const uint blocksize = item_id.get_local_range(0);

  sycl::multi_ptr<int[GPU_PARALLEL_ACTIVE_INDEX_DEFAULT_BLOCK_SIZE + 1],
                  sycl::access::address_space::local_space>
      ptr = sycl::ext::oneapi::group_local_memory<
          int[GPU_PARALLEL_ACTIVE_INDEX_DEFAULT_BLOCK_SIZE + 1]>(item_id.get_group());
  int *warp_offset = *ptr;
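  /* Query thread/warp identifiers directly from the nd_item rather than
   * deriving them with "%" and "/" as the CUDA-style path below does; the
   * upstream code notes this is significantly faster with DPC++. */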
  const uint thread_index = item_id.get_local_id(0);
  const uint thread_warp = item_id.get_sub_group().get_local_id();

  const uint warp_index = item_id.get_sub_group().get_group_id();
  const uint num_warps = item_id.get_sub_group().get_group_range()[0];

  const uint state_index = item_id.get_global_id(0);

  /* Test if state corresponding to this thread is active. */
  const uint is_active = (state_index < num_states) ? is_active_op(state_index) : 0;
#else /* !__KERNEL_ONEAPI__ */
#  ifndef __KERNEL_METAL__
template<typename IsActiveOp>
__device__
#  endif
    void
    gpu_parallel_active_index_array_impl(const uint num_states,
                                         ccl_global int *indices,
                                         ccl_global int *num_indices,
#  ifdef __KERNEL_METAL__
                                         const uint is_active,
                                         const uint blocksize,
                                         const int thread_index,
                                         const uint state_index,
                                         const int ccl_gpu_warp_size,
                                         const int thread_warp,
                                         const int warp_index,
                                         const int num_warps,
                                         threadgroup int *warp_offset)
{
#  else
                                         IsActiveOp is_active_op)
{
  extern ccl_gpu_shared int warp_offset[];
#    ifndef __KERNEL_METAL__
  const uint blocksize = ccl_gpu_block_dim_x;
#    endif

  const uint thread_index = ccl_gpu_thread_idx_x;
  const uint thread_warp = thread_index % ccl_gpu_warp_size;

  const uint warp_index = thread_index / ccl_gpu_warp_size;
  const uint num_warps = blocksize / ccl_gpu_warp_size;

  const uint state_index = ccl_gpu_block_idx_x * blocksize + thread_index;

  /* Test if state corresponding to this thread is active. */
  const uint is_active = (state_index < num_states) ? is_active_op(state_index) : 0;
#  endif
#endif /* !__KERNEL_ONEAPI__ */
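  /* For each thread within a warp, compute how many active states precede it:
   * an exclusive prefix sum of is_active over the warp. oneAPI uses a
   * sub-group scan; the CUDA/HIP path counts the active lanes below this one
   * with ballot + popcount. */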
#ifdef __KERNEL_ONEAPI__
  const uint thread_offset = sycl::exclusive_scan_over_group(
      item_id.get_sub_group(), is_active, std::plus<>());
#else
  const uint thread_offset = popcount(ccl_gpu_ballot(is_active) &
                                      ccl_gpu_thread_mask(thread_warp));
#endif
  /* Last thread in warp stores number of active states for each warp. */
#ifdef __KERNEL_ONEAPI__
  if (thread_warp == item_id.get_sub_group().get_local_range()[0] - 1) {
#else
  if (thread_warp == ccl_gpu_warp_size - 1) {
#endif
    warp_offset[warp_index] = thread_offset + is_active;
  }
#ifdef __KERNEL_ONEAPI__
  /* Only local memory (warp_offset) was written, so a cheaper local barrier
   * is sufficient here. */
  ccl_gpu_local_syncthreads();
#else
  ccl_gpu_syncthreads();
#endif
  /* Last thread in block converts per-warp sizes to offsets, increments global
   * size of the index array and gets the offset to write to. */
  if (thread_index == blocksize - 1) {
    /* Convert per-warp sizes to offsets. */
    int offset = 0;
    for (int i = 0; i < num_warps; i++) {
      int num_active = warp_offset[i];
      warp_offset[i] = offset;
      offset += num_active;
    }

    /* Increment block offset of state array. */
    const uint block_num_active = warp_offset[warp_index] + thread_offset + is_active;
    warp_offset[num_warps] = atomic_fetch_and_add_uint32(num_indices, block_num_active);
  }
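  /* After the barrier below, warp_offset[i] holds warp i's starting offset
   * within the block, and warp_offset[num_warps] holds the block's starting
   * offset in the global indices array, returned by the atomic add above. */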
#ifdef __KERNEL_ONEAPI__
  ccl_gpu_local_syncthreads();
#else
  ccl_gpu_syncthreads();
#endif

  /* Write to index array. */
  if (is_active) {
    const uint block_offset = warp_offset[num_warps];
    indices[block_offset + warp_offset[warp_index] + thread_offset] = state_index;
  }
}
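/* Entry point. Metal passes thread and SIMD-group state into the
 * implementation explicitly; the CUDA, HIP and oneAPI backends derive it
 * inside the implementation instead. */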
#ifdef __KERNEL_METAL__

#  define gpu_parallel_active_index_array(num_states, indices, num_indices, is_active_op) \
    const uint is_active = (ccl_gpu_global_id_x() < num_states) ? \
                               is_active_op(ccl_gpu_global_id_x()) : \
                               0; \
\
    gpu_parallel_active_index_array_impl(num_states, \
                                         indices, \
                                         num_indices, \
                                         is_active, \
                                         metal_local_size, \
                                         metal_local_id, \
                                         metal_global_id, \
                                         simdgroup_size, \
                                         simd_lane_index, \
                                         simd_group_index, \
                                         num_simd_groups, \
                                         (threadgroup int *)threadgroup_array)
#elif defined(__KERNEL_ONEAPI__)

#  define gpu_parallel_active_index_array(num_states, indices, num_indices, is_active_op) \
    gpu_parallel_active_index_array_impl(num_states, indices, num_indices, is_active_op)

#else

#  define gpu_parallel_active_index_array(num_states, indices, num_indices, is_active_op) \
    gpu_parallel_active_index_array_impl(num_states, indices, num_indices, is_active_op)

#endif
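/* Usage sketch (illustrative, not part of the upstream header): the caller
 * runs one thread per state and passes a predicate functor; the primitive
 * compacts the indices of states for which it returns true. The functor
 * `TerminatedStatesOp` and the `terminated_flags` array are hypothetical
 * names used only for this example.
 *
 *   struct TerminatedStatesOp {
 *     ccl_global const int *terminated_flags;  // hypothetical per-state flag array
 *
 *     bool operator()(const int state_index) const
 *     {
 *       return terminated_flags[state_index] != 0;
 *     }
 *   };
 *
 *   // Inside a GPU kernel launched with at least num_states threads:
 *   //   TerminatedStatesOp op{terminated_flags};
 *   //   gpu_parallel_active_index_array(num_states, indices, num_indices, op);
 */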
CCL_NAMESPACE_END