MUDA
Loading...
Searching...
No Matches
launch.inl
1#include <muda/compute_graph/compute_graph_builder.h>
2#include <muda/compute_graph/nodes/compute_graph_kernel_node.h>
3#include <muda/launch/kernel.h>
4namespace muda
5{
6namespace details
7{
8 template <typename F, typename UserTag>
9 MUDA_GLOBAL void generic_kernel(LaunchCallable<F> f)
10 {
11 static_assert(std::is_invocable_v<F>, "f:void (void)");
12 f.callable();
13 }
14
15 template <typename F, typename UserTag>
16 MUDA_GLOBAL void generic_kernel_with_range(LaunchCallable<F> f)
17 {
18 auto x = blockIdx.x * blockDim.x + threadIdx.x;
19 auto y = blockIdx.y * blockDim.y + threadIdx.y;
20 auto z = blockIdx.z * blockDim.z + threadIdx.z;
21
22 if(x < f.dim.x && y < f.dim.y && z < f.dim.z)
23 {
24 if constexpr(std::is_invocable_v<F, int2>)
25 {
26 f.callable(int2{static_cast<int>(x), static_cast<int>(y)});
27 }
28 else if constexpr(std::is_invocable_v<F, int3>)
29 {
30 f.callable(int3{static_cast<int>(x), static_cast<int>(y), static_cast<int>(z)});
31 }
32 else if constexpr(std::is_invocable_v<F, uint1>)
33 {
34 f.callable(uint1{x});
35 }
36 else if constexpr(std::is_invocable_v<F, uint2>)
37 {
38 f.callable(uint2{x, y});
39 }
40 else if constexpr(std::is_invocable_v<F, uint3>)
41 {
42 f.callable(uint3{x, y, z});
43 }
44 else if constexpr(std::is_invocable_v<F, dim3>)
45 {
46 f.callable(dim3{x, y, z});
47 }
48 else if constexpr(std::is_invocable_v<F, int>
49 || std::is_invocable_v<F, unsigned int>)
50 {
51 static_assert("You should use `ParallelFor()` instead of `Launch()` for better semantics");
52 }
53 else
54 {
55 static_assert(always_false_v<F>,
56 "invalid callable, it should be:"
57 "void (uint1) or"
58 "void (unsigned int) or"
59 "void (uint2) or"
60 "void (uint3) or"
61 "void (dim3)");
62 }
63 }
64 }
65} // namespace details
66
67MUDA_INLINE dim3 cube(int x) MUDA_NOEXCEPT
68{
69 return dim3(x, x, x);
70}
71
72MUDA_INLINE dim3 square(int x) MUDA_NOEXCEPT
73{
74 return dim3(x, x, 1);
75}
76
77template <typename F, typename UserTag>
78MUDA_INLINE MUDA_HOST auto Launch::as_node_parms(F&& f) -> S<NodeParms<F>>
79{
80 check_input();
81
82 using CallableType = raw_type_t<F>;
83 auto parms = std::make_shared<NodeParms<F>>(std::forward<F>(f), dim3{0});
84
85 parms->func((void*)details::generic_kernel<CallableType, UserTag>);
86 parms->grid_dim(m_grid_dim);
87 parms->block_dim(m_block_dim);
88 parms->shared_mem_bytes(static_cast<uint32_t>(m_shared_mem_size));
89 parms->parse([](details::LaunchCallable<CallableType>& p) -> std::vector<void*>
90 { return {&p}; });
91 return parms;
92}
93
94template <typename F, typename UserTag>
95MUDA_HOST MUDA_NODISCARD auto Launch::as_node_parms(F&& f, Tag<UserTag>)
96 -> S<NodeParms<F>>
97{
98 return as_node_parms<F, UserTag>(std::forward<F>(f));
99}
100
101template <typename F, typename UserTag>
102MUDA_INLINE MUDA_HOST auto Launch::as_node_parms(const dim3& active_dim, F&& f)
103 -> S<NodeParms<F>>
104{
105 check_input_with_range();
106
107 auto grid_dim = calculate_grid_dim(active_dim);
108
109 using CallableType = raw_type_t<F>;
110 auto parms = std::make_shared<NodeParms<F>>(std::forward<F>(f), active_dim);
111
112 parms->func((void*)details::generic_kernel_with_range<CallableType, UserTag>);
113 parms->grid_dim(grid_dim);
114 parms->block_dim(m_block_dim);
115 parms->shared_mem_bytes(m_shared_mem_size);
116 parms->parse([](details::LaunchCallable<CallableType>& p) -> std::vector<void*>
117 { return {&p}; });
118 return parms;
119}
120
121template <typename F, typename UserTag>
122MUDA_HOST MUDA_NODISCARD auto Launch::as_node_parms(const dim3& active_dim, F&& f, Tag<UserTag>)
123 -> S<NodeParms<F>>
124{
125 return as_node_parms<F, UserTag>(active_dim, std::forward<F>(f));
126}
127
128template <typename F, typename UserTag>
129MUDA_HOST void Launch::invoke(F&& f)
130{
131 check_input();
132
133 using CallableType = raw_type_t<F>;
134 auto callable = details::LaunchCallable<CallableType>{std::forward<F>(f), dim3{0}};
135 details::generic_kernel<CallableType, UserTag>
136 <<<m_grid_dim, m_block_dim, m_shared_mem_size, m_stream>>>(callable);
137}
138
139template <typename F, typename UserTag>
140MUDA_HOST void Launch::invoke(const dim3& active_dim, F&& f)
141{
142 check_input_with_range();
143
144 dim3 grid_dim = calculate_grid_dim(active_dim);
145
146 using CallableType = raw_type_t<F>;
147 auto callable = details::LaunchCallable<CallableType>{std::forward<F>(f), active_dim};
148 details::generic_kernel_with_range<CallableType, UserTag>
149 <<<grid_dim, m_block_dim, m_shared_mem_size, m_stream>>>(callable);
150}
151
152template <typename F, typename UserTag>
153MUDA_HOST Launch& Launch::apply(F&& f)
154{
155 if constexpr(COMPUTE_GRAPH_ON)
156 {
157 using CallableType = raw_type_t<F>;
158 ComputeGraphBuilder::invoke_phase_actions(
159 [&] { invoke<F, UserTag>(std::forward<F>(f)); },
160 [&]
161 {
162 auto parms = this->as_node_parms<F, UserTag>(std::forward<F>(f));
163 details::ComputeGraphAccessor().set_kernel_node(parms);
164 },
165 [&]
166 {
167 details::ComputeGraphAccessor().set_kernel_node<KernelNodeParms<CallableType>>(nullptr);
168 });
169 }
170 else
171 {
172 invoke<F, UserTag>(std::forward<F>(f));
173 }
174 pop_kernel_label();
175 return *this;
176}
177
178template <typename F, typename UserTag>
179MUDA_HOST Launch& Launch::apply(F&& f, Tag<UserTag>)
180{
181 return apply<F, UserTag>(std::forward<F>(f));
182}
183template <typename F, typename UserTag>
184MUDA_HOST Launch& muda::Launch::apply(const dim3& active_dim, F&& f)
185{
186 if constexpr(COMPUTE_GRAPH_ON)
187 {
188 using CallableType = raw_type_t<F>;
189 ComputeGraphBuilder::invoke_phase_actions(
190 [&] { invoke<F, UserTag>(active_dim, std::forward<F>(f)); },
191 [&]
192 {
193 auto parms =
194 this->as_node_parms<F, UserTag>(active_dim, std::forward<F>(f));
195 details::ComputeGraphAccessor().set_kernel_node(parms);
196 },
197 [&]
198 {
199 details::ComputeGraphAccessor().set_kernel_node<KernelNodeParms<CallableType>>(nullptr);
200 });
201 }
202 else
203 {
204 invoke<F, UserTag>(active_dim, std::forward<F>(f));
205 }
206 pop_kernel_label();
207
208 return *this;
209}
210
211template <typename F, typename UserTag>
212MUDA_HOST Launch& Launch::apply(const dim3& active_dim, F&& f, Tag<UserTag>)
213{
214 return apply<F, UserTag>(active_dim, std::forward<F>(f));
215}
216
217MUDA_INLINE MUDA_GENERIC dim3 Launch::calculate_grid_dim(const dim3& active_dim) const MUDA_NOEXCEPT
218{
219 dim3 ret;
220
221 ret.x = (active_dim.x + m_block_dim.x - 1) / m_block_dim.x;
222 ret.y = (active_dim.y + m_block_dim.y - 1) / m_block_dim.y;
223 ret.z = (active_dim.z + m_block_dim.z - 1) / m_block_dim.z;
224 return ret;
225}
226
// Precondition check for the range-based overloads (those taking an
// `active_dim`): asserts that `m_grid_dim.x` is 0. Those overloads compute
// the grid from `active_dim` via `calculate_grid_dim` rather than using
// `m_grid_dim`. Only the x component is inspected.
MUDA_INLINE MUDA_GENERIC void Launch::check_input_with_range() const MUDA_NOEXCEPT
{
    MUDA_ASSERT(m_grid_dim.x == 0, "grid_dim should be `dim3{0}`");
}
231
// Precondition check for the fixed-grid overloads (those without an
// `active_dim`): asserts that every component of `m_grid_dim` is greater
// than zero, since those overloads launch with `m_grid_dim` directly.
MUDA_INLINE MUDA_GENERIC void Launch::check_input() const MUDA_NOEXCEPT
{
    MUDA_ASSERT(m_grid_dim.x > 0 && m_grid_dim.y > 0 && m_grid_dim.z > 0,
                "grid_dim should be non-zero");
}
237} // namespace muda