MUDA
Loading...
Searching...
No Matches
doublet_vector_viewer.h
1#pragma once
2
3#include <string>
4#include <muda/viewer/viewer_base.h>
5#include <muda/ext/eigen/eigen_core_cxx20.h>
6
7/*
8* - 2024/2/23 remove viewer's subview, view's subview is enough
9*/
10
11namespace muda
12{
13template <bool IsConst, typename T, int N>
14class DoubletVectorViewerBase : public ViewerBase<IsConst>
15{
17 template <typename U>
18 using auto_const_t = typename Base::template auto_const_t<U>;
19
20 public:
21 using SegmentVector = Eigen::Matrix<T, N, 1>;
25
26
27 struct CDoublet
28 {
29 MUDA_GENERIC CDoublet(int index, const SegmentVector& segment)
30 : index(index)
31 , segment_value(segment)
32 {
33 }
34 int index;
35 const SegmentVector& segment_value;
36 };
37
38 protected:
39 // vector info
40 int m_total_segment_count = 0;
41
42 // doublet info
43 int m_doublet_index_offset = 0;
44 int m_doublet_count = 0;
45 int m_total_doublet_count = 0;
46
47 // subvector info
48 int m_subvector_offset = 0;
49 int m_subvector_extent = 0;
50
51 // data
52 auto_const_t<int>* m_segment_indices;
53 auto_const_t<SegmentVector>* m_segment_values;
54
55 public:
56 MUDA_GENERIC DoubletVectorViewerBase() = default;
57 MUDA_GENERIC DoubletVectorViewerBase(int total_segment_count,
58 int doublet_index_offset,
59 int doublet_count,
60 int total_doublet_count,
61 int subvector_offset,
62 int subvector_extent,
63 auto_const_t<int>* segment_indices,
64 auto_const_t<SegmentVector>* segment_values)
65 : m_total_segment_count(total_segment_count)
66 , m_doublet_index_offset(doublet_index_offset)
67 , m_doublet_count(doublet_count)
68 , m_total_doublet_count(total_doublet_count)
69 , m_subvector_offset(subvector_offset)
70 , m_subvector_extent(subvector_extent)
71 , m_segment_indices(segment_indices)
72 , m_segment_values(segment_values)
73 {
74 MUDA_KERNEL_ASSERT(doublet_index_offset + doublet_count <= total_doublet_count,
75 "DoubletVectorViewer: out of range, m_total_doublet_count=%d, "
76 "your doublet_index_offset=%d, doublet_count=%d",
77 m_total_doublet_count,
78 doublet_index_offset,
79 doublet_count);
80
81 MUDA_KERNEL_ASSERT(subvector_offset + subvector_extent <= total_segment_count,
82 "DoubletVectorViewer: out of range, m_total_segment_count=%d, "
83 "your subvector_offset=%d, subvector_extent=%d",
84 m_total_segment_count,
85 subvector_offset,
86 subvector_extent);
87 }
88
89 // implicit conversion
90
91 MUDA_GENERIC ConstViewer as_const() const noexcept
92 {
93 return ConstViewer{m_total_segment_count,
94 m_doublet_index_offset,
95 m_doublet_count,
96 m_total_doublet_count,
97 m_subvector_offset,
98 m_subvector_extent,
99 m_segment_indices,
100 m_segment_values};
101 }
102
103 MUDA_GENERIC operator ConstViewer() const noexcept { return as_const(); }
104
105 // const access
106 MUDA_GENERIC int doublet_count() const noexcept { return m_doublet_count; }
107 MUDA_GENERIC int total_doublet_count() const noexcept
108 {
109 return m_total_doublet_count;
110 }
111
112 MUDA_GENERIC CDoublet operator()(int i) const
113 {
114 auto index = get_index(i);
115 auto global_i = m_segment_indices[index];
116 auto sub_i = global_i - m_subvector_offset;
117
118 check_in_subvector(sub_i);
119 return CDoublet{sub_i, m_segment_values[index]};
120 }
121
122 protected:
123 MUDA_INLINE MUDA_GENERIC int get_index(int i) const noexcept
124 {
125 MUDA_KERNEL_ASSERT(i >= 0 && i < m_doublet_count,
126 "DoubletVectorViewer [%s:%s]: index out of range, m_doublet_count=%d, your index=%d",
127 this->name(),
128 this->kernel_name(),
129 m_doublet_count,
130 i);
131 auto index = i + m_doublet_index_offset;
132 return index;
133 }
134
135 MUDA_INLINE MUDA_GENERIC void check_in_subvector(int i) const noexcept
136 {
137 MUDA_KERNEL_ASSERT(i >= 0 && i < m_subvector_extent,
138 "DoubletVectorViewer [%s:%s]: index out of range, m_subvector_extent=%d, your index=%d",
139 this->name(),
140 this->kernel_name(),
141 m_subvector_extent,
142 i);
143 }
144};
145
// Read-only doublet-vector viewer for segment size N: a thin wrapper that
// inherits its base's constructors and adds a converting constructor from
// the base type.
// NOTE(review): this listing skips source lines 147 and 149, so the class
// head (presumably `class CDoubletVectorViewer : public
// DoubletVectorViewerBase<true, T, N>`, by analogy with the <T, 1>
// specialization further down) and the `Base` alias are missing here —
// confirm against the real header.
146template <typename T, int N>
148{
150 MUDA_VIEWER_COMMON_NAME(CDoubletVectorViewer);
151
152 public:
153 using Base::Base;
154 using SegmentVector = typename Base::SegmentVector;
// Converting constructor so a base-class viewer can be wrapped as this type.
155 MUDA_GENERIC CDoubletVectorViewer(const Base& base)
156 : Base(base)
157 {
158 }
159};
160
// Writable doublet-vector viewer for segment size N. operator()(i) on a
// non-const viewer returns a Proxy; reads on a const viewer go through the
// base's const operator().
// NOTE(review): this listing skips source lines 162 and 164, so the class
// head (presumably `class DoubletVectorViewer : public
// DoubletVectorViewerBase<false, T, N>`) and the `Base` alias are missing
// here — confirm against the real header.
161template <typename T, int N>
163{
165 MUDA_VIEWER_COMMON_NAME(DoubletVectorViewer);
166
167 public:
168 using SegmentVector = typename Base::SegmentVector;
169 using CDoublet = typename Base::CDoublet;
170 using Base::Base;
// Converting constructor so a base-class viewer can be wrapped as this type.
171 MUDA_GENERIC DoubletVectorViewer(const Base& base)
172 : Base(base)
173 {
174 }
175
// Re-expose the base's const operator() alongside the non-const overload below.
176 using Base::operator();
177
// Handle to the m_index-th doublet of the viewer's window. Both members are
// rvalue-qualified, so a Proxy is meant to be used directly as
// `viewer(i).write(...)` / `viewer(i).read()`.
178 class Proxy
179 {
180 friend class DoubletVectorViewer;
181 DoubletVectorViewer& m_viewer;
182 int m_index = 0;
183
184 private:
185 MUDA_GENERIC Proxy(DoubletVectorViewer& viewer, int index)
186 : m_viewer(viewer)
187 , m_index(index)
188 {
189 }
190
191 public:
// Returns the CDoublet produced by the viewer's const operator().
// NOTE(review): std::as_const lives in <utility>; this header only includes
// <string> from the standard library directly — consider including <utility>.
192 MUDA_GENERIC auto read() &&
193 {
194 return std::as_const(m_viewer).operator()(m_index);
195 }
196
// Stores (segment_i + subvector offset, block) into this doublet's slot.
// get_index asserts m_index is inside the doublet window, and
// check_in_subvector asserts segment_i is inside the subvector window.
197 MUDA_GENERIC void write(int segment_i, const SegmentVector& block) &&
198 {
199 auto index = m_viewer.get_index(m_index);
200
201 m_viewer.check_in_subvector(segment_i);
202
203 auto global_i = segment_i + m_viewer.m_subvector_offset;
204
205 m_viewer.m_segment_indices[index] = global_i;
206 m_viewer.m_segment_values[index] = block;
207 }
208
209 MUDA_GENERIC ~Proxy() = default;
210 };
211
// Non-const access: returns a Proxy for the i-th doublet (i is validated
// only when the Proxy is read from or written to).
212 MUDA_GENERIC Proxy operator()(int i) { return Proxy{*this, i}; }
213};
214
215template <bool IsConst, typename T>
216class DoubletVectorViewerBase<IsConst, T, 1> : public ViewerBase<IsConst>
217{
219 protected:
220 template <typename U>
221 using auto_const_t = typename Base::template auto_const_t<U>;
222 public:
226
227
228 struct CDoublet
229 {
230 MUDA_GENERIC CDoublet(int index, const T& segment)
231 : index(index)
232 , value(segment)
233 {
234 }
235 int index;
236 const T& value;
237 };
238
239 protected:
240 // vector info
241 int m_total_count = 0;
242
243 // doublet info
244 int m_doublet_index_offset = 0;
245 int m_doublet_count = 0;
246 int m_total_doublet_count = 0;
247
248 // subvector info
249 int m_subvector_offset = 0;
250 int m_subvector_extent = 0;
251
252 auto_const_t<int>* m_indices;
253 auto_const_t<T>* m_values;
254
255 public:
256 MUDA_GENERIC DoubletVectorViewerBase() = default;
257 MUDA_GENERIC DoubletVectorViewerBase(int total_count,
258 int doublet_index_offset,
259 int doublet_count,
260 int total_doublet_count,
261 int subvector_offset,
262 int subvector_extent,
263 auto_const_t<int>* indices,
264 auto_const_t<T>* values)
265 : m_total_count(total_count)
266 , m_doublet_index_offset(doublet_index_offset)
267 , m_doublet_count(doublet_count)
268 , m_total_doublet_count(total_doublet_count)
269 , m_indices(indices)
270 , m_values(values)
271 {
272 MUDA_KERNEL_ASSERT(doublet_index_offset + doublet_count <= total_doublet_count,
273 "DoubletVectorViewer: out of range, m_total_doublet_count=%d, "
274 "your doublet_index_offset=%d, doublet_count=%d",
275 m_total_doublet_count,
276 doublet_index_offset,
277 doublet_count);
278
279 MUDA_KERNEL_ASSERT(subvector_offset + subvector_extent <= total_count,
280 "DoubletVectorViewer: out of range, m_total_segment_count=%d, "
281 "your subvector_offset=%d, subvector_extent=%d",
282 m_total_count,
283 subvector_offset,
284 subvector_extent);
285 }
286
287 // implicit conversion
288
289 MUDA_GENERIC ConstViewer as_const() const noexcept
290 {
291 return ConstViewer{m_total_count,
292 m_doublet_index_offset,
293 m_doublet_count,
294 m_total_doublet_count,
295 m_subvector_offset,
296 m_subvector_extent,
297 m_indices,
298 m_values};
299 }
300
301 MUDA_GENERIC operator ConstViewer() const noexcept { return as_const(); }
302
303 // non-const access
304
305 MUDA_GENERIC CDoublet operator()(int i) const
306 {
307 check_in_subvector(i);
308 auto index = get_index(i);
309 auto global_i = m_indices[index];
310 auto sub_i = global_i - m_subvector_offset;
311
312 return CDoublet{sub_i, m_values[index]};
313 }
314
315 MUDA_GENERIC int extent() const noexcept { return m_subvector_extent; }
316 MUDA_GENERIC int total_extent() const noexcept { return m_total_count; }
317
318 MUDA_GENERIC int subvector_offset() const noexcept
319 {
320 return m_subvector_offset;
321 }
322
323 MUDA_GENERIC int doublet_count() const noexcept { return m_doublet_count; }
324 MUDA_GENERIC int total_doublet_count() const noexcept
325 {
326 return m_total_doublet_count;
327 }
328
329 protected:
330 MUDA_INLINE MUDA_GENERIC int get_index(int i) const noexcept
331 {
332
333 MUDA_KERNEL_ASSERT(i >= 0 && i < m_doublet_count,
334 "DoubletVectorViewer [%s:%s]: index out of range, m_doublet_count=%d, your index=%d",
335 this->name(),
336 this->kernel_name(),
337 m_doublet_count,
338 i);
339 auto index = i + m_doublet_index_offset;
340 return index;
341 }
342
343 MUDA_INLINE MUDA_GENERIC void check_in_subvector(int i) const noexcept
344 {
345 MUDA_KERNEL_ASSERT(i >= 0 && i < m_subvector_extent,
346 "DoubletVectorViewer [%s:%s]: index out of range, m_subvector_extent=%d, your index=%d",
347 this->name(),
348 this->kernel_name(),
349 m_subvector_extent,
350 i);
351 }
352};
353
// Read-only doublet-vector viewer for scalar values (N == 1): a thin wrapper
// over DoubletVectorViewerBase<true, T, 1> that inherits its constructors and
// adds a converting constructor from the base type.
// NOTE(review): this listing skips source lines 357 and 362. 357 is
// presumably the `Base` alias used below; 362's content is unknown — confirm
// against the real header.
354template <typename T>
355class CDoubletVectorViewer<T, 1> : public DoubletVectorViewerBase<true, T, 1>
356{
358 MUDA_VIEWER_COMMON_NAME(CDoubletVectorViewer);
359
360 public:
361 using Base::Base;
// Converting constructor so a base-class viewer can be wrapped as this type.
363 MUDA_GENERIC CDoubletVectorViewer(const Base& base)
364 : Base(base)
365 {
366 }
367};
368
// Writable doublet-vector viewer for scalar values (N == 1). operator()(i) on
// a non-const viewer returns a Proxy; reads on a const viewer go through the
// base's const operator().
// NOTE(review): this listing skips source lines 372, 377 and 378. 372 is
// presumably the `Base` alias used below; 377-378 are unknown — confirm
// against the real header.
369template <typename T>
370class DoubletVectorViewer<T, 1> : public DoubletVectorViewerBase<false, T, 1>
371{
373 MUDA_VIEWER_COMMON_NAME(DoubletVectorViewer);
374
375 public:
376 using CDoublet = typename Base::CDoublet;
379 using Base::Base;
// Converting constructor so a base-class viewer can be wrapped as this type.
380 MUDA_GENERIC DoubletVectorViewer(const Base& base)
381 : Base(base)
382 {
383 }
384
// Re-expose the base's const operator() alongside the non-const overload below.
385 using Base::operator();
386
// Handle to the m_index-th doublet of the viewer's window. Both members are
// rvalue-qualified, so a Proxy is meant to be used directly as
// `viewer(i).write(...)` / `viewer(i).read()`.
387 class Proxy
388 {
389 friend class DoubletVectorViewer;
390 DoubletVectorViewer& m_viewer;
391 int m_index = 0;
392
393 private:
394 MUDA_GENERIC Proxy(DoubletVectorViewer& viewer, int index)
395 : m_viewer(viewer)
396 , m_index(index)
397 {
398 }
399
400 public:
// Returns the CDoublet produced by the viewer's const operator().
// NOTE(review): std::as_const lives in <utility>; this header only includes
// <string> from the standard library directly — consider including <utility>.
401 MUDA_GENERIC auto read() &&
402 {
403 return std::as_const(m_viewer).operator()(m_index);
404 }
405
// Stores (i + subvector offset, value) into this doublet's slot.
// check_in_subvector asserts i is inside the subvector window, and
// get_index asserts m_index is inside the doublet window.
406 MUDA_GENERIC void write(int i, const T& value) &&
407 {
408 m_viewer.check_in_subvector(i);
409
410 auto index = m_viewer.get_index(m_index);
411
412 auto global_i = i + m_viewer.m_subvector_offset;
413 m_viewer.m_indices[index] = global_i;
414 m_viewer.m_values[index] = value;
415 }
416
417 MUDA_GENERIC ~Proxy() = default;
418 };
419
// Non-const access: returns a Proxy for the i-th doublet (i is validated
// only when the Proxy is read from or written to).
420 MUDA_GENERIC Proxy operator()(int i) { return Proxy{*this, i}; }
421};
422} // namespace muda
423
424#include "details/doublet_vector_viewer.inl"
Definition doublet_vector_viewer.h:356
Definition doublet_vector_viewer.h:148
Definition doublet_vector_viewer.h:179
Definition doublet_vector_viewer.h:371
Definition doublet_vector_viewer.h:217
Definition doublet_vector_viewer.h:15
Definition doublet_vector_viewer.h:163
Definition viewer_base.h:18
Definition doublet_vector_viewer.h:28