#include <muda/compute_graph/compute_graph.h>
#include <muda/type_traits/always.h>
#include <muda/launch/kernel_tag.h>

namespace muda
{
namespace details
{
template <typename F, typename UserTag>
MUDA_GLOBAL void parallel_for_kernel(ParallelForCallable<F> f)
{
    if constexpr(std::is_invocable_v<F, int>)
    {
        // one thread per element; the last block may be partially active
        auto tid = blockIdx.x * blockDim.x + threadIdx.x;
        auto i   = static_cast<int>(tid);
        if(i < f.count)
            f.callable(i);
    }
    else if constexpr(std::is_invocable_v<F, ParallelForDetails>)
    {
        ParallelForDetails details{ParallelForType::DynamicBlocks,
                                   static_cast<int>(blockIdx.x * blockDim.x
                                                    + threadIdx.x),
                                   f.count};
        if(details.i() < details.total_num())
            f.callable(details);
    }
    else
    {
        static_assert(always_false_v<F>,
                      "f must be void (int) or void (ParallelForDetails)");
    }
}
template <typename F, typename UserTag>
MUDA_GLOBAL void grid_stride_loop_kernel(ParallelForCallable<F> f)
{
    if constexpr(std::is_invocable_v<F, int>)
    {
        auto tid       = blockIdx.x * blockDim.x + threadIdx.x;
        auto grid_size = gridDim.x * blockDim.x;
        auto i         = tid;
        for(; i < f.count; i += grid_size)
            f.callable(i);
    }
    else if constexpr(std::is_invocable_v<F, ParallelForDetails>)
    {
        auto tid        = blockIdx.x * blockDim.x + threadIdx.x;
        auto grid_size  = gridDim.x * blockDim.x;
        auto block_size = blockDim.x;
        auto i          = tid;
        auto count      = f.count;
        // how many passes (batches) the whole grid needs to cover `count`
        auto round = (count + grid_size - 1) / grid_size;
        for(int j = 0; i < count; i += grid_size, ++j)
        {
            ParallelForDetails details{ParallelForType::GridStrideLoop,
                                       static_cast<int>(i),
                                       count};
            details.m_total_batch = round;
            details.m_batch_i     = j;
            if(i + block_size > details.total_num())  // tail batch may be partial
                details.m_active_num_in_block = count - j * grid_size;
            else
                details.m_active_num_in_block = block_size;
            f.callable(details);
        }
    }
    else
    {
        static_assert(always_false_v<F>,
                      "f must be void (int) or void (ParallelForDetails)");
    }
}
}  // namespace details
template <typename F, typename UserTag>
MUDA_HOST ParallelFor& ParallelFor::apply(int count, F&& f)
{
    if constexpr(COMPUTE_GRAPH_ON)
    {
        using CallableType = raw_type_t<F>;
        ComputeGraphBuilder::invoke_phase_actions(
            [&]  // direct launch: no graph is being built or updated
            { invoke<F, UserTag>(count, std::forward<F>(f)); },
            [&]  // building phase: record this launch as a kernel node
            {
                auto parms = as_node_parms<F, UserTag>(count, std::forward<F>(f));
                details::ComputeGraphAccessor().set_kernel_node(parms);
            },
            [&]  // updating phase: only the node type is needed here
            {
                details::ComputeGraphAccessor()
                    .set_kernel_node<details::ParallelForCallable<CallableType>>(nullptr);
            });
    }
    else
    {
        invoke<F, UserTag>(count, std::forward<F>(f));
    }
    return *this;
}
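// How the three phases are reached (sketch; assumes the ComputeGraph usage
// from muda's documentation, names illustrative):
//
//     muda::ComputeGraphVarManager manager;
//     muda::ComputeGraph graph{manager};
//     graph.create_node("pf_node") << [&]
//     {
//         // inside a graph closure, apply() records a kernel node
//         ParallelFor(256).apply(n, [=] __device__(int i) mutable { /*...*/ });
//     };
//     graph.launch();  // re-running after var updates takes the updating path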
template <typename F, typename UserTag>
MUDA_HOST ParallelFor& ParallelFor::apply(int count, F&& f, Tag<UserTag>)
{
    return apply<F, UserTag>(count, std::forward<F>(f));
}
template <typename F, typename UserTag>
MUDA_HOST MUDA_NODISCARD auto ParallelFor::as_node_parms(int count, F&& f)
{
    using CallableType = raw_type_t<F>;
    check_input(count);
    auto parms = std::make_shared<NodeParms<F>>(std::forward<F>(f), count);
    if(m_grid_dim <= 0)  // dynamic grid: one thread per element
    {
        int  best_block_size = calculate_block_dim<F, UserTag>(count);
        auto n_blocks        = calculate_grid_dim(count, best_block_size);
        parms->func((void*)details::parallel_for_kernel<CallableType, UserTag>);
        parms->grid_dim(n_blocks);
    }
    else  // fixed grid: grid-stride loop
    {
        parms->func((void*)details::grid_stride_loop_kernel<CallableType, UserTag>);
        parms->grid_dim(m_grid_dim);
    }
    parms->block_dim(m_block_dim);
    parms->shared_mem_bytes(static_cast<uint32_t>(m_shared_mem_size));
    // the kernel takes a single argument: the callable struct itself
    parms->parse([](details::ParallelForCallable<CallableType>& p) -> std::vector<void*>
                 { return {&p}; });
    return parms;
}
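// The parse() callback above supplies the kernel argument pointers that the
// CUDA graph API expects. Roughly equivalent raw CUDA (sketch; local names
// are illustrative):
//
//     cudaKernelNodeParams p{};
//     p.func           = (void*)details::parallel_for_kernel<CallableType, UserTag>;
//     p.gridDim        = dim3(n_blocks);
//     p.blockDim       = dim3(block_dim);
//     p.sharedMemBytes = shared_mem_bytes;
//     void* args[]     = {&callable};  // single argument: the callable struct
//     p.kernelParams   = args;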
template <typename F, typename UserTag>
MUDA_HOST MUDA_NODISCARD auto ParallelFor::as_node_parms(int count, F&& f, Tag<UserTag>)
{
    return as_node_parms<F, UserTag>(count, std::forward<F>(f));
}
template <typename F, typename UserTag>
MUDA_HOST void ParallelFor::invoke(int count, F&& f)
{
    using CallableType = raw_type_t<F>;
    check_input(count);
    if(count > 0)
    {
        if(m_grid_dim <= 0)  // dynamic grid: size it from occupancy
        {
            int  best_block_size = calculate_block_dim<F, UserTag>(count);
            auto n_blocks        = calculate_grid_dim(count, best_block_size);
            auto callable = details::ParallelForCallable<CallableType>{f, count};
            details::parallel_for_kernel<CallableType, UserTag>
                <<<n_blocks, best_block_size, m_shared_mem_size, m_stream>>>(callable);
        }
        else  // fixed grid: grid-stride loop
        {
            auto callable = details::ParallelForCallable<CallableType>{f, count};
            details::grid_stride_loop_kernel<CallableType, UserTag>
                <<<m_grid_dim, m_block_dim, m_shared_mem_size, m_stream>>>(callable);
        }
    }
}
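// The two launch shapes side by side (sketch; grid/block values illustrative):
//
//     ParallelFor(256).apply(n, f);      // m_grid_dim <= 0: occupancy-sized
//                                        // grid -> parallel_for_kernel
//     ParallelFor(64, 256).apply(n, f);  // fixed 64-block grid
//                                        // -> grid_stride_loop_kernel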
template <typename F, typename UserTag>
MUDA_INLINE MUDA_GENERIC int ParallelFor::calculate_block_dim(int count) const MUDA_NOEXCEPT
{
    using CallableType  = raw_type_t<F>;
    int best_block_size = -1;
#ifndef __CUDA_ARCH__
    // host path: ask the runtime for the occupancy-optimal block size
    int min_grid_size = -1;
    checkCudaErrors(cudaOccupancyMaxPotentialBlockSize(
        &min_grid_size,
        &best_block_size,
        details::parallel_for_kernel<CallableType, UserTag>,
        m_shared_mem_size));  // dynamic shared memory per block
#else
    // device path: occupancy queries are host-only, use the configured block dim
    best_block_size = m_block_dim;
#endif
    MUDA_ASSERT(best_block_size >= 0, "Invalid block dim");
    return best_block_size;
}
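// cudaOccupancyMaxPotentialBlockSize picks the block size that maximizes
// occupancy for this exact kernel instantiation. Standalone equivalent
// (sketch, `my_kernel` is a placeholder):
//
//     int min_grid_size = 0, block_size = 0;
//     checkCudaErrors(cudaOccupancyMaxPotentialBlockSize(
//         &min_grid_size, &block_size, my_kernel, /*dynamicSMemSize=*/0));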
MUDA_INLINE MUDA_GENERIC int ParallelFor::calculate_grid_dim(int count) const MUDA_NOEXCEPT
{
    // divide by the block dim, not the grid dim
    return calculate_grid_dim(count, m_block_dim);
}
MUDA_INLINE MUDA_GENERIC int ParallelFor::calculate_grid_dim(int count, int block_dim) MUDA_NOEXCEPT
{
    auto min_threads = count;
    auto min_blocks  = (min_threads + block_dim - 1) / block_dim;  // ceil(count / block_dim)
    return min_blocks;
}
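// Example: count = 1000, block_dim = 256 -> (1000 + 255) / 256 = 4 blocks,
// i.e. 1024 threads, the last 24 of which are masked off by the kernel guards.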
MUDA_INLINE MUDA_GENERIC void ParallelFor::check_input(int count) const MUDA_NOEXCEPT
{
    MUDA_KERNEL_ASSERT(count >= 0, "count must be >= 0");
    MUDA_KERNEL_ASSERT(m_block_dim > 0, "blockDim must be > 0");
}
MUDA_INLINE MUDA_DEVICE int ParallelForDetails::active_num_in_block() const MUDA_NOEXCEPT
{
    if(m_type == ParallelForType::DynamicBlocks)
    {
        auto block_id = blockIdx.x;
        return (blockIdx.x == gridDim.x - 1) ?
                   m_total_num - block_id * blockDim.x :
                   blockDim.x;
    }
    else if(m_type == ParallelForType::GridStrideLoop)
    {
        return m_active_num_in_block;
    }
    else
    {
        MUDA_KERNEL_ERROR("invalid parallel for type");
        return 0;  // unreachable
    }
}
MUDA_INLINE MUDA_DEVICE bool ParallelForDetails::is_final_block() const MUDA_NOEXCEPT
{
    if(m_type == ParallelForType::DynamicBlocks)
    {
        return (blockIdx.x == gridDim.x - 1);
    }
    else if(m_type == ParallelForType::GridStrideLoop)
    {
        return m_active_num_in_block == blockDim.x;
    }
    else
    {
        MUDA_KERNEL_ERROR("invalid parallel for type");
        return false;  // unreachable
    }
}
}  // namespace muda