1#include <muda/compute_graph/compute_graph.h>
2#include <muda/compute_graph/nodes/compute_graph_kernel_node.h>
3#include <muda/compute_graph/nodes/compute_graph_catpure_node.h>
4#include <muda/compute_graph/nodes/compute_graph_memory_node.h>
5#include <muda/compute_graph/nodes/compute_graph_event_node.h>
6#include <muda/compute_graph/compute_graph_closure.h>
7#include <muda/compute_graph/compute_graph_builder.h>
13 MUDA_INLINE ComputeGraphAccessor::ComputeGraphAccessor()
14 : m_cg(*ComputeGraphBuilder::current_graph())
18 MUDA_INLINE
void ComputeGraphAccessor::check_allow_var_eval()
const
20 if(m_cg.m_is_in_capture_func)
21 MUDA_ERROR_WITH_LOCATION(
"you can't eval a var in ComputeGraph::capture() function");
24 MUDA_INLINE
void ComputeGraphAccessor::check_allow_node_adding()
const
26 if(m_cg.current_graph_phase() != ComputeGraphPhase::None)
27 MUDA_ERROR_WITH_LOCATION(
"you are not allowed adding node at this point");
34 MUDA_INLINE
void ComputeGraphAccessor::set_kernel_node(
const S<KernelNodeParms<T>>& kernelParms)
36 switch(ComputeGraphBuilder::current_phase())
38 case ComputeGraphPhase::TopoBuilding:
39 MUDA_ASSERT(!kernelParms,
40 "When ComputeGraphPhase == TopoBuilding, "
41 "you don't need to create NodeParms, so keep it nullptr.");
43 case ComputeGraphPhase::Building:
44 add_kernel_node(kernelParms);
46 case ComputeGraphPhase::Updating:
47 update_kernel_node(kernelParms);
50 MUDA_ERROR_WITH_LOCATION(
"invalid phase");
55 MUDA_INLINE
void ComputeGraphAccessor::add_kernel_node(
const S<KernelNodeParms<T>>& parms)
57 access_graph([&](Graph& g) {
58 ComputeGraphKernelNode* kernel_node = get_or_create_node<ComputeGraphKernelNode>(
61 const auto& [name, closure] = current_closure();
62 return new ComputeGraphKernelNode(NodeId{m_cg.m_nodes.size()},
63 m_cg.current_access_index());
65 if(ComputeGraphBuilder::current_phase() == ComputeGraphPhase::Building)
67 kernel_node->set_node(g.add_kernel_node(parms));
72 MUDA_INLINE
void ComputeGraphAccessor::update_kernel_node(
const S<KernelNodeParms<T>>& kernelParms)
75 [&](GraphExec& g_exec)
77 const auto& [name, closure] = current_closure();
78 auto kernel_node = current_node<ComputeGraphKernelNode>();
79 g_exec.set_kernel_node_parms(kernel_node->m_node, kernelParms);
// Constructor taking the ComputeGraph to operate on by reference.
// NOTE(review): the member-initializer list and body are not visible in this
// excerpt - presumably m_cg is initialized from `graph`; confirm.
83 MUDA_INLINE ComputeGraphAccessor::ComputeGraphAccessor(ComputeGraph& graph)
// Constructor taking the ComputeGraph to operate on by pointer.
// NOTE(review): the member-initializer list and body are not visible in this
// excerpt - presumably m_cg is initialized from `*graph`, in which case a
// null pointer would be invalid; confirm.
88 MUDA_INLINE ComputeGraphAccessor::ComputeGraphAccessor(ComputeGraph* graph)
93 MUDA_INLINE
auto ComputeGraphAccessor::current_closure() const
94 -> const std::pair<std::
string, ComputeGraphClosure*>&
96 return m_cg.m_closures[m_cg.current_closure_id().value()];
99 MUDA_INLINE
auto ComputeGraphAccessor::current_closure()
100 -> std::pair<std::string, ComputeGraphClosure*>&
102 return m_cg.m_closures[m_cg.m_current_closure_id.value()];
105 MUDA_INLINE
const ComputeGraphNodeBase* ComputeGraphAccessor::current_node()
const
107 return current_closure().second->m_graph_nodes[m_cg.current_access_index()];
110 MUDA_INLINE ComputeGraphNodeBase* ComputeGraphAccessor::current_node()
112 return current_closure().second->m_graph_nodes[m_cg.current_access_index()];
115 MUDA_INLINE cudaStream_t ComputeGraphAccessor::current_stream()
const
117 return m_cg.m_current_single_stream;
120 MUDA_INLINE cudaStream_t ComputeGraphAccessor::capture_stream()
const
122 MUDA_ASSERT(m_cg.m_is_capturing,
"Not Capture Phase!");
123 return m_cg.shared_capture_stream();
126 MUDA_INLINE
bool ComputeGraphAccessor::is_topo_built()
const
128 return m_cg.m_is_topo_built;
131 template <
typename T>
132 T* ComputeGraphAccessor::current_node()
134 return dynamic_cast<T*
>(current_node());
// Phase dispatcher for linear memcpy nodes: selects the action from
// ComputeGraphBuilder::current_phase().
// NOTE(review): the parameter lines after `dst` are not visible in this
// excerpt; the body uses `src`, `size_bytes` and `kind`, so they are
// presumably declared there - confirm their types against the header.
137 MUDA_INLINE
void ComputeGraphAccessor::set_memcpy_node(
void* dst,
142 switch(ComputeGraphBuilder::current_phase())
// NOTE(review): TopoBuilding appears to share the Building action
// (add_memcpy_node only creates the CUDA node when the phase is Building);
// confirm that no statement sits between the two labels.
144 case ComputeGraphPhase::TopoBuilding:
146 case ComputeGraphPhase::Building:
147 add_memcpy_node(dst, src, size_bytes, kind);
// Updating: forward the same arguments to the update path.
149 case ComputeGraphPhase::Updating:
150 update_memcpy_node(dst, src, size_bytes, kind);
// Remaining phases are reported as an error.
153 MUDA_ERROR_WITH_LOCATION(
"invalid phase");
158 MUDA_INLINE
void ComputeGraphAccessor::set_memcpy_node(
const cudaMemcpy3DParms& parms)
160 switch(ComputeGraphBuilder::current_phase())
162 case ComputeGraphPhase::TopoBuilding:
164 case ComputeGraphPhase::Building:
165 add_memcpy_node(parms);
167 case ComputeGraphPhase::Updating:
168 update_memcpy_node(parms);
171 MUDA_ERROR_WITH_LOCATION(
"invalid phase");
176 MUDA_INLINE
void ComputeGraphAccessor::set_memset_node(
const cudaMemsetParams& parms)
178 switch(ComputeGraphBuilder::current_phase())
180 case ComputeGraphPhase::TopoBuilding:
182 case ComputeGraphPhase::Building:
183 add_memset_node(parms);
185 case ComputeGraphPhase::Updating:
186 update_memset_node(parms);
189 MUDA_ERROR_WITH_LOCATION(
"invalid phase");
// Phase dispatcher for event-record nodes: selects the action from
// ComputeGraphBuilder::current_phase().
194 MUDA_INLINE
void ComputeGraphAccessor::set_event_record_node(cudaEvent_t event)
196 switch(ComputeGraphBuilder::current_phase())
198 case ComputeGraphPhase::TopoBuilding:
// NOTE(review): the line that opens this diagnostic is not visible in this
// excerpt - presumably an assertion that `event` is null, mirroring the
// message text; confirm.
200 "When ComputeGraphPhase == TopoBuilding, "
201 "you don't need to create event, so keep it nullptr.");
// Building: create/register the record node for `event`.
203 case ComputeGraphPhase::Building:
204 add_event_record_node(event);
// Updating: re-target the existing record node to `event`.
206 case ComputeGraphPhase::Updating:
207 update_event_record_node(event);
// Remaining phases are reported as an error.
210 MUDA_ERROR_WITH_LOCATION(
"invalid phase");
// Obtains the ComputeGraphMemcpyNode for the current access slot and, in the
// Building phase only, adds a linear memcpy node to the graph.
// NOTE(review): the parameter lines after `dst` are not visible in this
// excerpt; the body uses `src`, `size_bytes` and `kind` - confirm their
// declarations against the header.
215 MUDA_INLINE
void ComputeGraphAccessor::add_memcpy_node(
void* dst,
220 access_graph([&](Graph& g) {
221 ComputeGraphMemcpyNode* memory_node = get_or_create_node<ComputeGraphMemcpyNode>(
// Factory: the new node is numbered with the current size of m_cg.m_nodes
// and tagged with the current access index.
224 const auto& [name, closure] = current_closure();
225 return new ComputeGraphMemcpyNode(NodeId{m_cg.m_nodes.size()},
226 m_cg.current_access_index());
// Only the Building phase creates the graph-level memcpy node and stores its
// handle on memory_node.
228 if(ComputeGraphBuilder::current_phase() == ComputeGraphPhase::Building)
229 memory_node->set_node(g.add_memcpy_node(dst, src, size_bytes, kind));
// Update-phase counterpart of the linear add_memcpy_node(): pushes new copy
// arguments to the memcpy node at the current access slot.
// NOTE(review): the parameter lines after `dst` and the call that receives
// the lambda below are not visible in this excerpt - presumably
// access_graph_exec(...); confirm.
232 MUDA_INLINE
void ComputeGraphAccessor::update_memcpy_node(
void* dst,
238 [&](GraphExec& g_exec)
240 const auto& [name, closure] = current_closure();
// current_node<T>() is a dynamic_cast; memory_node is dereferenced below
// without a null check.
241 auto memory_node = current_node<ComputeGraphMemcpyNode>();
242 g_exec.set_memcpy_node_parms(memory_node->m_node, dst, src, size_bytes, kind);
246 MUDA_INLINE
void ComputeGraphAccessor::add_memcpy_node(
const cudaMemcpy3DParms& parms)
248 access_graph([&](Graph& g) {
249 ComputeGraphMemcpyNode* memory_node = get_or_create_node<ComputeGraphMemcpyNode>(
252 const auto& [name, closure] = current_closure();
253 return new ComputeGraphMemcpyNode(NodeId{m_cg.m_nodes.size()},
254 m_cg.current_access_index());
256 if(ComputeGraphBuilder::current_phase() == ComputeGraphPhase::Building)
257 memory_node->set_node(g.add_memcpy_node(parms));
261 MUDA_INLINE
void ComputeGraphAccessor::update_memcpy_node(
const cudaMemcpy3DParms& parms)
264 [&](GraphExec& g_exec)
266 const auto& [name, closure] = current_closure();
267 auto memory_node = current_node<ComputeGraphMemcpyNode>();
268 g_exec.set_memcpy_node_parms(memory_node->m_node, parms);
272 MUDA_INLINE
void ComputeGraphAccessor::add_memset_node(
const cudaMemsetParams& parms)
274 access_graph([&](Graph& g) {
275 ComputeGraphMemsetNode* memory_node = get_or_create_node<ComputeGraphMemsetNode>(
278 const auto& [name, closure] = current_closure();
279 return new ComputeGraphMemsetNode(NodeId{m_cg.m_nodes.size()},
280 m_cg.current_access_index());
282 if(ComputeGraphBuilder::current_phase() == ComputeGraphPhase::Building)
283 memory_node->set_node(g.add_memset_node(parms));
287 MUDA_INLINE
void ComputeGraphAccessor::update_memset_node(
const cudaMemsetParams& parms)
290 [&](GraphExec& g_exec)
292 const auto& [name, closure] = current_closure();
293 auto memory_node = current_node<ComputeGraphMemsetNode>();
294 g_exec.set_memset_node_parms(memory_node->m_node, parms);
298 MUDA_INLINE
void ComputeGraphAccessor::add_event_record_node(cudaEvent_t event)
300 MUDA_ASSERT(!m_cg.m_flags.has(muda::GraphInstantiateFlagBit::DeviceLaunch),
301 "Event Record Node is not allowed in a graph that will be launched on device");
306 ComputeGraphEventRecordNode* event_record =
307 get_or_create_node<ComputeGraphEventRecordNode>(
310 const auto& [name, closure] = current_closure();
311 return new ComputeGraphEventRecordNode(
312 NodeId{m_cg.m_nodes.size()}, m_cg.current_access_index());
315 if(ComputeGraphBuilder::current_phase() == ComputeGraphPhase::Building)
317 event_record->set_node(g.add_event_record_node(event));
321 MUDA_INLINE
void ComputeGraphAccessor::update_event_record_node(cudaEvent_t event)
324 [&](GraphExec& g_exec)
326 const auto& [name, closure] = current_closure();
327 auto event_record = current_node<ComputeGraphEventRecordNode>();
328 g_exec.set_event_record_node_parms(event_record->m_node, event);
// Phase dispatcher for event-wait nodes: selects the action from
// ComputeGraphBuilder::current_phase().
333 MUDA_INLINE
void ComputeGraphAccessor::set_event_wait_node(cudaEvent_t event)
335 switch(ComputeGraphBuilder::current_phase())
337 case ComputeGraphPhase::TopoBuilding:
// NOTE(review): the line that opens this diagnostic is not visible in this
// excerpt - presumably an assertion that `event` is null, mirroring the
// message text; confirm.
339 "When ComputeGraphPhase == TopoBuilding, "
340 "you don't need to create event, so keep it nullptr.");
// Building: create/register the wait node for `event`.
342 case ComputeGraphPhase::Building:
343 add_event_wait_node(event);
// Updating: re-target the existing wait node to `event`.
345 case ComputeGraphPhase::Updating:
346 update_event_wait_node(event);
// Remaining phases are reported as an error.
349 MUDA_ERROR_WITH_LOCATION(
"invalid phase");
354 MUDA_INLINE
void ComputeGraphAccessor::set_capture_node(cudaGraph_t sub_graph)
356 switch(ComputeGraphBuilder::current_phase())
358 case ComputeGraphPhase::TopoBuilding:
359 MUDA_ASSERT(!sub_graph,
360 "When ComputeGraphPhase == TopoBuilding, "
361 "you don't need to create sub_graph, so keep it nullptr.");
362 case ComputeGraphPhase::Building:
363 add_capture_node(sub_graph);
365 case ComputeGraphPhase::Updating:
366 update_capture_node(sub_graph);
369 MUDA_ERROR_WITH_LOCATION(
"invalid phase");
374 MUDA_INLINE
void ComputeGraphAccessor::add_event_wait_node(cudaEvent_t event)
376 MUDA_ASSERT(!m_cg.m_flags.has(muda::GraphInstantiateFlagBit::DeviceLaunch),
377 "Event Wait Node is not allowed in a graph that will be launched on device");
382 ComputeGraphEventWaitNode* event_wait =
383 get_or_create_node<ComputeGraphEventWaitNode>(
386 const auto& [name, closure] = current_closure();
387 return new ComputeGraphEventWaitNode(
388 NodeId{m_cg.m_nodes.size()}, m_cg.current_access_index());
391 if(ComputeGraphBuilder::current_phase() == ComputeGraphPhase::Building)
393 event_wait->set_node(g.add_event_wait_node(event));
397 MUDA_INLINE
void ComputeGraphAccessor::update_event_wait_node(cudaEvent_t event)
400 [&](GraphExec& g_exec)
402 const auto& [name, closure] = current_closure();
403 auto event_wait = current_node<ComputeGraphEventWaitNode>();
404 g_exec.set_event_wait_node_parms(event_wait->m_node, event);
408 MUDA_INLINE
void ComputeGraphAccessor::add_capture_node(cudaGraph_t sub_graph)
413 auto capture_node = get_or_create_node<ComputeGraphCaptureNode>(
416 const auto& [name, closure] = current_closure();
417 return new ComputeGraphCaptureNode{NodeId{m_cg.m_nodes.size()},
418 m_cg.current_access_index()};
420 if(ComputeGraphBuilder::is_building())
422 cudaGraphNode_t node;
423 checkCudaErrors(cudaGraphAddChildGraphNode(
424 &node, g.handle(),
nullptr, 0, sub_graph));
425 capture_node->set_node(node);
426 capture_node->update_sub_graph(sub_graph);
431 MUDA_INLINE
void ComputeGraphAccessor::update_capture_node(cudaGraph_t sub_graph)
434 [&](GraphExec& g_exec)
436 const auto& [name, closure] = current_closure();
437 auto capture_node = current_node<ComputeGraphCaptureNode>();
438 checkCudaErrors(cudaGraphExecChildGraphNodeSetParams(
439 g_exec.handle(), capture_node->handle(), sub_graph));
440 capture_node->update_sub_graph(sub_graph);
// Helper through which the add_* functions touch the graph: callers pass a
// callable taking `Graph&`.
// NOTE(review): the statement that invokes `f` is not visible in this
// excerpt - presumably it is called with the graph owned by m_cg before the
// counter below is advanced; confirm.
446 template <
typename F>
447 void ComputeGraphAccessor::access_graph(F&& f)
// Advance the per-graph access counter once per access_graph() call.
450 ++m_cg.m_access_graph_index;
453 template <
typename F>
454 void ComputeGraphAccessor::access_graph_exec(F&& f)
456 f(*m_cg.m_graph_exec.get());
// Returns the node for the current access slot as a NodeType*.
// While the topology has not been built yet, a node pointer `ptr` is
// registered with both the current closure and the graph's node list;
// otherwise the already-registered node is looked up via current_node<>().
// NOTE(review): the definition of `ptr` and the exact branch/return
// structure are not visible in this excerpt - presumably `ptr` is the result
// of calling the factory `f`; confirm.
459 template <
typename NodeType,
typename F>
460 MUDA_INLINE NodeType* ComputeGraphAccessor::get_or_create_node(F&& f)
// Compile-time guard: only ComputeGraphNodeBase-derived types are accepted.
462 static_assert(std::is_base_of_v<ComputeGraphNodeBase, NodeType>,
463 "NodeType must be derived from ComputeGraphNodeBase");
464 if(!m_cg.m_is_topo_built)
// Register the node with the current closure and with the graph-wide list.
467 auto& [name, closure] = current_closure();
468 closure->m_graph_nodes.emplace_back(ptr);
469 m_cg.m_nodes.emplace_back(ptr);
// Existing-node lookup; the dynamic_cast inside current_node<>() yields
// nullptr on a type mismatch.
473 return current_node<NodeType>();
475 MUDA_INLINE
void ComputeGraphAccessor::set_var_usage(VarId
id, ComputeGraphVarUsage usage)
477 auto& dst_usage = current_closure().second->m_var_usages[id];
478 if(dst_usage < usage)