1#include <muda/check/check_cusparse.h>
2#include <muda/ext/linear_system/type_mapper/data_type_mapper.h>
// Destructor of DeviceBSRMatrix.
// NOTE(review): presumably releases the cached cuSPARSE descriptors via
// destroy_all_descr(); the body is not visible here — confirm.
6template <
typename Ty,
int N>
7DeviceBSRMatrix<Ty, N>::~DeviceBSRMatrix()
// Copy constructor: copies the BSR storage arrays (block row offsets,
// block column indices, block values) from `other`.
// The cuSPARSE descriptor handles are not among the visible initializers;
// descr()/legacy_descr() create a descriptor on demand when the cached
// handle is nullptr, so the copy presumably starts without its own
// descriptors — confirm against the member declarations.
12template <
typename Ty,
int N>
13DeviceBSRMatrix<Ty, N>::DeviceBSRMatrix(
const DeviceBSRMatrix& other)
16 , m_block_row_offsets(other.m_block_row_offsets)
17 , m_block_col_indices(other.m_block_col_indices)
18 , m_block_values(other.m_block_values)
// Move constructor: moves the BSR storage arrays out of `other` and takes
// over its cached legacy cuSPARSE descriptor handle.
22template <
typename Ty,
int N>
23DeviceBSRMatrix<Ty, N>::DeviceBSRMatrix(DeviceBSRMatrix&& other)
26 , m_block_row_offsets(std::move(other.m_block_row_offsets))
27 , m_block_col_indices(std::move(other.m_block_col_indices))
28 , m_block_values(std::move(other.m_block_values))
29 , m_legacy_descr(other.m_legacy_descr)
// Null out the source's handles so the moved-from object no longer refers
// to descriptors that are now held by *this.
33 other.m_legacy_descr =
nullptr;
34 other.m_descr =
nullptr;
// Copy assignment: copies the BSR storage arrays from `other` and resets
// the cached legacy descriptor so legacy_descr() will create a new one on
// the next request.
// NOTE(review): confirm the previously held descriptors are destroyed
// before being reset (otherwise they leak). Also confirm m_descr is
// reset: descr() builds it from the arrays' data() pointers and only
// rebuilds when it is nullptr, so a stale m_descr would point at the old
// storage.
37template <
typename Ty,
int N>
38DeviceBSRMatrix<Ty, N>& DeviceBSRMatrix<Ty, N>::operator=(
const DeviceBSRMatrix& other)
44 m_block_row_offsets = other.m_block_row_offsets;
45 m_block_col_indices = other.m_block_col_indices;
46 m_block_values = other.m_block_values;
50 m_legacy_descr =
nullptr;
// Move assignment: moves the BSR storage arrays out of `other` and takes
// over both of its cached cuSPARSE descriptor handles (legacy and generic).
// NOTE(review): the descriptors previously held by *this are overwritten
// below; confirm they are destroyed first (e.g. via destroy_all_descr())
// and that self-move is guarded, otherwise handles leak or are lost.
56template <
typename Ty,
int N>
57DeviceBSRMatrix<Ty, N>& DeviceBSRMatrix<Ty, N>::operator=(DeviceBSRMatrix&& other)
63 m_block_row_offsets = std::move(other.m_block_row_offsets);
64 m_block_col_indices = std::move(other.m_block_col_indices);
65 m_block_values = std::move(other.m_block_values);
69 m_legacy_descr = other.m_legacy_descr;
70 m_descr = other.m_descr;
// Detach the handles from the source so only *this refers to them.
74 other.m_legacy_descr =
nullptr;
75 other.m_descr =
nullptr;
// Sets the block-level shape of the matrix.
// The block-row offset array is resized to row + 1 entries, matching the
// BSR/CSR convention of one offset per block row plus a terminating end
// offset. `row` is therefore presumably the number of block rows, and `col`
// presumably the number of block columns (its use is not visible here —
// confirm).
80template <
typename Ty,
int N>
81void DeviceBSRMatrix<Ty, N>::reshape(
int row,
int col)
84 m_block_row_offsets.resize(row + 1);
// Reserves capacity for `non_zero_blocks` entries in both the block column
// index array and the block value array, without changing their sizes.
// The block-row offsets are not touched (see reserve_offsets()).
88template <
typename Ty,
int N>
89void DeviceBSRMatrix<Ty, N>::reserve(
int non_zero_blocks)
91 m_block_col_indices.reserve(non_zero_blocks);
92 m_block_values.reserve(non_zero_blocks);
// Reserves capacity for `size` entries in the block-row offset array
// without changing its size.
94template <
typename Ty,
int N>
95void DeviceBSRMatrix<Ty, N>::reserve_offsets(
int size)
97 m_block_row_offsets.reserve(size);
// Resizes the block column index array and the block value array to hold
// exactly `non_zero_blocks` non-zero blocks. The block-row offsets are not
// touched.
// NOTE(review): descr() captures these arrays' data() pointers and uses
// m_block_values.size() as the block nnz, and it only rebuilds when m_descr
// is nullptr. Confirm that a cached descriptor is invalidated when a resize
// reallocates or changes nnz.
99template <
typename Ty,
int N>
100void DeviceBSRMatrix<Ty, N>::resize(
int non_zero_blocks)
102 m_block_col_indices.resize(non_zero_blocks);
103 m_block_values.resize(non_zero_blocks);
// Returns the legacy cuSPARSE matrix descriptor (cusparseMatDescr_t) for
// this matrix. It is created lazily with cusparseCreateMatDescr on the first
// call, when the cached handle is still nullptr. Later calls return the
// cached handle.
// Declared const although it writes m_legacy_descr, so that member must be
// declared mutable.
// The cuSPARSE status is checked with checkCudaErrors; presumably
// check_cusparse.h provides the cusparseStatus_t overload — confirm.
105template <
typename Ty,
int N>
106cusparseMatDescr_t DeviceBSRMatrix<Ty, N>::legacy_descr()
const
108 if(m_legacy_descr ==
nullptr)
110 checkCudaErrors(cusparseCreateMatDescr(&m_legacy_descr));
112 return m_legacy_descr;
// Empties the matrix storage by replacing each BSR array with a freshly
// default-constructed instance.
// Unlike resize(0), assigning a new empty container presumably also
// releases the underlying device allocation — confirm against the buffer
// type.
115template <
typename Ty,
int N>
116void DeviceBSRMatrix<Ty, N>::clear()
120 m_block_row_offsets =
decltype(m_block_row_offsets)();
121 m_block_col_indices =
decltype(m_block_col_indices)();
122 m_block_values =
decltype(m_block_values)();
// Releases the cached cuSPARSE descriptors:
//  - the legacy descriptor, via cusparseDestroyMatDescr, after which
//    m_legacy_descr is reset to nullptr;
//  - the generic sparse-matrix descriptor, via cusparseDestroySpMat.
// Declared const although it assigns m_legacy_descr, so the descriptor
// members must be declared mutable.
// NOTE(review): confirm m_descr is also reset to nullptr after
// destruction. descr() returns the cached m_descr whenever it is
// non-null, so a destroyed but non-null handle would be returned as
// dangling.
126template <
typename Ty,
int N>
127void DeviceBSRMatrix<Ty, N>::destroy_all_descr()
const
131 checkCudaErrors(cusparseDestroyMatDescr(m_legacy_descr));
132 m_legacy_descr =
nullptr;
136 checkCudaErrors(cusparseDestroySpMat(m_descr));
141template <
typename Ty,
int N>
142cusparseSpMatDescr_t DeviceBSRMatrix<Ty, N>::descr()
const
144 if(m_descr ==
nullptr)
147 checkCudaErrors(cusparseCreateBsr(
151 m_block_values.size(),
154 remove_const(m_block_row_offsets.data()),
155 remove_const(m_block_col_indices.data()),
156 remove_const(m_block_values.data()),
157 cusparse_index_type<
typename decltype(m_block_row_offsets)::value_type>(),
158 cusparse_index_type<
typename decltype(m_block_col_indices)::value_type>(),
159 CUSPARSE_INDEX_BASE_ZERO,
160 cuda_data_type<Ty>(),
161 cusparseOrder_t::CUSPARSE_ORDER_COL));