45 using S = std::shared_ptr<T>;
46 MUDA_GENERIC ::cudaStream_t stream()
const {
return m_stream; }
// The CUDA stream handle stored by this object.
48 ::cudaStream_t m_stream;
// Declaration only; the definition is not visible here. MUDA_HOST presumably
// expands to __host__ — confirm.
// NOTE(review): presumably pops a label previously pushed via
// kernel_name()/file_line() — confirm against the definition.
49 MUDA_HOST
void pop_kernel_label();
// Static; takes a non-owning std::string_view `name`.
// NOTE(review): if the definition keeps `name` beyond the call, the caller's
// character buffer must outlive that use — confirm against the implementation.
52 static void kernel_name(std::string_view name);
// Static; takes a non-owning `file` view plus a `line` number.
// NOTE(review): presumably labels the next launch for diagnostics; the same
// string_view lifetime caveat as kernel_name() applies — confirm.
53 static void file_line(std::string_view file,
int line);
// Constructor taking a ::cudaStream_t; marked MUDA_GENERIC and MUDA_NOEXCEPT
// (presumably __host__ __device__ and noexcept — confirm macro definitions).
// NOTE(review): this single-argument constructor is not `explicit`, so a
// ::cudaStream_t converts implicitly to LaunchCore — confirm that is intended.
55 MUDA_GENERIC
LaunchCore(::cudaStream_t stream) MUDA_NOEXCEPT;
57 void init_stream(::cudaStream_t s) { m_stream = s; }
// Declaration only; takes the range name by const reference.
// NOTE(review): presumably opens a named profiler range (e.g. NVTX) that a
// matching pop closes — confirm in the definition.
59 void push_range(
const std::string& name);
// Declaration only. `flag` defaults to cudaEventRecordDefault, the default
// flag value of the CUDA runtime's cudaEventRecordWithFlags.
// NOTE(review): presumably records event `e` on this object's stream — confirm.
62 void record(cudaEvent_t e,
int flag = cudaEventRecordDefault);
64 const std::vector<ComputeGraphVarBase*>& vars);
65 template <
typename... ViewT>
// Declaration only. `flag` defaults to cudaEventWaitDefault, the default flag
// value of the CUDA runtime's cudaStreamWaitEvent.
// NOTE(review): presumably makes this object's stream wait on event `e`; how
// this differs from wait() below is not visible here — confirm.
67 void when(cudaEvent_t e,
int flag = cudaEventWaitDefault);
// Declaration only, with the same signature and default flag as when().
69 void wait(cudaEvent_t e,
int flag = cudaEventWaitDefault);
71 const std::vector<ComputeGraphVarBase*>& vars);
72 template <
typename... ViewT>
// Declaration only. `callback` is a std::function that is invoked with a
// ::cudaStream_t and a ::cudaError.
// NOTE(review): presumably enqueued on this object's stream (e.g. as a CUDA
// stream callback) — confirm in the definition.
75 void callback(
const std::function<
void(::cudaStream_t, ::cudaError)>& callback);
// Static helpers taking an event, a stream, or nothing, respectively.
// NOTE(review): presumably block the calling host thread until the given
// event / stream / whole device has completed — confirm in the definitions.
77 static void wait_event(cudaEvent_t event);
78 static void wait_stream(::cudaStream_t stream);
79 static void wait_device();
87 template <
typename Others>
// Exposes the class template parameter T under the name derived_type.
92 using derived_type = T;
// Constructor taking a ::cudaStream_t; marked MUDA_GENERIC and MUDA_NOEXCEPT
// (presumably __host__ __device__ and noexcept — confirm macro definitions).
// NOTE(review): not `explicit`, so a ::cudaStream_t converts implicitly —
// confirm that is intended.
93 MUDA_GENERIC
LaunchBase(::cudaStream_t stream) MUDA_NOEXCEPT;
// The following declarations return T&.
// NOTE(review): presumably each returns *this as T so calls can be chained —
// confirm in the definitions.
102 T& push_range(
const std::string& name);
108 T& kernel_name(std::string_view name);
109 T& file_line(std::string_view file,
int line);
// `flag` defaults to cudaEventRecordDefault, the CUDA runtime default for
// cudaEventRecordWithFlags.
117 T& record(cudaEvent_t e,
int flag = cudaEventRecordDefault);
120 const std::vector<ComputeGraphVarBase*>& vars);
122 template <
typename... ViewT>
// Declarations returning T&. `flag` defaults to cudaEventWaitDefault, the
// CUDA runtime default for cudaStreamWaitEvent.
// NOTE(review): how when() and wait() differ is not visible here — confirm.
136 T& when(cudaEvent_t e,
int flag = cudaEventWaitDefault);
138 T& wait(cudaEvent_t e,
int flag = cudaEventWaitDefault);
140 const std::vector<ComputeGraphVarBase*>& vars);
141 template <
typename... ViewT>
// Declaration returning T&. `callback` is a std::function that is invoked
// with a ::cudaStream_t and a ::cudaError.
// NOTE(review): presumably enqueued on this object's stream — confirm in the
// definition.
150 T& callback(
const std::function<
void(::cudaStream_t, ::cudaError)>& callback);
152 template <
typename Next>
// Declaration only: a member template returning a `Next`, taking `args` as
// forwarding references.
// NOTE(review): presumably constructs the next launcher from `args` (e.g. on
// the same stream) — confirm in the definition.
154 template <
typename Next,
typename... Args>
155 Next next(Args&&... args);
// Declaration only, returns T&.
// NOTE(review): presumably undoes a label set via kernel_name()/file_line() —
// confirm in the definition.
160 T& pop_kernel_label();
163 T& derived() MUDA_NOEXCEPT {
return *(T*)(
this); }