MUDA
Loading...
Searching...
No Matches
linear_system_context.h
1#pragma once
2#include <cublas_v2.h>
3#include <cusparse_v2.h>
4#include <cusolverDn.h>
5#include <cusolverSp.h>
6#include <list>
8#include <muda/literal/unit.h>
9#include <muda/mstl/span.h>
10#include <muda/ext/linear_system/dense_vector_view.h>
11#include <muda/ext/linear_system/dense_matrix_view.h>
12#include <muda/ext/linear_system/matrix_format_converter.h>
13#include <muda/ext/linear_system/linear_system_handles.h>
14#include <muda/ext/linear_system/linear_system_solve_tolerance.h>
15#include <muda/ext/linear_system/linear_system_solve_reorder.h>
16namespace muda
17{
19{
20 public:
21 cudaStream_t stream = nullptr;
22 // Base size of the temp buffer. If the current buffer is not large enough,
23 // we create a new buffer with size = buffer_byte_size_base * 2, * 4, * 8, * 16, ...
24 // and we do not release the old buffer, for safety.
25 size_t buffer_byte_size_base = 256_M;
26};
28{
29 private:
30 LinearSystemHandles m_handles;
31 std::list<DeviceBuffer<std::byte>> m_buffers;
32 std::list<std::vector<std::byte>> m_host_buffers;
33 DeviceBuffer<std::byte> m_scalar_buffer;
34
36 std::list<std::function<void()>> m_sync_callbacks;
37 std::string m_current_label;
38
39 void set_pointer_mode_device();
40 void set_pointer_mode_host();
41 void shrink_temp_buffers();
42 void add_sync_callback(std::function<void()>&& callback);
43 BufferView<std::byte> temp_buffer(size_t size);
44 span<std::byte> temp_host_buffer(size_t size);
45 template <typename T>
46 BufferView<T> temp_buffer(size_t size);
47 template <typename T>
48 span<T> temp_host_buffer(size_t size);
49 template <typename T>
50 std::vector<T*> temp_buffers(size_t size_in_buffer, size_t num_buffer);
51 template <typename T>
52 std::vector<T*> temp_host_buffers(size_t size_in_buffer, size_t num_buffer);
53
56 MatrixFormatConverter m_converter;
57
58 private:
59 auto cublas() const { return m_handles.cublas(); }
60 auto cusparse() const { return m_handles.cusparse(); }
61 auto cusolver_dn() const { return m_handles.cusolver_dn(); }
62 auto cusolver_sp() const { return m_handles.cusolver_sp(); }
63
64 public:
67 LinearSystemContext& operator=(const LinearSystemContext&) = delete;
69 LinearSystemContext& operator=(LinearSystemContext&&) = delete;
71
72 void label(std::string_view label) { m_current_label = label; }
73 auto label() const -> std::string_view { return m_current_label; }
74 auto stream() const { return m_handles.stream(); }
75 void stream(cudaStream_t stream);
76 void sync();
77
78 /***********************************************************************************************
79 Settings
80 ***********************************************************************************************/
81
82 auto& tolerance() { return m_tolerance; }
83 auto& reorder() { return m_reorder; }
84 auto reserve_ratio() const { return m_handles.m_reserve_ratio; }
85 void reserve_ratio(float ratio) { m_handles.m_reserve_ratio = ratio; }
86
87
88 public:
89 /***********************************************************************************************
90 Converter
91 ***********************************************************************************************/
92 // Triplet -> BCOO
93 template <typename T, int N>
94 void convert(const DeviceTripletMatrix<T, N>& from, DeviceBCOOMatrix<T, N>& to);
95
96 // BCOO -> Dense Matrix
97 template <typename T, int N>
98 void convert(const DeviceBCOOMatrix<T, N>& from,
100 bool clear_dense_matrix = true);
101
102 // BCOO -> COO
103 template <typename T, int N>
104 void convert(const DeviceBCOOMatrix<T, N>& from, DeviceCOOMatrix<T>& to);
105
106 // BCOO -> BSR
107 template <typename T, int N>
108 void convert(const DeviceBCOOMatrix<T, N>& from, DeviceBSRMatrix<T, N>& to);
109
110 // Doublet -> BCOO
111 template <typename T, int N>
112 void convert(const DeviceDoubletVector<T, N>& from, DeviceBCOOVector<T, N>& to);
113
114 // BCOO -> Dense Vector
115 template <typename T, int N>
116 void convert(const DeviceBCOOVector<T, N>& from,
118 bool clear_dense_vector = true);
119
120 // Doublet -> Dense Vector
121 template <typename T, int N>
122 void convert(const DeviceDoubletVector<T, N>& from,
124 bool clear_dense_vector = true);
125
126 // BSR -> CSR
127 template <typename T, int N>
128 void convert(const DeviceBSRMatrix<T, N>& from, DeviceCSRMatrix<T>& to);
129
130 // Triplet -> COO
131 template <typename T>
132 void convert(const DeviceTripletMatrix<T, 1>& from, DeviceCOOMatrix<T>& to);
133
134 // COO -> Dense Matrix
135 template <typename T>
136 void convert(const DeviceCOOMatrix<T>& from,
138 bool clear_dense_matrix = true);
139
140 // COO -> CSR
141 template <typename T>
142 void convert(const DeviceCOOMatrix<T>& from, DeviceCSRMatrix<T>& to);
143 template <typename T>
144 void convert(DeviceCOOMatrix<T>&& from, DeviceCSRMatrix<T>& to);
145
146 // Doublet -> COO
147 template <typename T>
148 void convert(const DeviceDoubletVector<T, 1>& from, DeviceCOOVector<T>& to);
149
150 // COO -> Dense Vector
151 template <typename T>
152 void convert(const DeviceCOOVector<T>& from,
154 bool clear_dense_vector = true);
155
156 // Doublet -> Dense Vector
157 template <typename T>
158 void convert(const DeviceDoubletVector<T, 1>& from,
160 bool clear_dense_vector = true);
161
162 public:
163 /***********************************************************************************************
164 Norm
165 ***********************************************************************************************/
166 template <typename T>
167 T norm(CDenseVectorView<T> x);
168 template <typename T>
169 void norm(CDenseVectorView<T> x, VarView<T> result);
170 template <typename T>
171 void norm(CDenseVectorView<T> x, T* result);
172
173 /***********************************************************************************************
174 Dot
175 ***********************************************************************************************/
176 template <typename T>
178 template <typename T>
180 template <typename T>
181 void dot(CDenseVectorView<T> x, CDenseVectorView<T> y, T* result);
182
183 /***********************************************************************************************
184 Max/Min
185 ***********************************************************************************************/
186 //TODO:
187
188
189 /***********************************************************************************************
190 Axpby
191 y = alpha * x + beta * y
192 ***********************************************************************************************/
193 // y = alpha * x + beta * y
194 template <typename T>
195 void axpby(const T& alpha, CDenseVectorView<T> x, const T& beta, DenseVectorView<T> y);
196 // y = alpha * x + beta * y
197 template <typename T>
199 // z = x + y
200 template <typename T>
202
203 /***********************************************************************************************
204 Spmv
205 y = a * A * x + b * y
206 ***********************************************************************************************/
207 // BSR
208 template <typename T, int N>
209 void spmv(const T& a,
212 const T& b,
214 template <typename T, int N>
216 // CSR
217 template <typename T>
218 void spmv(const T& a, CCSRMatrixView<T> A, CDenseVectorView<T> x, const T& b, DenseVectorView<T>& y);
219 template <typename T>
221 // BCOO & Triplet
222 template <typename T, int N>
223 void spmv(const T& a,
226 const T& b,
228 template <typename T, int N>
230 // COO
231 template <typename T>
232 void spmv(const T& a, CCOOMatrixView<T> A, CDenseVectorView<T> x, const T& b, DenseVectorView<T>& y);
233 template <typename T>
235
236
237 /***********************************************************************************************
238 Mv
239 y = a * A * x + b * y
240 ***********************************************************************************************/
241 template <typename T>
242 void mv(CDenseMatrixView<T> A,
243 const T& alpha,
245 const T& beta,
247 template <typename T>
248 void mv(CDenseMatrixView<T> A,
249 CVarView<T> alpha,
251 CVarView<T> beta,
253 template <typename T>
255
256 /***********************************************************************************************
257 Solve
258 A * x = b
259 ***********************************************************************************************/
260 // solve Ax = b, A will be modified for factorization
261 // and b will be modified to store the solution
262 template <typename T>
263 void solve(DenseMatrixView<T> A_to_fact, DenseVectorView<T> b_to_x);
264 // solve Ax = b
265 // A is the CSR Matrix
266 template <typename T>
268
269 private:
270 template <typename T>
271 void generic_spmv(const T& a,
272 cusparseOperation_t op,
273 cusparseSpMatDescr_t A,
274 const cusparseDnVecDescr* x,
275 const T& b,
276 cusparseDnVecDescr_t y);
277 template <typename T>
278 void sysv(DenseMatrixView<T> A_to_fact, DenseVectorView<T> b_to_x);
279 template <typename T>
280 void gesv(DenseMatrixView<T> A_to_fact, DenseVectorView<T> b_to_x);
281};
282} // namespace muda
283
284#include "details/linear_system_context.inl"
285#include "details/routines/convert.inl"
286#include "details/routines/norm.inl"
287#include "details/routines/dot.inl"
288#include "details/routines/axpby.inl"
289#include "details/routines/spmv.inl"
290#include "details/routines/mv.inl"
291#include "details/routines/solve.inl"
292#include "details/routines/mm.inl"
Definition bsr_matrix_view.h:8
Definition dense_matrix_view.h:93
Definition bcoo_matrix_view.h:15
Definition csr_matrix_view.h:8
Definition dense_matrix_view.h:108
Definition dense_vector_view.h:10
Definition device_bcoo_matrix.h:18
Definition device_bcoo_vector.h:28
Definition device_bcoo_vector.h:10
Definition device_bsr_matrix.h:16
A std::vector-like wrapper of CUDA device memory that allows the user to:
Definition device_buffer.h:46
Definition device_csr_matrix.h:16
Definition device_dense_matrix.h:16
Definition device_dense_vector.h:16
Definition device_doublet_vector.h:16
Definition device_triplet_matrix.h:14
Definition linear_system_context.h:19
Definition linear_system_context.h:28
Definition linear_system_handles.h:16
Definition linear_system_solve_reorder.h:11
Definition linear_system_solve_tolerance.h:7
Definition matrix_format_converter.h:48
Definition triplet_matrix_view.h:10
Definition var_view.h:11
A lightweight wrapper of CUDA device memory. Like std::vector, it allows the user to resize,...