MUDA
Loading...
Searching...
No Matches
device_triplet_matrix.h
1#pragma once
3#include <muda/ext/linear_system/triplet_matrix_view.h>
namespace muda::details
{
// Forward declaration of the sparse-matrix format converter. It is declared here
// (rather than included) so that DeviceTripletMatrix below can befriend every
// specialization of it without pulling in the converter's implementation header.
template <typename T, int N>
class MatrixFormatConverter;
}  // namespace muda::details
9
10namespace muda
11{
12template <typename T, int N>
14{
15 public:
16 template <typename U, int M>
18 using BlockMatrix = Eigen::Matrix<T, N, N>;
19
20 protected:
21 DeviceBuffer<BlockMatrix> m_block_values;
22 DeviceBuffer<int> m_block_row_indices;
23 DeviceBuffer<int> m_block_col_indices;
24
25 int m_block_rows = 0;
26 int m_block_cols = 0;
27
28 public:
29 DeviceTripletMatrix() = default;
30 ~DeviceTripletMatrix() = default;
33 DeviceTripletMatrix& operator=(const DeviceTripletMatrix&) = default;
34 DeviceTripletMatrix& operator=(DeviceTripletMatrix&&) = default;
35
36 void reshape(int row, int col)
37 {
38 m_block_rows = row;
39 m_block_cols = col;
40 }
41
42 void resize_triplets(size_t nonzero_count)
43 {
44 m_block_values.resize(nonzero_count);
45 m_block_row_indices.resize(nonzero_count);
46 m_block_col_indices.resize(nonzero_count);
47 }
48
49 void reserve_triplets(size_t nonzero_count)
50 {
51 m_block_values.reserve(nonzero_count);
52 m_block_row_indices.reserve(nonzero_count);
53 m_block_col_indices.reserve(nonzero_count);
54 }
55
56 void resize(int row, int col, size_t nonzero_count)
57 {
58 reshape(row, col);
59 resize_triplets(nonzero_count);
60 }
61
62 static constexpr int block_dim() { return N; }
63
64 auto block_values() { return m_block_values.view(); }
65 auto block_values() const { return m_block_values.view(); }
66 auto block_row_indices() { return m_block_row_indices.view(); }
67 auto block_row_indices() const { return m_block_row_indices.view(); }
68 auto block_col_indices() { return m_block_col_indices.view(); }
69 auto block_col_indices() const { return m_block_col_indices.view(); }
70
71 auto block_rows() const { return m_block_rows; }
72 auto block_cols() const { return m_block_cols; }
73 auto triplet_count() const { return m_block_values.size(); }
74 auto triplet_capacity() const { return m_block_values.capacity(); }
75
76 auto view()
77 {
78 return TripletMatrixView<T, N>{m_block_rows,
79 m_block_cols,
80 (int)m_block_values.size(),
81 m_block_row_indices.data(),
82 m_block_col_indices.data(),
83 m_block_values.data()};
84 }
85
86 auto view() const { return remove_const(*this).view().as_const(); }
87
88 auto cview() const { return view(); }
89
90 auto viewer() { return view().viewer(); }
91
92 auto cviewer() const { return view().cviewer(); }
93
94 operator TripletMatrixView<T, N>() { return view(); }
95 operator CTripletMatrixView<T, N>() const { return view(); }
96
97 void clear()
98 {
99 m_block_rows = 0;
100 m_block_cols = 0;
101 m_block_values.clear();
102 m_block_row_indices.clear();
103 m_block_col_indices.clear();
104 }
105};
106
107template <typename T>
109{
110 public:
111 template <typename U, int M>
113
114 protected:
115 DeviceBuffer<T> m_values;
116 DeviceBuffer<int> m_row_indices;
117 DeviceBuffer<int> m_col_indices;
118
119 int m_rows = 0;
120 int m_cols = 0;
121
122 public:
123 DeviceTripletMatrix() = default;
124 ~DeviceTripletMatrix() = default;
127 DeviceTripletMatrix& operator=(const DeviceTripletMatrix&) = default;
128 DeviceTripletMatrix& operator=(DeviceTripletMatrix&&) = default;
129
130 void reshape(int row, int col)
131 {
132 m_rows = row;
133 m_cols = col;
134 }
135
136 void resize_triplets(size_t nonzero_count)
137 {
138 m_values.resize(nonzero_count);
139 m_row_indices.resize(nonzero_count);
140 m_col_indices.resize(nonzero_count);
141 }
142
143 void reserve_triplets(size_t nonzero_count)
144 {
145 m_values.reserve(nonzero_count);
146 m_row_indices.reserve(nonzero_count);
147 m_col_indices.reserve(nonzero_count);
148 }
149
150 void resize(int row, int col, size_t nonzero_count)
151 {
152 reshape(row, col);
153 resize_triplets(nonzero_count);
154 }
155
156 static constexpr int block_size() { return 1; }
157
158 auto values() { return m_values.view(); }
159 auto values() const { return m_values.view(); }
160 auto row_indices() { return m_row_indices.view(); }
161 auto row_indices() const { return m_row_indices.view(); }
162 auto col_indices() { return m_col_indices.view(); }
163 auto col_indices() const { return m_col_indices.view(); }
164
165 auto rows() const { return m_rows; }
166 auto cols() const { return m_cols; }
167 auto triplet_count() const { return m_values.size(); }
168
169 auto view() const { return remove_const(*this).view().as_const(); }
170
171 auto view()
172 {
173 return TripletMatrixView<T, 1>{m_rows,
174 m_cols,
175 (int)m_values.size(),
176 m_row_indices.data(),
177 m_col_indices.data(),
178 m_values.data()};
179 }
180
181 auto viewer() { return view().viewer(); }
182 auto cviewer() const { return view().cviewer(); }
183
184 operator TripletMatrixView<T, 1>() { return view(); }
185 operator CTripletMatrixView<T, 1>() const { return view(); }
186
187 void clear()
188 {
189 m_rows = 0;
190 m_cols = 0;
191 m_values.clear();
192 m_row_indices.clear();
193 m_col_indices.clear();
194 }
195};
196} // namespace muda
197#include "details/device_triplet_matrix.inl"
A std::vector-like wrapper of CUDA device memory that allows the user to:
Definition device_buffer.h:46
Definition device_triplet_matrix.h:14
Definition triplet_matrix_view.h:10
Definition matrix_format_converter_impl.h:53
A lightweight wrapper of CUDA device memory. Like std::vector, it allows the user to resize, ...