File kernel_node.h
File List > graph > kernel_node.h
Go to the documentation of this file
#pragma once
#include <muda/graph/graph_base.h>
namespace muda
{
class KernelNode : public GraphNode
{
public:
using this_type = KernelNode;
friend class Graph;
};
template <typename U>
class KernelNodeParms : public NodeParms
{
std::vector<void*> m_args;
cudaKernelNodeParams m_parms;
public:
using this_type = KernelNodeParms;
friend class Graph;
friend class std::shared_ptr<this_type>;
friend class std::unique_ptr<this_type>;
friend class std::weak_ptr<this_type>;
template <typename... Args>
KernelNodeParms(Args&&... args)
: kernelParmData(std::forward<Args>(args)...)
, m_parms({})
{
}
KernelNodeParms() {}
U kernelParmData;
auto func() { return m_parms.func; }
void func(void* v) { m_parms.func = v; }
auto grid_dim() { return m_parms.gridDim; }
void grid_dim(const dim3& v) { m_parms.gridDim = v; }
auto block_dim() { return m_parms.blockDim; }
void block_dim(const dim3& v) { m_parms.blockDim = v; }
auto shared_mem_bytes() { return m_parms.sharedMemBytes; }
void shared_mem_bytes(unsigned int v) { m_parms.sharedMemBytes = v; }
auto kernel_params() { return m_parms.kernelParams; }
void kernel_params(const std::vector<void*>& v)
{
m_args = v;
m_parms.kernelParams = m_args.data();
}
void parse(std::function<std::vector<void*>(U&)> pred)
{
m_args = pred(kernelParmData);
m_parms.kernelParams = m_args.data();
}
auto extra() { return m_parms.extra; }
void extra(void** v) { m_parms.extra = v; }
const cudaKernelNodeParams* handle() const { return &m_parms; }
};
} // namespace muda