MUDA
Loading...
Searching...
No Matches
compute_graph_var.h
1#pragma once
2#include <string>
3#include <set>
4#include <map>
5#include <muda/launch/event.h>
6#include <muda/mstl/span.h>
7#include <muda/type_traits/type_modifier.h>
8#include <muda/compute_graph/compute_graph_closure_id.h>
9#include <muda/compute_graph/compute_graph_var_usage.h>
10#include <muda/compute_graph/compute_graph_var_id.h>
11#include <muda/compute_graph/graphviz_options.h>
12#include <muda/compute_graph/compute_graph_fwd.h>
13
14namespace muda
15{
17{
18 std::string_view m_name;
19 ComputeGraphVarManager* m_var_manager = nullptr;
20 VarId m_var_id;
21 bool m_is_valid;
22
23 public:
24 std::string_view name() const MUDA_NOEXCEPT { return m_name; }
25 VarId var_id() const MUDA_NOEXCEPT { return m_var_id; }
26 bool is_valid() const MUDA_NOEXCEPT { return m_is_valid; }
27 void update();
28 Event::QueryResult query();
29 bool is_using();
30 void sync();
31 virtual void graphviz_def(std::ostream& os,
32 const ComputeGraphGraphvizOptions& options) const;
33 virtual void graphviz_id(std::ostream& os, const ComputeGraphGraphvizOptions& options) const;
34
35 protected:
36 template <typename RWView>
37 RWView _eval(const RWView& view);
38 template <typename ROView>
39 ROView _ceval(ROView& view) const;
40
41 friend class ComputeGraph;
42 friend class ComputeGraphVarManager;
43
45 std::string_view name,
46 VarId var_id) MUDA_NOEXCEPT : m_var_manager(var_manager),
47 m_name(name),
48 m_var_id(var_id),
49 m_is_valid(false)
50 {
51 }
52
54 std::string_view name,
55 VarId var_id,
56 bool is_valid) MUDA_NOEXCEPT : m_var_manager(var_manager),
57 m_name(name),
58 m_var_id(var_id),
59 m_is_valid(is_valid)
60 {
61 }
62
63 virtual ~ComputeGraphVarBase() = default;
64
65
66 void base_update();
67
68 friend class LaunchCore;
69
70 mutable std::set<ClosureId> m_closure_ids;
71
72 private:
73 void _building_eval(ComputeGraphVarUsage usage) const;
74 void base_building_eval();
75 void base_building_ceval() const;
76 void remove_related_closure_infos(ComputeGraph* graph);
77
78 class RelatedClosureInfo
79 {
80 public:
81 ComputeGraph* graph;
82 std::set<ClosureId> closure_ids;
83 };
84
85 mutable std::map<ComputeGraph*, RelatedClosureInfo> m_related_closure_infos;
86};
87
88template <typename T>
90{
91 public:
92 static_assert(!std::is_const_v<T>, "T must not be const");
93 using ROViewer = read_only_viewer_t<T>;
94 using RWViewer = T;
95 static_assert(std::is_convertible_v<RWViewer, ROViewer>,
96 "RWViewer must be convertible to ROView");
97
98 protected:
99 friend class ComputeGraph;
100 friend class ComputeGraphVarManager;
101
102 using ComputeGraphVarBase::ComputeGraphVarBase;
103
104 ComputeGraphVar(ComputeGraphVarManager* var_manager, std::string_view name, VarId var_id) MUDA_NOEXCEPT
105 : ComputeGraphVarBase(var_manager, name, var_id)
106 {
107 }
108
110 std::string_view name,
111 VarId var_id,
112 const T& init_value) MUDA_NOEXCEPT
113 : ComputeGraphVarBase(var_manager, name, var_id, true),
114 m_value(init_value)
115 {
116 }
117
118 virtual ~ComputeGraphVar() = default;
119
120 public:
121 RWViewer eval() { return _eval(m_value); }
122 ROViewer ceval() const { return _ceval(m_value); }
123
124 operator ROViewer() const { return ceval(); }
125 operator RWViewer() { return eval(); }
126
127 void update(const RWViewer& view);
128 ComputeGraphVar<T>& operator=(const RWViewer& view);
129 virtual void graphviz_def(std::ostream& os,
130 const ComputeGraphGraphvizOptions& options) const override;
131
132 private:
133 RWViewer m_value;
134};
135
136// for host memory
137template <typename T>
139{
140 using type = const T*;
141};
142template <typename T>
143struct read_write_viewer<const T*>
144{
145 using type = T*;
146};
147
148// for cuda event
149template <>
150struct read_only_viewer<cudaEvent_t>
151{
152 using type = cudaEvent_t;
153};
154template <>
155struct read_write_viewer<cudaEvent_t>
156{
157 using type = cudaEvent_t;
158};
159
160} // namespace muda
161
162
163#include "details/compute_graph_var.inl"
Definition graphviz_options.h:6
Definition compute_graph.h:38
Definition compute_graph_var.h:17
Definition compute_graph_var.h:90
Definition compute_graph_var_manager.h:15
QueryResult
Definition event.h:28
Definition launch_base.h:42
Definition compute_graph_var_id.h:6
Definition type_modifier.h:22
Definition type_modifier.h:28