1#include <muda/exception.h>
2#include <muda/compute_graph/compute_graph_accessor.h>
3#include <muda/compute_graph/compute_graph_var.h>
4#include <muda/graph/graph.h>
6#include "launch_base.h"
// LaunchCore constructor taking the CUDA stream to launch on.
// NOTE(review): this listing omits some lines (braces, any member-initializer
// list, any preprocessor guards), so the nesting described below is inferred.
// When ComputeGraphBuilder is not in the "none" phase:
//   - if it is in the serial-launching phase, asserts that `stream` is nullptr
//     or equal to ComputeGraphAccessor().current_stream(), then calls
//     init_stream() with that current stream;
//   - else if it is capturing, calls init_stream() with
//     ComputeGraphAccessor().capture_stream().
10MUDA_INLINE MUDA_GENERIC LaunchCore::LaunchCore(cudaStream_t stream) MUDA_NOEXCEPT
16 if(!ComputeGraphBuilder::is_phase_none())
18 if(ComputeGraphBuilder::is_phase_serial_launching())
20 MUDA_ASSERT(stream ==
nullptr
21 || stream == details::ComputeGraphAccessor().current_stream(),
22 "LaunchBase: stream must be nullptr or equals to current stream");
23 init_stream(details::ComputeGraphAccessor().current_stream());
25 else if(ComputeGraphBuilder::is_caturing())
27 init_stream(details::ComputeGraphAccessor().capture_stream());
33MUDA_INLINE
void LaunchCore::push_range(
const std::string& name)
35 MUDA_ASSERT(ComputeGraphBuilder::is_phase_none(),
36 "`push_range()` is meaningless in ComputeGraph");
38 nvtxEventAttributes_t eventAttrib = {0};
39 eventAttrib.version = NVTX_VERSION;
40 eventAttrib.size = NVTX_EVENT_ATTRIB_STRUCT_SIZE;
41 eventAttrib.colorType = NVTX_COLOR_ARGB;
42 eventAttrib.color = 255;
43 eventAttrib.messageType = NVTX_MESSAGE_TYPE_ASCII;
44 eventAttrib.message.ascii = name.c_str();
45 nvtxRangePushEx(&eventAttrib);
// Asserts that ComputeGraphBuilder is in the "none" phase.
// NOTE(review): no statement after the assertion is visible in this listing;
// presumably the NVTX range opened by push_range() is popped here — confirm
// against the full file.
48MUDA_INLINE
void LaunchCore::pop_range()
50 MUDA_ASSERT(ComputeGraphBuilder::is_phase_none(),
51 "`pop_range()` is meaningless in ComputeGraph");
55MUDA_INLINE
void LaunchCore::record(cudaEvent_t e,
int flag)
57 MUDA_ASSERT(ComputeGraphBuilder::is_phase_none(),
58 "You need provide at least one ComputeGraphVar for dependency generation");
59 checkCudaErrors(cudaEventRecordWithFlags(e, stream(), flag));
// Records the event held by the ComputeGraphVar `e`.
// Evaluates `e` to obtain the event, calls base_building_eval() on `var`, and
// passes actions to ComputeGraphBuilder::invoke_phase_actions(): a
// cudaEventRecordWithFlags() call on m_stream with cudaEventRecordDefault,
// then set_event_record_node(event), then set_event_record_node(nullptr).
// NOTE(review): the line declaring `var` (presumably a loop over `vars`) and
// the opening of the first action are not visible in this listing — confirm
// against the full file.
62MUDA_INLINE
void LaunchCore::record(ComputeGraphVar<cudaEvent_t>& e,
63 const std::vector<ComputeGraphVarBase*>& vars)
65 auto event = e.eval();
67 var->base_building_eval();
68 ComputeGraphBuilder::invoke_phase_actions(
71 checkCudaErrors(cudaEventRecordWithFlags(event, m_stream, cudaEventRecordDefault));
73 [&] { details::ComputeGraphAccessor().set_event_record_node(event); },
74 [&] { details::ComputeGraphAccessor().set_event_record_node(
nullptr); });
77MUDA_INLINE
void LaunchCore::when(cudaEvent_t e,
int flag)
79 MUDA_ASSERT(ComputeGraphBuilder::is_phase_none(),
80 "`when()` makes code reader confused in ComputeGraph, please use `wait()` instead")
81 checkCudaErrors(cudaStreamWaitEvent(stream(), e, flag));
84MUDA_INLINE
void LaunchCore::wait(cudaEvent_t e,
int flag)
86 MUDA_ASSERT(ComputeGraphBuilder::is_phase_none(),
87 "You need provide at least one ComputeGraphVar for dependency generation");
89 checkCudaErrors(cudaStreamWaitEvent(m_stream, e, flag));
// Waits on the event held by the ComputeGraphVar `e`.
// Reads the event through e.ceval(), calls base_building_eval() on `var`, and
// passes actions to ComputeGraphBuilder::invoke_phase_actions(): a
// cudaStreamWaitEvent() call on m_stream with cudaEventWaitDefault, then
// set_event_wait_node(event), then set_event_wait_node(nullptr).
// NOTE(review): the line declaring `var` (presumably a loop over `vars`) and
// the opening of the first action are not visible in this listing — confirm
// against the full file.
92MUDA_INLINE
void LaunchCore::wait(
const ComputeGraphVar<cudaEvent_t>& e,
93 const std::vector<ComputeGraphVarBase*>& vars)
95 auto event = e.ceval();
97 var->base_building_eval();
98 ComputeGraphBuilder::invoke_phase_actions(
101 checkCudaErrors(cudaStreamWaitEvent(m_stream, event, cudaEventWaitDefault));
103 [&] { details::ComputeGraphAccessor().set_event_wait_node(event); },
104 [&] { details::ComputeGraphAccessor().set_event_wait_node(
nullptr); });
107MUDA_INLINE
void LaunchCore::wait()
109 wait_stream(m_stream);
112MUDA_INLINE
void LaunchCore::callback(
const std::function<
void(cudaStream_t, cudaError)>& callback)
114 MUDA_ASSERT(ComputeGraphBuilder::is_phase_none(),
115 "`callback()` in ComputeGraph is unsupported now");
116 auto userdata =
new std::function<void(cudaStream_t, cudaError)>(callback);
118 cudaStreamAddCallback(stream(), details::stream_error_callback, userdata, 0));
121template <
typename... ViewT>
122MUDA_INLINE
void LaunchCore::record(ComputeGraphVar<cudaEvent_t>& e,
123 ComputeGraphVar<ViewT>&... vars)
125 record(e, {
static_cast<ComputeGraphVarBase*
>(&vars)...});
129template <
typename... ViewT>
130MUDA_INLINE
void LaunchCore::wait(
const ComputeGraphVar<cudaEvent_t>& e,
131 ComputeGraphVar<ViewT>&... vars)
133 return wait(e, {
static_cast<ComputeGraphVarBase*
>(&vars)...});
136MUDA_INLINE
void LaunchCore::kernel_name(std::string_view name)
138 if constexpr(muda::RUNTIME_CHECK_ON)
139 details::LaunchInfoCache::current_kernel_name(name);
142MUDA_INLINE
void muda::LaunchCore::file_line(std::string_view file,
int line)
144 if constexpr(muda::RUNTIME_CHECK_ON)
146 details::LaunchInfoCache::current_kernel_file(file);
147 details::LaunchInfoCache::current_kernel_line(line);
151MUDA_INLINE MUDA_HOST
void LaunchCore::pop_kernel_label()
153 if constexpr(muda::RUNTIME_CHECK_ON)
155 details::LaunchInfoCache::current_kernel_name(
"");
156 details::LaunchInfoCache::current_kernel_file(
"");
157 details::LaunchInfoCache::current_kernel_line(0ull);
// LaunchCore destructor.
// When muda::RUNTIME_CHECK_ON is true, checks whether the builder is in
// direct-launching mode and Debug::is_debug_sync_all() is set.
// NOTE(review): the statement guarded by that condition is not visible in this
// listing; presumably it synchronizes — confirm against the full file.
162MUDA_INLINE LaunchCore::~LaunchCore() MUDA_NOEXCEPT
164 if constexpr(muda::RUNTIME_CHECK_ON)
166 if(ComputeGraphBuilder::is_direct_launching() && Debug::is_debug_sync_all())
171MUDA_INLINE
void LaunchCore::wait_event(cudaEvent_t event)
173 MUDA_ASSERT(ComputeGraphBuilder::is_phase_none(),
174 "`wait_event()` is meaningless in ComputeGraph");
175 checkCudaErrors(cudaEventSynchronize(event));
177 if constexpr(muda::RUNTIME_CHECK_ON)
179 Debug::call_sync_callback();
183MUDA_INLINE
void LaunchCore::wait_stream(cudaStream_t stream)
185 MUDA_ASSERT(ComputeGraphBuilder::is_phase_none(),
186 "`wait_stream()` a stream is meaningless in ComputeGraph");
187 checkCudaErrors(cudaStreamSynchronize(stream));
189 if constexpr(muda::RUNTIME_CHECK_ON)
191 Debug::call_sync_callback();
195MUDA_INLINE
void LaunchCore::wait_device()
197 MUDA_ASSERT(ComputeGraphBuilder::is_phase_none(),
198 "`wait_device()` a stream is meaningless in ComputeGraph");
199 checkCudaErrors(cudaDeviceSynchronize());
201 if constexpr(muda::RUNTIME_CHECK_ON)
203 Debug::call_sync_callback();
// LaunchBase<T> constructor taking the CUDA stream.
// NOTE(review): the template header, any initializer list and the body are
// not visible in this listing.
208MUDA_GENERIC LaunchBase<T>::LaunchBase(cudaStream_t stream) MUDA_NOEXCEPT
// Forwards to LaunchCore::push_range(name); declared to return T&.
// NOTE(review): the template header and the return statement are not visible
// in this listing.
214T& LaunchBase<T>::push_range(
const std::string& name)
216 LaunchCore::push_range(name);
// Forwards to LaunchCore::pop_range(); declared to return T&.
// NOTE(review): the template header and the return statement are not visible
// in this listing.
221T& LaunchBase<T>::pop_range()
223 LaunchCore::pop_range();
// Forwards to LaunchCore::record(e, flag); declared to return T&.
// NOTE(review): the template header and the return statement are not visible
// in this listing.
228T& LaunchBase<T>::record(cudaEvent_t e,
int flag)
230 LaunchCore::record(e, flag);
// Forwards to LaunchCore::record(e, vars); declared to return T&.
// NOTE(review): the template header and the return statement are not visible
// in this listing.
235T& LaunchBase<T>::record(ComputeGraphVar<cudaEvent_t>& e,
236 const std::vector<ComputeGraphVarBase*>& vars)
238 LaunchCore::record(e, vars);
// Forwards to LaunchCore::when(e, flag); declared to return T&.
// NOTE(review): the template header and the return statement are not visible
// in this listing.
243T& LaunchBase<T>::when(cudaEvent_t e,
int flag)
245 LaunchCore::when(e, flag);
// Forwards to LaunchCore::wait(e, flag); declared to return T&.
// NOTE(review): the template header and the return statement are not visible
// in this listing.
250T& LaunchBase<T>::wait(cudaEvent_t e,
int flag)
252 LaunchCore::wait(e, flag);
// Forwards to LaunchCore::wait(e, vars); declared to return T&.
// NOTE(review): the template header and the return statement are not visible
// in this listing.
257T& LaunchBase<T>::wait(
const ComputeGraphVar<cudaEvent_t>& e,
258 const std::vector<ComputeGraphVarBase*>& vars)
260 LaunchCore::wait(e, vars);
// Parameterless wait(); declared to return T&.
// NOTE(review): the template header and the entire body are not visible in
// this listing; presumably it forwards to LaunchCore::wait() — confirm.
265T& LaunchBase<T>::wait()
// Forwards to LaunchCore::callback(callback); declared to return T&.
// NOTE(review): the template header and the return statement are not visible
// in this listing.
272T& LaunchBase<T>::callback(
const std::function<
void(cudaStream_t, cudaError)>& callback)
274 LaunchCore::callback(callback);
// Variadic record(): converts the address of each ComputeGraphVar in `vars`
// to a ComputeGraphVarBase* and returns the result of calling record() with
// `e` and that list.
// NOTE(review): the outer `template <typename T>` header is not visible in
// this listing.
278template <
typename... ViewT>
279T& LaunchBase<T>::record(ComputeGraphVar<cudaEvent_t>& e, ComputeGraphVar<ViewT>&... vars)
281 return record(e, {
static_cast<ComputeGraphVarBase*
>(&vars)...});
// Variadic wait(): converts the address of each ComputeGraphVar in `vars` to
// a ComputeGraphVarBase* and returns the result of calling wait() with `e`
// and that list.
// NOTE(review): the outer `template <typename T>` header is not visible in
// this listing.
285template <
typename... ViewT>
286T& LaunchBase<T>::wait(
const ComputeGraphVar<cudaEvent_t>& e, ComputeGraphVar<ViewT>&... vars)
288 return wait(e, {
static_cast<ComputeGraphVarBase*
>(&vars)...});
// Chains to another launcher `n` on the same stream.
// Statically requires Next to derive from LaunchBase<Next>, asserts that
// ComputeGraphBuilder is in the "none" phase, and calls n.init_stream() with
// this launcher's stream().
// NOTE(review): the outer template header and the return statement are not
// visible in this listing; presumably `n` is returned — confirm.
292template <
typename Next>
293Next LaunchBase<T>::next(Next n)
295 static_assert(std::is_base_of_v<LaunchBase<Next>, Next>,
296 "Next should be derived from LaunchBase<Next>");
297 MUDA_ASSERT(ComputeGraphBuilder::is_phase_none(),
"`next()` is not allowed in ComputeGraph");
298 n.init_stream(stream());
// Constructs a Next launcher from perfectly forwarded `args` and puts it on
// the same stream.
// Statically requires Next to derive from LaunchBase<Next>, asserts that
// ComputeGraphBuilder is in the "none" phase, and calls n.init_stream() with
// this launcher's stream().
// NOTE(review): the outer template header and the return statement are not
// visible in this listing; presumably `n` is returned — confirm.
303template <
typename Next,
typename... Args>
304Next LaunchBase<T>::next(Args&&... args)
306 static_assert(std::is_base_of_v<LaunchBase<Next>, Next>,
307 "Next should be derived from LaunchBase<Next>");
308 MUDA_ASSERT(ComputeGraphBuilder::is_phase_none(),
"`next()` is not allowed in ComputeGraph");
309 Next n(std::forward<Args>(args)...);
310 n.init_stream(stream());
// Forwards to LaunchCore::kernel_name(name); declared to return T&.
// NOTE(review): the template header and the return statement are not visible
// in this listing.
315T& LaunchBase<T>::kernel_name(std::string_view name)
317 LaunchCore::kernel_name(name);
// Forwards `file` and `line` to LaunchCore::file_line().
// NOTE(review): the signature enclosing this statement is not visible in this
// listing; presumably it is LaunchBase<T>::file_line — confirm.
324 LaunchCore::file_line(file, line);
// Forwards to LaunchCore::pop_kernel_label(); declared to return T&.
// NOTE(review): the template header and the return statement are not visible
// in this listing.
329T& LaunchBase<T>::pop_kernel_label()
331 LaunchCore::pop_kernel_label();
// LaunchBase<T> destructor.
// NOTE(review): the template header and the body are not visible in this
// listing.
336LaunchBase<T>::~LaunchBase() MUDA_NOEXCEPT
340MUDA_INLINE Empty on(cudaStream_t stream)
342 MUDA_ASSERT(ComputeGraphBuilder::is_phase_none(),
343 "`on(stream)` is meaningless in ComputeGraph, using `on()` is enough");
344 return Empty(stream);
347MUDA_INLINE Empty on()
349 return Empty(
nullptr);
352MUDA_INLINE
void wait_device()
354 Empty::wait_device();
357MUDA_INLINE
void wait_stream(cudaStream_t stream)
359 Empty::wait_stream(stream);
362MUDA_INLINE
void wait_event(cudaEvent_t event)
364 Empty::wait_event(event);
// Definition: launch_base.h:86