MUDA
Loading...
Searching...
No Matches
cub_wrapper.h
1#pragma once
2#include <cub/version.cuh>
3#include <muda/launch/launch_base.h>
4#include <muda/buffer.h>
5#include <muda/container.h>
6#include <muda/buffer/buffer_launch.h>
7#include <muda/compute_graph/compute_graph.h>
8#include <muda/launch/stream.h>
9
10namespace muda
11{
12template <typename Derive>
13class CubWrapper : public LaunchBase<Derive>
14{
15 protected:
16 std::byte* prepare_buffer(size_t reqSize)
17 {
18 return m_muda_stream->workspace(reqSize);
19 }
20
21 public:
22 CubWrapper(Stream& stream = Stream::Default())
23 : LaunchBase<Derive>(stream)
24 , m_muda_stream(&stream)
25 {
26 }
27
28 // meaningless for cub, so we just delete it
29 void kernel_name(std::string_view) = delete;
30
31 Stream* m_muda_stream = nullptr;
32};
33} // namespace muda
Definition cub_wrapper.h:14
Definition launch_base.h:86
RAII wrapper for cudaStream.
Definition stream.h:18