1#include <muda/check/check_cusparse.h>
2#include <muda/ext/linear_system/type_mapper/data_type_mapper.h>
// Destructor.
// NOTE(review): the body is not visible in this excerpt; presumably it releases the
// cached cuSPARSE descriptors (e.g. via destroy_all_descr()) — confirm.
7DeviceCSRMatrix<Ty>::~DeviceCSRMatrix()
// Copy constructor: copies the CSR arrays (row offsets, column indices, values) from `other`.
// NOTE(review): the visible initializer list does not copy m_descr / m_legacy_descr.
// Presumably the copy starts without descriptors and builds its own lazily through
// descr() / legacy_descr(), because a descriptor records pointers into the *source's*
// arrays. Confirm against the lines not shown in this excerpt.
13DeviceCSRMatrix<Ty>::DeviceCSRMatrix(
const DeviceCSRMatrix& other)
16 , m_row_offsets(other.m_row_offsets)
17 , m_col_indices(other.m_col_indices)
18 , m_values(other.m_values)
// Move constructor: takes over the CSR arrays and the cached generic-API descriptor
// from `other`. It then nulls other's descriptor handles so `other` no longer refers
// to (or destroys) them.
// NOTE(review): the visible initializers do not transfer other.m_legacy_descr into *this.
// If it is not taken over on a line not shown here, the legacy descriptor handle is lost
// (leaked) when other.m_legacy_descr is nulled below — confirm.
24DeviceCSRMatrix<Ty>::DeviceCSRMatrix(DeviceCSRMatrix&& other) noexcept
27 , m_row_offsets(std::move(other.m_row_offsets))
28 , m_col_indices(std::move(other.m_col_indices))
29 , m_values(std::move(other.m_values))
30 , m_descr(other.m_descr)
// Leave the moved-from object without descriptor handles.
34 other.m_descr =
nullptr;
35 other.m_legacy_descr =
nullptr;
// Copy assignment: copies the CSR arrays from `other`. The legacy descriptor handle is
// reset to nullptr, so legacy_descr() will rebuild it on demand.
// NOTE(review): the lines between the array copies and the reset are not visible here.
// Confirm that the previously held descriptors are destroyed before the handles are
// cleared (otherwise they leak). Also confirm that m_descr is reset, since it would
// otherwise keep pointing at the old buffers.
40DeviceCSRMatrix<Ty>& DeviceCSRMatrix<Ty>::operator=(
const DeviceCSRMatrix& other)
46 m_row_offsets = other.m_row_offsets;
47 m_col_indices = other.m_col_indices;
48 m_values = other.m_values;
// Force lazy re-creation for the new contents.
52 m_legacy_descr =
nullptr;
// Move assignment: takes over the CSR arrays and both cached descriptor handles from
// `other`. It then nulls other's handles so ownership is transferred rather than shared.
// NOTE(review): any descriptors *this already held are overwritten below. The code that
// would release them first, and any self-assignment guard, is not visible in this
// excerpt — confirm both exist, or the old descriptors leak.
59DeviceCSRMatrix<Ty>& DeviceCSRMatrix<Ty>::operator=(DeviceCSRMatrix&& other)
noexcept
65 m_row_offsets = std::move(other.m_row_offsets);
66 m_col_indices = std::move(other.m_col_indices);
67 m_values = std::move(other.m_values);
70 m_descr = other.m_descr;
71 m_legacy_descr = other.m_legacy_descr;
// The moved-from object must not destroy the handles now owned by *this.
75 other.m_descr =
nullptr;
76 other.m_legacy_descr =
nullptr;
// Sets the matrix shape to `row` x `col`.
// The row-offset array is sized row + 1: one start offset per row plus a trailing end
// offset, as the CSR layout requires.
// NOTE(review): the rest of the body is not visible here. This presumably includes
// storing row/col. If a descriptor was already cached for the old shape, it must also
// be invalidated — confirm.
82void DeviceCSRMatrix<Ty>::reshape(
int row,
int col)
85 m_row_offsets.resize(row + 1);
// Pre-allocates storage for `non_zeros` entries in the column-index and value arrays.
// NOTE(review): presumably this is capacity-only, like std::vector::reserve — confirm the
// container's semantics. If reserve can reallocate, a previously cached descriptor would
// hold stale data pointers.
91void DeviceCSRMatrix<Ty>::reserve(
int non_zeros)
93 m_col_indices.reserve(non_zeros);
94 m_values.reserve(non_zeros);
// Lazily creates the cuSPARSE generic-API CSR descriptor (cusparseSpMatDescr_t) and
// caches it in m_descr. The return statement is not visible in this excerpt; presumably
// it returns m_descr.
// NOTE(review): m_descr is written from this const method, so it is presumably declared
// `mutable` (the declaration is not visible here).
// NOTE(review): cusparseCreateCsr records the raw data pointers and the nnz count when it
// is called. If the arrays are later reallocated or resized, the cached descriptor must
// be destroyed and recreated — confirm that every mutating path does so.
98cusparseSpMatDescr_t DeviceCSRMatrix<Ty>::descr()
const
100 if(m_descr ==
nullptr)
// Create only once; later calls reuse the cached handle.
102 checkCudaErrors(cusparseCreateCsr(
// nnz argument: the number of stored values. The preceding arguments are not shown here.
106 this->m_values.size(),
// descr() is const, so data() yields const pointers, but cusparseCreateCsr takes
// non-const void* buffers; hence remove_const.
107 remove_const(this->m_row_offsets.data()),
108 remove_const(this->m_col_indices.data()),
109 remove_const(this->m_values.data()),
// Index types are derived from the element types of the offset and index arrays.
110 cusparse_index_type<
typename decltype(this->m_row_offsets)::value_type>(),
111 cusparse_index_type<
typename decltype(this->m_col_indices)::value_type>(),
// Zero-based indexing; the value data type is mapped from Ty.
112 CUSPARSE_INDEX_BASE_ZERO,
113 cuda_data_type<Ty>()));
// Returns the legacy cuSPARSE matrix descriptor (cusparseMatDescr_t). It is created on
// first use and cached in m_legacy_descr. The descriptor is configured as a general
// matrix with zero-based indexing and a non-unit diagonal.
// NOTE(review): m_legacy_descr is written from this const method, so it is presumably
// declared `mutable` (the declaration is not visible here).
117template <
typename Ty>
118cusparseMatDescr_t DeviceCSRMatrix<Ty>::legacy_descr()
const
120 if(m_legacy_descr ==
nullptr)
122 checkCudaErrors(cusparseCreateMatDescr(&m_legacy_descr));
123 checkCudaErrors(cusparseSetMatType(m_legacy_descr, CUSPARSE_MATRIX_TYPE_GENERAL));
124 checkCudaErrors(cusparseSetMatIndexBase(m_legacy_descr, CUSPARSE_INDEX_BASE_ZERO));
125 checkCudaErrors(cusparseSetMatDiagType(m_legacy_descr, CUSPARSE_DIAG_TYPE_NON_UNIT));
127 return m_legacy_descr;
// Empties the CSR arrays.
// NOTE(review): only the row-offset and column-index clears are visible in this excerpt.
// Confirm that m_values is cleared too. Also confirm that the cached descriptors, which
// hold pointers into these arrays (see descr()), are released.
129template <
typename Ty>
130void DeviceCSRMatrix<Ty>::clear()
134 m_row_offsets.clear();
135 m_col_indices.clear();
// Destroys the cached cuSPARSE descriptors: the generic-API one via cusparseDestroySpMat
// and the legacy one via cusparseDestroyMatDescr. m_legacy_descr is then reset to
// nullptr, so legacy_descr() will recreate it on demand.
// NOTE(review): the null checks around each destroy call and the reset of m_descr are
// not visible in this excerpt. Confirm that both handles are checked before destruction
// and set to nullptr afterwards, so a repeated call cannot double-destroy.
140template <
typename Ty>
141void DeviceCSRMatrix<Ty>::destroy_all_descr()
const
145 checkCudaErrors(cusparseDestroySpMat(m_descr));
150 checkCudaErrors(cusparseDestroyMatDescr(m_legacy_descr));
151 m_legacy_descr =
nullptr;