1#include <muda/buffer/var_view.h>
3#include <muda/buffer/buffer_2d_view.h>
4#include <muda/buffer/buffer_3d_view.h>
6#include <muda/buffer/graph_var_view.h>
7#include <muda/buffer/graph_buffer_view.h>
8#include <muda/buffer/graph_buffer_2d_view.h>
9#include <muda/buffer/graph_buffer_3d_view.h>
11#include <muda/buffer/agent.h>
12#include <muda/buffer/reshape_nd/nd_reshaper.h>
26MUDA_HOST BufferLaunch& BufferLaunch::resize(DeviceBuffer<T>& buffer,
size_t new_size)
31 [&](BufferView<T> view)
33 if constexpr(std::is_trivially_constructible_v<T>)
35 Memory(m_stream).set(view.data(), view.size() *
sizeof(T), 0);
39 static_assert(std::is_constructible_v<T>,
40 "The type T must be constructible, which means T must have a 0-arg constructor");
42 details::buffer::kernel_construct(m_grid_dim, m_block_dim, m_stream, view);
48MUDA_HOST BufferLaunch& BufferLaunch::resize(DeviceBuffer2D<T>& buffer, Extent2D extent)
52 [&](Buffer2DView<T> view)
72 details::buffer::kernel_construct(m_grid_dim, m_block_dim, m_stream, view);
77MUDA_HOST BufferLaunch& BufferLaunch::resize(DeviceBuffer3D<T>& buffer, Extent3D extent)
81 [&](Buffer3DView<T> view)
98 details::buffer::kernel_construct(m_grid_dim, m_block_dim, m_stream, view);
103MUDA_HOST BufferLaunch& BufferLaunch::reserve(DeviceBuffer<T>& buffer,
size_t capacity)
105 NDReshaper::reserve(m_grid_dim, m_block_dim, m_stream, buffer, capacity);
110MUDA_HOST BufferLaunch& BufferLaunch::reserve(DeviceBuffer2D<T>& buffer, Extent2D capacity)
112 NDReshaper::reserve(m_grid_dim, m_block_dim, m_stream, buffer, capacity);
117MUDA_HOST BufferLaunch& BufferLaunch::reserve(DeviceBuffer3D<T>& buffer, Extent3D capacity)
119 NDReshaper::reserve(m_grid_dim, m_block_dim, m_stream, buffer, capacity);
124MUDA_HOST BufferLaunch& BufferLaunch::resize(DeviceBuffer<T>& buffer,
size_t new_size,
const T& val)
126 return resize(buffer, new_size, [&](BufferView<T> view) { fill(view, val); });
130MUDA_HOST BufferLaunch& BufferLaunch::resize(DeviceBuffer2D<T>& buffer,
134 return resize(buffer, extent, [&](Buffer2DView<T> view) { fill(view, val); });
138MUDA_HOST BufferLaunch& BufferLaunch::resize(DeviceBuffer3D<T>& buffer,
142 return resize(buffer, extent, [&](Buffer3DView<T> view) { fill(view, val); });
146MUDA_HOST BufferLaunch& BufferLaunch::clear(DeviceBuffer<T>& buffer)
153MUDA_HOST BufferLaunch& BufferLaunch::clear(DeviceBuffer2D<T>& buffer)
155 resize(buffer, Extent2D::Zero());
160MUDA_HOST BufferLaunch& BufferLaunch::clear(DeviceBuffer3D<T>& buffer)
162 resize(buffer, Extent3D::Zero());
167MUDA_HOST BufferLaunch& BufferLaunch::alloc(DeviceBuffer<T>& buffer,
size_t n)
169 MUDA_ASSERT(ComputeGraphBuilder::is_direct_launching(),
170 "cannot alloc a buffer in a compute graph");
171 MUDA_ASSERT(!buffer.m_data,
"The buffer is already allocated");
177MUDA_HOST BufferLaunch& BufferLaunch::alloc(DeviceBuffer2D<T>& buffer, Extent2D extent)
179 MUDA_ASSERT(ComputeGraphBuilder::is_direct_launching(),
180 "cannot alloc a buffer in a compute graph");
181 MUDA_ASSERT(!buffer.m_data,
"The buffer is already allocated");
182 resize(buffer, extent);
187MUDA_HOST BufferLaunch& BufferLaunch::alloc(DeviceBuffer3D<T>& buffer, Extent3D extent)
189 MUDA_ASSERT(ComputeGraphBuilder::is_direct_launching(),
190 "cannot alloc a buffer in a compute graph");
191 MUDA_ASSERT(!buffer.m_data,
"The buffer is already allocated");
192 resize(buffer, extent);
197MUDA_HOST BufferLaunch& BufferLaunch::free(DeviceBuffer<T>& buffer)
199 MUDA_ASSERT(ComputeGraphBuilder::is_direct_launching(),
200 "cannot free a buffer in a compute graph");
201 MUDA_ASSERT(buffer.m_data,
"The buffer is not allocated");
203 auto& m_data = buffer.m_data;
204 auto& m_size = buffer.m_size;
205 auto& m_capacity = buffer.m_capacity;
207 Memory(m_stream).free(m_data);
215MUDA_HOST BufferLaunch& BufferLaunch::free(DeviceBuffer2D<T>& buffer)
217 MUDA_ASSERT(ComputeGraphBuilder::is_direct_launching(),
218 "cannot free a buffer in a compute graph");
219 MUDA_ASSERT(buffer.m_data,
"The buffer is not allocated");
221 auto& m_data = buffer.m_data;
222 auto& m_pitch_bytes = buffer.m_pitch_bytes;
223 auto& m_extent = buffer.m_extent;
224 auto& m_capacity = buffer.m_capacity;
226 Memory(m_stream).free(m_data);
229 m_extent = Extent2D::Zero();
230 m_capacity = Extent2D::Zero();
235MUDA_HOST BufferLaunch& BufferLaunch::free(DeviceBuffer3D<T>& buffer)
237 MUDA_ASSERT(ComputeGraphBuilder::is_direct_launching(),
238 "cannot free a buffer in a compute graph");
239 MUDA_ASSERT(buffer.m_data,
"The buffer is not allocated");
241 auto& m_data = buffer.m_data;
242 auto& m_pitch_bytes = buffer.m_pitch_bytes;
243 auto& m_pitch_bytes_area = buffer.m_pitch_bytes_area;
244 auto& m_extent = buffer.m_extent;
245 auto& m_capacity = buffer.m_capacity;
247 Memory(m_stream).free(m_data);
250 m_pitch_bytes_area = 0;
251 m_extent = Extent3D::Zero();
252 m_capacity = Extent3D::Zero();
257MUDA_HOST BufferLaunch& BufferLaunch::shrink_to_fit(DeviceBuffer<T>& buffer)
259 MUDA_ASSERT(ComputeGraphBuilder::is_direct_launching(),
260 "cannot shrink a buffer in a compute graph");
261 NDReshaper::shrink_to_fit(m_grid_dim, m_block_dim, m_stream, buffer);
266MUDA_HOST BufferLaunch& BufferLaunch::shrink_to_fit(DeviceBuffer2D<T>& buffer)
268 MUDA_ASSERT(ComputeGraphBuilder::is_direct_launching(),
269 "cannot shrink a buffer in a compute graph");
270 NDReshaper::shrink_to_fit(m_grid_dim, m_block_dim, m_stream, buffer);
275MUDA_HOST BufferLaunch& BufferLaunch::shrink_to_fit(DeviceBuffer3D<T>& buffer)
277 MUDA_ASSERT(ComputeGraphBuilder::is_direct_launching(),
278 "cannot shrink a buffer in a compute graph");
279 NDReshaper::shrink_to_fit(m_grid_dim, m_block_dim, m_stream, buffer);
290MUDA_HOST BufferLaunch& BufferLaunch::copy(VarView<T> dst, CVarView<T> src)
292 details::buffer::kernel_assign(m_stream, dst, src);
297MUDA_HOST BufferLaunch& BufferLaunch::copy(BufferView<T> dst, CBufferView<T> src)
299 MUDA_ASSERT(dst.size() == src.size(),
"BufferView should have the same size");
300 details::buffer::kernel_assign(m_grid_dim, m_block_dim, m_stream, dst, src);
305MUDA_HOST BufferLaunch& BufferLaunch::copy(Buffer2DView<T> dst, CBuffer2DView<T> src)
307 MUDA_ASSERT(dst.extent() == src.extent(),
"BufferView should have the same size");
308 details::buffer::kernel_assign(m_grid_dim, m_block_dim, m_stream, dst, src);
313MUDA_HOST BufferLaunch& BufferLaunch::copy(Buffer3DView<T> dst, CBuffer3DView<T> src)
315 MUDA_ASSERT(dst.extent() == src.extent(),
"BufferView should have the same size");
316 details::buffer::kernel_assign(m_grid_dim, m_block_dim, m_stream, dst, src);
321MUDA_HOST BufferLaunch& BufferLaunch::copy(ComputeGraphVar<VarView<T>>& dst,
322 const ComputeGraphVar<VarView<T>>& src)
324 return copy(dst.eval(), src.ceval());
328MUDA_HOST BufferLaunch& BufferLaunch::copy(ComputeGraphVar<BufferView<T>>& dst,
329 const ComputeGraphVar<BufferView<T>>& src)
331 return copy(dst.eval(), src.ceval());
335MUDA_HOST BufferLaunch& BufferLaunch::copy(ComputeGraphVar<Buffer2DView<T>>& dst,
336 const ComputeGraphVar<Buffer2DView<T>>& src)
338 return copy(dst.eval(), src.ceval());
342MUDA_HOST BufferLaunch& BufferLaunch::copy(ComputeGraphVar<Buffer3DView<T>>& dst,
343 const ComputeGraphVar<Buffer3DView<T>>& src)
345 return copy(dst.eval(), src.ceval());
355MUDA_HOST BufferLaunch& BufferLaunch::copy(T* dst, CVarView<T> src)
357 Memory(m_stream).download(dst, src.data(),
sizeof(T));
362MUDA_HOST BufferLaunch& BufferLaunch::copy(T* dst, CBufferView<T> src)
364 Memory(m_stream).download(dst, src.data(), src.size() *
sizeof(T));
370MUDA_HOST BufferLaunch& BufferLaunch::copy(T* dst, CBuffer2DView<T> src)
372 cudaMemcpy3DParms parms = {0};
374 parms.srcPtr = src.cuda_pitched_ptr();
375 parms.srcPos = src.offset().template cuda_pos<T>();
376 parms.dstPtr = make_cudaPitchedPtr(
377 dst, parms.srcPtr.xsize, parms.srcPtr.xsize, parms.srcPtr.ysize);
378 parms.extent = src.extent().template cuda_extent<T>();
379 parms.dstPos = make_cudaPos(0, 0, 0);
381 Memory(m_stream).download(parms);
386MUDA_HOST BufferLaunch& BufferLaunch::copy(T* dst, CBuffer3DView<T> src)
388 cudaMemcpy3DParms parms = {0};
390 parms.srcPtr = src.cuda_pitched_ptr();
391 parms.srcPos = src.offset().template cuda_pos<T>();
392 parms.dstPtr = make_cudaPitchedPtr(
393 dst, parms.srcPtr.xsize, parms.srcPtr.xsize, parms.srcPtr.ysize);
394 parms.extent = src.extent().template cuda_extent<T>();
395 parms.dstPos = make_cudaPos(0, 0, 0);
397 Memory(m_stream).download(parms);
402MUDA_HOST BufferLaunch& BufferLaunch::copy(ComputeGraphVar<T*>& dst,
403 const ComputeGraphVar<VarView<T>>& src)
405 return copy(dst.eval(), src.ceval());
409MUDA_HOST BufferLaunch& BufferLaunch::copy(ComputeGraphVar<T*>& dst,
410 const ComputeGraphVar<BufferView<T>>& src)
412 return copy(dst.eval(), src.ceval());
416MUDA_HOST BufferLaunch& BufferLaunch::copy(ComputeGraphVar<T*>& dst,
417 const ComputeGraphVar<Buffer2DView<T>>& src)
419 return copy(dst.eval(), src.ceval());
423MUDA_HOST BufferLaunch& BufferLaunch::copy(ComputeGraphVar<T*>& dst,
424 const ComputeGraphVar<Buffer3DView<T>>& src)
426 return copy(dst.eval(), src.ceval());
435MUDA_HOST BufferLaunch& BufferLaunch::copy(VarView<T> dst,
const T* src)
437 Memory(m_stream).upload(dst.data(), src,
sizeof(T));
442MUDA_HOST BufferLaunch& BufferLaunch::copy(BufferView<T> dst,
const T* src)
444 Memory(m_stream).upload(dst.data(), src, dst.size() *
sizeof(T));
449MUDA_HOST BufferLaunch& BufferLaunch::copy(Buffer2DView<T> dst,
const T* src)
451 cudaMemcpy3DParms parms = {0};
453 parms.extent = dst.extent().template cuda_extent<T>();
454 parms.dstPos = dst.offset().template cuda_pos<T>();
455 parms.dstPtr = dst.cuda_pitched_ptr();
457 parms.srcPtr = make_cudaPitchedPtr(
const_cast<T*
>(src),
461 parms.srcPos = make_cudaPos(0, 0, 0);
463 Memory(m_stream).upload(parms);
469MUDA_HOST BufferLaunch& BufferLaunch::copy(Buffer3DView<T> dst,
const T* src)
471 cudaMemcpy3DParms parms = {0};
473 parms.extent = dst.extent().template cuda_extent<T>();
474 parms.dstPos = dst.offset().template cuda_pos<T>();
475 parms.dstPtr = dst.cuda_pitched_ptr();
477 parms.srcPtr = make_cudaPitchedPtr(
const_cast<T*
>(src),
481 parms.srcPos = make_cudaPos(0, 0, 0);
483 Memory(m_stream).upload(parms);
489MUDA_HOST BufferLaunch& BufferLaunch::copy(ComputeGraphVar<VarView<T>>& dst,
490 const ComputeGraphVar<T*>& src)
492 return copy(dst.eval(), src.ceval());
496MUDA_HOST BufferLaunch& BufferLaunch::copy(ComputeGraphVar<BufferView<T>>& dst,
497 const ComputeGraphVar<T*>& src)
499 return copy(dst.eval(), src.ceval());
503MUDA_HOST BufferLaunch& BufferLaunch::copy(ComputeGraphVar<Buffer2DView<T>>& dst,
504 const ComputeGraphVar<T*>& src)
506 return copy(dst.eval(), src.ceval());
510MUDA_HOST BufferLaunch& copy(ComputeGraphVar<Buffer3DView<T>>& dst,
511 const ComputeGraphVar<T*>& src)
513 return copy(dst.eval(), src.ceval());
522MUDA_HOST BufferLaunch& BufferLaunch::fill(VarView<T> view,
const T& val)
524 details::buffer::kernel_fill(m_stream, view, val);
529MUDA_HOST BufferLaunch& BufferLaunch::fill(BufferView<T> buffer,
const T& val)
531 details::buffer::kernel_fill(m_grid_dim, m_block_dim, m_stream, buffer, val);
536MUDA_HOST BufferLaunch& BufferLaunch::fill(Buffer2DView<T> buffer,
const T& val)
538 details::buffer::kernel_fill(m_grid_dim, m_block_dim, m_stream, buffer, val);
543MUDA_HOST BufferLaunch& BufferLaunch::fill(Buffer3DView<T> buffer,
const T& val)
545 details::buffer::kernel_fill(m_grid_dim, m_block_dim, m_stream, buffer, val);
550MUDA_HOST BufferLaunch& BufferLaunch::fill(ComputeGraphVar<VarView<T>>& buffer,
551 const ComputeGraphVar<T>& val)
553 return fill(buffer.eval(), val.ceval());
557MUDA_HOST BufferLaunch& BufferLaunch::fill(ComputeGraphVar<BufferView<T>>& buffer,
558 const ComputeGraphVar<T>& val)
560 return fill(buffer.eval(), val.ceval());
564MUDA_HOST BufferLaunch& BufferLaunch::fill(ComputeGraphVar<Buffer2DView<T>>& buffer,
565 const ComputeGraphVar<T>& val)
567 return fill(buffer.eval(), val.ceval());
571MUDA_HOST BufferLaunch& BufferLaunch::fill(ComputeGraphVar<Buffer3DView<T>>& buffer,
572 const ComputeGraphVar<T>& val)
574 return fill(buffer.eval(), val.ceval());
582template <
typename T,
typename FConstruct>
583MUDA_HOST BufferLaunch& BufferLaunch::resize(DeviceBuffer<T>& buffer,
size_t new_size, FConstruct&& fct)
585 MUDA_ASSERT(ComputeGraphBuilder::is_direct_launching(),
586 "cannot resize a buffer in a compute graph");
588 m_grid_dim, m_block_dim, m_stream, buffer, new_size, std::forward<FConstruct>(fct));
592template <
typename T,
typename FConstruct>
593MUDA_HOST BufferLaunch& BufferLaunch::resize(DeviceBuffer2D<T>& buffer,
597 MUDA_ASSERT(ComputeGraphBuilder::is_direct_launching(),
598 "cannot resize a buffer in a compute graph");
600 m_grid_dim, m_block_dim, m_stream, buffer, new_extent, std::forward<FConstruct>(fct));
605template <
typename T,
typename FConstruct>
606MUDA_HOST BufferLaunch& BufferLaunch::resize(DeviceBuffer3D<T>& buffer,
610 MUDA_ASSERT(ComputeGraphBuilder::is_direct_launching(),
611 "cannot resize a buffer in a compute graph");
614 m_grid_dim, m_block_dim, m_stream, buffer, new_extent, std::forward<FConstruct>(fct));
A view interface for any array-like liner memory, which can be constructed from DeviceBuffer/DeviceVe...