MUDA
parallel_for.h
1/*****************************************************************/
11#pragma once
12#include <muda/launch/launch_base.h>
13#include <muda/launch/kernel_tag.h>
14#include <stdexcept>
15#include <exception>
16
17namespace muda
18{
19namespace details
20{
21 template <typename F>
22    class ParallelForCallable
23    {
24 public:
25 F callable;
26 int count;
27 template <typename U>
28 MUDA_GENERIC ParallelForCallable(U&& callable, int count) MUDA_NOEXCEPT
29 : callable(std::forward<U>(callable)),
30 count(count)
31 {
32 }
33 // MUDA_GENERIC ~ParallelForCallable() = default;
34 };
35
36 template <typename F, typename UserTag>
37 MUDA_GLOBAL void parallel_for_kernel(ParallelForCallable<F> f);
38
39 template <typename F, typename UserTag>
40 MUDA_GLOBAL void grid_stride_loop_kernel(ParallelForCallable<F> f);
41} // namespace details
42
43enum class ParallelForType : uint32_t
44{
45 DynamicBlocks,
46 GridStrideLoop
47};
48
49class ParallelForDetails
50{
51 public:
52 MUDA_NODISCARD MUDA_DEVICE int active_num_in_block() const MUDA_NOEXCEPT;
53 MUDA_NODISCARD MUDA_DEVICE bool is_final_block() const MUDA_NOEXCEPT;
54 MUDA_NODISCARD MUDA_DEVICE ParallelForType parallel_for_type() const MUDA_NOEXCEPT
55 {
56 return m_type;
57 }
58
59 MUDA_NODISCARD MUDA_DEVICE int total_num() const MUDA_NOEXCEPT
60 {
61 return m_total_num;
62 }
63 MUDA_NODISCARD MUDA_DEVICE operator int() const MUDA_NOEXCEPT
64 {
65 return m_current_i;
66 }
67
68 MUDA_NODISCARD MUDA_DEVICE int i() const MUDA_NOEXCEPT
69 {
70 return m_current_i;
71 }
72
73 MUDA_NODISCARD MUDA_DEVICE int batch_i() const MUDA_NOEXCEPT
74 {
75 return m_batch_i;
76 }
77
78 MUDA_NODISCARD MUDA_DEVICE int total_batch() const MUDA_NOEXCEPT
79 {
80 return m_total_batch;
81 }
82
83 private:
84 template <typename F, typename UserTag>
85 friend MUDA_GLOBAL void details::parallel_for_kernel(ParallelForCallable<F> f);
86
87 template <typename F, typename UserTag>
88 friend MUDA_GLOBAL void details::grid_stride_loop_kernel(ParallelForCallable<F> f);
89
90 MUDA_DEVICE ParallelForDetails(ParallelForType type, int i, int total_num) MUDA_NOEXCEPT
91 : m_type(type),
92 m_total_num(total_num),
93 m_current_i(i)
94 {
95 }
96
97 ParallelForType m_type;
98 int m_total_num;
99 int m_total_batch = 1;
100 int m_batch_i = 0;
101 int m_active_num_in_block = 0;
102 int m_current_i = 0;
103};
104
105using details::grid_stride_loop_kernel;
106using details::parallel_for_kernel;
107
108
115class ParallelFor : public LaunchBase<ParallelFor>
116{
117 int m_grid_dim;
118 int m_block_dim;
119 size_t m_shared_mem_size;
120
121 public:
122 template <typename F>
123    using NodeParms = KernelNodeParms<details::ParallelForCallable<F>>;
124
142 MUDA_HOST ParallelFor(size_t shared_mem_size = 0, cudaStream_t stream = nullptr) MUDA_NOEXCEPT
143 : LaunchBase(stream),
144 m_grid_dim(0),
145 m_block_dim(-1),
146 m_shared_mem_size(shared_mem_size)
147 {
148 }
149
166 MUDA_HOST ParallelFor(int blockDim, size_t shared_mem_size = 0, cudaStream_t stream = nullptr) MUDA_NOEXCEPT
167 : LaunchBase(stream),
168 m_grid_dim(0),
169 m_block_dim(blockDim),
170 m_shared_mem_size(shared_mem_size)
171 {
172 }
173
174
192 MUDA_HOST ParallelFor(int gridDim,
193 int blockDim,
194 size_t shared_mem_size = 0,
195 cudaStream_t stream = nullptr) MUDA_NOEXCEPT
196 : LaunchBase(stream),
197 m_grid_dim(gridDim),
198 m_block_dim(blockDim),
199 m_shared_mem_size(shared_mem_size)
200 {
201 }
202
203 template <typename F, typename UserTag = Default>
204 MUDA_HOST ParallelFor& apply(int count, F&& f);
205
206 template <typename F, typename UserTag = Default>
207 MUDA_HOST ParallelFor& apply(int count, F&& f, Tag<UserTag>);
208
209
210 template <typename F, typename UserTag = Default>
211 MUDA_HOST MUDA_NODISCARD auto as_node_parms(int count, F&& f) -> S<NodeParms<F>>;
212
213 template <typename F, typename UserTag = Default>
214 MUDA_HOST MUDA_NODISCARD auto as_node_parms(int count, F&& f, Tag<UserTag>)
215 -> S<NodeParms<F>>;
216
217 MUDA_GENERIC MUDA_NODISCARD static int round_up_blocks(int count, int block_dim) MUDA_NOEXCEPT
218 {
219 return (count + block_dim - 1) / block_dim;
220 }
221
222 public:
223 template <typename F, typename UserTag>
224 MUDA_HOST void invoke(int count, F&& f);
225
226 template <typename F, typename UserTag>
227 MUDA_GENERIC int calculate_block_dim(int count) const MUDA_NOEXCEPT;
228
229 MUDA_GENERIC int calculate_grid_dim(int count) const MUDA_NOEXCEPT;
230
231 static MUDA_GENERIC int calculate_grid_dim(int count, int block_dim) MUDA_NOEXCEPT;
232
233 MUDA_GENERIC void check_input(int count) const MUDA_NOEXCEPT;
234};
235} // namespace muda
236
237#include "details/parallel_for.inl"
muda::ParallelFor: a frequently used parallel-for loop; DynamicBlockDim and GridStrideLoop strategies are provided, ...
Definition parallel_for.h:116
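For orientation, a minimal usage sketch (not taken from the library's documentation; the include path and the raw-pointer setup are assumptions): with the default constructor, ParallelFor picks the block size for occupancy and derives the grid size from the element count passed to apply().

#include <muda/launch/parallel_for.h>   // assumed include path
#include <cuda_runtime.h>

// Scale a device array in place; d_data must point to n floats in device memory.
// Requires nvcc with extended lambdas (--extended-lambda).
inline void scale(float* d_data, int n, float k, cudaStream_t stream)
{
    muda::ParallelFor(0 /*shared_mem_size*/, stream)
        .apply(n,
               [d_data, k] __device__(int i)  // i is the global element index
               { d_data[i] *= k; });
}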
MUDA_HOST ParallelFor(size_t shared_mem_size=0, cudaStream_t stream=nullptr) MUDA_NOEXCEPT
Calculate the grid dim automatically to cover the range, and automatically choose the block size to achieve maximum occupancy.
Definition parallel_for.h:142
MUDA_HOST ParallelFor(int gridDim, int blockDim, size_t shared_mem_size=0, cudaStream_t stream=nullptr) MUDA_NOEXCEPT
Use a grid-stride loop to cover the range; you need to manually set the grid size and block size. ...
Definition parallel_for.h:192
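The grid-stride variant sketched below fixes both the grid and block dimensions, so each thread may handle several indices; the callable can then take the ParallelForDetails object (which also converts implicitly to int) to query batch information. This is a hedged sketch, assuming the callable is invoked with that object, as the friend kernel declarations above suggest.

#include <muda/launch/parallel_for.h>   // assumed include path

// Record, for every element, which grid-stride pass produced it.
inline void tag_batches(int* d_out, int n, cudaStream_t stream)
{
    muda::ParallelFor(64 /*gridDim*/, 128 /*blockDim*/, 0 /*shared_mem_size*/, stream)
        .apply(n,
               [d_out] __device__(muda::ParallelForDetails d)
               {
                   d_out[d.i()] = d.batch_i();  // 0-based pass index, < d.total_batch()
               });
}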
MUDA_HOST ParallelFor(int blockDim, size_t shared_mem_size=0, cudaStream_t stream=nullptr) MUDA_NOEXCEPT
Calculate the grid dim automatically to cover the range, but you need to manually set the block size.
Definition parallel_for.h:166
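And the fixed-block-size variant: only blockDim is given, and the grid dim is computed from the count, conceptually ParallelFor::round_up_blocks(n, 256) blocks (a sketch under the same assumptions as above).

#include <muda/launch/parallel_for.h>   // assumed include path

// Zero-fill a device array using 256-thread blocks.
inline void zero_fill(float* d_data, int n, cudaStream_t stream)
{
    muda::ParallelFor(256 /*blockDim*/, 0 /*shared_mem_size*/, stream)
        .apply(n, [d_data] __device__(int i) { d_data[i] = 0.0f; });
}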