MUDA
Loading...
Searching...
No Matches
compute_graph_catpure_node.h
1#pragma once
2#include <muda/compute_graph/compute_graph_node.h>
3#include <muda/graph/graph.h>
4
5namespace muda
6{
8{
9 protected:
10 friend class ComputeGraph;
12 ComputeGraphCaptureNode(NodeId node_id, uint64_t access_index)
13 : ComputeGraphNodeBase(enum_name(ComputeGraphNodeType::CaptureNode),
14 node_id,
15 access_index,
16 ComputeGraphNodeType::CaptureNode)
17 {
18 auto n = std::string_view{
19 details::LaunchInfoCache::current_capture_name().auto_select()};
20 if(n.empty() || n == "")
21 m_name += std::string(":~");
22 else
23 m_name += std::string(":") + std::string(n.data());
24 }
25
26 virtual ~ComputeGraphCaptureNode() override { update_sub_graph(nullptr); }
27
28 void set_node(cudaGraphNode_t node) { set_handle(node); }
29
30 void update_sub_graph(cudaGraph_t sub_graph)
31 {
32 if(m_sub_graph)
33 checkCudaErrors(cudaGraphDestroy(m_sub_graph));
34 m_sub_graph = sub_graph;
35 }
36
37 cudaGraph_t m_sub_graph = nullptr;
38};
39} // namespace muda
Definition compute_graph_catpure_node.h:8
Definition compute_graph.h:38
Definition compute_graph_node.h:13
Definition compute_graph_node_id.h:6
Definition compute_graph_accessor.h:13