MUDA
Loading...
Searching...
No Matches
device_csr_matrix.inl
1#include <muda/check/check_cusparse.h>
2#include <muda/ext/linear_system/type_mapper/data_type_mapper.h>
3
4namespace muda
5{
6template <typename Ty>
7DeviceCSRMatrix<Ty>::~DeviceCSRMatrix()
8{
9 destroy_all_descr();
10}
11
12template <typename Ty>
13DeviceCSRMatrix<Ty>::DeviceCSRMatrix(const DeviceCSRMatrix& other)
14 : m_row(other.m_row)
15 , m_col(other.m_col)
16 , m_row_offsets(other.m_row_offsets)
17 , m_col_indices(other.m_col_indices)
18 , m_values(other.m_values)
19{
20}
21
22
23template <typename Ty>
24DeviceCSRMatrix<Ty>::DeviceCSRMatrix(DeviceCSRMatrix&& other) noexcept
25 : m_row(other.m_row)
26 , m_col(other.m_col)
27 , m_row_offsets(std::move(other.m_row_offsets))
28 , m_col_indices(std::move(other.m_col_indices))
29 , m_values(std::move(other.m_values))
30 , m_descr(other.m_descr)
31{
32 other.m_row = 0;
33 other.m_col = 0;
34 other.m_descr = nullptr;
35 other.m_legacy_descr = nullptr;
36}
37
38
39template <typename Ty>
40DeviceCSRMatrix<Ty>& DeviceCSRMatrix<Ty>::operator=(const DeviceCSRMatrix& other)
41{
42 if(this != &other)
43 {
44 m_row = other.m_row;
45 m_col = other.m_col;
46 m_row_offsets = other.m_row_offsets;
47 m_col_indices = other.m_col_indices;
48 m_values = other.m_values;
49 destroy_all_descr();
50
51 m_descr = nullptr;
52 m_legacy_descr = nullptr;
53 }
54 return *this;
55}
56
57
58template <typename Ty>
59DeviceCSRMatrix<Ty>& DeviceCSRMatrix<Ty>::operator=(DeviceCSRMatrix&& other) noexcept
60{
61 if(this != &other)
62 {
63 m_row = other.m_row;
64 m_col = other.m_col;
65 m_row_offsets = std::move(other.m_row_offsets);
66 m_col_indices = std::move(other.m_col_indices);
67 m_values = std::move(other.m_values);
68 destroy_all_descr();
69
70 m_descr = other.m_descr;
71 m_legacy_descr = other.m_legacy_descr;
72
73 other.m_row = 0;
74 other.m_col = 0;
75 other.m_descr = nullptr;
76 other.m_legacy_descr = nullptr;
77 }
78 return *this;
79}
80
81template <typename Ty>
82void DeviceCSRMatrix<Ty>::reshape(int row, int col)
83{
84 m_row = row;
85 m_row_offsets.resize(row + 1);
86 m_col = col;
87 destroy_all_descr();
88}
89
90template <typename Ty>
91void DeviceCSRMatrix<Ty>::reserve(int non_zeros)
92{
93 m_col_indices.reserve(non_zeros);
94 m_values.reserve(non_zeros);
95}
96
97template <typename Ty>
98cusparseSpMatDescr_t DeviceCSRMatrix<Ty>::descr() const
99{
100 if(m_descr == nullptr)
101 {
102 checkCudaErrors(cusparseCreateCsr(
103 &m_descr,
104 this->m_row,
105 this->m_col,
106 this->m_values.size(),
107 remove_const(this->m_row_offsets.data()),
108 remove_const(this->m_col_indices.data()),
109 remove_const(this->m_values.data()),
110 cusparse_index_type<typename decltype(this->m_row_offsets)::value_type>(),
111 cusparse_index_type<typename decltype(this->m_col_indices)::value_type>(),
112 CUSPARSE_INDEX_BASE_ZERO,
113 cuda_data_type<Ty>()));
114 }
115 return m_descr;
116}
117template <typename Ty>
118cusparseMatDescr_t DeviceCSRMatrix<Ty>::legacy_descr() const
119{
120 if(m_legacy_descr == nullptr)
121 {
122 checkCudaErrors(cusparseCreateMatDescr(&m_legacy_descr));
123 checkCudaErrors(cusparseSetMatType(m_legacy_descr, CUSPARSE_MATRIX_TYPE_GENERAL));
124 checkCudaErrors(cusparseSetMatIndexBase(m_legacy_descr, CUSPARSE_INDEX_BASE_ZERO));
125 checkCudaErrors(cusparseSetMatDiagType(m_legacy_descr, CUSPARSE_DIAG_TYPE_NON_UNIT));
126 }
127 return m_legacy_descr;
128}
129template <typename Ty>
130void DeviceCSRMatrix<Ty>::clear()
131{
132 m_row = 0;
133 m_col = 0;
134 m_row_offsets.clear();
135 m_col_indices.clear();
136 m_values.clear();
137 destroy_all_descr();
138}
139
140template <typename Ty>
141void DeviceCSRMatrix<Ty>::destroy_all_descr() const
142{
143 if(m_descr)
144 {
145 checkCudaErrors(cusparseDestroySpMat(m_descr));
146 m_descr = nullptr;
147 }
148 if(m_legacy_descr)
149 {
150 checkCudaErrors(cusparseDestroyMatDescr(m_legacy_descr));
151 m_legacy_descr = nullptr;
152 }
153}
154} // namespace muda