MUDA
Loading...
Searching...
No Matches
compute_graph_var_manager.h
1#pragma once
2#include <driver_types.h>
3#include <memory>
4#include <unordered_map>
5#include <unordered_set>
6#include <vector>
7#include <memory>
8#include <muda/mstl/span.h>
9#include <muda/compute_graph/compute_graph_flag.h>
10#include <muda/compute_graph/compute_graph_fwd.h>
11#include <muda/compute_graph/graphviz_options.h>
12namespace muda
13{
   // NOTE(review): this is an extracted source listing; each line carries its
   // listing line number as a prefix, and listing line 14 (the class-head) is
   // missing. The constructor name below indicates the class being declared is
   // ComputeGraphVarManager.
   //
   // The class keeps a name -> variable map, a flat list of variables and a set
   // of graphs (see the private members), and declares factory, lookup and
   // synchronization entry points over them. Only declarations are visible
   // here; the definitions are pulled in from
   // "details/compute_graph_var_manager.inl" after the namespace closes.
   //
   // NOTE(review): the declarations use std::string, std::string_view and
   // std::ostream, but no <string>, <string_view> or <iosfwd>/<ostream> include
   // is visible in this header (and <memory> is included twice) — confirm they
   // arrive transitively or add them.
15{
   // Shorthand for std::shared_ptr<T>, used for the return type of create_graph().
16 template <typename T>
17 using S = std::shared_ptr<T>;
18
19 public:
   // Defaulted default constructor.
   // NOTE(review): listing line 21 is absent from this extract; whatever was
   // declared there is not documented here.
20 ComputeGraphVarManager() = default;
22
   // Declares a factory returning a shared_ptr to a ComputeGraph.
   // name  : label for the graph, defaults to "graph".
   // flags : ComputeGraphFlag value, defaults to a value-initialized flag.
   // Presumably the new graph is also recorded in m_graphs — TODO confirm in the .inl.
23 S<ComputeGraph> create_graph(std::string_view name = "graph",
24 ComputeGraphFlag flags = {});
25
26
27 /**************************************************************
28 *
29 * GraphVar API
30 *
31 ***************************************************************/
   // Declares a variable factory that returns a reference to a ComputeGraphVar<T>
   // identified by `name`.
32 template <typename T>
33 ComputeGraphVar<T>& create_var(std::string_view name);
   // Same as above, with an additional `init_value` of type T supplied by the caller.
34 template <typename T>
35 ComputeGraphVar<T>& create_var(std::string_view name, const T& init_value);
   // Lookup by name; returns a pointer rather than a reference.
   // Presumably nullptr signals "not found" — TODO confirm in the .inl.
36 template <typename T>
37 ComputeGraphVar<T>* find_var(std::string_view name);
38
   // Argument-less query/synchronization overloads.
   // Presumably these apply to every variable owned by the manager (cf. var_span())
   // — TODO confirm in the .inl.
39 bool is_using() const;
40 void sync() const;
   // `stream` is the CUDA stream the operation is associated with.
41 void sync_on(cudaStream_t stream) const;
42
   // Variadic overloads: same three operations, restricted to the listed variables.
   // Each argument may be a ComputeGraphVar of a different element type.
43 template <typename... T>
44 bool is_using(const ComputeGraphVar<T>&... vars) const;
45 template <typename... T>
46 void sync(const ComputeGraphVar<T>&... vars) const;
47 template <typename... T>
48 void sync_on(cudaStream_t stream, const ComputeGraphVar<T>&... vars) const;
49
   // Span overloads: same three operations over a span of type-erased
   // ComputeGraphVarBase pointers.
50 bool is_using(const span<const ComputeGraphVarBase*> vars) const;
51 void sync(const span<const ComputeGraphVarBase*> vars) const;
52 void sync_on(cudaStream_t stream, const span<const ComputeGraphVarBase*> vars) const;
53
   // Read-only access to the set of graphs stored in m_graphs.
54 const auto& graphs() const { return m_graphs; }
   // Takes an output stream and graphviz options (defaulted).
   // Presumably emits a graphviz description to `os` — TODO confirm in the .inl.
55 void graphviz(std::ostream& os, const ComputeGraphGraphvizOptions& options = {}) const;
56
57 private:
   // These classes are granted access to the private members below.
58 friend class ComputeGraph;
59 friend class ComputeGraphNodeBase;
60 friend class ComputeGraphClosure;
   // Returns a vector of ComputeGraph pointers derived from `vars`.
   // Presumably the de-duplicated graphs that reference those variables — TODO confirm.
61 std::vector<ComputeGraph*> unique_graphs(span<const ComputeGraphVarBase*> vars) const;
   // Variables indexed by a std::string key (presumably the variable name).
62 std::unordered_map<std::string, ComputeGraphVarBase*> m_vars_map;
   // Variables as a flat sequence of raw pointers.
   // NOTE(review): ownership/lifetime of these pointers is not visible here.
63 std::vector<ComputeGraphVarBase*> m_vars;
   // Graphs known to this manager, exposed read-only through graphs().
64 std::unordered_set<ComputeGraph*> m_graphs;
   // Returns a span of const ComputeGraphVarBase pointers.
   // Presumably a view over m_vars — TODO confirm in the .inl.
65 span<const ComputeGraphVarBase*> var_span() const;
66};
67} // namespace muda
68
69#include "details/compute_graph_var_manager.inl"
Definition compute_graph_closure.h:15
Definition graphviz_options.h:6
Definition compute_graph.h:38
Definition compute_graph_node.h:13
Definition compute_graph_var.h:90
Definition compute_graph_var_manager.h:15