MUDA
Loading...
Searching...
No Matches
device_bsr_matrix.inl
1#include <muda/check/check_cusparse.h>
2#include <muda/ext/linear_system/type_mapper/data_type_mapper.h>
3
4namespace muda
5{
6template <typename Ty, int N>
7DeviceBSRMatrix<Ty, N>::~DeviceBSRMatrix()
8{
9 destroy_all_descr();
10}
11
12template <typename Ty, int N>
13DeviceBSRMatrix<Ty, N>::DeviceBSRMatrix(const DeviceBSRMatrix& other)
14 : m_row(other.m_row)
15 , m_col(other.m_col)
16 , m_block_row_offsets(other.m_block_row_offsets)
17 , m_block_col_indices(other.m_block_col_indices)
18 , m_block_values(other.m_block_values)
19{
20}
21
22template <typename Ty, int N>
23DeviceBSRMatrix<Ty, N>::DeviceBSRMatrix(DeviceBSRMatrix&& other)
24 : m_row(other.m_row)
25 , m_col(other.m_col)
26 , m_block_row_offsets(std::move(other.m_block_row_offsets))
27 , m_block_col_indices(std::move(other.m_block_col_indices))
28 , m_block_values(std::move(other.m_block_values))
29 , m_legacy_descr(other.m_legacy_descr)
30{
31 other.m_row = 0;
32 other.m_col = 0;
33 other.m_legacy_descr = nullptr;
34 other.m_descr = nullptr;
35}
36
37template <typename Ty, int N>
38DeviceBSRMatrix<Ty, N>& DeviceBSRMatrix<Ty, N>::operator=(const DeviceBSRMatrix& other)
39{
40 if(this != &other)
41 {
42 m_row = other.m_row;
43 m_col = other.m_col;
44 m_block_row_offsets = other.m_block_row_offsets;
45 m_block_col_indices = other.m_block_col_indices;
46 m_block_values = other.m_block_values;
47
48 destroy_all_descr();
49
50 m_legacy_descr = nullptr;
51 m_descr = nullptr;
52 }
53 return *this;
54}
55
56template <typename Ty, int N>
57DeviceBSRMatrix<Ty, N>& DeviceBSRMatrix<Ty, N>::operator=(DeviceBSRMatrix&& other)
58{
59 if(this != &other)
60 {
61 m_row = other.m_row;
62 m_col = other.m_col;
63 m_block_row_offsets = std::move(other.m_block_row_offsets);
64 m_block_col_indices = std::move(other.m_block_col_indices);
65 m_block_values = std::move(other.m_block_values);
66
67 destroy_all_descr();
68
69 m_legacy_descr = other.m_legacy_descr;
70 m_descr = other.m_descr;
71
72 other.m_row = 0;
73 other.m_col = 0;
74 other.m_legacy_descr = nullptr;
75 other.m_descr = nullptr;
76 }
77 return *this;
78}
79
80template <typename Ty, int N>
81void DeviceBSRMatrix<Ty, N>::reshape(int row, int col)
82{
83 m_row = row;
84 m_block_row_offsets.resize(row + 1);
85 m_col = col;
86 m_descr = nullptr;
87}
88template <typename Ty, int N>
89void DeviceBSRMatrix<Ty, N>::reserve(int non_zero_blocks)
90{
91 m_block_col_indices.reserve(non_zero_blocks);
92 m_block_values.reserve(non_zero_blocks);
93}
94template <typename Ty, int N>
95void DeviceBSRMatrix<Ty, N>::reserve_offsets(int size)
96{
97 m_block_row_offsets.reserve(size);
98}
99template <typename Ty, int N>
100void DeviceBSRMatrix<Ty, N>::resize(int non_zero_blocks)
101{
102 m_block_col_indices.resize(non_zero_blocks);
103 m_block_values.resize(non_zero_blocks);
104}
105template <typename Ty, int N>
106cusparseMatDescr_t DeviceBSRMatrix<Ty, N>::legacy_descr() const
107{
108 if(m_legacy_descr == nullptr)
109 {
110 checkCudaErrors(cusparseCreateMatDescr(&m_legacy_descr));
111 }
112 return m_legacy_descr;
113}
114
115template <typename Ty, int N>
116void DeviceBSRMatrix<Ty, N>::clear()
117{
118 m_row = 0;
119 m_col = 0;
120 m_block_row_offsets = decltype(m_block_row_offsets)();
121 m_block_col_indices = decltype(m_block_col_indices)();
122 m_block_values = decltype(m_block_values)();
123 destroy_all_descr();
124}
125
126template <typename Ty, int N>
127void DeviceBSRMatrix<Ty, N>::destroy_all_descr() const
128{
129 if(m_legacy_descr)
130 {
131 checkCudaErrors(cusparseDestroyMatDescr(m_legacy_descr));
132 m_legacy_descr = nullptr;
133 }
134 if(m_descr)
135 {
136 checkCudaErrors(cusparseDestroySpMat(m_descr));
137 m_descr = nullptr;
138 }
139}
140
141template <typename Ty, int N>
142cusparseSpMatDescr_t DeviceBSRMatrix<Ty, N>::descr() const
143{
144 if(m_descr == nullptr)
145 {
146 //checkCudaErrors(cusparseCreateMatDescr(&m_legacy_descr));
147 checkCudaErrors(cusparseCreateBsr(
148 &m_descr,
149 m_row,
150 m_col,
151 m_block_values.size(),
152 N,
153 N,
154 remove_const(m_block_row_offsets.data()),
155 remove_const(m_block_col_indices.data()),
156 remove_const(m_block_values.data()),
157 cusparse_index_type<typename decltype(m_block_row_offsets)::value_type>(),
158 cusparse_index_type<typename decltype(m_block_col_indices)::value_type>(),
159 CUSPARSE_INDEX_BASE_ZERO,
160 cuda_data_type<Ty>(),
161 cusparseOrder_t::CUSPARSE_ORDER_COL));
162 }
163 return m_descr;
164}
165} // namespace muda