3#include <muda/compute_graph/compute_graph_var.h>
4#include <muda/compute_graph/compute_graph.h>
8MUDA_INLINE ComputeGraphVarManager::~ComputeGraphVarManager()
10 for(
auto& var : m_vars)
15constexpr void check_var_type()
17 static_assert(!std::is_same_v<T, ::muda::Event>,
18 "please use cudaEvent_t as a ComputeGraphVar");
22MUDA_INLINE ComputeGraphVar<T>& ComputeGraphVarManager::create_var(std::string_view name)
25 auto ptr =
new ComputeGraphVar<T>(
this, name, VarId{m_vars.size()});
26 m_vars.emplace_back(ptr);
27 if(m_vars_map.find(std::string{name}) != m_vars_map.end())
28 MUDA_ERROR_WITH_LOCATION(
"var[%s] already exists", name.data());
29 m_vars_map.emplace(name, ptr);
33MUDA_INLINE ComputeGraphVar<T>& ComputeGraphVarManager::create_var(std::string_view name,
37 auto ptr =
new ComputeGraphVar<T>(
this, name, VarId{m_vars.size()}, init_value);
38 m_vars.emplace_back(ptr);
39 m_vars_map.emplace(name, ptr);
43MUDA_INLINE ComputeGraphVar<T>* ComputeGraphVarManager::find_var(std::string_view name)
45 auto it = m_vars_map.find(std::string{name});
46 if(it == m_vars_map.end())
48 return dynamic_cast<ComputeGraphVar<T>*
>(it->second);
51template <
typename... T>
52MUDA_INLINE
bool ComputeGraphVarManager::is_using(
const ComputeGraphVar<T>&... vars)
const
54 std::array<
const ComputeGraphVarBase*,
sizeof...(T)> var_array{&vars...};
55 return is_using(span<const ComputeGraphVarBase*>{var_array});
57template <
typename... T>
58MUDA_INLINE
void ComputeGraphVarManager::sync(
const ComputeGraphVar<T>&... vars)
const
60 std::array<
const ComputeGraphVarBase*,
sizeof...(T)> var_array{&vars...};
61 sync(span<const ComputeGraphVarBase*>{var_array});
63template <
typename... T>
64MUDA_INLINE
void ComputeGraphVarManager::sync_on(cudaStream_t stream,
65 const ComputeGraphVar<T>&... vars)
const
67 std::array<
const ComputeGraphVarBase*,
sizeof...(T)> var_array{&vars...};
68 sync_on(stream, span<const ComputeGraphVarBase*>{var_array});
71MUDA_INLINE
auto ComputeGraphVarManager::create_graph(std::string_view name, ComputeGraphFlag flags)
74 return std::make_shared<ComputeGraph>(*
this, name, flags);
77MUDA_INLINE
bool ComputeGraphVarManager::is_using()
const
79 return is_using(var_span());
82MUDA_INLINE
void ComputeGraphVarManager::sync()
const
87MUDA_INLINE
void ComputeGraphVarManager::sync_on(cudaStream_t stream)
const
89 sync_on(stream, var_span());
92MUDA_INLINE
bool ComputeGraphVarManager::is_using(
const span<const ComputeGraphVarBase*> vars)
const
94 auto graphs = unique_graphs(vars);
95 return std::any_of(graphs.begin(),
97 [](ComputeGraph* graph) {
98 return graph->query() == Event::QueryResult::eNotReady;
102MUDA_INLINE
void ComputeGraphVarManager::sync(
const span<const ComputeGraphVarBase*> vars)
const
104 auto graphs = unique_graphs(vars);
105 std::for_each(graphs.begin(),
107 [&](ComputeGraph* graph)
108 { checkCudaErrors(cudaEventSynchronize(graph->m_event)); });
111MUDA_INLINE
void ComputeGraphVarManager::sync_on(cudaStream_t stream,
112 const span<const ComputeGraphVarBase*> vars)
const
114 auto graphs = unique_graphs(vars);
115 std::for_each(graphs.begin(),
117 [&](ComputeGraph* graph) {
118 checkCudaErrors(cudaStreamWaitEvent(stream, graph->m_event, 0));
// Writes a Graphviz description of this manager to `o`.
// Output order, as far as visible here:
//   - opens a "digraph G {" block and writes options.graph_font;
//   - opens a "subgraph cluster_<id>" for the vars, taking <id> from
//     opt.graph_id (post-incremented);
//   - writes opt.cluster_var_style, then each var's definition via
//     graphviz_def();
//   - sets opt.as_subgraph = true and lets each graph in m_graphs write itself.
// NOTE(review): `opt` is not declared in the visible lines. It is presumably a
// mutable copy of `options`, since it is modified while `options` is a const&.
// The closing braces of the cluster and of the digraph are also not visible
// here — TODO confirm.
122MUDA_INLINE
void ComputeGraphVarManager::graphviz(std::ostream& o,
123 const ComputeGraphGraphvizOptions& options)
const
// Graph header and global font settings.
127 o <<
"digraph G {\n";
128 o << options.graph_font <<
"\n";
// Cluster holding all var nodes; graph_id++ gives each cluster a unique name.
131 o <<
"subgraph cluster_" << opt.graph_id++;
134 o << opt.cluster_var_style <<
"\n";
136 for(
auto var : m_vars)
138 var->graphviz_def(o, opt);
// Each graph is emitted nested inside the top-level digraph.
144 opt.as_subgraph =
true;
146 for(
auto graph : m_graphs)
148 graph->graphviz(o, opt);
154MUDA_INLINE std::vector<ComputeGraph*> ComputeGraphVarManager::unique_graphs(
155 span<const ComputeGraphVarBase*> vars)
const
157 std::vector<ComputeGraph*> graphs;
160 for(
auto& [graph, _] : var->m_related_closure_infos)
162 graphs.emplace_back(graph);
165 std::sort(graphs.begin(), graphs.end());
167 graphs.erase(std::unique(graphs.begin(), graphs.end()), graphs.end());
171MUDA_INLINE span<const ComputeGraphVarBase*> ComputeGraphVarManager::var_span()
const
173 return span<const ComputeGraphVarBase*>{
174 const_cast<const ComputeGraphVarBase**
>(m_vars.data()), m_vars.size()};