MUDA
Loading...
Searching...
No Matches
cub_wrapper.h
1
#pragma once
2
#include <cub/version.cuh>
3
#include <muda/launch/launch_base.h>
4
#include <muda/buffer.h>
5
#include <muda/container.h>
6
#include <muda/buffer/buffer_launch.h>
7
#include <muda/compute_graph/compute_graph.h>
8
#include <muda/launch/stream.h>
9
10
namespace
muda
11
{
12
template
<
typename
Derive>
13
class
CubWrapper
:
public
LaunchBase
<Derive>
14
{
15
protected
:
16
std::byte* prepare_buffer(
size_t
reqSize)
17
{
18
return
m_muda_stream->workspace(reqSize);
19
}
20
21
public
:
22
CubWrapper
(
Stream
& stream = Stream::Default())
23
:
LaunchBase<Derive>
(stream)
24
, m_muda_stream(&stream)
25
{
26
}
27
28
// meaningless for cub, so we just delete it
29
void
kernel_name(std::string_view) =
delete
;
30
31
Stream
* m_muda_stream =
nullptr
;
32
};
33
}
// namespace muda
muda::CubWrapper
Definition
cub_wrapper.h:14
muda::LaunchBase
Definition
launch_base.h:86
muda::Stream
RAII wrapper for cudaStream.
Definition
stream.h:18
src
muda
cub
device
cub_wrapper.h
Generated by
1.9.8