MUDA
Loading...
Searching...
No Matches
solve_dense.inl
1namespace muda
2{
3template <typename T>
4void muda::LinearSystemContext::sysv(DenseMatrixView<T> A, DenseVectorView<T> b)
5{
6 auto cusolver = cusolver_dn();
7
8 auto info = std::make_shared<DeviceVar<int>>();
9
10
11 size_t d_lwork = 0; /* size of workspace in device */
12 size_t h_lwork = 0; /* size of workspace in host */
13
14 cublasFillMode_t uplo = CUBLAS_FILL_MODE_LOWER;
15
16 auto m = A.row();
17
18 // query working space
19 checkCudaErrors(cusolverDnXpotrf_bufferSize(
20 cusolver, nullptr, uplo, m, cuda_data_type<T>(), A.data(), A.lda(), cuda_data_type<T>(), &d_lwork, &h_lwork));
21
22 auto device_buffer = temp_buffer<T>(d_lwork);
23 auto host_buffer = temp_host_buffer<T>(h_lwork);
24
25
26 // Cholesky factorization
27 checkCudaErrors(cusolverDnXpotrf(cusolver,
28 nullptr,
29 uplo,
30 m,
31 cuda_data_type<T>(),
32 A.data(),
33 A.lda(),
34 cuda_data_type<T>(),
35 device_buffer.data(),
36 d_lwork,
37 host_buffer.data(),
38 h_lwork,
39 info->data()));
40
41 // solve the system
42 checkCudaErrors(cusolverDnXpotrs(cusolver,
43 nullptr,
44 uplo,
45 m,
46 1, /* nrhs */
47 cuda_data_type<T>(),
48 A.data(),
49 A.lda(),
50 cuda_data_type<T>(),
51 b.data(),
52 m,
53 info->data()));
54
55 std::string label{this->label()};
56 this->label(""); // remove label because we consume it here
57
58 add_sync_callback(
59 [info = std::move(info), label = std::move(label)]() mutable
60 {
61 int result = *info;
62 if(result < 0)
63 MUDA_KERNEL_WARN_WITH_LOCATION("In calling label %s: A*x=b solving failed. The %d-th parameter in Cholesky factorization is wrong.",
64 label.c_str(),
65 -result);
66 });
67}
68
69template <typename T>
70void LinearSystemContext::gesv(DenseMatrixView<T> A, DenseVectorView<T> b)
71{
72 auto cusolver = cusolver_dn();
73
74 int64_t m = A.row();
75 size_t d_lwork = 0; /* size of workspace in device */
76 size_t h_lwork = 0; /* size of workspace in host */
77
78 auto info = std::make_shared<DeviceVar<int>>();
79
80 cusolverDnParams_t params;
81 cusolverDnCreateParams(&params);
82 cusolverDnSetAdvOptions(params, CUSOLVERDN_GETRF, CUSOLVER_ALG_0);
83
84 constexpr int pivot_on = 1;
85 size_t d_piv_count = A.row();
86
87 checkCudaErrors(cusolverDnXgetrf_bufferSize(cusolver,
88 params,
89 A.row(),
90 A.col(),
91 cuda_data_type<T>(),
92 A.data(),
93 A.lda(),
94 cuda_data_type<T>(),
95 &d_lwork,
96 &h_lwork));
97
98 auto buffer = temp_buffer(d_lwork * sizeof(T) + d_piv_count * sizeof(int64_t));
99
100 auto device_buffer = muda::BufferView<T>{(T*)buffer.data(), 0, d_lwork};
101
102 auto last = device_buffer.data(d_lwork);
103
104 auto d_piv = muda::BufferView<int64_t>{(int64_t*)(last), 0, d_piv_count};
105
106 auto host_buffer = temp_host_buffer<T>(h_lwork);
107
108 checkCudaErrors(cusolverDnXgetrf(cusolver,
109 params,
110 m,
111 m,
112 cuda_data_type<T>(),
113 A.data(),
114 A.lda(),
115 d_piv.data(),
116 cuda_data_type<T>(),
117 device_buffer.data(),
118 d_lwork,
119 host_buffer.data(),
120 h_lwork,
121 info->data()));
122
123 checkCudaErrors(cusolverDnXgetrs(cusolver,
124 params,
125 CUBLAS_OP_N,
126 m,
127 1, /* nrhs */
128 cuda_data_type<T>(),
129 A.data(),
130 A.lda(),
131 d_piv.data(),
132 cuda_data_type<T>(),
133 b.data(),
134 m,
135 info->data()));
136
137 std::string label{this->label()};
138 this->label(""); // remove label because we consume it here
139
140 add_sync_callback(
141 [info = std::move(info), label = std::move(label), params]() mutable
142 {
143 int result = *info;
144 if(result < 0)
145 MUDA_KERNEL_WARN_WITH_LOCATION("In calling label %f: A*x=b solving failed. The %d-th parameter in LU factorization is wrong.",
146 label.c_str(),
147 -result);
148
149 checkCudaErrors(cusolverDnDestroyParams(params));
150 });
151}
152
153
154template <typename T>
155void LinearSystemContext::solve(DenseMatrixView<T> A_to_fact, DenseVectorView<T> b_to_x)
156{
157 if(A_to_fact.is_sym())
158 {
159 sysv(A_to_fact, b_to_x);
160 }
161 else
162 {
163 gesv(A_to_fact, b_to_x);
164 }
165}
166} // namespace muda