MUDA
Loading...
Searching...
No Matches
compute_graph.h
1#pragma once
2#include <map>
3#include <functional>
4#include <set>
5#include <muda/launch/stream.h>
6#include <muda/launch/event.h>
7#include <muda/mstl/span.h>
8#include <muda/graph/graph.h>
9#include <muda/graph/graph_viewer.h>
10#include <muda/compute_graph/compute_graph_flag.h>
11#include <muda/compute_graph/compute_graph_phase.h>
12#include <muda/compute_graph/compute_graph_node_type.h>
13#include <muda/compute_graph/compute_graph_node_id.h>
14#include <muda/compute_graph/compute_graph_closure_id.h>
15#include <muda/compute_graph/compute_graph_var_id.h>
16#include <muda/compute_graph/compute_graph_var_usage.h>
17#include <muda/compute_graph/compute_graph_dependency.h>
18#include <muda/compute_graph/graphviz_options.h>
19#include <muda/compute_graph/compute_graph_fwd.h>
20
21namespace muda
22{
// Implementation-detail types used by ComputeGraph's variable bookkeeping
// (see the m_global_to_local_var_id / m_related_vars members further down).
// NOTE(review): source lines 25 and 29 (the two class-head lines) are missing
// from this listing. Judging by the later uses of `details::LocalVarId` and
// `details::LocalVarInfo`, they presumably declare those two classes, in that
// order — confirm against the original header.
23namespace details
24{
// Presumably `LocalVarId`. Its only member is a using-declaration that
// inherits the constructors of U64IdWithType, which therefore has to be a
// direct base class of this type.
26 {
27 using U64IdWithType::U64IdWithType;
28 };
// Presumably `LocalVarInfo`: a plain record pairing a local id with a pointer
// to the variable it refers to.
30 {
31 public:
// Value-initialized local id.
32 LocalVarId id{};
// Pointer to the variable; null until assigned.
33 ComputeGraphVarBase* var = nullptr;
34 };
35} // namespace details
36
38{
39 public:
// Body of the nested AddNodeProxy class (its class-head, source line 40, is
// missing from this listing). This is the type returned by
// ComputeGraph::create_node() and it is declared a friend of ComputeGraph.
41 {
// Graph that created this proxy (held by reference).
42 ComputeGraph& m_cg;
// Node name, stored as an owning std::string copy.
43 std::string m_node_name;
44
45 public:
// Binds the proxy to `cg` and records `node_name`.
46 AddNodeProxy(ComputeGraph& cg, std::string_view node_name);
// Supplies the node's callable. The trailing `&&` ref-qualifier restricts the
// call to rvalue proxies, e.g. the temporary returned by create_node().
// Returns a ComputeGraph reference.
// NOTE(review): presumably forwards the name and callable to
// ComputeGraph::add_node(); the definition is not visible here — confirm.
47 ComputeGraph& operator<<(std::function<void()>&& f) &&;
48 };
49 // A depends on B : from B to A
51
// Body of the nested GraphPhaseGuard class (its class-head, source line 52, is
// missing from this listing).
// NOTE(review): the name suggests a scope guard that switches the graph's
// phase for the guard's lifetime; source line 58 is also missing here and is
// presumably the destructor that restores it — confirm against the header.
53 {
// Graph whose phase this guard is associated with (held by reference).
54 ComputeGraph& m_cg;
55
56 public:
// Constructs the guard for `cg` with the requested `phase`.
57 GraphPhaseGuard(ComputeGraph& cg, ComputeGraphPhase phase);
59 };
60
// ComputeGraph is neither copyable nor movable: all four copy/move special
// member functions are explicitly deleted, so instances must be used in place
// (or handled through pointers/references).
61 // delete copy
62 ComputeGraph(const ComputeGraph&) = delete;
63 ComputeGraph& operator=(const ComputeGraph&) = delete;
64
65 // delete move
66 ComputeGraph(ComputeGraph&&) = delete;
67 ComputeGraph& operator=(ComputeGraph&&) = delete;
68
// Private aliases, friend declarations and bookkeeping containers.
69 private:
// Disabled (commented-out) helper type, kept by the original author.
70 //class TempNodeInfo
71 //{
72 // public:
73 // std::map<VarId, ComputeGraphVarUsage> var_usage;
74 //};
// Short aliases for the standard smart pointers.
// NOTE(review): <memory> is not included directly by this header (nor are
// <string>, <vector>, <unordered_map>); presumably they arrive transitively
// through the muda headers — confirm. `U` is not used in the visible part of
// this header.
75 template <typename T>
76 using U = std::unique_ptr<T>;
77 template <typename T>
78 using S = std::shared_ptr<T>;
79
80 friend class ComputeGraphVarBase;
81
// Graph object and its (shared) executable instance; the latter starts null.
82 Graph m_graph;
83 S<GraphExec> m_graph_exec{nullptr};
84
// cudaGraph_t handles keyed by the raw value of a NodeId.
85 std::unordered_map<NodeId::value_type, cudaGraph_t> m_sub_graphs;
86
// (string, closure pointer) pairs; the string is presumably the closure's
// name — confirm.
87 std::vector<std::pair<std::string, ComputeGraphClosure*>> m_closures;
88
// Ordered map from a VarId to the details::LocalVarId used inside this graph,
// plus the list of LocalVarInfo records for the variables involved.
89 std::map<VarId, details::LocalVarId> m_global_to_local_var_id;
90 std::vector<details::LocalVarInfo> m_related_vars;
// Registers `var` with this graph.
// NOTE(review): presumably updates the two containers above; definition not
// visible here — confirm.
91 void emplace_related_var(ComputeGraphVarBase* var);
92
93
// Node pointers: a flat list and a list of per-group lists.
// NOTE(review): the grouping criterion of m_graph_nodes is not visible here.
94 std::vector<ComputeGraphNodeBase*> m_nodes;
95 std::vector<std::vector<ComputeGraphNodeBase*>> m_graph_nodes;
// Dependency records. `Dependency` is presumably the alias declared on source
// line 50, which is missing from this listing — confirm.
96 std::vector<Dependency> m_deps;
97
// One int per entry; presumably a per-closure "needs update" flag — confirm.
98 std::vector<int> m_closure_need_update;
// Associated variable manager; null by default.
99 ComputeGraphVarManager* m_var_manager = nullptr;
100
101 friend class ComputeGraphVarManager;
102
// Event plus the cached result of querying it. The cache starts at eFinished
// and is `mutable` so that const member functions can modify it.
103 Event m_event;
104 mutable Event::QueryResult m_event_result = Event::QueryResult::eFinished;
106
107 public:
109 std::string_view name = "graph",
110 ComputeGraphFlag flag = ComputeGraphFlag::HostLaunch);
111
113
114 /**************************************************************
115 *
116 * Info API
117 *
118 ***************************************************************/
119
// Returns a non-owning view of the graph's name (m_name). The view refers to
// storage owned by this object, so it must not outlive the ComputeGraph.
120 std::string_view name() const { return m_name; }
121
122 /**************************************************************
123 *
124 * GraphNode API
125 *
126 ***************************************************************/
127
// Starts the declaration of a node called `node_name` and returns an
// AddNodeProxy by value; the node's callable is then supplied through the
// proxy's operator<<, which accepts a std::function<void()>&&.
128 AddNodeProxy create_node(std::string_view node_name);
129
130
131 /**************************************************************
132 *
133 * Graph Launch API
134 *
135 ***************************************************************/
136
// update(), build() and the two-argument launch() are only declared here;
// their definitions are not visible in this header (presumably in
// details/compute_graph.inl, included at the end of the file).
137 void update();
138
139 void build();
140
// Launches the graph. `single_stream` selects the launch mode and `s` is the
// stream to use (null by default).
// NOTE(review): the exact meaning of `single_stream` is defined in the
// implementation — confirm there.
141 void launch(bool single_stream, cudaStream_t s = nullptr);
142
// Convenience overload: same as launch(false, s).
// NOTE(review): a call with a literal 0 (e.g. launch(0)) may be ambiguous
// between this overload and the bool one — confirm and prefer nullptr or a
// named stream at call sites.
143 void launch(cudaStream_t s = nullptr) { return launch(false, s); }
144
145 /**************************************************************
146 *
147 * Graph Event Query API
148 *
149 ***************************************************************/
150
// Returns an Event::QueryResult for this graph. It is a const member function;
// note that m_event_result is declared `mutable`.
// NOTE(review): presumably queries m_event and caches the result in
// m_event_result — definition not visible here, confirm.
151 Event::QueryResult query() const;
152
153 /**************************************************************
154 *
155 * Graph Closure Capture Node API
156 *
157 ***************************************************************/
158
// Two overloads that accept a callable taking a cudaStream_t; the second one
// additionally takes a name. Both take the callable by rvalue reference.
// NOTE(review): presumably `f` is invoked with a capturing stream and the
// recorded work is added to the graph as a node — definition not visible
// here, confirm.
159 void capture(std::function<void(cudaStream_t)>&& f);
160 void capture(std::string_view name, std::function<void(cudaStream_t)>&& f);
161
162 /**************************************************************
163 *
164 * Graph Visualization API
165 *
166 ***************************************************************/
167
// Writes a Graphviz description of the graph to the stream `o`. `options`
// defaults to a default-constructed ComputeGraphGraphvizOptions.
// NOTE(review): the function is non-const, so it may modify the graph;
// confirm in the implementation.
168 void graphviz(std::ostream& o, const ComputeGraphGraphvizOptions& options = {});
169
170 /**************************************************************
171 *
172 * Graph Viewer API
173 *
174 ***************************************************************/
175
// Returns a GraphViewer for this graph (definition not visible here).
176 GraphViewer viewer();
177
// Implicit conversion to GraphViewer; simply forwards to viewer(). Because the
// operator is not `explicit`, a non-const ComputeGraph can be passed wherever
// a GraphViewer is expected.
178 operator GraphViewer() { return viewer(); }
179
// Internal helpers. Apart from the three inline accessors near the end, every
// function in this section is only declared here; the definitions are not
// visible in this header (presumably in details/compute_graph.inl).
180 private: // internal method
181 void topo_build();
182
183 void cuda_graph_add_deps();
184
185 void build_deps();
186
187 void serial_launch();
188
189 void _update();
190
191 void check_vars_valid();
192
// AddNodeProxy is a friend; add_node takes the node name by rvalue reference
// and the callable by const reference, and returns a ComputeGraph reference.
193 friend class AddNodeProxy;
194 ComputeGraph& add_node(std::string&& name, const std::function<void()>& f);
195
196 friend class ComputeGraphNodeBase;
197 friend class ComputeGraphClosure;
// Returns a read-only span of `count` Dependency records starting at `begin`.
// NOTE(review): presumably a window into m_deps; bounds handling is not
// visible here — confirm.
198 span<const Dependency> dep_span(size_t begin, size_t count) const;
199
200 void set_current_graph_as_this();
201
// Static: takes no ComputeGraph instance.
202 static void clear_current_graph();
203
// Static: returns a reference to a Stream.
// NOTE(review): the name suggests one stream shared by all graphs for
// capturing — confirm in the implementation.
204 static Stream& shared_capture_stream();
205
206 friend class ComputeGraphBuilder;
// Trivial accessors for the current closure id, node id and access index.
// NOTE(review): the `;` after the first two function bodies is redundant
// (harmless, but may trigger extra-semicolon warnings).
207 ClosureId current_closure_id() const { return m_current_closure_id; };
208
209 NodeId current_node_id() const { return m_current_node_id; };
210
211 size_t current_access_index() const { return m_access_graph_index; }
212
213 ComputeGraphPhase current_graph_phase() const;
214
// Internal state. Every member except the two ids has an in-class default.
215 private: // internal data
// Graph name, returned by name().
217 std::string m_name;
218 bool m_need_update = false;
// Ids returned by current_closure_id() / current_node_id(). They have no
// initializer here, so their starting value is whatever the default
// constructors of ClosureId / NodeId produce.
219 ClosureId m_current_closure_id;
220 NodeId m_current_node_id;
// Current phase; starts as ComputeGraphPhase::None.
221 ComputeGraphPhase m_current_graph_phase = ComputeGraphPhase::None;
222 bool m_allow_access_graph = false;
// Index returned by current_access_index().
223 size_t m_access_graph_index = 0;
224 bool m_allow_node_adding = true;
225 // TempNodeInfo m_temp_node_info;
// Null by default.
// NOTE(review): presumably the stream used by a single-stream launch — confirm.
226 cudaStream_t m_current_single_stream = nullptr;
227 bool m_is_capturing = false;
228 // in capture func, we don't allow any var eval()
229 bool m_is_in_capture_func = false;
230 // if we have already built the topo, we don't do that again
231 bool m_is_topo_built = false;
232};
233} // namespace muda
234
235#include "details/compute_graph.inl"
Definition compute_graph.h:41
Definition compute_graph.h:53
Definition compute_graph_dependency.h:6
Definition compute_graph.h:38
Definition compute_graph_var.h:17
Definition compute_graph_var_manager.h:15
RAII wrapper for cudaEvent.
Definition event.h:15
QueryResult
Definition event.h:28
Definition flag.h:9
Definition graph.h:18
Definition id_with_type.h:10
Definition compute_graph_accessor.h:13
Definition compute_graph.h:26
Definition compute_graph.h:30