// MUDA — launch_base.h
#pragma once
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
#include <device_launch_parameters.h>

#include <functional>
#include <memory>
#include <string>
#include <string_view>
#include <vector>
#include <cooperative_groups.h>

#include <cuda_profiler_api.h>
#include <nvtx3/nvToolsExt.h>
#include <nvtx3/nvToolsExtCuda.h>
#include <muda/type_traits/type_modifier.h>
#include <muda/tools/launch_info_cache.h>

#include <muda/check/check_cuda_errors.h>
#include <muda/muda_def.h>
#include <muda/launch/event.h>
#include <muda/launch/kernel_tag.h>

namespace muda
{
namespace details
{
    /// Host-side trampoline whose signature matches `cudaStreamCallback_t`.
    ///
    /// `userdata` must point to a `std::function<void(cudaStream_t, cudaError)>`
    /// allocated with `new`. The function is invoked once with the stream and
    /// status passed in by the runtime, after which it is destroyed; each
    /// allocation must therefore be registered for exactly one invocation.
    inline void CUDART_CB stream_error_callback(cudaStream_t stream, cudaError error, void* userdata)
    {
        using callback_type = std::function<void(cudaStream_t, cudaError)>;
        // Take ownership right away so the std::function is released even if
        // the user-supplied callable throws.
        std::unique_ptr<callback_type> callback{static_cast<callback_type*>(userdata)};
        (*callback)(stream, error);
    }
}  // namespace details

class ComputeGraphVarBase;

template <typename T>
class ComputeGraphVar;

/// Non-template base of the launchers.
///
/// Stores the CUDA stream that chained operations are issued on and declares
/// the stream-level operations (profiler ranges, event record / wait, host
/// waits, host callbacks). `LaunchBase<T>` re-exposes them with a fluent,
/// chainable `T&` return type.
class LaunchCore
{
  protected:
    template <typename T>
    using S = std::shared_ptr<T>;

    // The stream this launcher issues work on.
    MUDA_GENERIC ::cudaStream_t stream() const { return m_stream; }

    ::cudaStream_t m_stream;

    // NOTE(review): presumably clears the label set via kernel_name()/file_line();
    // confirm against the implementation.
    MUDA_HOST void pop_kernel_label();

  public:
    // Set the name / source location attached to the following kernel launch.
    static void kernel_name(std::string_view name);
    static void file_line(std::string_view file, int line);

    MUDA_GENERIC LaunchCore(::cudaStream_t stream) MUDA_NOEXCEPT;

    // Rebind the launcher to stream `s`.
    void init_stream(::cudaStream_t s) { m_stream = s; }

    // Open / close a named profiler range.
    void push_range(const std::string& name);
    void pop_range();

    // Record an event on the current stream.
    void record(cudaEvent_t e, int flag = cudaEventRecordDefault);
    // Compute-graph overloads; `vars` presumably lists the graph variables the
    // recorded node depends on — TODO confirm.
    void record(ComputeGraphVar<cudaEvent_t>&            e,
                const std::vector<ComputeGraphVarBase*>& vars);
    template <typename... ViewT>
    void record(ComputeGraphVar<cudaEvent_t>& e, ComputeGraphVar<ViewT>&... vars);

    // Make subsequent work on the current stream wait for event `e`.
    void when(cudaEvent_t e, int flag = cudaEventWaitDefault);

    // let the host wait for the event
    void wait(cudaEvent_t e, int flag = cudaEventWaitDefault);
    void wait(const ComputeGraphVar<cudaEvent_t>&      e,
              const std::vector<ComputeGraphVarBase*>& vars);
    template <typename... ViewT>
    void wait(const ComputeGraphVar<cudaEvent_t>& e, ComputeGraphVar<ViewT>&... vars);

    // let the host wait for the current stream
    void wait();

    // Register a host callback that runs once all previously issued work on
    // the current stream has finished.
    void callback(const std::function<void(::cudaStream_t, ::cudaError)>& callback);

    // Presumably host-blocking waits on an event / a stream / the whole device
    // (same names as the free functions declared at the end of this namespace).
    static void wait_event(cudaEvent_t event);
    static void wait_stream(::cudaStream_t stream);
    static void wait_device();

    ~LaunchCore() MUDA_NOEXCEPT;
};

/// CRTP front-end over LaunchCore: every operation returns `T&` so calls can
/// be chained, and `next<...>()` switches the chain to another launcher type.
template <typename T>
class LaunchBase : public LaunchCore
{
    template <typename Others>
    friend class LaunchBase;
    using Base = LaunchCore;

  public:
    using derived_type = T;
    MUDA_GENERIC LaunchBase(::cudaStream_t stream) MUDA_NOEXCEPT;

    // create a named scope for better recognition (if you are using some profile tools)
    // usage:
    //  on(stream)
    //      .push_range("part1")
    //      .next<launch>(1,1).apply(...)
    //      .pop_range()
    //      .wait();
    T& push_range(const std::string& name);
    T& pop_range();

    // create a name for the following kernel launch
    // viewers will record this name for the sake of better recognition when debugging
    T& kernel_name(std::string_view name);
    T& file_line(std::string_view file, int line);

    // record an event on this point with current stream, you could use .when() to
    // capture this event for synchronization
    // flags:
    //  cudaEventRecordDefault  : Default event record flag.
    //  cudaEventRecordExternal : Event is captured in the graph as an external
    //                            event node when performing stream capture.
    T& record(cudaEvent_t e, int flag = cudaEventRecordDefault);

    // Compute-graph overloads; `vars` presumably lists the graph variables the
    // recorded node depends on — TODO confirm.
    T& record(ComputeGraphVar<cudaEvent_t>&            e,
              const std::vector<ComputeGraphVarBase*>& vars);

    template <typename... ViewT>
    T& record(ComputeGraphVar<cudaEvent_t>& e, ComputeGraphVar<ViewT>&... vars);

    // let the following kernels wait until the event is triggered
    // (asynchronous with respect to the host)
    // usage:
    //  on(stream)
    //      .when(event)
    //      .next<launch>(1,1).apply(...)
    //      .wait();
    // flags:
    //  cudaEventWaitDefault  : Default event wait flag.
    //  cudaEventWaitExternal : Event wait is captured in the graph as an external
    //                          event node when performing stream capture.
    T& when(cudaEvent_t e, int flag = cudaEventWaitDefault);

    // let the host wait for the event
    T& wait(cudaEvent_t e, int flag = cudaEventWaitDefault);
    T& wait(const ComputeGraphVar<cudaEvent_t>&      e,
            const std::vector<ComputeGraphVarBase*>& vars);
    template <typename... ViewT>
    T& wait(const ComputeGraphVar<cudaEvent_t>& e, ComputeGraphVar<ViewT>&... vars);

    // let the host wait for the current stream
    T& wait();

    // register a host callback function, which will be called when all the jobs before
    // this point are done.
    T& callback(const std::function<void(::cudaStream_t, ::cudaError)>& callback);

    // Continue the chain with another launcher type, e.g. `.next<launch>(1,1)`
    // in the usage examples above.
    template <typename Next>
    Next next(Next n);
    template <typename Next, typename... Args>
    Next next(Args&&... args);

    ~LaunchBase() MUDA_NOEXCEPT;

  protected:
    T& pop_kernel_label();

  private:
    // CRTP downcast. The C-style cast is kept on purpose: unlike static_cast it
    // ignores base-class access, so it also works if T derives non-publicly.
    T& derived() MUDA_NOEXCEPT { return *(T*)(this); }
};

/// Launcher that carries only a stream; used to start chains such as
/// `on(stream).wait()`.
class Empty : public LaunchBase<Empty>
{
  public:
    Empty(::cudaStream_t stream = nullptr)
        : LaunchBase(stream)
    {
    }
};

// Start a chain of stream operations on `stream`.
Empty on(::cudaStream_t stream);

// Start a chain without an explicit stream (presumably the default/null
// stream — TODO confirm against the implementation).
Empty on();

void wait_device();
void wait_stream(::cudaStream_t stream);
void wait_event(cudaEvent_t event);
}  // namespace muda

#include "details/launch_base.inl"
// Definition compute_graph_var.h:90
// Definition launch_base.h:167
// Definition launch_base.h:86
// Definition launch_base.h:42