// MUDA — compute_graph_node.h
#pragma once
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <string_view>
#include <muda/compute_graph/compute_graph_node_type.h>
#include <muda/compute_graph/compute_graph_node_id.h>
#include <muda/compute_graph/compute_graph_var_usage.h>
#include <muda/compute_graph/compute_graph_var_id.h>
#include <muda/compute_graph/compute_graph_fwd.h>

10namespace muda
11{
13{
14 public:
15 auto node_id() const { return m_node_id; }
16 auto access_index() const { return m_access_index; }
17 auto type() const { return m_type; }
18 auto name() const { return std::string_view{m_name}; }
19
20 virtual ~ComputeGraphNodeBase() = default;
21
22 protected:
23 template <typename T>
24 using S = std::shared_ptr<T>;
25
26 friend class ComputeGraph;
27 friend class ComputeGraphVarBase;
28 ComputeGraphNodeBase(std::string_view name, NodeId node_id, uint64_t access_index, ComputeGraphNodeType type)
29 : m_name(name)
30 , m_node_id(node_id)
31 , m_access_index(access_index)
32 , m_type(type)
33 {
34 }
35
36 std::string m_name;
37 NodeId m_node_id;
38 uint64_t m_access_index;
39
40 ComputeGraphNodeType m_type;
41 cudaGraphNode_t m_cuda_node = nullptr;
42
43
44 auto handle() const { return m_cuda_node; }
45 void set_handle(cudaGraphNode_t handle) { m_cuda_node = handle; }
46 auto is_valid() const { return m_cuda_node; }
47};
48
49template <typename NodeT, ComputeGraphNodeType Type>
51{
52 protected:
53 friend class ComputeGraph;
55 ComputeGraphNode(NodeId node_id, uint64_t access_graph_index);
56
57 S<NodeT> m_node;
58 void set_node(S<NodeT> node);
59 virtual ~ComputeGraphNode() = default;
60};
61} // namespace muda
62
63#include "details/compute_graph_node.inl"
// Cross-reference targets listed by the generated documentation page:
//   Definition compute_graph.h:38
//   Definition compute_graph_node.h:13
//   Definition compute_graph_node.h:51
//   Definition compute_graph_var.h:17
//   Definition compute_graph_node_id.h:6
//   Definition compute_graph_accessor.h:13