// MUDA — stream.h
// (Doxygen page chrome from the documentation extract: "Loading...", "Searching...", "No Matches")
#pragma once
#include <cstddef>

#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
#include <device_launch_parameters.h>

#include <muda/check/check_cuda_errors.h>
#include <muda/tools/temp_buffer.h>

namespace muda
{
template <typename T>
class DeviceBuffer;

/**
 * @brief RAII wrapper for cudaStream.
 *
 * Holds a cudaStream_t handle (m_handle). Copying is deleted; moving is
 * allowed. All out-of-line members are defined in details/stream.inl (not
 * visible here), so notes about their behavior below are marked as
 * assumptions to confirm.
 */
class Stream
{
    // Raw CUDA stream handle; starts as nullptr until a constructor sets it.
    cudaStream_t m_handle = nullptr;

  public:
    /// Stream creation flags; each value is the matching CUDA runtime flag.
    enum class Flag : unsigned int
    {
        eDefault     = cudaStreamDefault,
        eNonBlocking = cudaStreamNonBlocking
    };

    /**
     * @brief Create a stream with the given creation flag.
     * @param f creation flag, Flag::eDefault when omitted.
     * NOTE(review): presumably wraps cudaStreamCreateWithFlags — confirm in
     * details/stream.inl.
     */
    MUDA_NODISCARD Stream(Flag f = Flag::eDefault);

    /// NOTE(review): presumably destroys m_handle when non-null — confirm in details/stream.inl.
    ~Stream();

    /// Implicit conversion, so a Stream can be passed wherever a cudaStream_t is expected.
    operator cudaStream_t() const { return m_handle; }

    /// Returns the raw handle; same value as the implicit conversion above.
    cudaStream_t view() const { return m_handle; }

    // Copying is deleted (presumably so each handle has exactly one owner).
    Stream(const Stream&)            = delete;
    Stream& operator=(const Stream&) = delete;

    // Moving is allowed; MUDA_NOEXCEPT presumably expands to noexcept.
    Stream(Stream&& o) MUDA_NOEXCEPT;
    Stream& operator=(Stream&& o) MUDA_NOEXCEPT;

    /// NOTE(review): presumably blocks the host until work queued on this
    /// stream has finished (cudaStreamSynchronize) — confirm.
    void wait() const;

    /**
     * @brief Begin capturing work on this stream into a CUDA graph.
     * @param mode capture mode, cudaStreamCaptureModeThreadLocal when omitted.
     * NOTE(review): presumably wraps cudaStreamBeginCapture — confirm.
     */
    void begin_capture(cudaStreamCaptureMode mode = cudaStreamCaptureModeThreadLocal) const;

    /**
     * @brief End a capture started with begin_capture().
     * @param graph out-parameter for the captured graph
     *        (presumably filled via cudaStreamEndCapture — confirm).
     */
    void end_capture(cudaGraph_t* graph) const;

    /// Returns a reference to a shared default Stream.
    /// NOTE(review): presumably a function-local static built with the private
    /// nullptr constructor, i.e. wrapping the null (default) stream — confirm.
    static Stream& Default();

    // Device-side tag types, each convertible to cudaStream_t in device code.
    // NOTE(review): the names match the CUDA device-graph-launch named streams
    // (cudaStreamTailLaunch, cudaStreamFireAndForget, cudaStreamGraphTailLaunch,
    // cudaStreamGraphFireAndForget); presumably the conversions return those —
    // confirm in details/stream.inl.
    class TailLaunch
    {
      public:
        MUDA_DEVICE TailLaunch() {}
        MUDA_DEVICE operator cudaStream_t() const;
    };

    class FireAndForget
    {
      public:
        MUDA_DEVICE FireAndForget() {}
        MUDA_DEVICE operator cudaStream_t() const;
    };

    class GraphTailLaunch
    {
      public:
        MUDA_DEVICE GraphTailLaunch() {}
        MUDA_DEVICE operator cudaStream_t() const;
    };

    class GraphFireAndForget
    {
      public:
        MUDA_DEVICE GraphFireAndForget() {}
        MUDA_DEVICE operator cudaStream_t() const;
    };

    /**
     * @brief Per-stream scratch memory.
     * @param byte_size requested size in bytes.
     * @return pointer to at least byte_size bytes.
     * NOTE(review): presumably backed by m_workspace, and a later call with a
     * larger size may invalidate earlier pointers — confirm.
     */
    std::byte* workspace(size_t byte_size);

  private:
    // Wraps a null handle without creating a stream (presumably used by Default()).
    // std::nullptr_t is spelled out: the unqualified global ::nullptr_t is not guaranteed.
    Stream(std::nullptr_t)
        : m_handle(nullptr)
    {
    }

    // Scratch storage (presumably the backing for workspace()).
    details::ByteTempBuffer m_workspace;
};

}  // namespace muda

#include "details/stream.inl"
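// Doxygen tooltip residue from the documentation extract (cross-reference summaries):
//   Stream::FireAndForget       — Definition stream.h:57
//   Stream::GraphFireAndForget  — Definition stream.h:71
//   Stream::GraphTailLaunch     — Definition stream.h:64
//   Stream::TailLaunch          — Definition stream.h:50
//   Stream                      — RAII wrapper for cudaStream. Definition stream.h:18
//   (presumably details::ByteTempBuffer) — Definition temp_buffer.h:8
//   (unidentified symbol)       — Definition kernel_tag.h:10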