// File: stream.h
// Go to the documentation of this file.
#pragma once
#include <cstddef>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
#include <device_launch_parameters.h>
#include <muda/check/check_cuda_errors.h>
#include <muda/tools/temp_buffer.h>
namespace muda
{
// Forward declaration of the DeviceBuffer class template (defined elsewhere).
// NOTE(review): none of the declarations visible in this header refer to it —
// confirm whether it is still needed here.
template <typename T>
class DeviceBuffer;
/// Move-only wrapper around a `cudaStream_t` handle, bundled with a byte
/// workspace buffer (`m_workspace`).
///
/// Members declared here without a body are defined out of line (this header
/// includes "details/stream.inl" at its end), so their exact behavior is not
/// visible from this declaration; such notes below are marked NOTE(review).
class Stream
{
    // Wrapped CUDA stream handle; defaults to nullptr.
    cudaStream_t m_handle = nullptr;

  public:
    /// Stream flags; each enumerator takes the value of the CUDA runtime
    /// constant it is initialized from.
    enum class Flag : unsigned int
    {
        eDefault     = cudaStreamDefault,
        eNonBlocking = cudaStreamNonBlocking
    };

    /// Constructs a stream from a flag (defaults to Flag::eDefault).
    /// NOTE(review): presumably creates the underlying CUDA stream with `f`
    /// — confirm in details/stream.inl.
    MUDA_NODISCARD Stream(Flag f = Flag::eDefault);

    /// Destructor.
    /// NOTE(review): presumably releases the wrapped handle — confirm in
    /// details/stream.inl.
    ~Stream();

    /// Implicit conversion to the wrapped handle, so a Stream can be passed
    /// wherever a `cudaStream_t` is expected.
    operator cudaStream_t() const { return m_handle; }

    /// Returns the wrapped handle (same value as the conversion operator).
    cudaStream_t view() const { return m_handle; }

    // delete copy constructor and copy assignment operator
    Stream(const Stream&)            = delete;
    Stream& operator=(const Stream&) = delete;

    // allow move constructor and move assignment operator
    Stream(Stream&& o) MUDA_NOEXCEPT;
    Stream& operator=(Stream&& o) MUDA_NOEXCEPT;

    /// NOTE(review): presumably blocks until the work queued on this stream
    /// has finished — confirm in details/stream.inl.
    void wait() const;

    /// Begins stream capture; `mode` defaults to
    /// `cudaStreamCaptureModeThreadLocal`.
    /// NOTE(review): presumably forwards to the CUDA stream-capture API —
    /// confirm in details/stream.inl.
    void begin_capture(cudaStreamCaptureMode mode = cudaStreamCaptureModeThreadLocal) const;

    /// Ends stream capture.
    /// @param graph NOTE(review): presumably receives the captured graph —
    ///              confirm in details/stream.inl.
    void end_capture(cudaGraph_t* graph) const;

    /// Returns a reference to a Stream object.
    /// NOTE(review): presumably a shared instance built with the private
    /// nullptr constructor below — confirm in details/stream.inl.
    static Stream& Default();

    // The four classes below are empty tag types. Their constructors and
    // conversion operators are MUDA_DEVICE-qualified, and each converts to
    // `cudaStream_t` through an operator defined out of line.
    // NOTE(review): the names match CUDA's device-side named streams
    // (tail launch / fire-and-forget and their graph variants) — confirm the
    // values returned in details/stream.inl.

    class TailLaunch
    {
      public:
        MUDA_DEVICE TailLaunch() {}
        MUDA_DEVICE operator cudaStream_t() const;
    };

    class FireAndForget
    {
      public:
        MUDA_DEVICE FireAndForget() {}
        MUDA_DEVICE operator cudaStream_t() const;
    };

    class GraphTailLaunch
    {
      public:
        MUDA_DEVICE GraphTailLaunch() {}
        MUDA_DEVICE operator cudaStream_t() const;
    };

    class GraphFireAndForget
    {
      public:
        MUDA_DEVICE GraphFireAndForget() {}
        MUDA_DEVICE operator cudaStream_t() const;
    };

    /// Returns a byte pointer for a request of `byte_size` bytes.
    /// NOTE(review): presumably backed by `m_workspace` — confirm ownership,
    /// lifetime and memory space in details/stream.inl.
    std::byte* workspace(size_t byte_size);

  private:
    /// Builds a Stream whose handle is nullptr; `m_workspace` is
    /// default-constructed. Private, so only members can use it.
    ///
    /// The parameter type is spelled `std::nullptr_t`: <cstddef> guarantees
    /// only the std-qualified name, while the unqualified `nullptr_t` exists
    /// only when <stddef.h> happens to be included. Both name the same type.
    Stream(std::nullptr_t)
        : m_handle(nullptr)
    {
    }

    // Declared after m_handle on purpose: keeps the original member order
    // (and therefore layout and construction/destruction order).
    details::ByteTempBuffer m_workspace;
};
} // namespace muda
#include "details/stream.inl"