Skip to content

File cub_wrapper.h

File List > cub > device > cub_wrapper.h

Go to the documentation of this file

#pragma once
#include <cub/version.cuh>
#include <muda/launch/launch_base.h>
#include <muda/buffer.h>
#include <muda/container.h>
#include <muda/buffer/buffer_launch.h>
#include <muda/compute_graph/compute_graph.h>
#include <muda/launch/stream.h>

namespace muda
{
template <typename Derive>
class CubWrapper : public LaunchBase<Derive>
{
  protected:
    std::byte* prepare_buffer(size_t reqSize)
    {
        return m_muda_stream->workspace(reqSize);
    }

  public:
    CubWrapper(Stream& stream = Stream::Default())
        : LaunchBase<Derive>(stream)
        , m_muda_stream(&stream)
    {
    }

    // meaningless for cub, so we just delete it
    void kernel_name(std::string_view) = delete;

    Stream* m_muda_stream = nullptr;
};
}  // namespace muda