MUDA
Loading...
Searching...
No Matches
compute_graph_accessor.inl
1#include <muda/compute_graph/compute_graph.h>
2#include <muda/compute_graph/nodes/compute_graph_kernel_node.h>
3#include <muda/compute_graph/nodes/compute_graph_catpure_node.h>
4#include <muda/compute_graph/nodes/compute_graph_memory_node.h>
5#include <muda/compute_graph/nodes/compute_graph_event_node.h>
6#include <muda/compute_graph/compute_graph_closure.h>
7#include <muda/compute_graph/compute_graph_builder.h>
8
9namespace muda
10{
11namespace details
12{
    // Default-construct an accessor bound to the graph currently being
    // built/updated, as published by ComputeGraphBuilder::current_graph().
    // NOTE(review): the returned pointer is dereferenced unconditionally —
    // presumably callers only construct this while a graph phase is active;
    // confirm current_graph() cannot be null here.
    MUDA_INLINE ComputeGraphAccessor::ComputeGraphAccessor()
        : m_cg(*ComputeGraphBuilder::current_graph())
    {
    }
17
    // Guard: evaluating a graph var is forbidden while the user's
    // ComputeGraph::capture() callback is running (m_is_in_capture_func set).
    MUDA_INLINE void ComputeGraphAccessor::check_allow_var_eval() const
    {
        if(m_cg.m_is_in_capture_func)
            MUDA_ERROR_WITH_LOCATION("you can't eval a var in ComputeGraph::capture() function");
    }
23
    // Guard: nodes may only be added while no graph phase is in progress
    // (current_graph_phase() == None); errors out otherwise.
    MUDA_INLINE void ComputeGraphAccessor::check_allow_node_adding() const
    {
        if(m_cg.current_graph_phase() != ComputeGraphPhase::None)
            MUDA_ERROR_WITH_LOCATION("you are not allowed adding node at this point");
    }
29
30 /*
31 * Set Graph Node
32 */
33 template <typename T>
34 MUDA_INLINE void ComputeGraphAccessor::set_kernel_node(const S<KernelNodeParms<T>>& kernelParms)
35 {
36 switch(ComputeGraphBuilder::current_phase())
37 {
38 case ComputeGraphPhase::TopoBuilding:
39 MUDA_ASSERT(!kernelParms,
40 "When ComputeGraphPhase == TopoBuilding, "
41 "you don't need to create NodeParms, so keep it nullptr.");
42 // fall through
43 case ComputeGraphPhase::Building:
44 add_kernel_node(kernelParms);
45 break;
46 case ComputeGraphPhase::Updating:
47 update_kernel_node(kernelParms);
48 break;
49 default:
50 MUDA_ERROR_WITH_LOCATION("invalid phase");
51 break;
52 }
53 }
    // Create (or, once topo is built, re-fetch) the logical kernel node for
    // the current access index; the underlying cuda graph node is only
    // materialized while actually Building.
    template <typename T>
    MUDA_INLINE void ComputeGraphAccessor::add_kernel_node(const S<KernelNodeParms<T>>& parms)
    {
        access_graph([&](Graph& g) { // create kernel node
            ComputeGraphKernelNode* kernel_node = get_or_create_node<ComputeGraphKernelNode>(
                [&]
                {
                    // binding unused in this factory — presumably kept for
                    // symmetry with the update path; verify before removing
                    const auto& [name, closure] = current_closure();
                    return new ComputeGraphKernelNode(NodeId{m_cg.m_nodes.size()},
                                                      m_cg.current_access_index());
                });
            if(ComputeGraphBuilder::current_phase() == ComputeGraphPhase::Building)
            {
                kernel_node->set_node(g.add_kernel_node(parms));
            }
        });
    }
    // Updating phase: push new kernel parameters into the instantiated
    // GraphExec for the node recorded at the current access index.
    template <typename T>
    MUDA_INLINE void ComputeGraphAccessor::update_kernel_node(const S<KernelNodeParms<T>>& kernelParms)
    {
        access_graph_exec(
            [&](GraphExec& g_exec)
            {
                // binding unused — presumably for debugging/symmetry
                const auto& [name, closure] = current_closure();
                auto kernel_node = current_node<ComputeGraphKernelNode>();
                g_exec.set_kernel_node_parms(kernel_node->m_node, kernelParms);
            });
    }
82
    // Bind the accessor to an explicit graph (reference form).
    MUDA_INLINE ComputeGraphAccessor::ComputeGraphAccessor(ComputeGraph& graph)
        : m_cg(graph)
    {
    }
87
    // Bind the accessor to an explicit graph (pointer form).
    // NOTE(review): dereferences `graph` without a null check — callers are
    // presumably required to pass a valid pointer.
    MUDA_INLINE ComputeGraphAccessor::ComputeGraphAccessor(ComputeGraph* graph)
        : m_cg(*graph)
    {
    }
92
    // Read-only access to the (name, closure) pair currently being
    // processed, indexed by the graph's current closure id.
    MUDA_INLINE auto ComputeGraphAccessor::current_closure() const
        -> const std::pair<std::string, ComputeGraphClosure*>&
    {
        return m_cg.m_closures[m_cg.current_closure_id().value()];
    }
98
99 MUDA_INLINE auto ComputeGraphAccessor::current_closure()
100 -> std::pair<std::string, ComputeGraphClosure*>&
101 {
102 return m_cg.m_closures[m_cg.m_current_closure_id.value()];
103 }
104
    // Read-only: the graph node of the current closure at the current
    // access index.
    MUDA_INLINE const ComputeGraphNodeBase* ComputeGraphAccessor::current_node() const
    {
        return current_closure().second->m_graph_nodes[m_cg.current_access_index()];
    }
109
    // Mutable counterpart of the const overload above.
    MUDA_INLINE ComputeGraphNodeBase* ComputeGraphAccessor::current_node()
    {
        return current_closure().second->m_graph_nodes[m_cg.current_access_index()];
    }
114
    // The single stream the graph currently records work on.
    MUDA_INLINE cudaStream_t ComputeGraphAccessor::current_stream() const
    {
        return m_cg.m_current_single_stream;
    }
119
    // The shared stream used for stream capture; only valid while the graph
    // is actively capturing (asserted).
    MUDA_INLINE cudaStream_t ComputeGraphAccessor::capture_stream() const
    {
        MUDA_ASSERT(m_cg.m_is_capturing, "Not Capture Phase!");
        return m_cg.shared_capture_stream();
    }
125
    // Whether the graph topology has already been established (first full
    // pass completed); governs get_or_create_node's create-vs-reuse choice.
    MUDA_INLINE bool ComputeGraphAccessor::is_topo_built() const
    {
        return m_cg.m_is_topo_built;
    }
130
    // Downcast the current node to the concrete node type.
    // NOTE(review): callers dereference the result without a null check, so
    // a type mismatch in dynamic_cast would surface as a null dereference.
    template <typename T>
    T* ComputeGraphAccessor::current_node()
    {
        return dynamic_cast<T*>(current_node());
    }
136
137 MUDA_INLINE void ComputeGraphAccessor::set_memcpy_node(void* dst,
138 const void* src,
139 size_t size_bytes,
140 cudaMemcpyKind kind)
141 {
142 switch(ComputeGraphBuilder::current_phase())
143 {
144 case ComputeGraphPhase::TopoBuilding:
145 // fall through
146 case ComputeGraphPhase::Building:
147 add_memcpy_node(dst, src, size_bytes, kind);
148 break;
149 case ComputeGraphPhase::Updating:
150 update_memcpy_node(dst, src, size_bytes, kind);
151 break;
152 default:
153 MUDA_ERROR_WITH_LOCATION("invalid phase");
154 break;
155 }
156 }
157
158 MUDA_INLINE void ComputeGraphAccessor::set_memcpy_node(const cudaMemcpy3DParms& parms)
159 {
160 switch(ComputeGraphBuilder::current_phase())
161 {
162 case ComputeGraphPhase::TopoBuilding:
163 // fall through
164 case ComputeGraphPhase::Building:
165 add_memcpy_node(parms);
166 break;
167 case ComputeGraphPhase::Updating:
168 update_memcpy_node(parms);
169 break;
170 default:
171 MUDA_ERROR_WITH_LOCATION("invalid phase");
172 break;
173 }
174 }
175
176 MUDA_INLINE void ComputeGraphAccessor::set_memset_node(const cudaMemsetParams& parms)
177 {
178 switch(ComputeGraphBuilder::current_phase())
179 {
180 case ComputeGraphPhase::TopoBuilding:
181 // fall through
182 case ComputeGraphPhase::Building:
183 add_memset_node(parms);
184 break;
185 case ComputeGraphPhase::Updating:
186 update_memset_node(parms);
187 break;
188 default:
189 MUDA_ERROR_WITH_LOCATION("invalid phase");
190 break;
191 }
192 }
193
194 MUDA_INLINE void ComputeGraphAccessor::set_event_record_node(cudaEvent_t event)
195 {
196 switch(ComputeGraphBuilder::current_phase())
197 {
198 case ComputeGraphPhase::TopoBuilding:
199 MUDA_ASSERT(!event,
200 "When ComputeGraphPhase == TopoBuilding, "
201 "you don't need to create event, so keep it nullptr.");
202 // fall through
203 case ComputeGraphPhase::Building:
204 add_event_record_node(event);
205 break;
206 case ComputeGraphPhase::Updating:
207 update_event_record_node(event);
208 break;
209 default:
210 MUDA_ERROR_WITH_LOCATION("invalid phase");
211 break;
212 }
213 }
214
    // Create (or re-fetch) the logical 1D memcpy node; the cuda node itself
    // is only added to the graph while Building.
    MUDA_INLINE void ComputeGraphAccessor::add_memcpy_node(void*          dst,
                                                           const void*    src,
                                                           size_t         size_bytes,
                                                           cudaMemcpyKind kind)
    {
        access_graph([&](Graph& g) { // create memory node
            ComputeGraphMemcpyNode* memory_node = get_or_create_node<ComputeGraphMemcpyNode>(
                [&]
                {
                    // binding unused — presumably kept for symmetry
                    const auto& [name, closure] = current_closure();
                    return new ComputeGraphMemcpyNode(NodeId{m_cg.m_nodes.size()},
                                                      m_cg.current_access_index());
                });
            if(ComputeGraphBuilder::current_phase() == ComputeGraphPhase::Building)
                memory_node->set_node(g.add_memcpy_node(dst, src, size_bytes, kind));
        });
    }
    // Updating phase: rewrite the 1D memcpy parameters in the instantiated
    // GraphExec for the node at the current access index.
    MUDA_INLINE void ComputeGraphAccessor::update_memcpy_node(void*          dst,
                                                              const void*    src,
                                                              size_t         size_bytes,
                                                              cudaMemcpyKind kind)
    {
        access_graph_exec(
            [&](GraphExec& g_exec)
            {
                // binding unused — presumably for debugging/symmetry
                const auto& [name, closure] = current_closure();
                auto memory_node = current_node<ComputeGraphMemcpyNode>();
                g_exec.set_memcpy_node_parms(memory_node->m_node, dst, src, size_bytes, kind);
            });
    }
245
    // Create (or re-fetch) the logical 3D memcpy node; cuda node added only
    // while Building.
    MUDA_INLINE void ComputeGraphAccessor::add_memcpy_node(const cudaMemcpy3DParms& parms)
    {
        access_graph([&](Graph& g) { // create memory node
            ComputeGraphMemcpyNode* memory_node = get_or_create_node<ComputeGraphMemcpyNode>(
                [&]
                {
                    // binding unused — presumably kept for symmetry
                    const auto& [name, closure] = current_closure();
                    return new ComputeGraphMemcpyNode(NodeId{m_cg.m_nodes.size()},
                                                      m_cg.current_access_index());
                });
            if(ComputeGraphBuilder::current_phase() == ComputeGraphPhase::Building)
                memory_node->set_node(g.add_memcpy_node(parms));
        });
    }
260
    // Updating phase: rewrite the 3D memcpy parameters in the GraphExec.
    MUDA_INLINE void ComputeGraphAccessor::update_memcpy_node(const cudaMemcpy3DParms& parms)
    {
        access_graph_exec(
            [&](GraphExec& g_exec)
            {
                // binding unused — presumably for debugging/symmetry
                const auto& [name, closure] = current_closure();
                auto memory_node = current_node<ComputeGraphMemcpyNode>();
                g_exec.set_memcpy_node_parms(memory_node->m_node, parms);
            });
    }
271
    // Create (or re-fetch) the logical memset node; cuda node added only
    // while Building.
    MUDA_INLINE void ComputeGraphAccessor::add_memset_node(const cudaMemsetParams& parms)
    {
        access_graph([&](Graph& g) { // create memory node
            ComputeGraphMemsetNode* memory_node = get_or_create_node<ComputeGraphMemsetNode>(
                [&]
                {
                    // binding unused — presumably kept for symmetry
                    const auto& [name, closure] = current_closure();
                    return new ComputeGraphMemsetNode(NodeId{m_cg.m_nodes.size()},
                                                      m_cg.current_access_index());
                });
            if(ComputeGraphBuilder::current_phase() == ComputeGraphPhase::Building)
                memory_node->set_node(g.add_memset_node(parms));
        });
    }
286
    // Updating phase: rewrite the memset parameters in the GraphExec.
    MUDA_INLINE void ComputeGraphAccessor::update_memset_node(const cudaMemsetParams& parms)
    {
        access_graph_exec(
            [&](GraphExec& g_exec)
            {
                // binding unused — presumably for debugging/symmetry
                const auto& [name, closure] = current_closure();
                auto memory_node = current_node<ComputeGraphMemsetNode>();
                g_exec.set_memset_node_parms(memory_node->m_node, parms);
            });
    }
297
    // Create (or re-fetch) the logical event-record node; rejected for
    // device-launchable graphs (event nodes are unsupported there).
    MUDA_INLINE void ComputeGraphAccessor::add_event_record_node(cudaEvent_t event)
    {
        MUDA_ASSERT(!m_cg.m_flags.has(muda::GraphInstantiateFlagBit::DeviceLaunch),
                    "Event Record Node is not allowed in a graph that will be launched on device");

        access_graph(
            [&](Graph& g)
            {
                ComputeGraphEventRecordNode* event_record =
                    get_or_create_node<ComputeGraphEventRecordNode>(
                        [&]
                        {
                            // binding unused — presumably kept for symmetry
                            const auto& [name, closure] = current_closure();
                            return new ComputeGraphEventRecordNode(
                                NodeId{m_cg.m_nodes.size()}, m_cg.current_access_index());
                        });

                // only materialize the cuda node while actually Building
                if(ComputeGraphBuilder::current_phase() == ComputeGraphPhase::Building)
                {
                    event_record->set_node(g.add_event_record_node(event));
                }
            });
    }
    // Updating phase: point the instantiated event-record node at `event`.
    MUDA_INLINE void ComputeGraphAccessor::update_event_record_node(cudaEvent_t event)
    {
        access_graph_exec(
            [&](GraphExec& g_exec)
            {
                // binding unused — presumably for debugging/symmetry
                const auto& [name, closure] = current_closure();
                auto event_record = current_node<ComputeGraphEventRecordNode>();
                g_exec.set_event_record_node_parms(event_record->m_node, event);
            });
    }
331
332
333 MUDA_INLINE void ComputeGraphAccessor::set_event_wait_node(cudaEvent_t event)
334 {
335 switch(ComputeGraphBuilder::current_phase())
336 {
337 case ComputeGraphPhase::TopoBuilding:
338 MUDA_ASSERT(!event,
339 "When ComputeGraphPhase == TopoBuilding, "
340 "you don't need to create event, so keep it nullptr.");
341 // fall through
342 case ComputeGraphPhase::Building:
343 add_event_wait_node(event);
344 break;
345 case ComputeGraphPhase::Updating:
346 update_event_wait_node(event);
347 break;
348 default:
349 MUDA_ERROR_WITH_LOCATION("invalid phase");
350 break;
351 }
352 }
353
    // Phase dispatcher for child-graph (capture) nodes; during TopoBuilding
    // the sub_graph must be nullptr.
    // NOTE(review): the original lacked the "fall through" annotation every
    // sibling dispatcher carries; the fall-through is intentional here too.
    MUDA_INLINE void ComputeGraphAccessor::set_capture_node(cudaGraph_t sub_graph)
    {
        switch(ComputeGraphBuilder::current_phase())
        {
        case ComputeGraphPhase::TopoBuilding:
            MUDA_ASSERT(!sub_graph,
                        "When ComputeGraphPhase == TopoBuilding, "
                        "you don't need to create sub_graph, so keep it nullptr.");
            // fall through
        case ComputeGraphPhase::Building:
            add_capture_node(sub_graph);
            break;
        case ComputeGraphPhase::Updating:
            update_capture_node(sub_graph);
            break;
        default:
            MUDA_ERROR_WITH_LOCATION("invalid phase");
            break;
        }
    }
373
    // Create (or re-fetch) the logical event-wait node; rejected for
    // device-launchable graphs (event nodes are unsupported there).
    MUDA_INLINE void ComputeGraphAccessor::add_event_wait_node(cudaEvent_t event)
    {
        MUDA_ASSERT(!m_cg.m_flags.has(muda::GraphInstantiateFlagBit::DeviceLaunch),
                    "Event Wait Node is not allowed in a graph that will be launched on device");

        access_graph(
            [&](Graph& g)
            {
                ComputeGraphEventWaitNode* event_wait =
                    get_or_create_node<ComputeGraphEventWaitNode>(
                        [&]
                        {
                            // binding unused — presumably kept for symmetry
                            const auto& [name, closure] = current_closure();
                            return new ComputeGraphEventWaitNode(
                                NodeId{m_cg.m_nodes.size()}, m_cg.current_access_index());
                        });

                // only materialize the cuda node while actually Building
                if(ComputeGraphBuilder::current_phase() == ComputeGraphPhase::Building)
                {
                    event_wait->set_node(g.add_event_wait_node(event));
                }
            });
    }
    // Updating phase: point the instantiated event-wait node at `event`.
    MUDA_INLINE void ComputeGraphAccessor::update_event_wait_node(cudaEvent_t event)
    {
        access_graph_exec(
            [&](GraphExec& g_exec)
            {
                // binding unused — presumably for debugging/symmetry
                const auto& [name, closure] = current_closure();
                auto event_wait = current_node<ComputeGraphEventWaitNode>();
                g_exec.set_event_wait_node_parms(event_wait->m_node, event);
            });
    }
407
    // Create (or re-fetch) the logical capture node; while Building, embed
    // `sub_graph` as a child-graph node via cudaGraphAddChildGraphNode and
    // remember the sub graph on the logical node.
    MUDA_INLINE void ComputeGraphAccessor::add_capture_node(cudaGraph_t sub_graph)
    {
        access_graph(
            [&](Graph& g)
            {
                auto capture_node = get_or_create_node<ComputeGraphCaptureNode>(
                    [&]
                    {
                        // binding unused — presumably kept for symmetry
                        const auto& [name, closure] = current_closure();
                        return new ComputeGraphCaptureNode{NodeId{m_cg.m_nodes.size()},
                                                           m_cg.current_access_index()};
                    });
                if(ComputeGraphBuilder::is_building())
                {
                    cudaGraphNode_t node;
                    // no extra dependencies here: ordering is presumably
                    // handled by the surrounding graph machinery — confirm
                    checkCudaErrors(cudaGraphAddChildGraphNode(
                        &node, g.handle(), nullptr, 0, sub_graph));
                    capture_node->set_node(node);
                    capture_node->update_sub_graph(sub_graph); // update sub graph
                }
            });
    }
430
    // Updating phase: swap the child graph embedded in the instantiated
    // GraphExec for `sub_graph`, and refresh the logical node's record.
    MUDA_INLINE void ComputeGraphAccessor::update_capture_node(cudaGraph_t sub_graph)
    {
        access_graph_exec(
            [&](GraphExec& g_exec)
            {
                // binding unused — presumably for debugging/symmetry
                const auto& [name, closure] = current_closure();
                auto capture_node = current_node<ComputeGraphCaptureNode>();
                checkCudaErrors(cudaGraphExecChildGraphNodeSetParams(
                    g_exec.handle(), capture_node->handle(), sub_graph));
                capture_node->update_sub_graph(sub_graph); // update sub graph
            });
        // m_is_capturing = false;
    }
444
445
    // Run `f` against the underlying muda::Graph, then advance the access
    // index so successive node accesses line up across build/update passes.
    template <typename F>
    void ComputeGraphAccessor::access_graph(F&& f)
    {
        f(m_cg.m_graph);
        ++m_cg.m_access_graph_index;
    }
452
    // Run `f` against the instantiated GraphExec.
    // NOTE(review): dereferences m_graph_exec unconditionally — assumes the
    // graph has been instantiated; unlike access_graph, does not advance
    // m_access_graph_index (presumably tracked elsewhere during updates).
    template <typename F>
    void ComputeGraphAccessor::access_graph_exec(F&& f)
    {
        f(*m_cg.m_graph_exec.get());
    }
458
    // First pass (topo not yet built): invoke factory `f`, register the new
    // node with both the current closure and the graph-wide node list.
    // Later passes: topo is fixed, so reuse the node recorded for the
    // current access index.
    template <typename NodeType, typename F>
    MUDA_INLINE NodeType* ComputeGraphAccessor::get_or_create_node(F&& f)
    {
        static_assert(std::is_base_of_v<ComputeGraphNodeBase, NodeType>,
                      "NodeType must be derived from ComputeGraphNodeBase");
        if(!m_cg.m_is_topo_built)
        {
            NodeType* ptr = f();
            auto& [name, closure] = current_closure();
            // raw pointer ownership presumably transfers to these
            // containers — confirm who deletes the nodes
            closure->m_graph_nodes.emplace_back(ptr);
            m_cg.m_nodes.emplace_back(ptr);
            return ptr;
        }
        else
            return current_node<NodeType>();
    }
    // Record that the current closure uses var `id` with at least `usage`;
    // keeps the strongest usage seen so far (operator< presumably orders
    // usages Read < ReadWrite — confirm in ComputeGraphVarUsage).
    MUDA_INLINE void ComputeGraphAccessor::set_var_usage(VarId id, ComputeGraphVarUsage usage)
    {
        auto& dst_usage = current_closure().second->m_var_usages[id];
        if(dst_usage < usage)
            dst_usage = usage;
    }
481
482} // namespace details
483} // namespace muda