MUDA
Loading...
Searching...
No Matches
compute_graph.inl
1#include <memory>
2#include <muda/exception.h>
3#include <muda/debug.h>
4#include <muda/compute_graph/compute_graph.h>
5#include <muda/compute_graph/compute_graph_builder.h>
6#include <muda/compute_graph/compute_graph_var.h>
7#include <muda/compute_graph/compute_graph_var_manager.h>
8#include <muda/compute_graph/compute_graph_node.h>
9#include <muda/compute_graph/compute_graph_closure.h>
10#include <muda/compute_graph/compute_graph_accessor.h>
11
12namespace muda
13{
// Remember which graph the node goes into and the name it will carry.
// Nothing is added yet: the node is created when the proxy receives its
// body through operator<<.
MUDA_INLINE ComputeGraph::AddNodeProxy::AddNodeProxy(ComputeGraph& cg, std::string_view node_name)
    : m_cg{cg}
    , m_node_name{node_name}
{
}
19
// Scope guard: registers `cg` as the builder's current graph and switches it
// into `phase`. The destructor undoes both.
MUDA_INLINE ComputeGraph::GraphPhaseGuard::GraphPhaseGuard(ComputeGraph& cg, ComputeGraphPhase phase)
    : m_cg{cg}
{
    auto& graph = m_cg;
    // register first, then enter the phase (same order as the teardown reverses)
    graph.set_current_graph_as_this();
    graph.m_current_graph_phase = phase;
}
26
// Leave the guarded phase (back to ComputeGraphPhase::None) and unregister the
// builder's current graph.
MUDA_INLINE ComputeGraph::GraphPhaseGuard::~GraphPhaseGuard()
{
    auto& graph = m_cg;
    graph.m_current_graph_phase = ComputeGraphPhase::None;
    ComputeGraph::clear_current_graph();
}
32
// Complete a create_node(...) expression: add a node named m_node_name whose
// body is `f`. The stored name is moved out, so the proxy is rvalue-only.
// Returns the owning graph.
MUDA_INLINE ComputeGraph& ComputeGraph::AddNodeProxy::operator<<(std::function<void()>&& f) &&
{
    ComputeGraph& graph = m_cg;
    graph.add_node(std::move(m_node_name), f);
    return graph;
}
38
// Write this graph to `o` in Graphviz DOT syntax.
//
// topo_build() runs first so closures and their dependencies are known.
// - options.as_subgraph: emit `subgraph cluster_<graph_id> { ... }` (vars are
//   not emitted in this mode); otherwise emit a standalone `digraph G { ... }`.
// - options.show_vars / options.show_nodes select which sections are written.
// Closures with no dependencies get an arc from a node standing for the graph
// itself (`node_g<graph_id>`). Every Dependency becomes an arc from dep.from to
// dep.to.
MUDA_INLINE void ComputeGraph::graphviz(std::ostream& o, const ComputeGraphGraphvizOptions& options)
{
    topo_build();

    if(options.as_subgraph)
    {
        o << "subgraph cluster_" << options.graph_id;
        o << " {\n";
        o << "label=\"" << name() << "\";\n";
        o << options.cluster_style << "\n";
    }
    else
    {
        o << "digraph G {\n";
        o << options.graph_font << "\n";
    }

    // var definitions are only written at the top level, never inside a cluster
    if(options.show_vars && !options.as_subgraph)
    {
        o << "// vars: \n";
        for(auto&& [local_id, var] : m_related_vars)
        {
            var->graphviz_def(o, options);
            o << "\n";
        }
        o << "\n";
    }

    if(options.show_nodes)
    {
        o << "// nodes: \n";
        // node representing the graph itself
        o << "node_g" << options.graph_id << "[label=\"" << name() << "\""
          << options.node_style << "]\n";

        for(auto& [closure_name, closure] : m_closures)
        {
            closure->graphviz_def(o, options);
            o << "\n";
        }
        o << "\n";

        if(options.show_vars)
        {
            o << "// node var usages: \n";
            for(auto& [closure_name, closure] : m_closures)
            {
                closure->graphviz_var_usages(o, options);
                o << "\n";
            }
            o << "\n";
        }

        o << "// node deps: \n";
        // root closures (no deps) hang off the graph node
        for(auto& [closure_name, closure] : m_closures)
        {
            if(closure->deps().size() != 0)
                continue;
            o << "node_g" << options.graph_id << "->";
            closure->graphviz_id(o, options);
            o << "[" << options.arc_style << "]\n";
        }

        // one arc per dependency, dep.from -> dep.to; bind by reference so the
        // (name, closure) pairs stored in m_closures are not copied
        for(const auto& dep : m_deps)
        {
            const auto& from = m_closures[dep.from.value()].second;
            const auto& to   = m_closures[dep.to.value()].second;
            from->graphviz_id(o, options);
            o << "->";
            to->graphviz_id(o, options);
            o << "[" << options.arc_style << "]\n";
        }
    }
    o << "}\n";
}
117
// Return a GraphViewer over the instantiated executable graph together with
// this graph's instantiate flags. The graph must already be built.
MUDA_INLINE GraphViewer ComputeGraph::viewer()
{
    MUDA_ASSERT(m_graph_exec, "graph is not built yet, call ComputeGraph::build() to build it.");
    auto exec_handle = m_graph_exec->handle();
    return GraphViewer{exec_handle, m_flags};
}
123
// Derive closure-to-closure dependencies without building a CUDA graph.
//
// Each closure body is invoked once in ComputeGraphPhase::TopoBuilding, then
// build_deps() computes m_deps. Once that has run (build_deps() sets
// m_is_topo_built), further calls return immediately.
MUDA_INLINE void ComputeGraph::topo_build()
{
    if(m_is_topo_built)
        return;

    // one update flag per closure, all cleared
    m_closure_need_update.assign(m_closures.size(), false);

    GraphPhaseGuard phase_guard(*this, ComputeGraphPhase::TopoBuilding);
    for(size_t idx = 0; idx < m_closures.size(); ++idx)
    {
        m_current_closure_id = ClosureId{idx};
        m_allow_access_graph = true;
        m_access_graph_index = 0;
        (*m_closures[idx].second)();
    }

    build_deps();
}
143
// Build, instantiate and upload the CUDA graph (no-op if m_graph_exec exists).
//
// Every closure is invoked in ComputeGraphPhase::Building. If the topology was
// not derived beforehand, the per-closure update flags are reset first and
// build_deps() runs afterwards. cuda_graph_add_deps() then wires the nodes
// before instantiation with m_flags.
MUDA_INLINE void ComputeGraph::build()
{
    if(m_graph_exec)
        return;

    GraphPhaseGuard phase_guard(*this, ComputeGraphPhase::Building);

    if(!m_is_topo_built)
        m_closure_need_update.assign(m_closures.size(), false);

    for(size_t idx = 0; idx < m_closures.size(); ++idx)
    {
        m_current_closure_id = ClosureId{idx};
        m_allow_access_graph = true;
        m_access_graph_index = 0;
        (*m_closures[idx].second)();
    }

    if(!m_is_topo_built)
        build_deps();
    cuda_graph_add_deps();

    m_graph_exec = m_graph.instantiate(m_flags);
    m_graph_exec->upload();
}
170
// Invoke every closure in order in ComputeGraphPhase::SerialLaunching, without
// going through a CUDA graph. Graph access is disabled for each closure, and
// m_is_capturing is reset after each one.
MUDA_INLINE void ComputeGraph::serial_launch()
{
    GraphPhaseGuard phase_guard(*this, ComputeGraphPhase::SerialLaunching);

    for(size_t idx = 0; idx < m_closures.size(); ++idx)
    {
        m_current_closure_id = ClosureId{idx};
        m_allow_access_graph = false;  // serial launching never touches the graph
        (*m_closures[idx].second)();
        m_is_capturing = false;
    }
}
184
// Report (via MUDA_ERROR_WITH_LOCATION) every related var whose is_valid()
// returns false.
MUDA_INLINE void ComputeGraph::check_vars_valid()
{
    for(auto&& [local_id, var] : m_related_vars)
    {
        if(var->is_valid())
            continue;

        MUDA_ERROR_WITH_LOCATION(
            "var[%s] is not valid, "
            "you need update the var before launch this graph",
            var->name().data());
    }
}
196
// Append a new closure named `name` with body `f` and return *this.
//
// Only allowed before the graph is built, updated or launched; both the
// accessor check and MUDA_ASSERT enforce that. The closure gets the next
// ClosureId (its index in m_closures). m_closures owns the raw pointer, and
// ~ComputeGraph deletes it.
MUDA_INLINE ComputeGraph& ComputeGraph::add_node(std::string&& name,
                                                 const std::function<void()>& f)
{
    details::ComputeGraphAccessor(this).check_allow_node_adding();
    MUDA_ASSERT(m_allow_node_adding,
                "This graph is built or updated, so you can't add new nodes any more.");

    auto closure_id = ClosureId{m_closures.size()};
    auto closure    = new ComputeGraphClosure{this, closure_id, name, f};
    try
    {
        // the closure has already taken what it needs from `name`, so the
        // string can be moved into m_closures instead of copied again
        m_closures.emplace_back(std::move(name), closure);
    }
    catch(...)
    {
        // m_closures never took ownership: free the closure instead of leaking it
        delete closure;
        throw;
    }
    return *this;
}
208
// View `count` dependencies of m_deps starting at index `begin`.
MUDA_INLINE auto ComputeGraph::dep_span(size_t begin, size_t count) const
    -> span<const Dependency>
{
    const span<const Dependency> all_deps{m_deps};
    return all_deps.subspan(begin, count);
}
214
// Register this graph as ComputeGraphBuilder's current graph.
MUDA_INLINE void ComputeGraph::set_current_graph_as_this()
{
    ComputeGraph* self = this;
    ComputeGraphBuilder::current_graph(self);
}
219
// Unregister whatever graph ComputeGraphBuilder currently holds.
MUDA_INLINE void ComputeGraph::clear_current_graph()
{
    ComputeGraph* no_graph = nullptr;
    ComputeGraphBuilder::current_graph(no_graph);
}
224
// Stream used for stream capture: one non-blocking Stream per host thread,
// created on first use and reused afterwards.
MUDA_INLINE Stream& ComputeGraph::shared_capture_stream()
{
    thread_local static Stream capture_stream(Stream::Flag::eNonBlocking);
    return capture_stream;
}
230
// Unnamed capture: forwards to the named overload with an empty name.
MUDA_INLINE void ComputeGraph::capture(std::function<void(cudaStream_t)>&& f)
{
    const std::string_view unnamed{""};
    capture(unnamed, std::move(f));
}
235
// Run user stream work `f` inside the current closure; what happens depends on
// the graph phase:
// - TopoBuilding:    `f` is NOT called; an empty capture node is registered
//                    (set_capture_node(nullptr)) under `name`.
// - SerialLaunching: `f` is called directly on m_current_single_stream.
// - Updating:        `f` is stream-captured and the resulting graph is set as
//                    the closure's capture node.
// - Building:        as Updating, with `name` published as the capture name.
// - anything else:   reported with MUDA_ERROR_WITH_LOCATION.
// m_is_in_capture_func is true while this function runs.
MUDA_INLINE void ComputeGraph::capture(std::string_view name,
                                       std::function<void(cudaStream_t)>&& f)
{
    m_is_in_capture_func = true;

    // capture `f` on the shared capture stream and hand the captured graph to
    // the current closure
    auto record_capture = [&]
    {
        auto& stream = shared_capture_stream();
        m_is_capturing = true;
        stream.begin_capture();
        f(stream);
        cudaGraph_t captured;
        stream.end_capture(&captured);
        details::ComputeGraphAccessor(this).set_capture_node(captured);
        m_is_capturing = false;
    };

    const auto phase = current_graph_phase();
    if(phase == ComputeGraphPhase::TopoBuilding)
    {
        details::LaunchInfoCache::current_capture_name(name);
        details::ComputeGraphAccessor(this).set_capture_node(nullptr);
        details::LaunchInfoCache::current_capture_name("");
    }
    else if(phase == ComputeGraphPhase::SerialLaunching)
    {
        f(m_current_single_stream);
    }
    else if(phase == ComputeGraphPhase::Updating)
    {
        record_capture();
    }
    else if(phase == ComputeGraphPhase::Building)
    {
        details::LaunchInfoCache::current_capture_name(name);
        record_capture();
        details::LaunchInfoCache::current_capture_name("");
    }
    else
    {
        MUDA_ERROR_WITH_LOCATION("invoking capture() outside Graph Closure is not allowed");
    }

    m_is_in_capture_func = false;
}
284
// Phase the graph is currently in (set by GraphPhaseGuard).
MUDA_INLINE ComputeGraphPhase ComputeGraph::current_graph_phase() const
{
    const ComputeGraphPhase phase = m_current_graph_phase;
    return phase;
}
289
// Re-invoke, in ComputeGraphPhase::Updating, every closure whose entry in
// m_closure_need_update is set, clearing each flag afterwards. Returns
// immediately when m_need_update is false.
// NOTE(review): m_need_update is not cleared here, so every later call scans
// all flags again until something else resets it — confirm this is intended.
MUDA_INLINE void ComputeGraph::_update()
{
    if(!m_need_update)
        return;

    GraphPhaseGuard phase_guard(*this, ComputeGraphPhase::Updating);

    for(size_t idx = 0; idx < m_closure_need_update.size(); ++idx)
    {
        if(!m_closure_need_update[idx])
            continue;

        m_current_closure_id = ClosureId{idx};
        m_allow_access_graph = true;
        m_access_graph_index = 0;
        (*m_closures[idx].second)();
        m_closure_need_update[idx] = false;
    }
}
314
// Detach this graph from its vars and the var manager, then free the graph
// nodes and closures it owns (both stored as raw owning pointers).
MUDA_INLINE ComputeGraph::~ComputeGraph()
{
    for(const auto& info : m_related_vars)
        info.var->remove_related_closure_infos(this);

    m_var_manager->m_graphs.erase(this);

    for(auto* graph_node : m_nodes)
        delete graph_node;

    for(auto& entry : m_closures)
        delete entry.second;
}
328
// Register `var` with this graph. The first registration assigns the next
// local id (its index in m_related_vars) and records the global→local
// mapping. Registering an already-known var does nothing.
MUDA_INLINE void ComputeGraph::emplace_related_var(ComputeGraphVarBase* var)
{
    const auto global_var_id = var->var_id();
    if(m_global_to_local_var_id.find(global_var_id) != m_global_to_local_var_id.end())
        return;  // already tracked

    const auto local_id = details::LocalVarId{m_related_vars.size()};
    m_related_vars.emplace_back(details::LocalVarInfo{local_id, var});
    m_global_to_local_var_id.emplace(global_var_id, local_id);
}
340
// Create a graph named `name` that is registered with `manager`.
// Reports an error when compute graphs are compiled out (COMPUTE_GRAPH_ON is
// false). ComputeGraphFlag::DeviceLaunch adds the DeviceLaunch instantiate
// flag; other flags add nothing.
MUDA_INLINE ComputeGraph::ComputeGraph(ComputeGraphVarManager& manager,
                                       std::string_view        name,
                                       ComputeGraphFlag        flag)
    : m_var_manager(&manager)
    , m_name(name)
{
    if constexpr(!COMPUTE_GRAPH_ON)
    {
        MUDA_ERROR_WITH_LOCATION("ComputeGraph is disabled, please define MUDA_COMPUTE_GRAPH_ON=1 to enable it.");
    }
    m_var_manager->m_graphs.insert(this);

    if(flag == ComputeGraphFlag::DeviceLaunch)
        m_flags |= GraphInstantiateFlagBit::DeviceLaunch;
}
361
362
// Start adding a node named `node_name`; the node itself is added when the
// returned proxy is given its body via operator<<.
MUDA_INLINE ComputeGraph::AddNodeProxy ComputeGraph::create_node(std::string_view node_name)
{
    return AddNodeProxy(*this, node_name);
}
367
// Freeze the node set, validate vars, then re-run closures flagged for update.
MUDA_INLINE void ComputeGraph::update()
{
    this->m_allow_node_adding = false;  // no more add_node() after this
    this->check_vars_valid();
    this->_update();
}
374
// Launch the graph on stream `s`.
// - single_stream == false: validate vars, build (first time), apply pending
//   updates, then launch the executable graph.
// - single_stream == true:  invoke the closures serially on `s`; no vars
//   check, build or update happens on this path.
// In both cases m_event is recorded on `s` afterwards for query(), and with
// MUDA_CHECK_ON plus debug-sync-all the stream is synchronized.
MUDA_INLINE void ComputeGraph::launch(bool single_stream, cudaStream_t s)
{
    m_allow_node_adding = false;

    if(!single_stream)
    {
        check_vars_valid();
        build();
        _update();
        m_graph_exec->launch(s);
    }
    else
    {
        m_current_single_stream = s;
        serial_launch();
    }

    m_event_result = Event::QueryResult::eNotReady;
    checkCudaErrors(cudaEventRecord(m_event, s));
#if MUDA_CHECK_ON
    if(Debug::is_debug_sync_all())
        checkCudaErrors(cudaStreamSynchronize(s));
#endif
}
397
// Completion state of the last launch. The event is re-polled only while the
// cached result is eNotReady; any other result is returned from the cache.
MUDA_INLINE Event::QueryResult ComputeGraph::query() const
{
    if(m_event_result != Event::QueryResult::eNotReady)
        return m_event_result;

    m_event_result = m_event.query();
    return m_event_result;
}
404} // namespace muda
405
406/********************************************************************************
407 *
408 * Build Graph Dependencies
409 *
410 ********************************************************************************/
411namespace muda
412{
413namespace details
414{
415 MUDA_INLINE void process_node(std::vector<ComputeGraph::Dependency>& deps,
416 std::vector<ClosureId>& last_read_or_write_nodes,
417 std::vector<ClosureId>& last_write_nodes,
418 ComputeGraphClosure& closure,
419 const std::vector<std::pair<LocalVarId, ComputeGraphVarUsage>>& local_var_usage,
420 uint64_t& dep_begin,
421 uint64_t& dep_count)
422 {
423 auto is_read_write = [](ComputeGraphVarUsage usage)
424 { return usage == ComputeGraphVarUsage::ReadWrite; };
425 auto is_read_only = [](ComputeGraphVarUsage usage)
426 { return usage == ComputeGraphVarUsage::Read; };
427
428 std::unordered_set<ClosureId> unique_deps;
429
430 for(auto& [local_var_id, usage] : local_var_usage)
431 {
432 // if this is a written resource,
433 // this should depend on any write and read before it
434 // to get newest data or to avoid data corruption
435 if(is_read_write(usage))
436 {
437 auto dst_nid = last_read_or_write_nodes[local_var_id.value()];
438 if(dst_nid.is_valid())
439 {
440 // the last accessing node reads or writes this resrouce, so I should depend on it
441 if(unique_deps.find(dst_nid) == unique_deps.end())
442 {
443 // record this dependency
444 unique_deps.insert(dst_nid);
445 }
446 }
447 }
448 // if this is a read resource,
449 // this should depend on any write before it
450 // to get newest data
451 // but it has no need to depend on any read before it
452 else if(is_read_only(usage))
453 {
454 auto dst_nid = last_write_nodes[local_var_id.value()];
455 if(dst_nid.is_valid())
456 {
457 // the last accessing node writes this resrouce, so I should depend on it
458 if(unique_deps.find(dst_nid) == unique_deps.end())
459 {
460 // record this dependency
461 unique_deps.insert(dst_nid);
462 }
463 }
464 }
465 }
466
467 auto current_closure_id = closure.clousure_id();
468
469 // set up res node map with pair [res, node]
470 for(auto& [local_var_id, usage] : local_var_usage)
471 {
472 // if this is a write resource,
473 // the latter read/write kernel should depend on this
474 // to get the newest data.
475 if(is_read_write(usage))
476 {
477 last_read_or_write_nodes[local_var_id.value()] = current_closure_id;
478 last_write_nodes[local_var_id.value()] = current_closure_id;
479 }
480 // if this is a read resource,
481 // the latter write kernel should depend on this
482 // to avoid data corruption.
483 else if(is_read_only(usage))
484 {
485 last_read_or_write_nodes[local_var_id.value()] = current_closure_id;
486 }
487 }
488
489 // add dependencies to deps
490 dep_begin = deps.size();
491 for(auto dep : unique_deps)
492 deps.emplace_back(ComputeGraph::Dependency{dep, current_closure_id});
493 dep_count = unique_deps.size();
494 }
495} // namespace details
496
// Add all edges to the underlying cudaGraph_t:
// 1) inside each closure, its graph nodes are chained in creation order;
// 2) for each Dependency, the last node of dep.from precedes the first node of
//    dep.to.
// Every closure is expected to own at least one graph node.
MUDA_INLINE void ComputeGraph::cuda_graph_add_deps()
{
    std::vector<cudaGraphNode_t> froms;
    std::vector<cudaGraphNode_t> tos;
    froms.reserve(m_deps.size());
    tos.reserve(m_deps.size());

    // 1) intra-closure chain: nodes[0] -> nodes[1] -> ... -> nodes[k-1]
    for(auto& [name, closure] : m_closures)
    {
        auto& nodes = closure->m_graph_nodes;
        MUDA_ASSERT(nodes.size() > 0, "closure[%s] has no nodes", name.data());
        // starting at 1 also makes an empty node list a no-op here
        for(size_t i = 1; i < nodes.size(); ++i)
        {
            froms.emplace_back(nodes[i - 1]->handle());
            tos.emplace_back(nodes[i]->handle());
        }
    }

    // 2) inter-closure edges: last node of the producer -> first node of the consumer
    for(const auto& dep : m_deps)
    {
        auto from = m_closures[dep.from.value()].second->m_graph_nodes.back();
        auto to   = m_closures[dep.to.value()].second->m_graph_nodes.front();
        froms.emplace_back(from->handle());
        tos.emplace_back(to->handle());
    }

    // nothing to connect (e.g. a single closure with a single node): skip the
    // call instead of passing empty arrays to the CUDA API
    if(froms.empty())
        return;

    checkCudaErrors(cudaGraphAddDependencies(
        m_graph.handle(), froms.data(), tos.data(), froms.size()));
}
535
// Rebuild m_deps from scratch: walk the closures in order, translate each
// closure's var usages to local var ids, and let details::process_node derive
// its dependencies. Each closure then gets its [begin, count) range into
// m_deps. Sets m_is_topo_built on completion.
MUDA_INLINE void ComputeGraph::build_deps()
{
    m_deps.clear();
    const auto local_var_count = m_related_vars.size();

    // local var id -> last closure that wrote it (ClosureId{} means none yet)
    auto last_write_nodes = std::vector<ClosureId>(local_var_count, ClosureId{});
    // local var id -> last closure that read or wrote it (ClosureId{} means none yet)
    auto last_read_or_write_nodes = std::vector<ClosureId>(local_var_count, ClosureId{});

    for(size_t i = 0u; i < m_closures.size(); i++)
    {
        auto& [name, closure] = m_closures[i];

        // map global var ids to this graph's local var ids
        std::vector<std::pair<details::LocalVarId, ComputeGraphVarUsage>> local_var_usage;
        local_var_usage.reserve(closure->var_usages().size());
        for(auto&& [var_id, usage] : closure->var_usages())
        {
            // find() instead of operator[]: a var that was never registered with
            // this graph must not be silently mapped to a default LocalVarId,
            // which would index the tracking tables above out of range
            auto iter = m_global_to_local_var_id.find(var_id);
            if(iter == m_global_to_local_var_id.end())
            {
                MUDA_ERROR_WITH_LOCATION(
                    "closure[%s] uses a var that is not related to this graph",
                    name.data());
                continue;
            }
            local_var_usage.emplace_back(iter->second, usage);
        }

        // process_node takes uint64_t&; declare exactly that type (size_t is not
        // the same type as uint64_t on every platform) and initialise it
        uint64_t dep_begin = 0;
        uint64_t dep_count = 0;
        details::process_node(
            m_deps, last_read_or_write_nodes, last_write_nodes, *closure, local_var_usage, dep_begin, dep_count);
        closure->set_deps_range(dep_begin, dep_count);
    }

    m_is_topo_built = true;
}
568} // namespace muda