MUDA
Loading...
Searching...
No Matches
graph.h
1#pragma once
#include <list>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include <muda/graph/graph_base.h>
#include <muda/graph/graph_exec.h>

#include <muda/graph/kernel_node.h>
#include <muda/graph/memory_node.h>
#include <muda/graph/host_node.h>
#include <muda/graph/event_node.h>

#include <muda/graph/graph_instantiate_flag.h>
14
15namespace muda
16{
17class Graph
18{
19 template <typename T>
20 using S = std::shared_ptr<T>;
21 template <typename T>
22 using U = std::unique_ptr<T>;
23
24 public:
25 Graph();
26 ~Graph();
27
28 // delete copy
29 Graph(const Graph&) = delete;
30 Graph& operator=(const Graph&) = delete;
31
32 // move
33 Graph(Graph&&);
34 Graph& operator=(Graph&&);
35
36
37 friend class GraphExec;
38 friend class std::shared_ptr<Graph>;
39
40 MUDA_NODISCARD S<GraphExec> instantiate();
41 MUDA_NODISCARD S<GraphExec> instantiate(Flags<GraphInstantiateFlagBit> flags);
42
43 template <typename T>
44 S<KernelNode> add_kernel_node(const S<KernelNodeParms<T>>& kernelParms,
45 const std::vector<S<GraphNode>>& deps);
46 template <typename T>
47 S<KernelNode> add_kernel_node(const S<KernelNodeParms<T>>& kernelParms);
48
49
50 template <typename T>
51 S<HostNode> add_host_node(const S<HostNodeParms<T>>& hostParms,
52 const std::vector<S<GraphNode>>& deps);
53 template <typename T>
54 S<HostNode> add_host_node(const S<HostNodeParms<T>>& hostParms);
55
56
57 S<MemcpyNode> add_memcpy_node(void* dst,
58 const void* src,
59 size_t size_bytes,
60 cudaMemcpyKind kind,
61 const std::vector<S<GraphNode>>& deps);
62 S<MemcpyNode> add_memcpy_node(void* dst, const void* src, size_t size_bytes, cudaMemcpyKind kind);
63 S<MemcpyNode> add_memcpy_node(const cudaMemcpy3DParms& parms);
64 S<MemcpyNode> add_memcpy_node(const cudaMemcpy3DParms& parms,
65 const std::vector<S<GraphNode>>& deps);
66
67 S<MemsetNode> add_memset_node(const cudaMemsetParams& parms,
68 const std::vector<S<GraphNode>>& deps);
69 S<MemsetNode> add_memset_node(const cudaMemsetParams& parms);
70
71
72 S<EventRecordNode> add_event_record_node(cudaEvent_t e,
73 const std::vector<S<GraphNode>>& deps);
74 S<EventRecordNode> add_event_record_node(cudaEvent_t e);
75 S<EventWaitNode> add_event_wait_node(cudaEvent_t e,
76 const std::vector<S<GraphNode>>& deps);
77 S<EventWaitNode> add_event_wait_node(cudaEvent_t e);
78
79
80 void add_dependency(S<GraphNode> from, S<GraphNode> to);
81
82 cudaGraph_t handle() const { return m_handle; }
83 cudaGraph_t handle() { return m_handle; }
84 static auto create() { return std::make_shared<Graph>(); }
85
86 private:
87 cudaGraph_t m_handle;
88 // keep the ref count > 0 for those whose data should be kept alive for the graph life.
89 std::list<S<NodeParms>> m_cached;
90 static std::vector<cudaGraphNode_t> map_dependencies(const std::vector<S<GraphNode>>& deps);
91};
92} // namespace muda
93
94#include "details/graph.inl"
Definition flag.h:9
Definition graph_exec.h:11
Definition graph.h:18
Definition host_node.h:15
Definition kernel_node.h:15