MUDA
Loading...
Searching...
No Matches
temp_buffer.h
#pragma once
#include <cstddef>
#include <utility>
#include <vector>
#include <cuda_runtime.h>
#include <muda/check/check.h>
namespace muda::details
{
// Owning wrapper around a single cudaMalloc'ed array of T, meant as scratch
// storage.
//
// Contract notes (all established by the code below):
//  - Growing past capacity() allocates a new block and frees the old one
//    WITHOUT copying the old contents; after a reallocation the elements are
//    uninitialized.
//  - Copy construction yields an EMPTY buffer and copy assignment leaves the
//    target unchanged ("no change on copy"); device memory is never copied
//    implicitly. Use copy_from() for an explicit device-to-device copy.
//  - Moving transfers ownership and leaves the source empty.
//  - Checked CUDA calls go through checkCudaErrors (muda/check/check.h); what
//    happens on failure (throw / abort / log) is defined there, not here.
template <typename T>
class TempBuffer
{
  public:
    TempBuffer() = default;

    // Allocates room for `size` elements and sets size() to `size`.
    // The contents are uninitialized.
    TempBuffer(size_t size) { resize(size); }

    ~TempBuffer()
    {
        if(m_data)
        {
            // we don't check the error here to prevent exception when app is shutting down
            cudaFree(m_data);
        }
    }

    // Steals other's allocation; other is left empty (size 0, capacity 0, null data).
    TempBuffer(TempBuffer&& other) noexcept
        : m_size(std::exchange(other.m_size, 0))
        , m_capacity(std::exchange(other.m_capacity, 0))
        , m_data(std::exchange(other.m_data, nullptr))
    {
    }

    // Releases the current allocation, then steals other's; other is left empty.
    TempBuffer& operator=(TempBuffer&& other) noexcept
    {
        if(this == &other)
        {
            return *this;
        }
        // Free what we own before overwriting m_data (previously leaked).
        // Unchecked, like the destructor, because this function is noexcept.
        if(m_data)
        {
            cudaFree(m_data);
        }
        m_size     = std::exchange(other.m_size, 0);
        m_capacity = std::exchange(other.m_capacity, 0);
        m_data     = std::exchange(other.m_data, nullptr);
        return *this;
    }

    // no change on copy: the new object starts empty, device memory is not duplicated
    TempBuffer(const TempBuffer&) noexcept {}
    // no change on copy: the target keeps its own allocation and size
    TempBuffer& operator=(const TempBuffer&) noexcept { return *this; }

    // Copies the first size() elements to the host, resizing `vec` to size().
    // `vec` is ordinary (pageable) host memory; consult CUDA's "API
    // synchronization behavior" notes for when the data is visible on the host.
    void copy_to(std::vector<T>& vec, cudaStream_t stream = nullptr) const
    {
        vec.resize(m_size);
        if(m_size == 0)
        {
            return;  // nothing to transfer; avoid passing null pointers
        }
        checkCudaErrors(cudaMemcpyAsync(
            vec.data(), m_data, m_size * sizeof(T), cudaMemcpyDeviceToHost, stream));
    }

    // Device-to-device copy of other's first other.size() elements into this
    // buffer, growing it if needed (growth discards current contents).
    // Accepts const and temporary sources; copying from *this is a no-op.
    void copy_from(const TempBuffer<T>& other, cudaStream_t stream = nullptr)
    {
        if(&other == this)
        {
            return;  // contents already identical; avoid an overlapping memcpy
        }
        resize(other.size(), stream);
        if(other.size() == 0)
        {
            return;
        }
        checkCudaErrors(cudaMemcpyAsync(
            m_data, other.data(), other.size() * sizeof(T), cudaMemcpyDeviceToDevice, stream));
    }

    // Host-to-device copy of `vec` into this buffer; size() becomes vec.size().
    void copy_from(const std::vector<T>& vec, cudaStream_t stream = nullptr)
    {
        resize(vec.size(), stream);
        if(vec.empty())
        {
            return;  // nothing to transfer; avoid passing null pointers
        }
        checkCudaErrors(cudaMemcpyAsync(
            m_data, vec.data(), vec.size() * sizeof(T), cudaMemcpyHostToDevice, stream));
    }

    TempBuffer(const std::vector<T>& vec) { copy_from(vec); }

    TempBuffer& operator=(const std::vector<T>& vec)
    {
        copy_from(vec);
        return *this;
    }

    // Ensures capacity() >= new_cap. On growth a new block is allocated and the
    // old one freed; existing contents are NOT preserved. size() is unchanged.
    // `stream` is accepted for API symmetry but is currently unused
    // (cudaMalloc/cudaFree are not stream-ordered).
    void reserve(size_t new_cap, cudaStream_t stream = nullptr)
    {
        if(new_cap <= m_capacity)
        {
            return;
        }
        // Allocate first so the old block survives if the allocation fails.
        T* new_data = nullptr;
        checkCudaErrors(cudaMalloc(&new_data, new_cap * sizeof(T)));
        if(m_data)
        {
            checkCudaErrors(cudaFree(m_data));
        }
        m_data     = new_data;
        m_capacity = new_cap;
    }

    // Sets size() to `size`, growing capacity if necessary (see reserve: growth
    // discards contents). Never shrinks the allocation.
    void resize(size_t size, cudaStream_t stream = nullptr)
    {
        reserve(size, stream);  // no-op when size <= capacity()
        m_size = size;
    }

    // Releases the allocation and resets to the empty state.
    // NOTE(review): declared noexcept, so if checkCudaErrors throws on a failed
    // cudaFree, std::terminate is called; confirm against muda/check/check.h.
    void free() noexcept
    {
        // Detach first so the object is empty regardless of the error handler.
        T* ptr     = m_data;
        m_data     = nullptr;
        m_size     = 0;
        m_capacity = 0;
        if(ptr)
        {
            checkCudaErrors(cudaFree(ptr));
        }
    }

    auto size() const noexcept { return m_size; }
    auto data() const noexcept { return m_data; }
    auto capacity() const noexcept { return m_capacity; }

  private:
    size_t m_size     = 0;        // number of logically valid elements
    size_t m_capacity = 0;        // number of elements the allocation can hold
    T*     m_data     = nullptr;  // device pointer from cudaMalloc, or null
};
}  // namespace muda::details
Definition temp_buffer.h:8