// MUDA: kernel.h
#pragma once
#include <cuda.h>
#include <muda/muda_def.h>
#include <muda/launch/stream_define.h>
#include <type_traits>
#include <utility>

namespace muda
{
/**
 * @brief Binds a kernel entry point to a launch configuration and launches it
 *        when invoked on an rvalue.
 *
 * Typical use: `Kernel{grid_dim, block_dim, f}(args...);`
 *
 * @tparam F  Type of the kernel entry point. It is launched with the
 *            `<<<grid, block, smem, stream>>>` syntax, so it is presumably a
 *            pointer to a `__global__` function (confirm at call sites).
 *
 * Stream handling:
 *  - In device code (`__CUDA_ARCH__` defined) the constructor asserts that the
 *    stream is `details::stream::tail_launch()` or
 *    `details::stream::fire_and_forget()`. The convenience constructors that
 *    take no stream default it to the null stream, so they trip this assertion
 *    when used in device code.
 *  - When `MUDA_WITH_DEVICE_STREAM_MODEL` is non-zero, the stored stream is
 *    passed to the launch unchanged.
 *  - Otherwise:
 *    - a tail-launch stream triggers `cudaDeviceSynchronize()` followed by a
 *      launch on the null stream;
 *    - a fire-and-forget stream launches on the null stream directly;
 *    - any other stream is used as-is.
 *
 * Copy and move are deleted; construct a temporary and call it immediately.
 *
 * NOTE(review): `Kernel{grid, block, 0, f}` is ambiguous between the
 * `size_t shared_memory_size` and `cudaStream_t stream` overloads, because the
 * literal `0` converts equally well to both. Pass `size_t{0}` or `nullptr`.
 */
template <typename F>
class Kernel
{
    dim3         m_grid_dim;
    dim3         m_block_dim;
    size_t       m_shared_memory_size;
    cudaStream_t m_stream;
    F            m_kernel;

  public:
    /**
     * @brief Primary constructor; every other constructor delegates here.
     * @param grid_dim            Number of blocks in the grid.
     * @param block_dim           Number of threads per block.
     * @param shared_memory_size  Dynamic shared memory per block, in bytes.
     * @param stream              Stream to launch on (see class notes for
     *                            the device-side restriction).
     * @param f                   Kernel entry point.
     */
    MUDA_GENERIC Kernel(dim3 grid_dim, dim3 block_dim, size_t shared_memory_size, cudaStream_t stream, F f)
        : m_grid_dim(grid_dim)
        , m_block_dim(block_dim)
        , m_shared_memory_size(shared_memory_size)
        , m_stream(stream)
        , m_kernel(f)
    {
#ifdef __CUDA_ARCH__
        // From device code only the named device streams are accepted.
        MUDA_KERNEL_ASSERT(stream == details::stream::tail_launch()
                               || stream == details::stream::fire_and_forget(),
                           "Kernel Launch on device with invalid stream! "
                           "Only Stream::TailLaunch{} and Stream::FireAndForget{} are allowed");
#endif
    }

    /// Single-thread launch (1 block x 1 thread), no dynamic shared memory, null stream.
    MUDA_GENERIC Kernel(F f)
        : Kernel{dim3{1}, dim3{1}, size_t{0}, nullptr, f}
    {
    }

    /// Launch with the given grid/block, no dynamic shared memory, null stream.
    MUDA_GENERIC Kernel(dim3 grid_dim, dim3 block_dim, F f)
        : Kernel{grid_dim, block_dim, size_t{0}, nullptr, f}
    {
    }

    /// Launch with the given grid/block and dynamic shared memory, null stream.
    MUDA_GENERIC Kernel(dim3 grid_dim, dim3 block_dim, size_t shared_memory_size, F f)
        : Kernel{grid_dim, block_dim, shared_memory_size, nullptr, f}
    {
    }

    /// Launch with the given grid/block on `stream`, no dynamic shared memory.
    MUDA_GENERIC Kernel(dim3 grid_dim, dim3 block_dim, cudaStream_t stream, F f)
        : Kernel{grid_dim, block_dim, size_t{0}, stream, f}
    {
    }

    /// Single-thread launch (1 block x 1 thread) on `stream`, no dynamic shared memory.
    MUDA_GENERIC Kernel(cudaStream_t stream, F f)
        : Kernel{dim3{1}, dim3{1}, size_t{0}, stream, f}
    {
    }

    /**
     * @brief Launch the bound kernel with `args...`.
     *
     * Rvalue-qualified, so it is meant to be called on a temporary
     * (`Kernel{...}(args...)`). After the launch, the result of
     * `cudaGetLastError()` is passed to `checkCudaErrors`.
     *
     * @param args  Arguments forwarded to the kernel; `F` must be invocable
     *              with them (checked at compile time).
     */
    template <typename... Args>
    MUDA_GENERIC void operator()(Args&&... args) &&
    {
        static_assert(std::is_invocable_v<F, Args...>, "invalid arguments");
#if MUDA_WITH_DEVICE_STREAM_MODEL
        // The stored stream (including the named device streams) is used verbatim.
        m_kernel<<<m_grid_dim, m_block_dim, m_shared_memory_size, m_stream>>>(
            std::forward<Args>(args)...);
        checkCudaErrors(cudaGetLastError());
#else
        // Without the device stream model, the named device streams are mapped
        // onto the null stream; any other stream is used as-is.
        cudaStream_t stream = nullptr;
        if(m_stream == details::stream::tail_launch())
        {
            // Tail-launch: synchronize before launching (presumably to order
            // this launch after previously issued work; confirm intent).
            checkCudaErrors(cudaDeviceSynchronize());
        }
        else if(m_stream == details::stream::fire_and_forget())
        {
            // Fire-and-forget: launch on the null stream without synchronizing.
        }
        else
        {
            stream = m_stream;
        }
        m_kernel<<<m_grid_dim, m_block_dim, m_shared_memory_size, stream>>>(
            std::forward<Args>(args)...);
        checkCudaErrors(cudaGetLastError());
#endif
    }

    // Non-copyable and non-movable: a Kernel is a one-shot launch descriptor.
    MUDA_GENERIC Kernel(const Kernel&)            = delete;
    MUDA_GENERIC Kernel& operator=(const Kernel&) = delete;
    MUDA_GENERIC Kernel(Kernel&&)                 = delete;
    MUDA_GENERIC Kernel& operator=(Kernel&&)      = delete;
};
}  // namespace muda
// Definition kernel.h:11