MUDA
Loading...
Searching...
No Matches
launch_base.inl
1#include <muda/exception.h>
2#include <muda/compute_graph/compute_graph_accessor.h>
3#include <muda/compute_graph/compute_graph_var.h>
4#include <muda/graph/graph.h>
5#include <iostream>
6#include "launch_base.h"
7
8namespace muda
9{
10MUDA_INLINE MUDA_GENERIC LaunchCore::LaunchCore(cudaStream_t stream) MUDA_NOEXCEPT
11 : m_stream(stream)
12{
13//Logger::instance();
14#ifdef __CUDA_ARCH__
15#else
16 if(!ComputeGraphBuilder::is_phase_none())
17 {
18 if(ComputeGraphBuilder::is_phase_serial_launching())
19 {
20 MUDA_ASSERT(stream == nullptr
21 || stream == details::ComputeGraphAccessor().current_stream(),
22 "LaunchBase: stream must be nullptr or equals to current stream");
23 init_stream(details::ComputeGraphAccessor().current_stream());
24 }
25 else if(ComputeGraphBuilder::is_caturing())
26 {
27 init_stream(details::ComputeGraphAccessor().capture_stream());
28 }
29 }
30#endif
31}
32
33MUDA_INLINE void LaunchCore::push_range(const std::string& name)
34{
35 MUDA_ASSERT(ComputeGraphBuilder::is_phase_none(),
36 "`push_range()` is meaningless in ComputeGraph");
37
38 nvtxEventAttributes_t eventAttrib = {0};
39 eventAttrib.version = NVTX_VERSION;
40 eventAttrib.size = NVTX_EVENT_ATTRIB_STRUCT_SIZE;
41 eventAttrib.colorType = NVTX_COLOR_ARGB;
42 eventAttrib.color = 255;
43 eventAttrib.messageType = NVTX_MESSAGE_TYPE_ASCII;
44 eventAttrib.message.ascii = name.c_str();
45 nvtxRangePushEx(&eventAttrib);
46}
47
48MUDA_INLINE void LaunchCore::pop_range()
49{
50 MUDA_ASSERT(ComputeGraphBuilder::is_phase_none(),
51 "`pop_range()` is meaningless in ComputeGraph");
52 nvtxRangePop();
53}
54
55MUDA_INLINE void LaunchCore::record(cudaEvent_t e, int flag)
56{
57 MUDA_ASSERT(ComputeGraphBuilder::is_phase_none(),
58 "You need provide at least one ComputeGraphVar for dependency generation");
59 checkCudaErrors(cudaEventRecordWithFlags(e, stream(), flag));
60}
61
62MUDA_INLINE void LaunchCore::record(ComputeGraphVar<cudaEvent_t>& e,
63 const std::vector<ComputeGraphVarBase*>& vars)
64{
65 auto event = e.eval();
66 for(auto var : vars)
67 var->base_building_eval(); // eval all vars (for safety, we eval them as RWViewer)
68 ComputeGraphBuilder::invoke_phase_actions(
69 [&]
70 {
71 checkCudaErrors(cudaEventRecordWithFlags(event, m_stream, cudaEventRecordDefault));
72 },
73 [&] { details::ComputeGraphAccessor().set_event_record_node(event); },
74 [&] { details::ComputeGraphAccessor().set_event_record_node(nullptr); });
75}
76
77MUDA_INLINE void LaunchCore::when(cudaEvent_t e, int flag)
78{
79 MUDA_ASSERT(ComputeGraphBuilder::is_phase_none(),
80 "`when()` makes code reader confused in ComputeGraph, please use `wait()` instead")
81 checkCudaErrors(cudaStreamWaitEvent(stream(), e, flag));
82}
83
84MUDA_INLINE void LaunchCore::wait(cudaEvent_t e, int flag)
85{
86 MUDA_ASSERT(ComputeGraphBuilder::is_phase_none(),
87 "You need provide at least one ComputeGraphVar for dependency generation");
88
89 checkCudaErrors(cudaStreamWaitEvent(m_stream, e, flag));
90}
91
92MUDA_INLINE void LaunchCore::wait(const ComputeGraphVar<cudaEvent_t>& e,
93 const std::vector<ComputeGraphVarBase*>& vars)
94{
95 auto event = e.ceval();
96 for(auto var : vars)
97 var->base_building_eval(); // eval all vars (for safety, we eval them as RWViewer)
98 ComputeGraphBuilder::invoke_phase_actions(
99 [&]
100 {
101 checkCudaErrors(cudaStreamWaitEvent(m_stream, event, cudaEventWaitDefault));
102 },
103 [&] { details::ComputeGraphAccessor().set_event_wait_node(event); },
104 [&] { details::ComputeGraphAccessor().set_event_wait_node(nullptr); });
105}
106
107MUDA_INLINE void LaunchCore::wait()
108{
109 wait_stream(m_stream);
110}
111
112MUDA_INLINE void LaunchCore::callback(const std::function<void(cudaStream_t, cudaError)>& callback)
113{
114 MUDA_ASSERT(ComputeGraphBuilder::is_phase_none(),
115 "`callback()` in ComputeGraph is unsupported now");
116 auto userdata = new std::function<void(cudaStream_t, cudaError)>(callback);
117 checkCudaErrors(
118 cudaStreamAddCallback(stream(), details::stream_error_callback, userdata, 0));
119}
120
121template <typename... ViewT>
122MUDA_INLINE void LaunchCore::record(ComputeGraphVar<cudaEvent_t>& e,
123 ComputeGraphVar<ViewT>&... vars)
124{
125 record(e, {static_cast<ComputeGraphVarBase*>(&vars)...});
126}
127
128
129template <typename... ViewT>
130MUDA_INLINE void LaunchCore::wait(const ComputeGraphVar<cudaEvent_t>& e,
131 ComputeGraphVar<ViewT>&... vars)
132{
133 return wait(e, {static_cast<ComputeGraphVarBase*>(&vars)...});
134}
135
136MUDA_INLINE void LaunchCore::kernel_name(std::string_view name)
137{
138 if constexpr(muda::RUNTIME_CHECK_ON)
139 details::LaunchInfoCache::current_kernel_name(name);
140}
141
142MUDA_INLINE void muda::LaunchCore::file_line(std::string_view file, int line)
143{
144 if constexpr(muda::RUNTIME_CHECK_ON)
145 {
146 details::LaunchInfoCache::current_kernel_file(file);
147 details::LaunchInfoCache::current_kernel_line(line);
148 }
149}
150
151MUDA_INLINE MUDA_HOST void LaunchCore::pop_kernel_label()
152{
153 if constexpr(muda::RUNTIME_CHECK_ON)
154 {
155 details::LaunchInfoCache::current_kernel_name("");
156 details::LaunchInfoCache::current_kernel_file("");
157 details::LaunchInfoCache::current_kernel_line(0ull);
158 }
159}
160
161
162MUDA_INLINE LaunchCore::~LaunchCore() MUDA_NOEXCEPT
163{
164 if constexpr(muda::RUNTIME_CHECK_ON)
165 {
166 if(ComputeGraphBuilder::is_direct_launching() && Debug::is_debug_sync_all())
167 wait();
168 }
169}
170
171MUDA_INLINE void LaunchCore::wait_event(cudaEvent_t event)
172{
173 MUDA_ASSERT(ComputeGraphBuilder::is_phase_none(),
174 "`wait_event()` is meaningless in ComputeGraph");
175 checkCudaErrors(cudaEventSynchronize(event));
176
177 if constexpr(muda::RUNTIME_CHECK_ON)
178 {
179 Debug::call_sync_callback();
180 }
181}
182
183MUDA_INLINE void LaunchCore::wait_stream(cudaStream_t stream)
184{
185 MUDA_ASSERT(ComputeGraphBuilder::is_phase_none(),
186 "`wait_stream()` a stream is meaningless in ComputeGraph");
187 checkCudaErrors(cudaStreamSynchronize(stream));
188
189 if constexpr(muda::RUNTIME_CHECK_ON)
190 {
191 Debug::call_sync_callback();
192 }
193}
194
195MUDA_INLINE void LaunchCore::wait_device()
196{
197 MUDA_ASSERT(ComputeGraphBuilder::is_phase_none(),
198 "`wait_device()` a stream is meaningless in ComputeGraph");
199 checkCudaErrors(cudaDeviceSynchronize());
200
201 if constexpr(muda::RUNTIME_CHECK_ON)
202 {
203 Debug::call_sync_callback();
204 }
205}
206
207template <typename T>
208MUDA_GENERIC LaunchBase<T>::LaunchBase(cudaStream_t stream) MUDA_NOEXCEPT
209 : LaunchCore(stream)
210{
211}
212
213template <typename T>
214T& LaunchBase<T>::push_range(const std::string& name)
215{
216 LaunchCore::push_range(name);
217 return derived();
218}
219
220template <typename T>
221T& LaunchBase<T>::pop_range()
222{
223 LaunchCore::pop_range();
224 return derived();
225}
226
227template <typename T>
228T& LaunchBase<T>::record(cudaEvent_t e, int flag)
229{
230 LaunchCore::record(e, flag);
231 return derived();
232}
233
234template <typename T>
235T& LaunchBase<T>::record(ComputeGraphVar<cudaEvent_t>& e,
236 const std::vector<ComputeGraphVarBase*>& vars)
237{
238 LaunchCore::record(e, vars);
239 return derived();
240}
241
242template <typename T>
243T& LaunchBase<T>::when(cudaEvent_t e, int flag)
244{
245 LaunchCore::when(e, flag);
246 return derived();
247}
248
249template <typename T>
250T& LaunchBase<T>::wait(cudaEvent_t e, int flag)
251{
252 LaunchCore::wait(e, flag);
253 return derived();
254}
255
256template <typename T>
257T& LaunchBase<T>::wait(const ComputeGraphVar<cudaEvent_t>& e,
258 const std::vector<ComputeGraphVarBase*>& vars)
259{
260 LaunchCore::wait(e, vars);
261 return derived();
262}
263
264template <typename T>
265T& LaunchBase<T>::wait()
266{
267 LaunchCore::wait();
268 return derived();
269}
270
271template <typename T>
272T& LaunchBase<T>::callback(const std::function<void(cudaStream_t, cudaError)>& callback)
273{
274 LaunchCore::callback(callback);
275 return derived();
276}
277template <typename T>
278template <typename... ViewT>
279T& LaunchBase<T>::record(ComputeGraphVar<cudaEvent_t>& e, ComputeGraphVar<ViewT>&... vars)
280{
281 return record(e, {static_cast<ComputeGraphVarBase*>(&vars)...});
282}
283
284template <typename T>
285template <typename... ViewT>
286T& LaunchBase<T>::wait(const ComputeGraphVar<cudaEvent_t>& e, ComputeGraphVar<ViewT>&... vars)
287{
288 return wait(e, {static_cast<ComputeGraphVarBase*>(&vars)...});
289}
290
291template <typename T>
292template <typename Next>
293Next LaunchBase<T>::next(Next n)
294{
295 static_assert(std::is_base_of_v<LaunchBase<Next>, Next>,
296 "Next should be derived from LaunchBase<Next>");
297 MUDA_ASSERT(ComputeGraphBuilder::is_phase_none(), "`next()` is not allowed in ComputeGraph");
298 n.init_stream(stream());
299 return n;
300}
301
302template <typename T>
303template <typename Next, typename... Args>
304Next LaunchBase<T>::next(Args&&... args)
305{
306 static_assert(std::is_base_of_v<LaunchBase<Next>, Next>,
307 "Next should be derived from LaunchBase<Next>");
308 MUDA_ASSERT(ComputeGraphBuilder::is_phase_none(), "`next()` is not allowed in ComputeGraph");
309 Next n(std::forward<Args>(args)...);
310 n.init_stream(stream());
311 return n;
312}
313
314template <typename T>
315T& LaunchBase<T>::kernel_name(std::string_view name)
316{
317 LaunchCore::kernel_name(name);
318 return derived();
319}
320
321template <typename T>
322T& muda::LaunchBase<T>::file_line(std::string_view file, int line)
323{
324 LaunchCore::file_line(file, line);
325 return derived();
326}
327
328template <typename T>
329T& LaunchBase<T>::pop_kernel_label()
330{
331 LaunchCore::pop_kernel_label();
332 return derived();
333}
334
335template <typename T>
336LaunchBase<T>::~LaunchBase() MUDA_NOEXCEPT
337{
338}
339
340MUDA_INLINE Empty on(cudaStream_t stream)
341{
342 MUDA_ASSERT(ComputeGraphBuilder::is_phase_none(),
343 "`on(stream)` is meaningless in ComputeGraph, using `on()` is enough");
344 return Empty(stream);
345}
346
347MUDA_INLINE Empty on()
348{
349 return Empty(nullptr);
350}
351
352MUDA_INLINE void wait_device()
353{
354 Empty::wait_device();
355}
356
357MUDA_INLINE void wait_stream(cudaStream_t stream)
358{
359 Empty::wait_stream(stream);
360}
361
362MUDA_INLINE void wait_event(cudaEvent_t event)
363{
364 Empty::wait_event(event);
365}
366} // namespace muda
Definition launch_base.h:86