MUDA
Loading...
Searching...
No Matches
compute_graph_var_manager.inl
#include <algorithm>
#include <functional>
#include <numeric>
#include <muda/compute_graph/compute_graph_var.h>
#include <muda/compute_graph/compute_graph.h>
5
6namespace muda
7{
8MUDA_INLINE ComputeGraphVarManager::~ComputeGraphVarManager()
9{
10 for(auto& var : m_vars)
11 delete var;
12}
13
14template <typename T>
15constexpr void check_var_type()
16{
17 static_assert(!std::is_same_v<T, ::muda::Event>,
18 "please use cudaEvent_t as a ComputeGraphVar");
19}
20
21template <typename T>
22MUDA_INLINE ComputeGraphVar<T>& ComputeGraphVarManager::create_var(std::string_view name)
23{
24 check_var_type<T>();
25 auto ptr = new ComputeGraphVar<T>(this, name, VarId{m_vars.size()});
26 m_vars.emplace_back(ptr);
27 if(m_vars_map.find(std::string{name}) != m_vars_map.end())
28 MUDA_ERROR_WITH_LOCATION("var[%s] already exists", name.data());
29 m_vars_map.emplace(name, ptr);
30 return *ptr;
31}
32template <typename T>
33MUDA_INLINE ComputeGraphVar<T>& ComputeGraphVarManager::create_var(std::string_view name,
34 const T& init_value)
35{
36 check_var_type<T>();
37 auto ptr = new ComputeGraphVar<T>(this, name, VarId{m_vars.size()}, init_value);
38 m_vars.emplace_back(ptr);
39 m_vars_map.emplace(name, ptr);
40 return *ptr;
41}
42template <typename T>
43MUDA_INLINE ComputeGraphVar<T>* ComputeGraphVarManager::find_var(std::string_view name)
44{
45 auto it = m_vars_map.find(std::string{name});
46 if(it == m_vars_map.end())
47 return nullptr;
48 return dynamic_cast<ComputeGraphVar<T>*>(it->second);
49}
50
51template <typename... T>
52MUDA_INLINE bool ComputeGraphVarManager::is_using(const ComputeGraphVar<T>&... vars) const
53{
54 std::array<const ComputeGraphVarBase*, sizeof...(T)> var_array{&vars...};
55 return is_using(span<const ComputeGraphVarBase*>{var_array});
56}
57template <typename... T>
58MUDA_INLINE void ComputeGraphVarManager::sync(const ComputeGraphVar<T>&... vars) const
59{
60 std::array<const ComputeGraphVarBase*, sizeof...(T)> var_array{&vars...};
61 sync(span<const ComputeGraphVarBase*>{var_array});
62}
63template <typename... T>
64MUDA_INLINE void ComputeGraphVarManager::sync_on(cudaStream_t stream,
65 const ComputeGraphVar<T>&... vars) const
66{
67 std::array<const ComputeGraphVarBase*, sizeof...(T)> var_array{&vars...};
68 sync_on(stream, span<const ComputeGraphVarBase*>{var_array});
69};
70
71MUDA_INLINE auto ComputeGraphVarManager::create_graph(std::string_view name, ComputeGraphFlag flags)
72 -> S<ComputeGraph>
73{
74 return std::make_shared<ComputeGraph>(*this, name, flags);
75}
76
77MUDA_INLINE bool ComputeGraphVarManager::is_using() const
78{
79 return is_using(var_span());
80}
81
82MUDA_INLINE void ComputeGraphVarManager::sync() const
83{
84 sync(var_span());
85}
86
87MUDA_INLINE void ComputeGraphVarManager::sync_on(cudaStream_t stream) const
88{
89 sync_on(stream, var_span());
90}
91
92MUDA_INLINE bool ComputeGraphVarManager::is_using(const span<const ComputeGraphVarBase*> vars) const
93{
94 auto graphs = unique_graphs(vars);
95 return std::any_of(graphs.begin(),
96 graphs.end(),
97 [](ComputeGraph* graph) {
98 return graph->query() == Event::QueryResult::eNotReady;
99 });
100}
101
102MUDA_INLINE void ComputeGraphVarManager::sync(const span<const ComputeGraphVarBase*> vars) const
103{
104 auto graphs = unique_graphs(vars);
105 std::for_each(graphs.begin(),
106 graphs.end(),
107 [&](ComputeGraph* graph)
108 { checkCudaErrors(cudaEventSynchronize(graph->m_event)); });
109}
110
111MUDA_INLINE void ComputeGraphVarManager::sync_on(cudaStream_t stream,
112 const span<const ComputeGraphVarBase*> vars) const
113{
114 auto graphs = unique_graphs(vars);
115 std::for_each(graphs.begin(),
116 graphs.end(),
117 [&](ComputeGraph* graph) {
118 checkCudaErrors(cudaStreamWaitEvent(stream, graph->m_event, 0));
119 });
120}
121
122MUDA_INLINE void ComputeGraphVarManager::graphviz(std::ostream& o,
123 const ComputeGraphGraphvizOptions& options) const
124{
125 auto opt = options;
126
127 o << "digraph G {\n";
128 o << options.graph_font << "\n";
129 if(opt.show_vars)
130 {
131 o << "subgraph cluster_" << opt.graph_id++;
132 o << " {\n"
133 "beautify=true;\n";
134 o << opt.cluster_var_style << "\n";
135 o << "// vars: \n";
136 for(auto var : m_vars)
137 {
138 var->graphviz_def(o, opt);
139 o << "\n";
140 }
141 o << "}\n";
142 }
143
144 opt.as_subgraph = true;
145
146 for(auto graph : m_graphs)
147 {
148 graph->graphviz(o, opt);
149 opt.graph_id++;
150 }
151 o << "}\n";
152}
153
154MUDA_INLINE std::vector<ComputeGraph*> ComputeGraphVarManager::unique_graphs(
155 span<const ComputeGraphVarBase*> vars) const
156{
157 std::vector<ComputeGraph*> graphs;
158 for(auto var : vars)
159 {
160 for(auto& [graph, _] : var->m_related_closure_infos)
161 {
162 graphs.emplace_back(graph);
163 }
164 }
165 std::sort(graphs.begin(), graphs.end());
166 // get unique graphs
167 graphs.erase(std::unique(graphs.begin(), graphs.end()), graphs.end());
168 return graphs;
169}
170
171MUDA_INLINE span<const ComputeGraphVarBase*> ComputeGraphVarManager::var_span() const
172{
173 return span<const ComputeGraphVarBase*>{
174 const_cast<const ComputeGraphVarBase**>(m_vars.data()), m_vars.size()};
175}
176} // namespace muda