MUDA
Loading...
Searching...
No Matches
graph.inl
1namespace muda
2{
3MUDA_INLINE Graph::Graph()
4{
5 checkCudaErrors(cudaGraphCreate(&m_handle, 0));
6}
7
8MUDA_INLINE Graph::~Graph()
9{
10 if(m_handle)
11 checkCudaErrors(cudaGraphDestroy(m_handle));
12}
13
14
15MUDA_INLINE Graph::Graph(Graph&& o)
16 : m_handle(std::move(o.m_handle))
17 , m_cached(std::move(o.m_cached))
18{
19 o.m_handle = nullptr;
20}
21
22MUDA_INLINE Graph& Graph::operator=(Graph&& o)
23{
24 if(this == &o)
25 return *this;
26 m_handle = std::move(o.m_handle);
27 m_cached = std::move(o.m_cached);
28 o.m_handle = nullptr;
29 return *this;
30}
31
32MUDA_INLINE auto Graph::instantiate() -> S<GraphExec>
33{
34 auto ret = std::make_shared<GraphExec>();
35 checkCudaErrors(cudaGraphInstantiate(&ret->m_handle, m_handle, nullptr, nullptr, 0));
36 return ret;
37}
38
39MUDA_INLINE auto Graph::instantiate(Flags<GraphInstantiateFlagBit> flags) -> S<GraphExec>
40{
41 auto ret = std::make_shared<GraphExec>();
42#if MUDA_WITH_DEVICE_STREAM_MODEL
43 checkCudaErrors(
44 cudaGraphInstantiateWithFlags(&ret->m_handle, m_handle, static_cast<int>(flags)));
45#else
46 checkCudaErrors(cudaGraphInstantiateWithFlags(
47 &ret->m_handle, m_handle, static_cast<int>(flags & GraphInstantiateFlagBit::FreeOnLaunch)));
48#endif
49 ret->m_flags = flags;
50 return ret;
51}
52
53template <typename T>
54auto Graph::add_kernel_node(const S<KernelNodeParms<T>>& kernelParms,
55 const std::vector<S<GraphNode>>& deps) -> S<KernelNode>
56{
57 auto ret = std::make_shared<KernelNode>();
58 std::vector<cudaGraphNode_t> nodes = map_dependencies(deps);
59 checkCudaErrors(cudaGraphAddKernelNode(
60 &ret->m_handle, m_handle, nodes.data(), nodes.size(), kernelParms->handle()));
61 return ret;
62}
63
64template <typename T>
65auto Graph::add_kernel_node(const S<KernelNodeParms<T>>& kernelParms) -> S<KernelNode>
66{
67 auto ret = std::make_shared<KernelNode>();
68 checkCudaErrors(cudaGraphAddKernelNode(
69 &ret->m_handle, m_handle, nullptr, 0, kernelParms->handle()));
70 return ret;
71}
72
73template <typename T>
74auto Graph::add_host_node(const S<HostNodeParms<T>>& hostParms,
75 const std::vector<S<GraphNode>>& deps) -> S<HostNode>
76{
77 m_cached.push_back(hostParms);
78 auto ret = std::make_shared<HostNode>();
79 std::vector<cudaGraphNode_t> nodes = map_dependencies(deps);
80 checkCudaErrors(cudaGraphAddHostNode(
81 &ret->m_handle, m_handle, nodes.data(), nodes.size(), hostParms->handle()));
82 return ret;
83}
84
85template <typename T>
86auto Graph::add_host_node(const S<HostNodeParms<T>>& hostParms) -> S<HostNode>
87{
88 m_cached.push_back(hostParms);
89 auto ret = std::make_shared<HostNode>();
90 checkCudaErrors(
91 cudaGraphAddHostNode(&ret->m_handle, m_handle, nullptr, 0, hostParms->handle()));
92 return ret;
93}
94
95
96MUDA_INLINE auto Graph::add_memcpy_node(void* dst,
97 const void* src,
98 size_t size_bytes,
99 cudaMemcpyKind kind,
100 const std::vector<S<GraphNode>>& deps) -> S<MemcpyNode>
101{
102 auto ret = std::make_shared<MemcpyNode>();
103 std::vector<cudaGraphNode_t> nodes = map_dependencies(deps);
104 checkCudaErrors(cudaGraphAddMemcpyNode1D(
105 &ret->m_handle, m_handle, nodes.data(), nodes.size(), dst, src, size_bytes, kind));
106 return ret;
107}
108
109MUDA_INLINE auto Graph::add_memcpy_node(void* dst, const void* src, size_t size_bytes, cudaMemcpyKind kind)
110 -> S<MemcpyNode>
111{
112 auto ret = std::make_shared<MemcpyNode>();
113 checkCudaErrors(cudaGraphAddMemcpyNode1D(
114 &ret->m_handle, m_handle, nullptr, 0, dst, src, size_bytes, kind));
115 return ret;
116}
117
118
119MUDA_INLINE auto Graph::add_memcpy_node(const cudaMemcpy3DParms& parms,
120 const std::vector<S<GraphNode>>& deps) -> S<MemcpyNode>
121{
122 auto ret = std::make_shared<MemcpyNode>();
123 std::vector<cudaGraphNode_t> nodes = map_dependencies(deps);
124 checkCudaErrors(cudaGraphAddMemcpyNode(
125 &ret->m_handle, m_handle, nodes.data(), nodes.size(), &parms));
126 return ret;
127}
128
129MUDA_INLINE auto Graph::add_memset_node(const cudaMemsetParams& parms,
130 const std::vector<S<GraphNode>>& deps) -> S<MemsetNode>
131{
132 auto ret = std::make_shared<MemsetNode>();
133 std::vector<cudaGraphNode_t> nodes = map_dependencies(deps);
134 checkCudaErrors(cudaGraphAddMemsetNode(
135 &ret->m_handle, m_handle, nodes.data(), nodes.size(), &parms));
136 return ret;
137}
138
139MUDA_INLINE auto Graph::add_memset_node(const cudaMemsetParams& parms) -> S<MemsetNode>
140{
141 auto ret = std::make_shared<MemsetNode>();
142 checkCudaErrors(cudaGraphAddMemsetNode(&ret->m_handle, m_handle, nullptr, 0, &parms));
143 return ret;
144}
145
146MUDA_INLINE auto Graph::add_memcpy_node(const cudaMemcpy3DParms& parms) -> S<MemcpyNode>
147{
148 auto ret = std::make_shared<MemcpyNode>();
149 checkCudaErrors(cudaGraphAddMemcpyNode(&ret->m_handle, m_handle, nullptr, 0, &parms));
150 return ret;
151}
152
153MUDA_INLINE auto Graph::add_event_record_node(cudaEvent_t e,
154 const std::vector<S<GraphNode>>& deps)
155 -> S<EventRecordNode>
156{
157 auto ret = std::make_shared<EventRecordNode>();
158 std::vector<cudaGraphNode_t> nodes = map_dependencies(deps);
159 checkCudaErrors(cudaGraphAddEventRecordNode(
160 &ret->m_handle, m_handle, nodes.data(), nodes.size(), e));
161 return ret;
162}
163
164MUDA_INLINE auto Graph::add_event_record_node(cudaEvent_t e) -> S<EventRecordNode>
165{
166 auto ret = std::make_shared<EventRecordNode>();
167 checkCudaErrors(cudaGraphAddEventRecordNode(&ret->m_handle, m_handle, nullptr, 0, e));
168 return ret;
169}
170
171MUDA_INLINE auto Graph::add_event_wait_node(cudaEvent_t e,
172 const std::vector<S<GraphNode>>& deps)
173 -> S<EventWaitNode>
174{
175 auto ret = std::make_shared<EventWaitNode>();
176 std::vector<cudaGraphNode_t> nodes = map_dependencies(deps);
177 checkCudaErrors(cudaGraphAddEventWaitNode(
178 &ret->m_handle, m_handle, nodes.data(), nodes.size(), e));
179 return ret;
180}
181
182MUDA_INLINE auto Graph::add_event_wait_node(cudaEvent_t e) -> S<EventWaitNode>
183{
184 auto ret = std::make_shared<EventWaitNode>();
185 checkCudaErrors(cudaGraphAddEventWaitNode(&ret->m_handle, m_handle, nullptr, 0, e));
186 return ret;
187}
188
189MUDA_INLINE void Graph::add_dependency(S<GraphNode> from, S<GraphNode> to)
190{
191 checkCudaErrors(
192 cudaGraphAddDependencies(m_handle, &(from->m_handle), &(to->m_handle), 1));
193}
194
195MUDA_INLINE std::vector<cudaGraphNode_t> Graph::map_dependencies(const std::vector<S<GraphNode>>& deps)
196{
197 std::vector<cudaGraphNode_t> nodes;
198 nodes.reserve(deps.size());
199 for(auto d : deps)
200 nodes.push_back(d->m_handle);
201 return nodes;
202}
203} // namespace muda