MUDA
Loading...
Searching...
No Matches
device_bcoo_matrix.h
1#pragma once
3#include <muda/ext/linear_system/bcoo_matrix_view.h>
4#include <muda/ext/linear_system/device_triplet_matrix.h>
5#include <cusparse.h>
6#include <muda/ext/linear_system/type_mapper/data_type_mapper.h>
7
namespace muda
{
namespace details
{
    // Forward declaration only. DeviceBCOOMatrix names this converter as a
    // friend, so the template must be declared before the matrix classes.
    template <typename T, int N>
    class MatrixFormatConverter;
}  // namespace details
}  // namespace muda
13
14namespace muda
15{
16template <typename T, int N>
18{
19 friend class details::MatrixFormatConverter<T, N>;
20
21 public:
22 using BlockMatrix = Eigen::Matrix<T, N, N>;
23
24 DeviceBCOOMatrix() = default;
25 ~DeviceBCOOMatrix() = default;
26 DeviceBCOOMatrix(const DeviceBCOOMatrix&) = default;
28 DeviceBCOOMatrix& operator=(const DeviceBCOOMatrix&) = default;
29 DeviceBCOOMatrix& operator=(DeviceBCOOMatrix&&) = default;
30 auto non_zero_blocks() const { return this->m_block_values.size(); }
31};
32
33template <typename Ty>
34class DeviceBCOOMatrix<Ty, 1> : public DeviceTripletMatrix<Ty, 1>
35{
36 template <typename U, int M>
38
39 protected:
40 mutable cusparseMatDescr_t m_legacy_descr = nullptr;
41 mutable cusparseSpMatDescr_t m_descr = nullptr;
42
43 public:
44 DeviceBCOOMatrix() = default;
45 ~DeviceBCOOMatrix() { destroy_all_descr(); }
46
49 , m_legacy_descr{nullptr}
50 , m_descr{nullptr}
51 {
52 }
53
55 : DeviceTripletMatrix<Ty, 1>{std::move(other)}
56 , m_legacy_descr{other.m_legacy_descr}
57 , m_descr{other.m_descr}
58 {
59 other.m_legacy_descr = nullptr;
60 other.m_descr = nullptr;
61 }
62
63 DeviceBCOOMatrix& operator=(const DeviceBCOOMatrix& other)
64 {
65 if(this == &other)
66 return *this;
68 destroy_all_descr();
69 m_legacy_descr = nullptr;
70 m_descr = nullptr;
71 return *this;
72 }
73
74 DeviceBCOOMatrix& operator=(DeviceBCOOMatrix&& other)
75 {
76 if(this == &other)
77 return *this;
79 destroy_all_descr();
80 m_legacy_descr = other.m_legacy_descr;
81 m_descr = other.m_descr;
82 other.m_legacy_descr = nullptr;
83 other.m_descr = nullptr;
84 return *this;
85 }
86
87
88 auto view()
89 {
90 return COOMatrixView<Ty>{this->m_rows,
91 this->m_cols,
92 (int)this->m_values.size(),
93 this->m_row_indices.data(),
94 this->m_col_indices.data(),
95 this->m_values.data(),
96 descr(),
97 legacy_descr(),
98 false};
99 }
100
101 auto view() const
102 {
103 return CCOOMatrixView<Ty>{this->m_rows,
104 this->m_cols,
105 (int)this->m_values.size(),
106 this->m_row_indices.data(),
107 this->m_col_indices.data(),
108 this->m_values.data(),
109 descr(),
110 legacy_descr(),
111 false};
112 }
113
114 auto cview() const { return view(); }
115
116 auto viewer() { return view().viewer(); }
117
118 auto cviewer() const { return view().cviewer(); }
119
120 auto non_zeros() const { return this->m_values.size(); }
121
122 auto legacy_descr() const
123 {
124 if(m_legacy_descr == nullptr)
125 {
126 checkCudaErrors(cusparseCreateMatDescr(&m_legacy_descr));
127 checkCudaErrors(cusparseSetMatType(m_legacy_descr, CUSPARSE_MATRIX_TYPE_GENERAL));
128 checkCudaErrors(cusparseSetMatIndexBase(m_legacy_descr, CUSPARSE_INDEX_BASE_ZERO));
129 }
130 return m_legacy_descr;
131 }
132
133 auto descr() const
134 {
135 if(m_descr == nullptr)
136 {
137 checkCudaErrors(cusparseCreateCoo(&m_descr,
138 this->m_rows,
139 this->m_cols,
140 non_zeros(),
141 (void*)this->m_row_indices.data(),
142 (void*)this->m_col_indices.data(),
143 (void*)this->m_values.data(),
144 CUSPARSE_INDEX_32I,
145 CUSPARSE_INDEX_BASE_ZERO,
146 cuda_data_type<Ty>()));
147 }
148 return m_descr;
149 }
150
151 //auto T() const { return view().T(); }
152 //auto T() { return view().T(); }
153
154 operator COOMatrixView<Ty>() { return view(); }
155 operator CCOOMatrixView<Ty>() const { return view(); }
156
157 void clear()
158 {
160 destroy_all_descr();
161 }
162
163 private:
164 void destroy_all_descr()
165 {
166 if(m_legacy_descr != nullptr)
167 {
168 checkCudaErrors(cusparseDestroyMatDescr(m_legacy_descr));
169 m_legacy_descr = nullptr;
170 }
171 if(m_descr != nullptr)
172 {
173 checkCudaErrors(cusparseDestroySpMat(m_descr));
174 m_descr = nullptr;
175 }
176 }
177};
178
179template <typename T>
181} // namespace muda
182
183#include "details/device_bcoo_matrix.inl"
Definition bcoo_matrix_view.h:15
Definition device_bcoo_matrix.h:18
Definition device_triplet_matrix.h:14
Definition matrix_format_converter_impl.h:53
A lightweight wrapper around CUDA device memory. Like std::vector, it allows the user to resize,...