// parallel_for.inl (MUDA)
#include <muda/compute_graph/compute_graph.h>
#include <muda/type_traits/always.h>
#include <muda/launch/kernel_tag.h>
namespace muda
{
namespace details
{
    /*
    **************************************************************************
    * This part is the core of the "launch part of muda"                     *
    **************************************************************************
    * F: the callable object                                                 *
    * UserTag: the tag struct for user to recognize on profiling             *
    **************************************************************************
    */
16 template <typename F, typename UserTag>
17 MUDA_GLOBAL void parallel_for_kernel(ParallelForCallable<F> f)
18 {
19 if constexpr(std::is_invocable_v<F, int>)
20 {
21 auto tid = blockIdx.x * blockDim.x + threadIdx.x;
22 auto i = tid;
23 if(i < f.count)
24 {
25 f.callable(i);
26 }
27 }
28 else if constexpr(std::is_invocable_v<F, ParallelForDetails>)
29 {
30 ParallelForDetails details{ParallelForType::DynamicBlocks,
31 static_cast<int>(blockIdx.x * blockDim.x
32 + threadIdx.x),
33 f.count};
34 if(details.i() < details.total_num())
35 {
36 f.callable(details);
37 }
38 }
39 else
40 {
41 static_assert(always_false_v<F>, "f must be void (int) or void (ParallelForDetails)");
42 }
43 }
44
45 template <typename F, typename UserTag>
46 MUDA_GLOBAL void grid_stride_loop_kernel(ParallelForCallable<F> f)
47 {
48 if constexpr(std::is_invocable_v<F, int>)
49 {
50 auto tid = blockIdx.x * blockDim.x + threadIdx.x;
51 auto grid_size = gridDim.x * blockDim.x;
52 auto i = tid;
53 for(; i < f.count; i += grid_size)
54 f.callable(i);
55 }
56 else if constexpr(std::is_invocable_v<F, ParallelForDetails>)
57 {
58 auto tid = blockIdx.x * blockDim.x + threadIdx.x;
59 auto grid_size = gridDim.x * blockDim.x;
60 auto block_size = blockDim.x;
61 auto i = tid;
62 auto count = f.count;
63 auto round = (count + grid_size - 1) / grid_size;
64 for(int j = 0; i < count; i += grid_size, ++j)
65 {
66 ParallelForDetails details{
67 ParallelForType::GridStrideLoop, static_cast<int>(i), count};
68
69 details.m_total_batch = round;
70 details.m_batch_i = j;
71 if(i + block_size > details.total_num()) // the block may be incomplete in the last round
72 details.m_active_num_in_block = count - j * grid_size;
73 else
74 details.m_active_num_in_block = block_size;
75 f.callable(details);
76 }
77 }
78 else
79 {
80 static_assert(always_false_v<F>, "f must be void (int) or void (ParallelForDetails)");
81 }
82 }
} // namespace details


86template <typename F, typename UserTag>
87MUDA_HOST ParallelFor& ParallelFor::apply(int count, F&& f)
88{
89 if constexpr(COMPUTE_GRAPH_ON)
90 {
91 using CallableType = raw_type_t<F>;
92
93 ComputeGraphBuilder::invoke_phase_actions(
94 [&] { // direct invoke
95 invoke<F, UserTag>(count, std::forward<F>(f));
96 },
97 [&]
98 {
99 // as node parms
100 auto parms = as_node_parms<F, UserTag>(count, std::forward<F>(f));
101 details::ComputeGraphAccessor().set_kernel_node(parms);
102 },
103 [&]
104 {
105 // topo build
106 details::ComputeGraphAccessor().set_kernel_node<details::ParallelForCallable<CallableType>>(
107 nullptr);
108 });
109 }
110 else
111 {
112 invoke<F, UserTag>(count, std::forward<F>(f));
113 }
114 pop_kernel_label();
115 return *this;
116}
117
118template <typename F, typename UserTag>
119MUDA_HOST ParallelFor& ParallelFor::apply(int count, F&& f, Tag<UserTag>)
120{
121 return apply<F, UserTag>(count, std::forward<F>(f));
122}
123
124template <typename F, typename UserTag>
125MUDA_HOST MUDA_NODISCARD auto ParallelFor::as_node_parms(int count, F&& f)
126 -> S<NodeParms<F>>
127{
128 using CallableType = raw_type_t<F>;
129
130 check_input(count);
131
132 auto parms = std::make_shared<NodeParms<F>>(std::forward<F>(f), count);
133 if(m_grid_dim <= 0) // dynamic grid dim
134 {
135 int best_block_size = calculate_block_dim<F, UserTag>(count);
136 auto n_blocks = calculate_grid_dim(count, best_block_size);
137 parms->func((void*)details::parallel_for_kernel<CallableType, UserTag>);
138 parms->grid_dim(n_blocks);
139 }
140 else // grid-stride loop
141 {
142 parms->func((void*)details::grid_stride_loop_kernel<CallableType, UserTag>);
143 parms->grid_dim(m_grid_dim);
144 }
145
146 parms->block_dim(m_block_dim);
147 parms->shared_mem_bytes(static_cast<uint32_t>(m_shared_mem_size));
148 parms->parse([](details::ParallelForCallable<CallableType>& p) -> std::vector<void*>
149 { return {&p}; });
150
151 return parms;
152}
153
154template <typename F, typename UserTag>
155MUDA_HOST MUDA_NODISCARD auto ParallelFor::as_node_parms(int count, F&& f, Tag<UserTag>)
156 -> S<NodeParms<F>>
157{
158 return as_node_parms<F, UserTag>(count, std::forward<F>(f));
159}
160
161template <typename F, typename UserTag>
162MUDA_HOST void ParallelFor::invoke(int count, F&& f)
163{
164 using CallableType = raw_type_t<F>;
165 // check_input(count);
166 if(count > 0)
167 {
168 if(m_grid_dim <= 0) // parallel for
169 {
170 // calculate the blocks we need
171 int best_block_size = calculate_block_dim<F, UserTag>(count);
172 auto n_blocks = calculate_grid_dim(count, best_block_size);
173 auto callable = details::ParallelForCallable<CallableType>{f, count};
174 details::parallel_for_kernel<CallableType, UserTag>
175 <<<n_blocks, best_block_size, m_shared_mem_size, m_stream>>>(callable);
176 }
177 else // grid stride loop
178 {
179 auto callable = details::ParallelForCallable<CallableType>{f, count};
180 details::grid_stride_loop_kernel<CallableType, UserTag>
181 <<<m_grid_dim, m_block_dim, m_shared_mem_size, m_stream>>>(callable);
182 }
183 }
184}
185
186template <typename F, typename UserTag>
187MUDA_INLINE MUDA_GENERIC int ParallelFor::calculate_block_dim(int count) const MUDA_NOEXCEPT
188{
189 using CallableType = raw_type_t<F>;
190 int best_block_size = -1;
191 if(m_block_dim <= 0) // automatic choose
192 {
193 int min_grid_size = -1;
194 checkCudaErrors(cudaOccupancyMaxPotentialBlockSize(
195 &min_grid_size,
196 &best_block_size,
197 details::parallel_for_kernel<CallableType, UserTag>,
198 m_shared_mem_size));
199 }
200 else
201 {
202 best_block_size = m_block_dim;
203 }
204 MUDA_ASSERT(best_block_size >= 0, "Invalid block dim");
205 return best_block_size;
206}
207
208MUDA_INLINE MUDA_GENERIC int ParallelFor::calculate_grid_dim(int count) const MUDA_NOEXCEPT
209{
210 return calculate_grid_dim(count, m_grid_dim);
211}
212
213MUDA_INLINE MUDA_GENERIC int ParallelFor::calculate_grid_dim(int count, int block_dim) MUDA_NOEXCEPT
214{
215 auto min_threads = count;
216 auto min_blocks = (min_threads + block_dim - 1) / block_dim;
217 return min_blocks;
218}
219
220MUDA_INLINE MUDA_GENERIC void ParallelFor::check_input(int count) const MUDA_NOEXCEPT
221{
222 MUDA_KERNEL_ASSERT(count >= 0, "count must be >= 0");
223 MUDA_KERNEL_ASSERT(m_block_dim > 0, "blockDim must be > 0");
224}
225
226MUDA_INLINE MUDA_DEVICE int ParallelForDetails::active_num_in_block() const MUDA_NOEXCEPT
227{
228 if(m_type == ParallelForType::DynamicBlocks)
229 {
230 auto block_id = blockIdx.x;
231 return (blockIdx.x == gridDim.x - 1) ? m_total_num - block_id * blockDim.x :
232 blockDim.x;
233 }
234 else if(m_type == ParallelForType::GridStrideLoop)
235 {
236 return m_active_num_in_block;
237 }
238 else
239 {
240 MUDA_KERNEL_ERROR("invalid paralell for type");
241 return 0;
242 }
243}
244
245MUDA_INLINE MUDA_DEVICE bool ParallelForDetails::is_final_block() const MUDA_NOEXCEPT
246{
247 if(m_type == ParallelForType::DynamicBlocks)
248 {
249 return (blockIdx.x == gridDim.x - 1);
250 }
251 else if(m_type == ParallelForType::GridStrideLoop)
252 {
253 return m_active_num_in_block == blockDim.x;
254 }
255 else
256 {
257 MUDA_KERNEL_ERROR("invalid paralell for type");
258 return false;
259 }
260}
} // namespace muda