MUDA
Loading...
Searching...
No Matches
compute_graph_var.inl
1#include <muda/compute_graph/compute_graph_builder.h>
2#include <muda/compute_graph/compute_graph_accessor.h>
3
4namespace muda
5{
6MUDA_INLINE void ComputeGraphVarBase::base_update()
7{
8 for(auto& [graph, info] : m_related_closure_infos)
9 {
10 graph->m_need_update = true;
11 for(auto& id : info.closure_ids)
12 graph->m_closure_need_update[id.value()] = true;
13 }
14 m_is_valid = true;
15}
16
17MUDA_INLINE void ComputeGraphVarBase::base_building_eval()
18{
19 _building_eval(ComputeGraphVarUsage::ReadWrite);
20}
21
22MUDA_INLINE void ComputeGraphVarBase::base_building_ceval() const
23{
24 _building_eval(ComputeGraphVarUsage::Read);
25}
26
27MUDA_INLINE void ComputeGraphVarBase::_building_eval(ComputeGraphVarUsage usage) const
28{
29 auto acc = details::ComputeGraphAccessor();
30 auto graph = ComputeGraphBuilder::instance().current_graph();
31 m_related_closure_infos[graph].closure_ids.insert(graph->current_closure_id());
32 graph->emplace_related_var(const_cast<ComputeGraphVarBase*>(this));
33 acc.set_var_usage(var_id(), usage);
34}
35
36MUDA_INLINE void ComputeGraphVarBase::remove_related_closure_infos(ComputeGraph* graph)
37{
38 auto iter = m_related_closure_infos.find(graph);
39 if(iter != m_related_closure_infos.end())
40 {
41 m_related_closure_infos.erase(iter);
42 }
43}
44
45MUDA_INLINE void ComputeGraphVarBase::graphviz_def(std::ostream& o,
46 const ComputeGraphGraphvizOptions& options) const
47{
48 graphviz_id(o, options);
49 o << "[";
50 if(!name().empty())
51 o << "label=\"" << name() << "\",";
52 o << options.var_style << "]";
53}
54
55MUDA_INLINE void ComputeGraphVarBase::graphviz_id(std::ostream& o,
56 const ComputeGraphGraphvizOptions& options) const
57{
58 o << "var_v" << var_id();
59}
60
61MUDA_INLINE void ComputeGraphVarBase::update()
62{
63 MUDA_ASSERT(!is_using(), "ComputeGraphVar is using, can't update");
64 this->base_update();
65}
66
67MUDA_INLINE Event::QueryResult ComputeGraphVarBase::query()
68{
69 for(auto& [graph, info] : m_related_closure_infos)
70 {
71 if(graph->query() == Event::QueryResult::eNotReady)
73 }
75}
76
77MUDA_INLINE bool ComputeGraphVarBase::is_using()
78{
79 return query() == Event::QueryResult::eNotReady;
80}
81
82MUDA_INLINE void ComputeGraphVarBase::sync()
83{
84 for (auto& [graph, info] : m_related_closure_infos)
85 {
86 checkCudaErrors(cudaEventSynchronize(graph->m_event));
87 }
88}
89
90template <typename RWView>
91RWView ComputeGraphVarBase::_eval(const RWView& view)
92{
93 auto phase = ComputeGraphBuilder::current_phase();
94 switch(phase)
95 {
96 //case ComputeGraphPhase::None: {
97 // MUDA_ERROR_WITH_LOCATION("ComputeGraphVar.eval() is not allowed outside Graph Closure");
98 //}
99 //break;
100 case ComputeGraphPhase::TopoBuilding:
101 case ComputeGraphPhase::Building: {
102 auto acc = details::ComputeGraphAccessor();
103 acc.check_allow_var_eval();
104 MUDA_ASSERT(ComputeGraphBuilder::is_topo_building() || is_valid(),
105 "ComputeGraphVar[%s] is not valid, please update it before use",
106 name().data());
107
108 constexpr auto const_eval = is_uniform_viewer_v<RWView>;
109
110 if constexpr(const_eval)
111 {
112 // they are all read only(e.g. host float/int ...)
113 this->base_building_ceval();
114 }
115 else
116 {
117 this->base_building_eval();
118 }
119 }
120 break;
121 case ComputeGraphPhase::Updating:
122 default: // nothing to do
123 break;
124 }
125 return view;
126}
127
128template <typename ROView>
129ROView ComputeGraphVarBase::_ceval(ROView& view) const
130{
131 auto phase = ComputeGraphBuilder::current_phase();
132 switch(phase)
133 {
134 //case ComputeGraphPhase::None: {
135 // MUDA_ERROR_WITH_LOCATION("ComputeGraphVar.eval() is not allowed outside Graph Closure");
136 //}
137 //break;
138 case ComputeGraphPhase::TopoBuilding:
139 case ComputeGraphPhase::Building: {
140 auto acc = details::ComputeGraphAccessor();
141 acc.check_allow_var_eval();
142 MUDA_ASSERT(ComputeGraphBuilder::is_topo_building() || is_valid(),
143 "ComputeGraphVar[%s] is not valid, please update it before use",
144 name().data());
145
146 this->base_building_ceval();
147 }
148 break;
149 case ComputeGraphPhase::Updating: {
150 // nothing to do
151 }
152 default:
153 break;
154 }
155 return view;
156}
157
// ComputeGraphVar<T> member definitions:
159
160template <typename T>
161MUDA_INLINE void ComputeGraphVar<T>::update(const RWViewer& view)
162{
163 ComputeGraphVarBase::update();
164 m_value = view;
165}
166
167template <typename T>
168MUDA_INLINE ComputeGraphVar<T>& ComputeGraphVar<T>::operator=(const RWViewer& view)
169{
170 update(view);
171 return *this;
172}
173
174template <typename T>
175MUDA_INLINE void ComputeGraphVar<T>::graphviz_def(std::ostream& o,
176 const ComputeGraphGraphvizOptions& options) const
177{
178 graphviz_id(o, options);
179 o << "[";
180 if(!name().empty())
181 o << "label=\"" << name() << "\",";
182
183 if constexpr(std::is_same_v<T, cudaEvent_t>)
184 {
185 o << options.event_style;
186 }
187 else
188 {
189 o << options.var_style;
190 }
191
192 o << "]";
193}
194} // namespace muda
QueryResult
Definition event.h:28