MUDA
Loading...
Searching...
No Matches
triplet_matrix_viewer.h
1#pragma once
2#include <string>
3#include <muda/viewer/viewer_base.h>
5#include <muda/tools/cuda_vec_utils.h>
6#include <muda/ext/eigen/eigen_core_cxx20.h>
7
8
/*
* - 2024/2/23: removed the viewer's subview; the view's subview is sufficient
*/
12
13namespace muda
14{
15template <bool IsConst, typename T, int N>
16class TripletMatrixViewerBase : public ViewerBase<IsConst>
17{
19 template <typename U>
20 using auto_const_t = typename Base::template auto_const_t<U>;
21
22 public:
23 using BlockMatrix = Eigen::Matrix<T, N, N>;
27
28 struct CTriplet
29 {
30 MUDA_GENERIC CTriplet(int row_index, int col_index, const BlockMatrix& block)
31 : block_row_index(row_index)
32 , block_col_index(col_index)
33 , block_value(block)
34 {
35 }
36 int block_row_index;
37 int block_col_index;
38 const BlockMatrix& block_value;
39 };
40
41
42 protected:
43 // matrix info
44 int m_total_block_rows = 0;
45 int m_total_block_cols = 0;
46
47 // triplet info
48 int m_triplet_index_offset = 0;
49 int m_triplet_count = 0;
50 int m_total_triplet_count = 0;
51
52 // sub matrix info
53 int2 m_submatrix_offset = {0, 0};
54 int2 m_submatrix_extent = {0, 0};
55
56 // data
57 auto_const_t<int>* m_block_row_indices;
58 auto_const_t<int>* m_block_col_indices;
59 auto_const_t<BlockMatrix>* m_block_values;
60
61
62 public:
63 MUDA_GENERIC TripletMatrixViewerBase() = default;
64 MUDA_GENERIC TripletMatrixViewerBase(int total_block_rows,
65 int total_block_cols,
66 int triplet_index_offset,
67 int triplet_count,
68 int total_triplet_count,
69
70 int2 submatrix_offset,
71 int2 submatrix_extent,
72
73 auto_const_t<int>* block_row_indices,
74 auto_const_t<int>* block_col_indices,
75 auto_const_t<BlockMatrix>* block_values)
76 : m_total_block_rows(total_block_rows)
77 , m_total_block_cols(total_block_cols)
78 , m_triplet_index_offset(triplet_index_offset)
79 , m_triplet_count(triplet_count)
80 , m_total_triplet_count(total_triplet_count)
81 , m_submatrix_offset(submatrix_offset)
82 , m_submatrix_extent(submatrix_extent)
83 , m_block_row_indices(block_row_indices)
84 , m_block_col_indices(block_col_indices)
85 , m_block_values(block_values)
86 {
87 MUDA_KERNEL_ASSERT(triplet_index_offset + triplet_count <= total_triplet_count,
88 "TripletMatrixViewer [%s:%s]: out of range, m_total_triplet_count=%d, "
89 "your triplet_index_offset=%d, triplet_count=%d",
90 this->name(),
91 this->kernel_name(),
92 total_triplet_count,
93 triplet_index_offset,
94 triplet_count);
95
96 MUDA_KERNEL_ASSERT(submatrix_offset.x >= 0 && submatrix_offset.y >= 0,
97 "TripletMatrixViewer[%s:%s]: submatrix_offset is out of range, submatrix_offset.x=%d, submatrix_offset.y=%d",
98 this->name(),
99 this->kernel_name(),
100 submatrix_offset.x,
101 submatrix_offset.y);
102
103 MUDA_KERNEL_ASSERT(submatrix_offset.x + submatrix_extent.x <= total_block_rows,
104 "TripletMatrixViewer[%s:%s]: submatrix is out of range, submatrix_offset.x=%d, submatrix_extent.x=%d, total_block_rows=%d",
105 this->name(),
106 this->kernel_name(),
107 submatrix_offset.x,
108 submatrix_extent.x,
109 total_block_rows);
110
111 MUDA_KERNEL_ASSERT(submatrix_offset.y + submatrix_extent.y <= total_block_cols,
112 "TripletMatrixViewer[%s:%s]: submatrix is out of range, submatrix_offset.y=%d, submatrix_extent.y=%d, total_block_cols=%d",
113 this->name(),
114 this->kernel_name(),
115 submatrix_offset.y,
116 submatrix_extent.y,
117 total_block_cols);
118 }
119
120 MUDA_GENERIC ConstViewer as_const() const
121 {
122 return ConstViewer{m_total_block_rows,
123 m_total_block_cols,
124 m_triplet_index_offset,
125 m_triplet_count,
126 m_total_triplet_count,
127 m_submatrix_offset,
128 m_submatrix_extent,
129 m_block_row_indices,
130 m_block_col_indices,
131 m_block_values};
132 }
133
134 MUDA_GENERIC operator ConstViewer() const { return as_const(); }
135
136 // const accessor
137
138 MUDA_GENERIC auto total_block_rows() const { return m_total_block_rows; }
139 MUDA_GENERIC auto total_block_cols() const { return m_total_block_cols; }
140 MUDA_GENERIC auto total_extent() const
141 {
142 return int2{m_total_block_rows, m_total_block_cols};
143 }
144
145 MUDA_GENERIC auto submatrix_offset() const { return m_submatrix_offset; }
146 MUDA_GENERIC auto extent() const { return m_submatrix_extent; }
147
148 MUDA_GENERIC auto triplet_count() const { return m_triplet_count; }
149 MUDA_GENERIC auto tripet_index_offset() const
150 {
151 return m_triplet_index_offset;
152 }
153 MUDA_GENERIC auto total_triplet_count() const
154 {
155 return m_total_triplet_count;
156 }
157
158 MUDA_GENERIC CTriplet operator()(int i) const
159 {
160 auto index = get_index(i);
161 auto global_i = m_block_row_indices[index];
162 auto global_j = m_block_col_indices[index];
163 auto sub_i = global_i - m_submatrix_offset.x;
164 auto sub_j = global_j - m_submatrix_offset.y;
165 check_in_submatrix(sub_i, sub_j);
166 return CTriplet{sub_i, sub_j, m_block_values[index]};
167 }
168
169 protected:
170 MUDA_INLINE MUDA_GENERIC int get_index(int i) const noexcept
171 {
172
173 MUDA_KERNEL_ASSERT(i >= 0 && i < m_triplet_count,
174 "TripletMatrixViewer [%s:%s]: triplet_index out of range, block_count=%d, your index=%d",
175 this->name(),
176 this->kernel_name(),
177 m_triplet_count,
178 i);
179 auto index = i + m_triplet_index_offset;
180 return index;
181 }
182
183 MUDA_INLINE MUDA_GENERIC void check_in_submatrix(int i, int j) const noexcept
184 {
185 MUDA_KERNEL_ASSERT(i >= 0 && i < m_submatrix_extent.x,
186 "TripletMatrixViewer [%s:%s]: row index out of submatrix range, submatrix_extent.x=%d, your i=%d",
187 this->name(),
188 this->kernel_name(),
189 m_submatrix_extent.x,
190 i);
191
192 MUDA_KERNEL_ASSERT(j >= 0 && j < m_submatrix_extent.y,
193 "TripletMatrixViewer [%s:%s]: col index out of submatrix range, submatrix_extent.y=%d, your j=%d",
194 this->name(),
195 this->kernel_name(),
196 m_submatrix_extent.y,
197 j);
198 }
199};
200
201template <typename T, int N>
203{
205 MUDA_VIEWER_COMMON_NAME(CTripletMatrixViewer);
206 using BlockMatrix = typename Base::BlockMatrix;
207
208 public:
209 using Base::Base;
210
211 MUDA_GENERIC CTripletMatrixViewer(const Base& base)
212 : Base(base)
213 {
214 }
215};
216
217template <typename T, int N>
219{
221 MUDA_VIEWER_COMMON_NAME(TripletMatrixViewer);
222
223 public:
224 using Base::Base;
225 using BlockMatrix = typename Base::BlockMatrix;
226 using CTriplet = typename Base::CTriplet;
229
230
231 MUDA_GENERIC TripletMatrixViewer(const Base& base)
232 : Base(base)
233 {
234 }
235
236 using Base::operator();
237
238 class Proxy
239 {
240 friend class TripletMatrixViewer;
241 TripletMatrixViewer& m_viewer;
242 int m_index = 0;
243
244 private:
245 MUDA_GENERIC Proxy(TripletMatrixViewer& viewer, int index)
246 : m_viewer(viewer)
247 , m_index(index)
248 {
249 }
250
251 public:
252 MUDA_GENERIC auto read() &&
253 {
254 return std::as_const(m_viewer).operator()(m_index);
255 }
256
257 MUDA_GENERIC
258 void write(int block_row_index, int block_col_index, const BlockMatrix& block) &&
259 {
260 auto index = m_viewer.get_index(m_index);
261
262 m_viewer.check_in_submatrix(block_row_index, block_col_index);
263
264 auto global_i = m_viewer.m_submatrix_offset.x + block_row_index;
265 auto global_j = m_viewer.m_submatrix_offset.y + block_col_index;
266
267 m_viewer.m_block_row_indices[index] = global_i;
268 m_viewer.m_block_col_indices[index] = global_j;
269 m_viewer.m_block_values[index] = block;
270 }
271
272 MUDA_GENERIC ~Proxy() = default;
273 };
274
275 MUDA_GENERIC Proxy operator()(int i) { return Proxy{*this, i}; }
276};
277
278
279template <bool IsConst, typename T>
280class TripletMatrixViewerBase<IsConst, T, 1> : public ViewerBase<IsConst>
281{
283 protected:
284 template <typename U>
285 using auto_const_t = typename Base::template auto_const_t<U>;
286
287 public:
291
292 struct CTriplet
293 {
294 MUDA_GENERIC CTriplet(int row_index, int col_index, const T& block)
295 : row_index(row_index)
296 , col_index(col_index)
297 , value(block)
298 {
299 }
300 int row_index;
301 int col_index;
302 const T& value;
303 };
304
305 protected:
306 // matrix info
307 int m_total_rows = 0;
308 int m_total_cols = 0;
309
310 // triplet info
311 int m_triplet_index_offset = 0;
312 int m_triplet_count = 0;
313 int m_total_triplet_count = 0;
314
315 // sub matrix info
316 int2 m_submatrix_offset = {0, 0};
317 int2 m_submatrix_extent = {0, 0};
318
319 // data
320 auto_const_t<int>* m_row_indices;
321 auto_const_t<int>* m_col_indices;
322 auto_const_t<T>* m_values;
323
324 public:
325 MUDA_GENERIC TripletMatrixViewerBase() = default;
326 MUDA_GENERIC TripletMatrixViewerBase(int total_rows,
327 int total_cols,
328
329 int triplet_index_offset,
330 int triplet_count,
331 int total_triplet_count,
332
333 int2 submatrix_offset,
334 int2 submatrix_extent,
335
336 auto_const_t<int>* row_indices,
337 auto_const_t<int>* col_indices,
338 auto_const_t<T>* values)
339 : m_total_rows(total_rows)
340 , m_total_cols(total_cols)
341 , m_triplet_index_offset(triplet_index_offset)
342 , m_triplet_count(triplet_count)
343 , m_total_triplet_count(total_triplet_count)
344 , m_submatrix_offset(submatrix_offset)
345 , m_submatrix_extent(submatrix_extent)
346 , m_row_indices(row_indices)
347 , m_col_indices(col_indices)
348 , m_values(values)
349 {
350 MUDA_KERNEL_ASSERT(triplet_index_offset + triplet_count <= total_triplet_count,
351 "TripletMatrixViewer [%s:%s]: out of range, m_total_triplet_count=%d, "
352 "your triplet_index_offset=%d, triplet_count=%d",
353 this->name(),
354 this->kernel_name(),
355 total_triplet_count,
356 triplet_index_offset,
357 triplet_count);
358
359 MUDA_KERNEL_ASSERT(submatrix_offset.x >= 0 && submatrix_offset.y >= 0,
360 "TripletMatrixViewer [%s:%s]: submatrix_offset is out of range, submatrix_offset.x=%d, submatrix_offset.y=%d",
361 this->name(),
362 this->kernel_name(),
363 submatrix_offset.x,
364 submatrix_offset.y);
365
366 MUDA_KERNEL_ASSERT(submatrix_offset.x + submatrix_extent.x <= total_rows,
367 "TripletMatrixViewer [%s:%s]: submatrix is out of range, submatrix_offset.x=%d, submatrix_extent.x=%d, rows=%d",
368 this->name(),
369 this->kernel_name(),
370 submatrix_offset.x,
371 submatrix_extent.x,
372 total_rows);
373
374 MUDA_KERNEL_ASSERT(submatrix_offset.y + submatrix_extent.y <= total_cols,
375 "TripletMatrixViewer [%s:%s]: submatrix is out of range, submatrix_offset.y=%d, submatrix_extent.y=%d, cols=%d",
376 this->name(),
377 this->kernel_name(),
378 submatrix_offset.y,
379 submatrix_extent.y,
380 total_cols);
381 }
382
383 // implicit conversion
384
385 MUDA_GENERIC ConstViewer as_const() const
386 {
387 return ConstViewer{m_total_rows,
388 m_total_cols,
389 m_triplet_index_offset,
390 m_triplet_count,
391 m_total_triplet_count,
392 m_submatrix_offset,
393 m_submatrix_extent,
394 m_row_indices,
395 m_col_indices,
396 m_values};
397 }
398
399 MUDA_GENERIC operator ConstViewer() const { return as_const(); }
400
401
402 MUDA_GENERIC CTriplet operator()(int i) const
403 {
404 auto index = get_index(i);
405
406 auto global_i = m_row_indices[index];
407 auto global_j = m_col_indices[index];
408 auto sub_i = global_i - m_submatrix_offset.x;
409 auto sub_j = global_j - m_submatrix_offset.y;
410 check_in_submatrix(sub_i, sub_j);
411 return CTriplet{sub_i, sub_j, m_values[index]};
412 }
413
414 auto total_rows() const { return m_total_rows; }
415 auto total_cols() const { return m_total_cols; }
416
417 auto triplet_count() const { return m_triplet_count; }
418 auto tripet_index_offset() const { return m_triplet_index_offset; }
419 auto total_triplet_count() const { return m_total_triplet_count; }
420
421 auto submatrix_offset() const { return m_submatrix_offset; }
422 auto extent() const { return m_submatrix_extent; }
423 auto total_extent() const { return int2{m_total_rows, m_total_cols}; }
424
425 protected:
426 MUDA_INLINE MUDA_GENERIC int get_index(int i) const noexcept
427 {
428
429 MUDA_KERNEL_ASSERT(i >= 0 && i < m_triplet_count,
430 "TripletMatrixViewer [%s:%s]: triplet_index out of range, block_count=%d, your index=%d",
431 this->name(),
432 this->kernel_name(),
433 m_triplet_count,
434 i);
435 auto index = i + m_triplet_index_offset;
436 return index;
437 }
438
439 MUDA_INLINE MUDA_GENERIC void check_in_submatrix(int i, int j) const noexcept
440 {
441 MUDA_KERNEL_ASSERT(i >= 0 && i < m_submatrix_extent.x,
442 "TripletMatrixViewer [%s:%s]: row index out of submatrix range, submatrix_extent.x=%d, yours=%d",
443 this->name(),
444 this->kernel_name(),
445 m_submatrix_extent.x,
446 i);
447
448 MUDA_KERNEL_ASSERT(j >= 0 && j < m_submatrix_extent.y,
449 "TripletMatrixViewer [%s:%s]: col index out of submatrix range, submatrix_extent.y=%d, yours=%d",
450 this->name(),
451 this->kernel_name(),
452 m_submatrix_extent.y,
453 j);
454 }
455};
456
457template <typename T>
458class CTripletMatrixViewer<T, 1> : public TripletMatrixViewerBase<true, T, 1>
459{
461 MUDA_VIEWER_COMMON_NAME(CTripletMatrixViewer);
462
463 public:
464 using Base::Base;
466
467 MUDA_GENERIC CTripletMatrixViewer(const Base& base)
468 : Base(base)
469 {
470 }
471};
472
473template <typename T>
474class TripletMatrixViewer<T, 1> : public TripletMatrixViewerBase<false, T, 1>
475{
477 MUDA_VIEWER_COMMON_NAME(TripletMatrixViewer);
478
479 public:
482
483 using Base::Base;
484 using CTriplet = typename Base::CTriplet;
485 MUDA_GENERIC TripletMatrixViewer(const Base& base)
486 : Base(base)
487 {
488 }
489
490 class Proxy
491 {
492 friend class TripletMatrixViewer;
493 TripletMatrixViewer& m_viewer;
494 int m_index = 0;
495
496 private:
497 MUDA_GENERIC Proxy(TripletMatrixViewer& viewer, int index)
498 : m_viewer(viewer)
499 , m_index(index)
500 {
501 }
502
503 public:
504 MUDA_GENERIC auto read() &&
505 {
506 return std::as_const(m_viewer).operator()(m_index);
507 }
508
509 MUDA_GENERIC void write(int row_index, int col_index, const T& value) &&
510 {
511 auto index = m_viewer.get_index(m_index);
512 m_viewer.check_in_submatrix(row_index, col_index);
513
514 auto global_i = m_viewer.m_submatrix_offset.x + row_index;
515 auto global_j = m_viewer.m_submatrix_offset.y + col_index;
516
517 m_viewer.m_row_indices[index] = global_i;
518 m_viewer.m_col_indices[index] = global_j;
519 m_viewer.m_values[index] = value;
520 }
521
522 MUDA_GENERIC ~Proxy() = default;
523 };
524
525 using Base::operator();
526
527 MUDA_GENERIC Proxy operator()(int i)
528 {
529 auto index = Base::get_index(i);
530 return Proxy{*this, index};
531 }
532};
533} // namespace muda
534
535#include "details/triplet_matrix_viewer.inl"
Definition triplet_matrix_viewer.h:459
Definition triplet_matrix_viewer.h:203
Definition triplet_matrix_viewer.h:239
Definition triplet_matrix_viewer.h:475
Definition triplet_matrix_viewer.h:281
Definition triplet_matrix_viewer.h:17
Definition triplet_matrix_viewer.h:219
Definition viewer_base.h:18
A light-weight wrapper of cuda device memory. Like std::vector, allow user to resize,...
Definition triplet_matrix_viewer.h:29