2#include <muda/exception.h>
4#include <muda/compute_graph/compute_graph.h>
5#include <muda/compute_graph/compute_graph_builder.h>
6#include <muda/compute_graph/compute_graph_var.h>
7#include <muda/compute_graph/compute_graph_var_manager.h>
8#include <muda/compute_graph/compute_graph_node.h>
9#include <muda/compute_graph/compute_graph_closure.h>
10#include <muda/compute_graph/compute_graph_accessor.h>
// AddNodeProxy constructor: stores the caller-supplied node name in
// m_node_name.
// NOTE(review): the initializer for the graph reference and the (empty?) body
// are not visible in this listing — confirm against the full file.
14MUDA_INLINE ComputeGraph::AddNodeProxy::AddNodeProxy(ComputeGraph& cg, std::string_view node_name)
16 , m_node_name(node_name)
// GraphPhaseGuard constructor: marks the guarded graph as the current graph
// and switches it into `phase`. The matching destructor undoes both.
// NOTE(review): m_cg is presumably bound to `cg`; its initializer is not
// visible in this listing.
20MUDA_INLINE ComputeGraph::GraphPhaseGuard::GraphPhaseGuard(ComputeGraph& cg, ComputeGraphPhase phase)
// Publish this graph through ComputeGraphBuilder::current_graph().
23 m_cg.set_current_graph_as_this();
24 m_cg.m_current_graph_phase = phase;
// GraphPhaseGuard destructor: resets the guarded graph's phase to
// ComputeGraphPhase::None and clears the current-graph registration.
27MUDA_INLINE ComputeGraph::GraphPhaseGuard::~GraphPhaseGuard()
29 m_cg.m_current_graph_phase = ComputeGraphPhase::None;
30 ComputeGraph::clear_current_graph();
// Attaches the callable `f` to the graph as a new node carrying the name held
// by this proxy. The operator is rvalue-qualified (`&&`), so it can only be
// invoked on a temporary/moved proxy.
// NOTE(review): the return statement is not visible in this listing; the
// declared return type is ComputeGraph&.
33MUDA_INLINE ComputeGraph& ComputeGraph::AddNodeProxy::operator<<(std::function<
void()>&& f) &&
// add_node takes the name as std::string&&, hence the move; `f` is forwarded
// as an lvalue to add_node's const-reference parameter.
35 m_cg.add_node(std::move(m_node_name), f);
// Writes a Graphviz (DOT) description of this compute graph to `o`.
//
// `options` controls the output: whether the graph is emitted as a cluster
// subgraph (as_subgraph), whether variables and nodes are shown (show_vars,
// show_nodes), the graph id used to build identifiers, and the style strings
// (cluster_style, graph_font, node_style, arc_style).
//
// NOTE(review): this listing omits a number of original lines (braces,
// else-branches, arrow text, closing brackets), so the comments below describe
// only the statements that are visible.
39MUDA_INLINE
void ComputeGraph::graphviz(std::ostream& o,
const ComputeGraphGraphvizOptions& options)
// Subgraph mode: open a cluster named after the graph id, labelled with name().
43 if(options.as_subgraph)
45 o <<
"subgraph cluster_" << options.graph_id;
47 o <<
"label=\"" << name() <<
"\";\n";
48 o << options.cluster_style <<
"\n";
53 o << options.graph_font <<
"\n";
// Variable definitions are emitted only when this graph is not a subgraph.
57 if(options.show_vars && !options.as_subgraph)
60 for(
auto&& [local_id, var] : m_related_vars)
62 var->graphviz_def(o, options);
// Node definitions: a per-graph node "node_g<graph_id>" labelled with the
// graph name, followed by one definition per closure.
68 if(options.show_nodes)
71 o <<
"node_g" << options.graph_id <<
"[label=\"" << name() <<
"\""
72 << options.node_style <<
"]\n";
74 for(
auto& [name, node] : m_closures)
76 node->graphviz_def(o, options);
// Edges between closures and the variables they use.
82 o <<
"// node var usages: \n";
83 for(
auto& [name, node] : m_closures)
85 node->graphviz_var_usages(o, options);
// Edges from the per-graph node to closures, then closure-to-closure edges.
90 o <<
"// node deps: \n";
91 for(
auto& [name, node] : m_closures)
// NOTE(review): the statement guarded by this check is not visible in this
// listing, so it is unclear which closures get linked to "node_g<graph_id>".
93 if(node->deps().size() != 0)
95 o <<
"node_g" << options.graph_id <<
"->";
96 node->graphviz_id(o, options);
98 o <<
"[" << options.arc_style
// One edge per recorded dependency. `src` is the closure at dep.to and `dst`
// the closure at dep.from; dst's id is written before src's id.
// NOTE(review): `auto` copies the (name, closure) entry here; a reference
// would avoid copying the name string — confirm and consider changing.
103 for(
auto dep : m_deps)
105 auto src = m_closures[dep.to.value()];
106 auto dst = m_closures[dep.from.value()];
107 dst.second->graphviz_id(o, options);
109 src.second->graphviz_id(o, options);
110 o <<
"[" << options.arc_style
// Returns a GraphViewer built from the instantiated graph's handle and this
// graph's instantiate flags. Asserts that m_graph_exec is set, i.e. that the
// graph has been built first.
118MUDA_INLINE GraphViewer ComputeGraph::viewer()
120 MUDA_ASSERT(m_graph_exec,
"graph is not built yet, call ComputeGraph::build() to build it.");
121 return GraphViewer{m_graph_exec->handle(), m_flags};
// Runs every closure once in the TopoBuilding phase.
// NOTE(review): lines before and after the visible statements are missing
// from this listing (e.g. any early-return check or follow-up call).
124MUDA_INLINE
void ComputeGraph::topo_build()
// Reset the per-closure "need update" flags to false, one per closure.
129 m_closure_need_update.clear();
130 m_closure_need_update.resize(m_closures.size(),
false);
// The guard sets the phase for this scope and restores it on destruction.
131 GraphPhaseGuard guard(*
this, ComputeGraphPhase::TopoBuilding);
132 for(
size_t i = 0; i < m_closures.size(); ++i)
// Per-closure state: current closure id, graph access enabled, and the
// access index reset to 0 before invoking the closure.
135 m_current_closure_id = ClosureId{i};
136 m_allow_access_graph =
true;
137 m_access_graph_index = 0;
138 m_closures[i].second->operator()();
// Builds the executable graph: runs every closure in the Building phase, adds
// the recorded dependencies to the CUDA graph, then instantiates and uploads
// the executable graph.
// NOTE(review): some original lines (e.g. anything preceding the guard) are
// not visible in this listing.
144MUDA_INLINE
void ComputeGraph::build()
149 GraphPhaseGuard guard(*
this, ComputeGraphPhase::Building);
// Reset the per-closure "need update" flags to false, one per closure.
152 m_closure_need_update.clear();
153 m_closure_need_update.resize(m_closures.size(),
false);
155 for(
size_t i = 0; i < m_closures.size(); ++i)
// Per-closure state is reset before each closure is invoked.
158 m_current_closure_id = ClosureId{i};
159 m_allow_access_graph =
true;
160 m_access_graph_index = 0;
161 m_closures[i].second->operator()();
// Wire up the graph-node dependencies once all closures have run.
165 cuda_graph_add_deps();
// Instantiate with this graph's flags and upload the executable graph.
167 m_graph_exec = m_graph.instantiate(m_flags);
168 m_graph_exec->upload();
// Invokes every closure in order in the SerialLaunching phase, with graph
// access disabled for each closure.
171MUDA_INLINE
void ComputeGraph::serial_launch()
173 GraphPhaseGuard guard(*
this, ComputeGraphPhase::SerialLaunching);
175 for(
size_t i = 0; i < m_closures.size(); ++i)
178 m_current_closure_id = ClosureId{i};
179 m_allow_access_graph =
false;
180 m_closures[i].second->operator()();
// Clear the capturing flag after the closure call.
181 m_is_capturing =
false;
// Walks all variables related to this graph and reports an error for a
// variable that is not valid.
// NOTE(review): the validity condition and the format argument of the error
// message are not visible in this listing.
185MUDA_INLINE
void ComputeGraph::check_vars_valid()
187 for(
auto&& [local_id, var] : m_related_vars)
190 MUDA_ERROR_WITH_LOCATION(
191 "var[%s] is not valid, "
192 "you need update the var before launch this graph",
// Appends a new closure named `name` that wraps `f` to this graph.
//
// The closure's id is the number of closures already stored, so ids follow
// insertion order. Adding is refused (assert) once m_allow_node_adding is
// false.
// NOTE(review): the return statement is not visible in this listing; the
// declared return type is ComputeGraph&.
197MUDA_INLINE ComputeGraph& ComputeGraph::add_node(std::string&& name,
198 const std::function<
void()>& f)
200 details::ComputeGraphAccessor(
this).check_allow_node_adding();
201 MUDA_ASSERT(m_allow_node_adding,
202 "This graph is built or updated, so you can't add new nodes any more.");
203 auto size = m_closures.size();
// Allocated with `new` and stored as a raw pointer in m_closures.
// NOTE(review): where this pointer is released is not visible here — verify.
204 auto closure =
new ComputeGraphClosure{
this, ClosureId{size}, name, f};
205 m_closures.emplace_back(name, closure);
// Returns a read-only view of `count` dependencies of m_deps starting at
// index `begin`.
209MUDA_INLINE
auto ComputeGraph::dep_span(
size_t begin,
size_t count)
const
210 -> span<const Dependency>
212 return span<const Dependency>{m_deps}.subspan(begin, count);
// Registers this graph as the current graph with ComputeGraphBuilder.
215MUDA_INLINE
void ComputeGraph::set_current_graph_as_this()
217 ComputeGraphBuilder::current_graph(
this);
// Clears ComputeGraphBuilder's current graph by setting it to nullptr.
220MUDA_INLINE
void ComputeGraph::clear_current_graph()
222 ComputeGraphBuilder::current_graph(
nullptr);
// Provides a Stream created with Stream::Flag::eNonBlocking. The stream is a
// function-local `static thread_local`, so each thread gets its own instance
// that is constructed on first use.
// NOTE(review): the return statement is not visible in this listing.
225MUDA_INLINE Stream& ComputeGraph::shared_capture_stream()
227 static thread_local Stream s(Stream::Flag::eNonBlocking);
// Convenience overload: forwards `f` to capture(name, f) with an empty name.
231MUDA_INLINE
void ComputeGraph::capture(std::function<
void(cudaStream_t)>&& f)
233 capture(
"", std::move(f));
// Runs `f` according to the graph's current phase (see the switch below).
// m_is_in_capture_func is set for the duration of the call.
//
// NOTE(review): large parts of this function (the body of `do_capture`, the
// bodies of the Updating/Building cases, break statements) are not visible in
// this listing; comments describe only the visible statements.
236MUDA_INLINE
void ComputeGraph::capture(std::string_view name,
237 std::function<
void(cudaStream_t)>&& f)
239 m_is_in_capture_func =
true;
// Helper lambda: uses the per-thread shared capture stream, flags the graph
// as capturing while it works, and hands the resulting `g` to the accessor.
// NOTE(review): how `g` is produced is not visible here.
241 auto do_capture = [&]
243 auto& s = shared_capture_stream();
245 m_is_capturing =
true;
250 details::ComputeGraphAccessor(
this).set_capture_node(g);
251 m_is_capturing =
false;
254 switch(current_graph_phase())
// TopoBuilding: register a null capture node under `name`, then reset the
// cached capture name to the empty string.
256 case ComputeGraphPhase::TopoBuilding:
259 details::LaunchInfoCache::current_capture_name(name);
260 details::ComputeGraphAccessor(
this).set_capture_node(
nullptr);
261 details::LaunchInfoCache::current_capture_name(
"");
// SerialLaunching: call `f` directly on the stream stored by launch().
263 case ComputeGraphPhase::SerialLaunching: {
265 f(m_current_single_stream);
268 case ComputeGraphPhase::Updating: {
// Building: the capture name is set around the (not visible) capture work.
272 case ComputeGraphPhase::Building: {
273 details::LaunchInfoCache::current_capture_name(name);
275 details::LaunchInfoCache::current_capture_name(
"");
// Any other phase is reported as an error.
279 MUDA_ERROR_WITH_LOCATION(
"invoking capture() outside Graph Closure is not allowed");
282 m_is_in_capture_func =
false;
// Returns the phase this graph is currently in (m_current_graph_phase).
285MUDA_INLINE ComputeGraphPhase ComputeGraph::current_graph_phase()
const
287 return m_current_graph_phase;
// Re-invokes closures in the Updating phase, iterating over the per-closure
// "need update" flags.
// NOTE(review): the check on `need_update` (and any reset of it) is not
// visible in this listing, so which closures are actually re-run cannot be
// confirmed from here.
290MUDA_INLINE
void ComputeGraph::_update()
295 GraphPhaseGuard guard(*
this, ComputeGraphPhase::Updating);
297 for(
size_t i = 0; i < m_closure_need_update.size(); ++i)
299 auto& need_update = m_closure_need_update[i];
// Per-closure state is reset before the closure is invoked.
302 m_current_closure_id = ClosureId{i};
304 m_allow_access_graph =
true;
305 m_access_graph_index = 0;
306 m_closures[i].second->operator()();
// Destructor: detaches this graph from its related variables and from the
// variable manager, then iterates the stored nodes and closures.
// NOTE(review): the bodies of the last two loops are not visible in this
// listing (presumably they release the nodes/closures — confirm).
315MUDA_INLINE ComputeGraph::~ComputeGraph()
// Ask every related variable to drop the closure infos tied to this graph.
317 for(
auto var_info : m_related_vars)
318 var_info.var->remove_related_closure_infos(this);
// Unregister from the manager's graph set (inserted in the constructor).
320 m_var_manager->m_graphs.erase(
this);
322 for(
auto node : m_nodes)
325 for(
auto& [name, closure] : m_closures)
// Registers `var` as related to this graph if its global id has no local id
// yet. A new local id equals the current number of related variables, so
// local ids index directly into m_related_vars.
329MUDA_INLINE
void ComputeGraph::emplace_related_var(ComputeGraphVarBase* var)
331 auto global_var_id = var->var_id();
332 auto iter = m_global_to_local_var_id.find(global_var_id);
// Only variables without an existing global-to-local mapping are added.
333 if(iter == m_global_to_local_var_id.end())
335 auto local_id = details::LocalVarId{m_related_vars.size()};
336 m_related_vars.emplace_back(details::LocalVarInfo{local_id, var});
337 m_global_to_local_var_id.emplace(std::make_pair(global_var_id, local_id));
// Constructs a compute graph attached to `manager`.
//
// Reports an error when COMPUTE_GRAPH_ON is false, registers the graph in the
// manager's graph set, and translates `flag` into instantiate flags.
// NOTE(review): the remaining member initializers (e.g. for `name`) and the
// other switch cases are not visible in this listing.
341MUDA_INLINE ComputeGraph::ComputeGraph(ComputeGraphVarManager& manager,
342 std::string_view name,
343 ComputeGraphFlag flag)
344 : m_var_manager(&manager)
// Compile-time check: the feature must be enabled via MUDA_COMPUTE_GRAPH_ON.
347 if constexpr(!COMPUTE_GRAPH_ON)
349 MUDA_ERROR_WITH_LOCATION(
"ComputeGraph is disabled, please define MUDA_COMPUTE_GRAPH_ON=1 to enable it.");
// Register with the manager; the destructor erases this entry again.
351 m_var_manager->m_graphs.insert(
this);
354 case ComputeGraphFlag::DeviceLaunch:
355 m_flags |= GraphInstantiateFlagBit::DeviceLaunch;
// Returns an AddNodeProxy bound to this graph and `node_name`; streaming a
// callable into the proxy (operator<<) adds the node.
363MUDA_INLINE ComputeGraph::AddNodeProxy ComputeGraph::create_node(std::string_view node_name)
365 return AddNodeProxy{*
this, node_name};
// Public update entry point. Disallows adding further nodes from this point.
// NOTE(review): the rest of the body is not visible in this listing.
368MUDA_INLINE
void ComputeGraph::update()
370 m_allow_node_adding =
false;
// Launches the graph on stream `s` and records m_event on that stream so that
// query() can later report completion.
//
// NOTE(review): the handling of `single_stream` and any build/update/validity
// steps before the launch are not visible in this listing.
375MUDA_INLINE
void ComputeGraph::launch(
bool single_stream, cudaStream_t s)
// No further nodes may be added once the graph has been launched.
377 m_allow_node_adding =
false;
// Remember the stream; capture() uses it in the SerialLaunching phase.
380 m_current_single_stream = s;
388 m_graph_exec->launch(s);
// Mark the result as pending and record the completion event on `s`.
390 m_event_result = Event::QueryResult::eNotReady;
391 checkCudaErrors(cudaEventRecord(m_event, s));
// Optional debug mode: block until the stream has finished.
393 if(Debug::is_debug_sync_all())
394 checkCudaErrors(cudaStreamSynchronize(s));
// Returns the cached event query result. The event is queried again only
// while the cached result is still eNotReady.
// NOTE(review): m_event_result is assigned inside a const method, so it is
// presumably declared mutable — confirm in the class definition.
398MUDA_INLINE Event::QueryResult ComputeGraph::query()
const
400 if(m_event_result == Event::QueryResult::eNotReady)
401 m_event_result = m_event.query();
402 return m_event_result;
// Computes the dependencies of one closure from its variable usages and
// appends them to `deps`.
//
//   deps                      output list; new Dependency entries are appended.
//   last_read_or_write_nodes  per local variable: last closure that read or
//                             wrote it (updated here).
//   last_write_nodes          per local variable: last closure that wrote it
//                             (updated here).
//   closure                   the closure being processed.
//   local_var_usage           (local variable id, usage) pairs of the closure.
//   dep_begin / dep_count     outputs: start index in `deps` and number of
//                             entries appended for this closure.
//
// NOTE(review): the tail of the parameter list and several statements are
// not visible in this listing.
415 MUDA_INLINE
void process_node(std::vector<ComputeGraph::Dependency>& deps,
416 std::vector<ClosureId>& last_read_or_write_nodes,
417 std::vector<ClosureId>& last_write_nodes,
418 ComputeGraphClosure& closure,
419 const std::vector<std::pair<LocalVarId, ComputeGraphVarUsage>>& local_var_usage,
423 auto is_read_write = [](ComputeGraphVarUsage usage)
424 {
return usage == ComputeGraphVarUsage::ReadWrite; };
425 auto is_read_only = [](ComputeGraphVarUsage usage)
426 {
return usage == ComputeGraphVarUsage::Read; };
// Set of closures this closure depends on; a set removes duplicates.
428 std::unordered_set<ClosureId> unique_deps;
// Pass 1: collect dependencies.
430 for(
auto& [local_var_id, usage] : local_var_usage)
// A read-write usage depends on the last closure that read or wrote the var.
435 if(is_read_write(usage))
437 auto dst_nid = last_read_or_write_nodes[local_var_id.value()];
438 if(dst_nid.is_valid())
441 if(unique_deps.find(dst_nid) == unique_deps.end())
444 unique_deps.insert(dst_nid);
// A read-only usage depends only on the last closure that wrote the var.
452 else if(is_read_only(usage))
454 auto dst_nid = last_write_nodes[local_var_id.value()];
455 if(dst_nid.is_valid())
458 if(unique_deps.find(dst_nid) == unique_deps.end())
461 unique_deps.insert(dst_nid);
467 auto current_closure_id = closure.clousure_id();
// Pass 2: record this closure as the latest accessor of each variable.
470 for(
auto& [local_var_id, usage] : local_var_usage)
475 if(is_read_write(usage))
477 last_read_or_write_nodes[local_var_id.value()] = current_closure_id;
478 last_write_nodes[local_var_id.value()] = current_closure_id;
483 else if(is_read_only(usage))
485 last_read_or_write_nodes[local_var_id.value()] = current_closure_id;
// Append the collected dependencies and report their range in `deps`.
490 dep_begin = deps.size();
491 for(
auto dep : unique_deps)
492 deps.emplace_back(ComputeGraph::Dependency{dep, current_closure_id});
493 dep_count = unique_deps.size();
// Collects (from, to) CUDA graph node pairs and adds them to the CUDA graph
// with a single cudaGraphAddDependencies call.
//
// Two kinds of edges are collected: edges between the graph nodes inside one
// closure, and edges between closures taken from m_deps.
// NOTE(review): some statements (e.g. the action for single-node closures,
// the declaration of `to`, and any update of `from` inside the inner loop)
// are not visible in this listing.
497MUDA_INLINE
void ComputeGraph::cuda_graph_add_deps()
// Reserve for the closure-to-closure edges; intra-closure edges may add more.
499 std::vector<cudaGraphNode_t> froms;
500 froms.reserve(m_deps.size());
501 std::vector<cudaGraphNode_t> tos;
502 tos.reserve(m_deps.size());
// Intra-closure edges: every closure must own at least one graph node.
506 for(
auto& [name, closure] : m_closures)
508 MUDA_ASSERT(closure->m_graph_nodes.size() > 0,
"closure[%s] has no nodes", name.data());
509 if(closure->m_graph_nodes.size() == 1)
512 auto from = closure->m_graph_nodes.front();
514 for(
size_t i = 1; i < closure->m_graph_nodes.size(); ++i)
516 to = closure->m_graph_nodes[i];
517 froms.emplace_back(from->handle());
518 tos.emplace_back(to->handle());
// Inter-closure edges: from the last graph node of the `from` closure to the
// first graph node of the `to` closure.
524 for(
auto dep : m_deps)
526 auto from = m_closures[dep.from.value()].second->m_graph_nodes.back();
527 auto to = m_closures[dep.to.value()].second->m_graph_nodes.front();
528 froms.emplace_back(from->handle());
529 tos.emplace_back(to->handle());
// froms and tos are filled pairwise, so froms.size() is the edge count.
532 checkCudaErrors(cudaGraphAddDependencies(
533 m_graph.handle(), froms.data(), tos.data(), froms.size()));
// Builds m_deps by processing the closures in insertion order and tracking,
// per local variable, the last writer and the last reader-or-writer. Each
// closure is told which range of m_deps belongs to it, and the graph is
// finally flagged as topologically built.
// NOTE(review): some original lines (e.g. any clearing of m_deps beforehand)
// are not visible in this listing.
536MUDA_INLINE
void ComputeGraph::build_deps()
539 auto local_var_count = m_related_vars.size();
// Both tables start with default-constructed ClosureId values, one per var.
542 auto last_write_nodes = std::vector<ClosureId>(local_var_count, ClosureId{});
544 auto last_read_or_write_nodes = std::vector<ClosureId>(local_var_count, ClosureId{});
547 for(
size_t i = 0u; i < m_closures.size(); i++)
549 auto& [name, closure] = m_closures[i];
// Translate the closure's variable usages from global ids to local ids.
552 std::vector<std::pair<details::LocalVarId, ComputeGraphVarUsage>> local_var_usage;
553 local_var_usage.reserve(closure->var_usages().size());
554 for(
auto&& [var_id, usage] : closure->var_usages())
556 auto local_id = m_global_to_local_var_id[var_id];
557 local_var_usage.emplace_back(local_id, usage);
// process_node appends this closure's dependencies to m_deps and reports the
// range through dep_begin/dep_count.
560 size_t dep_begin, dep_count;
561 details::process_node(
562 m_deps, last_read_or_write_nodes, last_write_nodes, *closure, local_var_usage, dep_begin, dep_count);
563 closure->set_deps_range(dep_begin, dep_count);
566 m_is_topo_built =
true;