Skip to content

File kernel.h

File List > launch > kernel.h

Go to the documentation of this file

#pragma once
#include <cuda.h>
#include <muda/muda_def.h>
#include <muda/launch/stream_define.h>
#include <type_traits>

namespace muda
{
// Launcher that bundles a launch configuration (grid dimensions, block
// dimensions, dynamic shared memory size and stream) together with a kernel
// `f`, and performs the actual `<<<...>>>` launch when invoked.
//
// `F` is the type of the launchable entity stored in the object; it is used
// directly with the `<<<...>>>` launch syntax, so it has to be something that
// syntax accepts (typically a pointer to a `__global__` function).
//
// The object is neither copyable nor movable and `operator()` is
// rvalue-qualified, so the intended usage is a single expression:
//
//     Kernel{grid, block, smem, stream, my_kernel}(args...);
template <typename F>
class Kernel
{
    dim3         m_grid_dim;
    dim3         m_block_dim;
    size_t       m_shared_memory_size;
    cudaStream_t m_stream;
    F            m_kernel;

  public:
    // Full constructor; every other constructor delegates to this one.
    //
    //   grid_dim           grid dimensions passed to the launch
    //   block_dim          block dimensions passed to the launch
    //   shared_memory_size dynamic shared memory size passed to the launch
    //   stream             stream passed to the launch
    //   f                  the kernel to launch
    MUDA_GENERIC Kernel(dim3 grid_dim, dim3 block_dim, size_t shared_memory_size, cudaStream_t stream, F f)
        : m_grid_dim(grid_dim)
        , m_block_dim(block_dim)
        , m_shared_memory_size(shared_memory_size)
        , m_stream(stream)
        , m_kernel(f)
    {
#ifdef __CUDA_ARCH__
        // When compiled for the device, only the two sentinel streams
        // returned by details::stream are accepted.
        // NOTE(review): the overloads below that default the stream to 0 will
        // trip this assertion in device code unless one of the sentinels
        // compares equal to 0 - confirm against stream_define.h.
        MUDA_KERNEL_ASSERT(stream == details::stream::tail_launch()
                               || stream == details::stream::fire_and_forget(),
                           "Kernel Launch on device with invalid stream! "
                           "Only Stream::TailLaunch{} and Stream::FireAndForget{} are allowed");
#endif
    }

    // 1x1 launch, no dynamic shared memory, stream 0.
    MUDA_GENERIC Kernel(F f)
        : Kernel{1, 1, 0, 0, f}
    {
    }

    // Explicit grid/block dimensions, no dynamic shared memory, stream 0.
    MUDA_GENERIC Kernel(dim3 grid_dim, dim3 block_dim, F f)
        : Kernel{grid_dim, block_dim, 0, 0, f}
    {
    }

    // Explicit grid/block dimensions and dynamic shared memory size, stream 0.
    MUDA_GENERIC Kernel(dim3 grid_dim, dim3 block_dim, size_t shared_memory_size, F f)
        : Kernel{grid_dim, block_dim, shared_memory_size, 0, f}
    {
    }

    // Explicit grid/block dimensions and stream, no dynamic shared memory.
    MUDA_GENERIC Kernel(dim3 grid_dim, dim3 block_dim, cudaStream_t stream, F f)
        : Kernel{grid_dim, block_dim, 0, stream, f}
    {
    }

    // 1x1 launch on the given stream, no dynamic shared memory.
    MUDA_GENERIC Kernel(cudaStream_t stream, F f)
        : Kernel{1, 1, 0, stream, f}
    {
    }

    // Launches the stored kernel with the stored configuration, forwarding
    // `args` as the kernel arguments, then checks cudaGetLastError().
    //
    // Callable only on an rvalue Kernel. Compilation fails with
    // "invalid arguments" if `F` is not invocable with `Args...`.
    template <typename... Args>
    MUDA_GENERIC void operator()(Args&&... args) &&
    {
        static_assert(std::is_invocable_v<F, Args...>, "invalid arguments");

        // With the device stream model enabled the stored stream is used
        // as-is, sentinel values included.
        cudaStream_t stream = m_stream;
#if !MUDA_WITH_DEVICE_STREAM_MODEL
        // Without the device stream model the two sentinel streams cannot be
        // handed to the launch, so both are replaced by the null stream.
        if(m_stream == details::stream::tail_launch())
        {
            // The tail-launch sentinel additionally synchronizes the device
            // before the launch is issued.
            checkCudaErrors(cudaDeviceSynchronize());
            stream = nullptr;
        }
        else if(m_stream == details::stream::fire_and_forget())
        {
            // No extra work, just launch on the null stream.
            stream = nullptr;
        }
#endif
        m_kernel<<<m_grid_dim, m_block_dim, m_shared_memory_size, stream>>>(
            std::forward<Args>(args)...);
        // A launch does not return a status; configuration errors are only
        // visible through cudaGetLastError().
        checkCudaErrors(cudaGetLastError());
    }

    // delete copy and move
    MUDA_GENERIC         Kernel(const Kernel&)    = delete;
    MUDA_GENERIC Kernel& operator=(const Kernel&) = delete;
    MUDA_GENERIC         Kernel(Kernel&&)         = delete;
    MUDA_GENERIC Kernel& operator=(Kernel&&)      = delete;
};
}  // namespace muda