MUDA
Loading...
Searching...
No Matches
solve_sparse.inl
1#include <muda/check/check_cusolver.h>
2namespace muda
3{
4// using T = float;
namespace details::linear_system
{
    /**
     * @brief Precision dispatcher for cuSOLVER's sparse QR linear solver.
     *
     * Forwards every argument unchanged to cusolverSpScsrlsvqr when T is float,
     * or to cusolverSpDcsrlsvqr when T is double, and checks the returned
     * status with checkCudaErrors. Any other T fails to compile.
     *
     * @param handle      cuSOLVER sparse handle.
     * @param m           Row count of A, forwarded as-is.
     * @param nnz         Number of stored entries of A, forwarded as-is.
     * @param descrA      cuSPARSE matrix descriptor of A.
     * @param csrValA     CSR value array of A.
     * @param csrRowPtrA  CSR row-offset array of A.
     * @param csrColIndA  CSR column-index array of A.
     * @param b           Right-hand side.
     * @param tol         Tolerance forwarded to cuSOLVER.
     * @param reorder     Reordering scheme id forwarded to cuSOLVER.
     * @param x           Output solution.
     * @param singularity Output singularity index written by cuSOLVER.
     */
    template <typename T>
    void svqr(cusolverSpHandle_t       handle,
              int                      m,
              int                      nnz,
              const cusparseMatDescr_t descrA,
              const T*                 csrValA,
              const int*               csrRowPtrA,
              const int*               csrColIndA,
              const T*                 b,
              T                        tol,
              int                      reorder,
              T*                       x,
              int*                     singularity)
    {
        constexpr bool single_precision = std::is_same_v<T, float>;
        constexpr bool double_precision = std::is_same_v<T, double>;

        // Reject unsupported scalar types up front (dependent on T, so it only
        // fires on instantiation with a bad T).
        static_assert(single_precision || double_precision, "Unsupported type");

        if constexpr(single_precision)
        {
            checkCudaErrors(cusolverSpScsrlsvqr(
                handle, m, nnz, descrA, csrValA, csrRowPtrA, csrColIndA, b, tol, reorder, x, singularity));
        }
        else
        {
            // Only double can reach here thanks to the static_assert above.
            checkCudaErrors(cusolverSpDcsrlsvqr(
                handle, m, nnz, descrA, csrValA, csrRowPtrA, csrColIndA, b, tol, reorder, x, singularity));
        }
    }
}  // namespace details::linear_system
38
39
/**
 * @brief Solve the square sparse linear system A * x = b with cuSOLVER's sparse
 *        QR solver (cusolverSp<t>csrlsvqr).
 *
 * @param x Output dense vector receiving the solution; must hold A.rows() entries.
 * @param A CSR system matrix; must be square and must not be transposed.
 * @param b Right-hand-side dense vector; must hold A.rows() entries.
 *
 * The tolerance and reordering scheme are taken from this context's m_tolerance
 * and m_reorder settings. The singularity index reported by cuSOLVER is checked in
 * a sync callback, which emits a warning (tagged with the current label) when it is
 * not -1. The context's label is consumed (reset to "") by this call.
 */
template <typename T>
void LinearSystemContext::solve(DenseVectorView<T> x, CCSRMatrixView<T> A, CDenseVectorView<T> b)
{
    MUDA_ASSERT(!A.is_trans(), "CSRMatrix A must not be transposed");

    // csrlsvqr only receives the row count `m` and treats A as an m x m matrix;
    // a non-square A or wrongly sized x / b would make cuSOLVER access memory
    // out of bounds without reporting an error.
    MUDA_ASSERT(A.rows() == A.cols(),
                "CSRMatrix A must be square, but got %d x %d",
                static_cast<int>(A.rows()),
                static_cast<int>(A.cols()));
    MUDA_ASSERT(static_cast<int>(b.size()) == static_cast<int>(A.rows()),
                "b.size() (%d) must equal A.rows() (%d)",
                static_cast<int>(b.size()),
                static_cast<int>(A.rows()));
    MUDA_ASSERT(static_cast<int>(x.size()) == static_cast<int>(A.rows()),
                "x.size() (%d) must equal A.rows() (%d)",
                static_cast<int>(x.size()),
                static_cast<int>(A.rows()));

    // cuSOLVER rejects m <= 0 with CUSOLVER_STATUS_INVALID_VALUE. An empty system
    // is trivially solved, so skip the call but still consume the label so it
    // does not leak into the next operation.
    if(A.rows() == 0)
    {
        this->label("");
        return;
    }

    auto handle = cusolver_sp();

    // Host-side slot cuSOLVER writes the singularity index into; shared so the
    // deferred callback below can read it after synchronization.
    auto singularity = std::make_shared<int>(0);

    details::linear_system::svqr(handle,
                                 A.rows(),
                                 A.non_zeros(),
                                 A.legacy_descr(),
                                 A.values(),
                                 A.row_offsets(),
                                 A.col_indices(),
                                 b.data(),
                                 m_tolerance.solve_sparse_error_threshold<T>(),
                                 m_reorder.reorder_method_int(),
                                 x.data(),
                                 singularity.get());

    std::string label{this->label()};
    this->label("");  // remove label because we consume it here

    add_sync_callback(
        [info = std::move(singularity), label = std::move(label)]() mutable
        {
            // cuSOLVER reports -1 when A is invertible; otherwise the first j
            // such that R(j,j) is (numerically) zero.
            int result = *info;
            if(result != -1)
            {
                MUDA_KERNEL_WARN_WITH_LOCATION("In calling label %s: A*x=b solving failed. R(%d,%d) is almost 0",
                                               label.c_str(),
                                               result,
                                               result);
            }
        });
}
78} // namespace muda