MUDA
Loading...
Searching...
No Matches
bcoo_matrix_view.h
1#pragma once
2#include <muda/ext/linear_system/triplet_matrix_view.h>
3#include <muda/ext/linear_system/bcoo_matrix_viewer.h>
4namespace muda
5{
6template <typename T, int N>
7using BCOOMatrixView = TripletMatrixView<T, N>;
8template <typename T, int N>
9using CBCOOMatrixView = CTripletMatrixView<T, N>;
10} // namespace muda
11namespace muda
12{
13template <bool IsConst, typename Ty>
14class COOMatrixViewBase : public ViewBase<IsConst>
15{
16 using Base = ViewBase<IsConst>;
17 template<typename U>
18 using auto_const_t = typename Base::template auto_const_t<U>;
19
20 public:
21 static_assert(!std::is_const_v<Ty>, "Ty must be non-const");
25
26 protected:
27 // matrix info
28 int m_rows = 0;
29 int m_cols = 0;
30
31 // triplet info
32 int m_triplet_index_offset = 0;
33 int m_triplet_count = 0;
34 int m_total_triplet_count = 0;
35
36 // sub matrix info
37 int2 m_submatrix_offset = {0, 0};
38 int2 m_submatrix_extent = {0, 0};
39
40 // data
41 auto_const_t<int>* m_row_indices;
42 auto_const_t<int>* m_col_indices;
43 auto_const_t<Ty>* m_values;
44
45 mutable cusparseMatDescr_t m_legacy_descr = nullptr;
46 mutable cusparseSpMatDescr_t m_descr = nullptr;
47 bool m_trans = false;
48
49 public:
50 MUDA_GENERIC COOMatrixViewBase() = default;
51 MUDA_GENERIC COOMatrixViewBase(int rows,
52 int cols,
53 int triplet_index_offset,
54 int triplet_count,
55 int total_triplet_count,
56 int2 submatrix_offset,
57 int2 submatrix_extent,
58 auto_const_t<int>* row_indices,
59 auto_const_t<int>* col_indices,
60 auto_const_t<Ty>* values,
61 cusparseSpMatDescr_t descr,
62 cusparseMatDescr_t legacy_descr,
63 bool trans)
64
65 : m_rows(rows)
66 , m_cols(cols)
67 , m_triplet_index_offset(triplet_index_offset)
68 , m_triplet_count(triplet_count)
69 , m_total_triplet_count(total_triplet_count)
70 , m_row_indices(row_indices)
71 , m_col_indices(col_indices)
72 , m_values(values)
73 , m_submatrix_offset(submatrix_offset)
74 , m_submatrix_extent(submatrix_extent)
75 , m_descr(descr)
76 , m_legacy_descr(legacy_descr)
77 , m_trans(trans)
78 {
79 MUDA_KERNEL_ASSERT(triplet_index_offset + triplet_count <= total_triplet_count,
80 "COOMatrixView: out of range, m_total_triplet_count=%d, "
81 "your triplet_index_offset=%d, triplet_count=%d",
82 total_triplet_count,
83 triplet_index_offset,
84 triplet_count);
85
86
87 MUDA_KERNEL_ASSERT(submatrix_offset.x >= 0 && submatrix_offset.y >= 0,
88 "TripletMatrixView: submatrix_offset is out of range, submatrix_offset.x=%d, submatrix_offset.y=%d",
89 submatrix_offset.x,
90 submatrix_offset.y);
91
92 MUDA_KERNEL_ASSERT(submatrix_offset.x + submatrix_extent.x <= rows,
93 "TripletMatrixView: submatrix is out of range, submatrix_offset.x=%d, submatrix_extent.x=%d, total_block_rows=%d",
94 submatrix_offset.x,
95 submatrix_extent.x,
96 rows);
97
98 MUDA_KERNEL_ASSERT(submatrix_offset.y + submatrix_extent.y <= cols,
99 "TripletMatrixView: submatrix is out of range, submatrix_offset.y=%d, submatrix_extent.y=%d, total_block_cols=%d",
100 submatrix_offset.y,
101 submatrix_extent.y,
102 cols);
103 }
104
105 MUDA_GENERIC COOMatrixViewBase(int rows,
106 int cols,
107 int total_triplet_count,
108 auto_const_t<int>* row_indices,
109 auto_const_t<int>* col_indices,
110 auto_const_t<Ty>* values,
111 cusparseSpMatDescr_t descr,
112 cusparseMatDescr_t legacy_descr,
113 bool trans)
114 : COOMatrixViewBase(rows,
115 cols,
116 0,
117 total_triplet_count,
118 total_triplet_count,
119 {0, 0},
120 {rows, cols},
121 row_indices,
122 col_indices,
123 values,
124 descr,
125 legacy_descr,
126 trans)
127 {
128 }
129
130 MUDA_GENERIC auto as_const() const
131 {
132 return ConstView{m_rows,
133 m_cols,
134 m_triplet_index_offset,
135 m_triplet_count,
136 m_total_triplet_count,
137 m_submatrix_offset,
138 m_submatrix_extent,
139 m_row_indices,
140 m_col_indices,
141 m_values,
142 m_descr,
143 m_legacy_descr,
144 m_trans};
145 }
146
147 MUDA_GENERIC operator ConstView() const { return as_const(); }
148
149 MUDA_GENERIC auto cviewer() const
150 {
151 MUDA_KERNEL_ASSERT(!m_trans,
152 "COOMatrixView: cviewer() is not supported for "
153 "transposed matrix, please use a non-transposed view of this matrix");
154 return CTripletMatrixViewer<Ty, 1>{m_rows,
155 m_cols,
156 m_triplet_index_offset,
157 m_triplet_count,
158 m_total_triplet_count,
159 m_submatrix_offset,
160 m_submatrix_extent,
161 m_row_indices,
162 m_col_indices,
163 m_values};
164 }
165
166 MUDA_GENERIC auto viewer()
167 {
168 MUDA_ASSERT(!m_trans,
169 "COOMatrixView: viewer() is not supported for "
170 "transposed matrix, please use a non-transposed view of this matrix");
171 return TripletMatrixViewer<Ty, 1>{m_rows,
172 m_cols,
173 m_triplet_index_offset,
174 m_triplet_count,
175 m_total_triplet_count,
176 m_submatrix_offset,
177 m_submatrix_extent,
178 m_row_indices,
179 m_col_indices,
180 m_values};
181 }
182
183 // non-const access
184 auto_const_t<Ty>* block_values() { return m_values; }
185 auto_const_t<int>* block_row_indices() { return m_row_indices; }
186 auto_const_t<int>* block_col_indices() { return m_col_indices; }
187
188
189 // const access
190 auto block_values() const { return m_values; }
191 auto block_row_indices() const { return m_row_indices; }
192 auto block_col_indices() const { return m_col_indices; }
193
194 auto block_rows() const { return m_rows; }
195 auto block_cols() const { return m_cols; }
196 auto triplet_count() const { return m_triplet_count; }
197 auto tripet_index_offset() const { return m_triplet_index_offset; }
198 auto total_triplet_count() const { return m_total_triplet_count; }
199 auto is_trans() const { return m_trans; }
200
201 auto legacy_descr() const { return m_legacy_descr; }
202 auto descr() const { return m_descr; }
203};
204
205template <typename Ty>
207template <typename Ty>
209} // namespace muda
210
211namespace muda
212{
213template <typename T>
215{
216 using type = CCOOMatrixView<T>;
217};
218
219template <typename T>
221{
222 using type = COOMatrixView<T>;
223};
224} // namespace muda
225#include "details/bcoo_matrix_view.inl"
Definition bcoo_matrix_view.h:15
Definition triplet_matrix_viewer.h:203
Definition triplet_matrix_viewer.h:219
Definition view_base.h:8
Definition type_modifier.h:22
Definition type_modifier.h:28