1#include <cuda_device_runtime_api.h>
2#include <muda/launch/stream_define.h>
6MUDA_INLINE Stream::Stream(Stream::Flag f)
8 checkCudaErrors(cudaStreamCreateWithFlags(&m_handle,
static_cast<unsigned int>(f)));
11MUDA_INLINE
void Stream::wait()
const
13 checkCudaErrors(cudaStreamSynchronize(m_handle));
16MUDA_INLINE
void Stream::begin_capture(cudaStreamCaptureMode mode)
const
18 checkCudaErrors(cudaStreamBeginCapture(m_handle, mode));
21MUDA_INLINE
void Stream::end_capture(cudaGraph_t* graph)
const
23 checkCudaErrors(cudaStreamEndCapture(m_handle, graph));
26MUDA_INLINE Stream& Stream::Default()
28 thread_local static Stream s{
nullptr};
32MUDA_INLINE std::byte* Stream::workspace(
size_t byte_size)
34 m_workspace.resize(byte_size, m_handle);
35 return m_workspace.data();
38MUDA_INLINE MUDA_DEVICE Stream::TailLaunch::operator cudaStream_t()
const
40 return details::stream::tail_launch();
43MUDA_INLINE MUDA_DEVICE Stream::FireAndForget::operator cudaStream_t()
const
46 return details::stream::fire_and_forget();
49MUDA_INLINE MUDA_DEVICE Stream::GraphTailLaunch::operator cudaStream_t()
const
52 return details::stream::graph_tail_launch();
55MUDA_INLINE MUDA_DEVICE Stream::GraphFireAndForget::operator cudaStream_t()
const
57 return details::stream::graph_fire_and_forget();
60MUDA_INLINE Stream::~Stream()
63 checkCudaErrors(cudaStreamDestroy(m_handle));
66MUDA_INLINE Stream::Stream(Stream&& o) MUDA_NOEXCEPT : m_handle(o.m_handle)
71MUDA_INLINE Stream& Stream::operator=(Stream&& o) MUDA_NOEXCEPT
77 checkCudaErrors(cudaStreamDestroy(m_handle));
79 m_handle = o.m_handle;