// Name string member. NOTE(review): the class that owns this field is outside
// this chunk; presumably it is the name of the node being added — confirm.
43 std::string m_node_name;
// Shorthand for std::unique_ptr<T>. NOTE(review): the `template` header that
// introduces T is not visible in this chunk.
76 using U = std::unique_ptr<T>;
// Shorthand for std::shared_ptr<T> (same caveat about T's template header).
78 using S = std::shared_ptr<T>;
// Shared handle to a GraphExec object; initialized to null.
83 S<GraphExec> m_graph_exec{
nullptr};
// cudaGraph_t handles keyed by NodeId::value_type.
85 std::unordered_map<NodeId::value_type, cudaGraph_t> m_sub_graphs;
// (name, closure pointer) pairs. The pointers are raw; who owns the closures
// is not visible from here.
87 std::vector<std::pair<std::string, ComputeGraphClosure*>> m_closures;
// Ordered mapping from a VarId to a details::LocalVarId.
89 std::map<VarId, details::LocalVarId> m_global_to_local_var_id;
// details::LocalVarInfo records; presumably indexed by the local var ids
// stored in the map above — TODO confirm.
90 std::vector<details::LocalVarInfo> m_related_vars;
// Flat list of raw node pointers (ownership not visible from here).
94 std::vector<ComputeGraphNodeBase*> m_nodes;
// Node pointers grouped into inner vectors; the grouping criterion is not
// visible in this chunk.
95 std::vector<std::vector<ComputeGraphNodeBase*>> m_graph_nodes;
// Dependency records; presumably the storage that dep_span() views — confirm.
96 std::vector<Dependency> m_deps;
// One int per entry; presumably a per-closure "needs update" flag — confirm.
98 std::vector<int> m_closure_need_update;
109 std::string_view name =
"graph",
110 ComputeGraphFlag flag = ComputeGraphFlag::HostLaunch);
120 std::string_view name()
const {
return m_name; }
// Declaration only (definition not visible in this chunk): takes a node name
// and returns an AddNodeProxy by value.
128 AddNodeProxy create_node(std::string_view node_name);
// Declaration only (definition not visible in this chunk).
// single_stream: mode flag; presumably selects launching everything on one
//                stream — confirm against the definition.
// s:             CUDA stream to use; defaults to nullptr.
141 void launch(
bool single_stream, cudaStream_t s =
nullptr);
143 void launch(cudaStream_t s =
nullptr) {
return launch(
false, s); }
// Declarations only (definitions not visible in this chunk).
// Both overloads take an rvalue std::function that receives a cudaStream_t.
// Overload without a name.
159 void capture(std::function<
void(cudaStream_t)>&& f);
// Overload that additionally takes a name for the captured work.
160 void capture(std::string_view name, std::function<
void(cudaStream_t)>&& f);
// Declaration only (definition not visible in this chunk).
// o:       output stream handed to the function; presumably receives the
//          Graphviz text — confirm.
// options: ComputeGraphGraphvizOptions; defaults to a value-initialized object.
168 void graphviz(std::ostream& o,
const ComputeGraphGraphvizOptions& options = {});
// Declaration only: returns a GraphViewer by value. Non-const member.
176 GraphViewer viewer();
178 operator GraphViewer() {
return viewer(); }
// Declarations only; none of the definitions are visible in this chunk, so the
// notes below are inferred from the names and should be confirmed.
// Presumably adds the recorded dependencies to the CUDA graph — confirm.
183 void cuda_graph_add_deps();
// Presumably runs the work serially instead of as a graph launch — confirm.
187 void serial_launch();
// Presumably validates the variables referenced by the graph — confirm.
191 void check_vars_valid();
// AddNodeProxy is granted access to this class's non-public members.
193 friend class AddNodeProxy;
// Declaration only: takes the node name by rvalue reference and a nullary
// callable by const reference; returns a ComputeGraph reference.
194 ComputeGraph& add_node(std::string&& name,
const std::function<
void()>& f);
// ComputeGraphNodeBase and ComputeGraphClosure are also friends.
196 friend class ComputeGraphNodeBase;
197 friend class ComputeGraphClosure;
// Declaration only: returns a read-only span of Dependency described by a
// starting index and an element count. Presumably a view into m_deps — confirm.
198 span<const Dependency> dep_span(
size_t begin,
size_t count)
const;
// Declaration only; presumably registers this object as the "current graph"
// — confirm against the definition.
200 void set_current_graph_as_this();
// Static counterpart; presumably resets the "current graph" — confirm.
202 static void clear_current_graph();
// Static; returns a reference to a Stream. Lifetime and sharing rules of the
// returned Stream are not visible from here.
204 static Stream& shared_capture_stream();
// ComputeGraphBuilder is a friend as well.
206 friend class ComputeGraphBuilder;
207 ClosureId current_closure_id()
const {
return m_current_closure_id; };
209 NodeId current_node_id()
const {
return m_current_node_id; };
211 size_t current_access_index()
const {
return m_access_graph_index; }
// Declaration only (definition not visible in this chunk): const member that
// returns a ComputeGraphPhase. Presumably reports m_current_graph_phase — confirm.
213 ComputeGraphPhase current_graph_phase()
const;
// Bookkeeping state. The initial values noted below come from the in-class
// initializers; how each field is used is not visible in this chunk.
// Starts false. Presumably marks that an update is required before the next
// launch — confirm.
218 bool m_need_update =
false;
// Value returned by current_closure_id().
219 ClosureId m_current_closure_id;
// Value returned by current_node_id().
220 NodeId m_current_node_id;
// Starts as ComputeGraphPhase::None.
221 ComputeGraphPhase m_current_graph_phase = ComputeGraphPhase::None;
// Starts false.
222 bool m_allow_access_graph =
false;
// Value returned by current_access_index(); starts at 0.
223 size_t m_access_graph_index = 0;
// Starts true.
224 bool m_allow_node_adding =
true;
// Starts null. Presumably the stream used for a single-stream launch — confirm.
226 cudaStream_t m_current_single_stream =
nullptr;
// Starts false.
227 bool m_is_capturing =
false;
// Starts false.
229 bool m_is_in_capture_func =
false;
// Starts false. Presumably records whether the topology has been built — confirm.
231 bool m_is_topo_built =
false;