MUDA
Loading...
Searching...
No Matches
triplet_matrix_view.h
1#pragma once
3#include <muda/ext/linear_system/triplet_matrix_viewer.h>
4#include <muda/view/view_base.h>
5
6namespace muda
7{
8template <bool IsConst, typename Ty, int N>
9class TripletMatrixViewBase : public ViewBase<IsConst>
10{
11 using Base = ViewBase<IsConst>;
12 template <typename U>
13 using auto_const_t = typename Base::template auto_const_t<U>;
14
15 public:
16 static_assert(!std::is_const_v<Ty>, "Ty must be non-const");
20
21 private:
24 using ThisViewer = std::conditional_t<IsConst, ConstViewer, NonConstViewer>;
25
26 public:
27 using BlockMatrix = Eigen::Matrix<Ty, N, N>;
28
29 protected:
30 // matrix info
31 int m_total_block_rows = 0;
32 int m_total_block_cols = 0;
33
34 // triplet info
35 int m_triplet_index_offset = 0;
36 int m_triplet_count = 0;
37 int m_total_triplet_count = 0;
38
39 // sub matrix info
40 int2 m_submatrix_offset = {0, 0};
41 int2 m_submatrix_extent = {0, 0};
42
43 // data
44 auto_const_t<int>* m_block_row_indices = nullptr;
45 auto_const_t<int>* m_block_col_indices = nullptr;
46 auto_const_t<BlockMatrix>* m_block_values = nullptr;
47
48 public:
49 MUDA_GENERIC TripletMatrixViewBase() = default;
50 MUDA_GENERIC TripletMatrixViewBase(int total_block_rows,
51 int total_block_cols,
52
53 int triplet_index_offset,
54 int triplet_count,
55 int total_triplet_count,
56
57 int2 submatrix_offset,
58 int2 submatrix_extent,
59
60 auto_const_t<int>* block_row_indices,
61 auto_const_t<int>* block_col_indices,
62 auto_const_t<BlockMatrix>* block_values)
63 : m_total_block_rows(total_block_rows)
64 , m_total_block_cols(total_block_cols)
65 , m_triplet_index_offset(triplet_index_offset)
66 , m_triplet_count(triplet_count)
67 , m_total_triplet_count(total_triplet_count)
68 , m_block_row_indices(block_row_indices)
69 , m_block_col_indices(block_col_indices)
70 , m_block_values(block_values)
71 , m_submatrix_offset(submatrix_offset)
72 , m_submatrix_extent(submatrix_extent)
73 {
74 MUDA_KERNEL_ASSERT(triplet_index_offset + triplet_count <= total_triplet_count,
75 "TripletMatrixView: out of range, m_total_triplet_count=%d, "
76 "your triplet_index_offset=%d, triplet_count=%d",
77 total_triplet_count,
78 triplet_index_offset,
79 triplet_count);
80
81 MUDA_KERNEL_ASSERT(submatrix_offset.x >= 0 && submatrix_offset.y >= 0,
82 "TripletMatrixView: submatrix_offset is out of range, submatrix_offset.x=%d, submatrix_offset.y=%d",
83 submatrix_offset.x,
84 submatrix_offset.y);
85
86 MUDA_KERNEL_ASSERT(submatrix_offset.x + submatrix_extent.x <= total_block_rows,
87 "TripletMatrixView: submatrix is out of range, submatrix_offset.x=%d, submatrix_extent.x=%d, total_block_rows=%d",
88 submatrix_offset.x,
89 submatrix_extent.x,
90 total_block_rows);
91
92 MUDA_KERNEL_ASSERT(submatrix_offset.y + submatrix_extent.y <= total_block_cols,
93 "TripletMatrixView: submatrix is out of range, submatrix_offset.y=%d, submatrix_extent.y=%d, total_block_cols=%d",
94 submatrix_offset.y,
95 submatrix_extent.y,
96 total_block_cols);
97 }
98
99 MUDA_GENERIC TripletMatrixViewBase(int total_block_rows,
100 int total_block_cols,
101 int total_triplet_count,
102 auto_const_t<int>* block_row_indices,
103 auto_const_t<int>* block_col_indices,
104 auto_const_t<BlockMatrix>* block_values)
105 : TripletMatrixViewBase(total_block_rows,
106 total_block_cols,
107 0,
108 total_triplet_count,
109 total_triplet_count,
110 {0, 0},
111 {total_block_rows, total_block_cols},
112 block_row_indices,
113 block_col_indices,
114 block_values)
115 {
116 }
117
118 // explicit conversion to non-const
119 MUDA_GENERIC ConstView as_const() const
120 {
121 return ConstView{m_total_block_rows,
122 m_total_block_cols,
123 m_triplet_index_offset,
124 m_triplet_count,
125 m_total_triplet_count,
126 m_submatrix_offset,
127 m_submatrix_extent,
128 m_block_row_indices,
129 m_block_col_indices,
130 m_block_values};
131 }
132
133 // implicit conversion to const
134 MUDA_GENERIC operator ConstView() const { return as_const(); }
135
136 MUDA_GENERIC auto subview(int offset, int count) const
137 {
138 MUDA_ASSERT(offset + count <= m_triplet_count,
139 "TripletMatrixView: offset is out of range, size=%d, your offset=%d, your count=%d",
140 m_triplet_count,
141 offset,
142 count);
143
144 return ThisView{m_total_block_rows,
145 m_total_block_cols,
146 m_triplet_index_offset + offset,
147 count,
148 m_total_triplet_count,
149 m_submatrix_offset,
150 m_submatrix_extent,
151 m_block_row_indices,
152 m_block_col_indices,
153 m_block_values};
154 }
155
156 MUDA_GENERIC auto subview(int offset) const
157 {
158 return subview(offset, m_triplet_count - offset);
159 }
160
161 MUDA_GENERIC auto cviewer() const
162 {
163 return ConstViewer{m_total_block_rows,
164 m_total_block_cols,
165 m_triplet_index_offset,
166 m_triplet_count,
167 m_total_triplet_count,
168 m_submatrix_offset,
169 m_submatrix_extent,
170 m_block_row_indices,
171 m_block_col_indices,
172 m_block_values};
173 }
174
175 MUDA_GENERIC auto viewer()
176 {
177 return ThisViewer{m_total_block_rows,
178 m_total_block_cols,
179 m_triplet_index_offset,
180 m_triplet_count,
181 m_total_triplet_count,
182 m_submatrix_offset,
183 m_submatrix_extent,
184 m_block_row_indices,
185 m_block_col_indices,
186 m_block_values};
187 }
188
189 // non-const access
190 MUDA_GENERIC auto_const_t<BlockMatrix>* block_values()
191 {
192 return m_block_values;
193 }
194 MUDA_GENERIC auto_const_t<int>* block_row_indices()
195 {
196 return m_block_row_indices;
197 }
198 MUDA_GENERIC auto_const_t<int>* block_col_indices()
199 {
200 return m_block_col_indices;
201 }
202
203 MUDA_GENERIC auto submatrix(int2 offset, int2 extent) const
204 {
205 MUDA_KERNEL_ASSERT(offset.x >= 0 && offset.y >= 0,
206 "TripletMatrixView: submatrix is out of range, offset=(%d, %d)",
207 offset.x,
208 offset.y);
209
210 MUDA_KERNEL_ASSERT(offset.x + extent.x <= m_submatrix_extent.x
211 && offset.y + extent.y <= m_submatrix_extent.y,
212 "TripletMatrixView: submatrix is out of range, offset=(%d, %d), extent=(%d, %d), origin offset=(%d,%d), extent(%d,%d).",
213 offset.x,
214 offset.y,
215 extent.x,
216 extent.y,
217 m_submatrix_offset.x,
218 m_submatrix_offset.y,
219 m_submatrix_extent.x,
220 m_submatrix_extent.y);
221
222 return ThisView{m_total_block_rows,
223 m_total_block_cols,
224 m_triplet_index_offset,
225 m_triplet_count,
226 m_total_triplet_count,
227 m_submatrix_offset + offset,
228 extent,
229 m_block_row_indices,
230 m_block_col_indices,
231 m_block_values};
232 }
233
234 // const access
235 MUDA_GENERIC auto total_block_rows() const { return m_total_block_rows; }
236 MUDA_GENERIC auto total_block_cols() const { return m_total_block_cols; }
237 MUDA_GENERIC auto total_extent() const
238 {
239 return int2{m_total_block_rows, m_total_block_cols};
240 }
241
242 MUDA_GENERIC auto submatrix_offset() const { return m_submatrix_offset; }
243 MUDA_GENERIC auto extent() const { return m_submatrix_extent; }
244
245 MUDA_GENERIC auto triplet_count() const { return m_triplet_count; }
246 MUDA_GENERIC auto tripet_index_offset() const
247 {
248 return m_triplet_index_offset;
249 }
250 MUDA_GENERIC auto total_triplet_count() const
251 {
252 return m_total_triplet_count;
253 }
254};
255
256template <bool IsConst, typename Ty>
257class TripletMatrixViewBase<IsConst, Ty, 1> : public ViewBase<IsConst>
258{
259 using Base = ViewBase<IsConst>;
260
261 protected:
262 template <typename U>
263 using auto_const_t = typename Base::template auto_const_t<U>;
264
265 public:
266 static_assert(!std::is_const_v<Ty>, "Ty must be non-const");
270
271 private:
274 using ThisViewer = std::conditional_t<IsConst, ConstViewer, NonConstViewer>;
275
276 protected:
277 // matrix info
278 int m_total_rows = 0;
279 int m_total_cols = 0;
280
281 // triplet info
282 int m_triplet_index_offset = 0;
283 int m_triplet_count = 0;
284 int m_total_triplet_count = 0;
285
286 // sub matrix info
287 int2 m_submatrix_offset = {0, 0};
288 int2 m_submatrix_extent = {0, 0};
289
290 // data
291 auto_const_t<int>* m_row_indices;
292 auto_const_t<int>* m_col_indices;
293 auto_const_t<Ty>* m_values;
294
295
296 public:
297 MUDA_GENERIC TripletMatrixViewBase() = default;
298
299 MUDA_GENERIC TripletMatrixViewBase(int total_rows,
300 int total_cols,
301 int triplet_index_offset,
302 int triplet_count,
303 int total_triplet_count,
304
305 int2 submatrix_offset,
306 int2 submatrix_extent,
307
308 auto_const_t<int>* row_indices,
309 auto_const_t<int>* col_indices,
310 auto_const_t<Ty>* values)
311 : m_total_rows(total_rows)
312 , m_total_cols(total_cols)
313 , m_triplet_index_offset(triplet_index_offset)
314 , m_triplet_count(triplet_count)
315 , m_total_triplet_count(total_triplet_count)
316 , m_submatrix_offset(submatrix_offset)
317 , m_submatrix_extent(submatrix_extent)
318 , m_row_indices(row_indices)
319 , m_col_indices(col_indices)
320 , m_values(values)
321 {
322 MUDA_KERNEL_ASSERT(triplet_index_offset + triplet_count <= total_triplet_count,
323 "TripletMatrixView: out of range, m_total_triplet_count=%d, "
324 "your triplet_index_offset=%d, triplet_count=%d",
325 total_triplet_count,
326 triplet_index_offset,
327 triplet_count);
328
329 MUDA_KERNEL_ASSERT(submatrix_offset.x >= 0 && submatrix_offset.y >= 0,
330 "TripletMatrixView: submatrix_offset is out of range, submatrix_offset.x=%d, submatrix_offset.y=%d",
331 submatrix_offset.x,
332 submatrix_offset.y);
333
334 MUDA_KERNEL_ASSERT(submatrix_offset.x + submatrix_extent.x <= total_rows,
335 "TripletMatrixView: submatrix is out of range, submatrix_offset.x=%d, submatrix_extent.x=%d, total_rows=%d",
336 submatrix_offset.x,
337 submatrix_extent.x,
338 total_rows);
339
340 MUDA_KERNEL_ASSERT(submatrix_offset.y + submatrix_extent.y <= total_cols,
341 "TripletMatrixView: submatrix is out of range, submatrix_offset.y=%d, submatrix_extent.y=%d, total_cols=%d",
342 submatrix_offset.y,
343 submatrix_extent.y,
344 total_cols);
345 }
346
347
348 MUDA_GENERIC TripletMatrixViewBase(int total_rows,
349 int total_cols,
350 int total_triplet_count,
351 auto_const_t<int>* row_indices,
352 auto_const_t<int>* col_indices,
353 auto_const_t<Ty>* values)
354 : TripletMatrixViewBase(total_rows,
355 total_cols,
356 0,
357 total_triplet_count,
358 total_triplet_count,
359 {0, 0},
360 {total_rows, total_cols},
361 row_indices,
362 col_indices,
363 values)
364 {
365 }
366
367
368 // explicit conversion to non-const
369 MUDA_GENERIC ConstView as_const() const
370 {
371 return ConstView{m_total_rows,
372 m_total_cols,
373 m_triplet_index_offset,
374 m_triplet_count,
375 m_total_triplet_count,
376 m_submatrix_offset,
377 m_submatrix_extent,
378 m_row_indices,
379 m_col_indices,
380 m_values};
381 }
382
383 // implicit conversion to const
384 MUDA_GENERIC operator ConstView() const { return as_const(); }
385
386 MUDA_GENERIC auto subview(int offset, int count) const
387 {
388 MUDA_ASSERT(offset + count < m_triplet_count,
389 "TripletMatrixView: offset is out of range, size=%d, your offset=%d, your count=%d",
390 m_triplet_count,
391 offset,
392 count);
393
394 return ThisView{m_total_rows,
395 m_total_cols,
396 m_triplet_index_offset + offset,
397 count,
398 m_total_triplet_count,
399 m_submatrix_offset,
400 m_submatrix_extent,
401 m_row_indices,
402 m_col_indices,
403 m_values};
404 }
405
406 MUDA_GENERIC auto submatrix(int2 offset, int2 extent) const
407 {
408 MUDA_KERNEL_ASSERT(offset.x >= 0 && offset.y >= 0,
409 "TripletMatrixView: submatrix is out of range, submatrix_offset.x=%d, submatrix_offset.y=%d",
410 offset.x,
411 offset.y);
412
413 MUDA_KERNEL_ASSERT(offset.x + extent.x <= m_submatrix_extent.x
414 && offset.y + extent.y <= m_submatrix_extent.y,
415 "TripletMatrixView: submatrix is out of range, submatrix_offset.x=%d, submatrix_extent.x=%d, submatrix_offset.y=%d, submatrix_extent.y=%d",
416 offset.x,
417 m_submatrix_extent.x,
418 offset.y,
419 m_submatrix_extent.y);
420
421 return ThisView{m_total_rows,
422 m_total_cols,
423 m_triplet_index_offset,
424 m_triplet_count,
425 m_total_triplet_count,
426 m_submatrix_offset + offset,
427 extent,
428 m_row_indices,
429 m_col_indices,
430 m_values};
431 }
432
433 MUDA_GENERIC auto subview(int offset) const
434 {
435 return subview(offset, m_triplet_count - offset);
436 }
437
438 MUDA_GENERIC auto cviewer() const
439 {
440 return ConstViewer{m_total_rows,
441 m_total_cols,
442 m_triplet_index_offset,
443 m_triplet_count,
444 m_total_triplet_count,
445 m_submatrix_offset,
446 m_submatrix_extent,
447 m_row_indices,
448 m_col_indices,
449 m_values};
450 }
451
452 MUDA_GENERIC auto viewer()
453 {
454 return ThisViewer{m_total_rows,
455 m_total_cols,
456 m_triplet_index_offset,
457 m_triplet_count,
458 m_total_triplet_count,
459 m_submatrix_offset,
460 m_submatrix_extent,
461 m_row_indices,
462 m_col_indices,
463 m_values};
464 }
465
466 // non-const access
467 MUDA_GENERIC auto_const_t<Ty>* values() { return m_values; }
468 MUDA_GENERIC auto_const_t<int>* row_indices() { return m_row_indices; }
469 MUDA_GENERIC auto_const_t<int>* col_indices() { return m_col_indices; }
470
471
472 // const access
473 MUDA_GENERIC auto values() const { return m_values; }
474 MUDA_GENERIC auto row_indices() const { return m_row_indices; }
475 MUDA_GENERIC auto col_indices() const { return m_col_indices; }
476
477 MUDA_GENERIC auto total_rows() const { return m_total_rows; }
478 MUDA_GENERIC auto total_cols() const { return m_total_cols; }
479
480 MUDA_GENERIC auto triplet_count() const { return m_triplet_count; }
481 MUDA_GENERIC auto tripet_index_offset() const
482 {
483 return m_triplet_index_offset;
484 }
485 MUDA_GENERIC auto total_triplet_count() const
486 {
487 return m_total_triplet_count;
488 }
489
490 MUDA_GENERIC auto submatrix_offset() const { return m_submatrix_offset; }
491 MUDA_GENERIC auto extent() const { return m_submatrix_extent; }
492 MUDA_GENERIC auto total_extent() const
493 {
494 return int2{m_total_rows, m_total_cols};
495 }
496};
497
498template <typename Ty, int N>
500template <typename Ty, int N>
502} // namespace muda
503
504namespace muda
505{
506template <typename Ty, int N>
511
512template <typename Ty, int N>
517} // namespace muda
518
519
520#include "details/triplet_matrix_view.inl"
A view interface for any array-like linear memory, which can be constructed from DeviceBuffer/DeviceVe...
Definition triplet_matrix_viewer.h:203
Definition triplet_matrix_view.h:258
Definition triplet_matrix_view.h:10
Definition triplet_matrix_viewer.h:219
Definition view_base.h:8
Definition type_modifier.h:22
Definition type_modifier.h:28