MUDA
Loading...
Searching...
No Matches
kernel_node.h
1#pragma once
2#include <muda/graph/graph_base.h>
3
4namespace muda
5{
6class KernelNode : public GraphNode
7{
8 public:
9 using this_type = KernelNode;
10 friend class Graph;
11};
12
13template <typename U>
15{
16 std::vector<void*> m_args;
17 cudaKernelNodeParams m_parms;
18
19 public:
21 friend class Graph;
22 friend class std::shared_ptr<this_type>;
23 friend class std::unique_ptr<this_type>;
24 friend class std::weak_ptr<this_type>;
25
26 template <typename... Args>
27 KernelNodeParms(Args&&... args)
28 : kernelParmData(std::forward<Args>(args)...)
29 , m_parms({})
30 {
31 }
32
34 U kernelParmData;
35 auto func() { return m_parms.func; }
36 void func(void* v) { m_parms.func = v; }
37 auto grid_dim() { return m_parms.gridDim; }
38 void grid_dim(const dim3& v) { m_parms.gridDim = v; }
39 auto block_dim() { return m_parms.blockDim; }
40 void block_dim(const dim3& v) { m_parms.blockDim = v; }
41 auto shared_mem_bytes() { return m_parms.sharedMemBytes; }
42 void shared_mem_bytes(unsigned int v) { m_parms.sharedMemBytes = v; }
43 auto kernel_params() { return m_parms.kernelParams; }
44 void kernel_params(const std::vector<void*>& v)
45 {
46 m_args = v;
47 m_parms.kernelParams = m_args.data();
48 }
49 void parse(std::function<std::vector<void*>(U&)> pred)
50 {
51 m_args = pred(kernelParmData);
52 m_parms.kernelParams = m_args.data();
53 }
54 auto extra() { return m_parms.extra; }
55 void extra(void** v) { m_parms.extra = v; }
56
57 const cudaKernelNodeParams* handle() const { return &m_parms; }
58};
59} // namespace muda
Definition graph.h:18
Definition graph_base.h:27
Definition kernel_node.h:7
Definition kernel_node.h:15
Definition graph_base.h:20