1#include <muda/compute_graph/compute_graph_builder.h>
2#include <muda/compute_graph/nodes/compute_graph_kernel_node.h>
3#include <muda/launch/kernel.h>
// Kernel entry for unranged launches (used by Launch::invoke(F&&) and
// Launch::as_node_parms(F&&)). The callable type F must be invocable with no
// arguments; this is enforced at compile time by the static_assert below.
// NOTE(review): no call to `f.callable()` is visible in this extract; the
// kernel presumably invokes it after the check — confirm against full source.
// NOTE(review): UserTag is not referenced in the visible body; presumably it
// only distinguishes kernel instantiations per tag — confirm.
8 template <
typename F,
typename UserTag>
9 MUDA_GLOBAL
void generic_kernel(LaunchCallable<F> f)
11 static_assert(std::is_invocable_v<F>,
"f:void (void)");
// Kernel entry for ranged launches (used by Launch::invoke(active_dim, f) and
// Launch::as_node_parms(active_dim, f)). Each thread derives its global
// (x, y, z) index from blockIdx/blockDim/threadIdx. Only if that index lies
// inside `f.dim` does it call `f.callable`, with the index packed into the
// first argument type the callable accepts. The candidate types are tried in
// this order via `if constexpr`: int2, int3, uint1, uint2, uint3, dim3.
// NOTE(review): this extract is lossy. The `uint1` branch body and the tail
// of the final error message are not visible; they are left untouched here.
15 template <
typename F,
typename UserTag>
16 MUDA_GLOBAL
void generic_kernel_with_range(LaunchCallable<F> f)
18 auto x = blockIdx.x * blockDim.x + threadIdx.x;
19 auto y = blockIdx.y * blockDim.y + threadIdx.y;
20 auto z = blockIdx.z * blockDim.z + threadIdx.z;
// Threads of the rounded-up grid that fall outside the active range do nothing.
22 if(x < f.dim.x && y < f.dim.y && z < f.dim.z)
// int2/int3 callables receive the unsigned indices narrowed to int; this
// assumes every index stays below INT_MAX.
24 if constexpr(std::is_invocable_v<F, int2>)
26 f.callable(int2{
static_cast<int>(x),
static_cast<int>(y)});
28 else if constexpr(std::is_invocable_v<F, int3>)
30 f.callable(int3{
static_cast<int>(x),
static_cast<int>(y),
static_cast<int>(z)});
32 else if constexpr(std::is_invocable_v<F, uint1>)
36 else if constexpr(std::is_invocable_v<F, uint2>)
38 f.callable(uint2{x, y});
40 else if constexpr(std::is_invocable_v<F, uint3>)
42 f.callable(uint3{x, y, z});
44 else if constexpr(std::is_invocable_v<F, dim3>)
46 f.callable(dim3{x, y, z});
48 else if constexpr(std::is_invocable_v<F, int>
49 || std::is_invocable_v<F, unsigned int>)
// A bare `static_assert("...")` uses the string literal as its condition.
// The literal converts to `true`, so the assertion never fires, and a 1D
// int/unsigned callable would compile into a kernel that silently does
// nothing. This branch therefore uses a condition dependent on F, as the
// final branch does. It becomes a compile error exactly when this branch is
// instantiated.
51 static_assert(always_false_v<F>,
"You should use `ParallelFor()` instead of `Launch()` for better semantics");
55 static_assert(always_false_v<F>,
56 "invalid callable, it should be:"
58 "void (unsigned int) or"
// Builds a dim3 launch shape from a single int extent `x`.
// NOTE(review): only the signature is visible in this extract; the name
// suggests dim3(x, x, x) — confirm against the full source.
67MUDA_INLINE dim3 cube(
int x) MUDA_NOEXCEPT
// Builds a dim3 launch shape from a single int extent `x`.
// NOTE(review): only the signature is visible in this extract; the name
// suggests dim3(x, x, 1) — confirm against the full source.
72MUDA_INLINE dim3 square(
int x) MUDA_NOEXCEPT
// Builds kernel-node parameters for an unranged launch of `f`. apply(F&&)
// uses this to register `f` with the compute graph instead of launching it.
// The node is configured with generic_kernel<CallableType, UserTag> as its
// function, plus the user-set grid/block dims and dynamic shared-memory size.
// The callable is stored with a `dim3{0}` placeholder range.
// NOTE(review): the parse lambda body and the return statement are not
// visible in this extract. Presumably parse yields the kernel argument
// pointers for `p` — confirm.
// NOTE(review): no check_input() call is visible here, whereas the ranged
// overload calls check_input_with_range() — confirm whether it was dropped.
77template <
typename F,
typename UserTag>
78MUDA_INLINE MUDA_HOST
auto Launch::as_node_parms(F&& f) -> S<NodeParms<F>>
82 using CallableType = raw_type_t<F>;
83 auto parms = std::make_shared<NodeParms<F>>(std::forward<F>(f), dim3{0});
// The kernel symbol is type-erased to void* for the node parameters.
85 parms->func((
void*)details::generic_kernel<CallableType, UserTag>);
86 parms->grid_dim(m_grid_dim);
87 parms->block_dim(m_block_dim);
// The dynamic shared-memory size is narrowed to uint32_t explicitly.
88 parms->shared_mem_bytes(
static_cast<uint32_t
>(m_shared_mem_size));
89 parms->parse([](details::LaunchCallable<CallableType>& p) -> std::vector<void*>
// Tag-dispatch overload of as_node_parms(f). It lets a caller supply the user
// tag as a `Tag<UserTag>` value instead of an explicit template argument,
// then forwards to as_node_parms<F, UserTag>(f).
// NOTE(review): the return type is not visible in this extract (the non-Tag
// overload is declared `-> S<NodeParms<F>>`) — confirm the full declaration.
94template <
typename F,
typename UserTag>
95MUDA_HOST MUDA_NODISCARD
auto Launch::as_node_parms(F&& f, Tag<UserTag>)
98 return as_node_parms<F, UserTag>(std::forward<F>(f));
// Builds kernel-node parameters for a ranged launch of `f` over `active_dim`.
// apply(active_dim, f) uses this to register `f` with the compute graph.
// The grid is derived from active_dim and the block dim. The callable stores
// active_dim so generic_kernel_with_range can bounds-check each thread.
// NOTE(review): the parse lambda body and the return statement are not
// visible in this extract — confirm.
101template <
typename F,
typename UserTag>
102MUDA_INLINE MUDA_HOST
auto Launch::as_node_parms(
const dim3& active_dim, F&& f)
// Ranged launches compute their own grid, so grid_dim must be left unset.
105 check_input_with_range();
107 auto grid_dim = calculate_grid_dim(active_dim);
109 using CallableType = raw_type_t<F>;
110 auto parms = std::make_shared<NodeParms<F>>(std::forward<F>(f), active_dim);
112 parms->func((
void*)details::generic_kernel_with_range<CallableType, UserTag>);
113 parms->grid_dim(grid_dim);
114 parms->block_dim(m_block_dim);
// NOTE(review): the unranged overload passes static_cast<uint32_t>(m_shared_mem_size);
// consider the same explicit cast here for consistency.
115 parms->shared_mem_bytes(m_shared_mem_size);
116 parms->parse([](details::LaunchCallable<CallableType>& p) -> std::vector<void*>
// Tag-dispatch overload of as_node_parms(active_dim, f). It takes the user
// tag as a `Tag<UserTag>` value and forwards to
// as_node_parms<F, UserTag>(active_dim, f).
// NOTE(review): the return type is not visible in this extract — confirm it
// matches the non-Tag ranged overload.
121template <
typename F,
typename UserTag>
122MUDA_HOST MUDA_NODISCARD
auto Launch::as_node_parms(
const dim3& active_dim, F&& f, Tag<UserTag>)
125 return as_node_parms<F, UserTag>(active_dim, std::forward<F>(f));
// Launches `f` immediately as an unranged kernel. generic_kernel is launched
// on m_stream using the user-set grid/block dims and dynamic shared-memory
// size. The callable is wrapped with a `dim3{0}` range, which the unranged
// kernel does not read in the visible code.
// NOTE(review): no check_input() call is visible here, whereas the ranged
// overload calls check_input_with_range() — confirm whether it was dropped
// from this extract.
// NOTE(review): no cudaGetLastError() follows the launch in the visible code;
// confirm that launch-configuration errors are surfaced elsewhere.
128template <
typename F,
typename UserTag>
129MUDA_HOST
void Launch::invoke(F&& f)
133 using CallableType = raw_type_t<F>;
134 auto callable = details::LaunchCallable<CallableType>{std::forward<F>(f), dim3{0}};
135 details::generic_kernel<CallableType, UserTag>
136 <<<m_grid_dim, m_block_dim, m_shared_mem_size, m_stream>>>(callable);
// Launches `f` immediately over the 3D index range [0, active_dim) on
// m_stream. The grid is derived from active_dim and the user-set block dim,
// so grid_dim must have been left at its `dim3{0}` sentinel. Threads of the
// rounded-up grid that fall outside the range are masked off inside
// generic_kernel_with_range.
template <typename F, typename UserTag>
MUDA_HOST void Launch::invoke(const dim3& active_dim, F&& f)
{
    // Ranged launches compute their own grid; reject a user-set grid_dim.
    check_input_with_range();

    using CallableType = raw_type_t<F>;

    const dim3 blocks = calculate_grid_dim(active_dim);

    // The callable carries active_dim so the kernel can bounds-check.
    auto launch_callable =
        details::LaunchCallable<CallableType>{std::forward<F>(f), active_dim};

    details::generic_kernel_with_range<CallableType, UserTag>
        <<<blocks, m_block_dim, m_shared_mem_size, m_stream>>>(launch_callable);
}
// Unranged entry point: launches `f` or records it into a compute graph.
// When COMPUTE_GRAPH_ON holds, the work is handed to
// ComputeGraphBuilder::invoke_phase_actions. The actions visible below are:
//   - launch directly via invoke;
//   - build kernel-node parameters with as_node_parms and pass them to
//     ComputeGraphAccessor().set_kernel_node;
//   - register a null KernelNodeParms<CallableType> node.
// Which action runs is decided inside invoke_phase_actions (not visible
// here). Without COMPUTE_GRAPH_ON, `f` is launched directly via invoke.
// NOTE(review): `f` is std::forward-ed in more than one lambda. This is only
// safe if invoke_phase_actions runs at most one of them per call — confirm.
// NOTE(review): the lambda delimiters, the else-branch structure and the
// return statement (presumably `return *this;`) are not visible in this
// extract.
152template <
typename F,
typename UserTag>
153MUDA_HOST Launch& Launch::apply(F&& f)
155 if constexpr(COMPUTE_GRAPH_ON)
157 using CallableType = raw_type_t<F>;
158 ComputeGraphBuilder::invoke_phase_actions(
159 [&] { invoke<F, UserTag>(std::forward<F>(f)); },
162 auto parms = this->as_node_parms<F, UserTag>(std::forward<F>(f));
163 details::ComputeGraphAccessor().set_kernel_node(parms);
167 details::ComputeGraphAccessor().set_kernel_node<KernelNodeParms<CallableType>>(
nullptr);
172 invoke<F, UserTag>(std::forward<F>(f));
// Tag-dispatch overload of apply(f). It takes the user tag as a
// `Tag<UserTag>` value rather than as an explicit template argument, then
// forwards to apply<F, UserTag>(f) and returns that call's result.
template <typename F, typename UserTag>
MUDA_HOST Launch& Launch::apply(F&& f, Tag<UserTag>)
{
    return this->apply<F, UserTag>(std::forward<F>(f));
}
// Ranged entry point: launches `f` over [0, active_dim) or records it into a
// compute graph. It mirrors apply(F&&), but uses the ranged invoke and
// as_node_parms overloads.
// NOTE(review): in this extract the line carrying original 194 shows only
// `this->as_node_parms<...>(...)`, yet the next line uses `parms`. The
// `auto parms =` on the preceding line was presumably dropped — confirm.
// NOTE(review): `f` is std::forward-ed in more than one lambda. This is only
// safe if invoke_phase_actions runs at most one of them per call — confirm.
// NOTE(review): this definition is qualified `muda::Launch::apply` while the
// others use `Launch::...`; consider making them consistent.
183template <
typename F,
typename UserTag>
184MUDA_HOST Launch& muda::Launch::apply(
const dim3& active_dim, F&& f)
186 if constexpr(COMPUTE_GRAPH_ON)
188 using CallableType = raw_type_t<F>;
189 ComputeGraphBuilder::invoke_phase_actions(
190 [&] { invoke<F, UserTag>(active_dim, std::forward<F>(f)); },
194 this->as_node_parms<F, UserTag>(active_dim, std::forward<F>(f));
195 details::ComputeGraphAccessor().set_kernel_node(parms);
199 details::ComputeGraphAccessor().set_kernel_node<KernelNodeParms<CallableType>>(
nullptr);
204 invoke<F, UserTag>(active_dim, std::forward<F>(f));
// Tag-dispatch overload of apply(active_dim, f). It takes the user tag as a
// `Tag<UserTag>` value rather than as an explicit template argument, then
// forwards to apply<F, UserTag>(active_dim, f) and returns that call's result.
template <typename F, typename UserTag>
MUDA_HOST Launch& Launch::apply(const dim3& active_dim, F&& f, Tag<UserTag>)
{
    return this->apply<F, UserTag>(active_dim, std::forward<F>(f));
}
// Computes the grid needed to cover `active_dim` with blocks of m_block_dim.
// Each component is a ceiling division: (active + block - 1) / block.
// NOTE(review): the declaration of `ret` and the return statement are not
// visible in this extract.
// NOTE(review): dim3 components are unsigned int. If active_dim.x exceeds
// UINT_MAX - m_block_dim.x + 1, the sum wraps around, and a zero block
// component would divide by zero. Presumably callers never pass such values —
// confirm.
217MUDA_INLINE MUDA_GENERIC dim3 Launch::calculate_grid_dim(
const dim3& active_dim)
const MUDA_NOEXCEPT
221 ret.x = (active_dim.x + m_block_dim.x - 1) / m_block_dim.x;
222 ret.y = (active_dim.y + m_block_dim.y - 1) / m_block_dim.y;
223 ret.z = (active_dim.z + m_block_dim.z - 1) / m_block_dim.z;
// Precondition for ranged launches. The grid is computed from `active_dim`
// by calculate_grid_dim, so the user must leave grid_dim at its `dim3{0}`
// sentinel. Only the x component is inspected.
MUDA_INLINE MUDA_GENERIC void Launch::check_input_with_range() const MUDA_NOEXCEPT
{
    MUDA_ASSERT(m_grid_dim.x == 0, "grid_dim should be `dim3{0}`");
}
// Asserts that every component of the user-set grid_dim is positive.
// Presumably this is the precondition for unranged launches (no caller is
// visible in this extract).
MUDA_INLINE MUDA_GENERIC void Launch::check_input() const MUDA_NOEXCEPT
{
    MUDA_ASSERT(m_grid_dim.x > 0 && m_grid_dim.y > 0 && m_grid_dim.z > 0,
                "grid_dim should be non-zero");
}