// Solves A*x = b for a symmetric matrix A using a Cholesky factorization:
// cusolverDnXpotrf factors A, then cusolverDnXpotrs solves with that factor.
// NOTE(review): this is a partial extract. Leading numbers are original line
// numbers and several body lines (and braces) are not visible here.
4void muda::LinearSystemContext::sysv(DenseMatrixView<T> A, DenseVectorView<T> b)
6 auto cusolver = cusolver_dn();
// Shared device-side status word. It is later moved into the error-report callback.
8 auto info = std::make_shared<DeviceVar<int>>();
// The factorization works on the lower triangle of A.
14 cublasFillMode_t uplo = CUBLAS_FILL_MODE_LOWER;
// Query the device (d_lwork) and host (h_lwork) workspace sizes for Xpotrf.
// NOTE(review): the params argument is nullptr here, whereas gesv creates a
// cusolverDnParams_t. Confirm that cuSOLVER accepts a null params handle.
19 checkCudaErrors(cusolverDnXpotrf_bufferSize(
20 cusolver,
nullptr, uplo, m, cuda_data_type<T>(), A.data(), A.lda(), cuda_data_type<T>(), &d_lwork, &h_lwork));
// NOTE(review): cuSOLVER reports workspace sizes in bytes. If temp_buffer<T> /
// temp_host_buffer<T> count elements of T, this over-allocates by sizeof(T).
// Confirm against the helpers.
22 auto device_buffer = temp_buffer<T>(d_lwork);
23 auto host_buffer = temp_host_buffer<T>(h_lwork);
// Cholesky factorization of A (in place, presumably; confirm with the full call).
27 checkCudaErrors(cusolverDnXpotrf(cusolver,
// Solve using the factor produced above.
42 checkCudaErrors(cusolverDnXpotrs(cusolver,
// Copy the context label so the callback owns it.
55 std::string label{this->label()};
59 [info = std::move(info), label = std::move(label)]()
mutable
// Warns with the label and the offending parameter index on failure.
// Assumed to be read from *info; verify in the full file.
63 MUDA_KERNEL_WARN_WITH_LOCATION(
"In calling label %s: A*x=b solving failed. The %d-th parameter in Cholesky factorization is wrong.",
// Solves A*x = b for a general matrix A using an LU factorization:
// cusolverDnXgetrf factors A (with a pivot array), then cusolverDnXgetrs
// solves with the factors.
// NOTE(review): this is a partial extract. Leading numbers are original line
// numbers and several body lines (and braces) are not visible here.
70void LinearSystemContext::gesv(DenseMatrixView<T> A, DenseVectorView<T> b)
72 auto cusolver = cusolver_dn();
// Shared device-side status word. It is later moved into the error-report callback.
78 auto info = std::make_shared<DeviceVar<int>>();
// Advanced-options handle for the 64-bit cuSOLVER API; algorithm 0 is selected for GETRF.
// Both setup calls are status-checked like every other cuSOLVER call here.
80 cusolverDnParams_t params;
81 checkCudaErrors(cusolverDnCreateParams(&params));
82 checkCudaErrors(cusolverDnSetAdvOptions(params, CUSOLVERDN_GETRF, CUSOLVER_ALG_0));
// Presumably toggles partial pivoting; one pivot entry per row of A.
84 constexpr int pivot_on = 1;
85 size_t d_piv_count = A.row();
87 checkCudaErrors(cusolverDnXgetrf_bufferSize(cusolver,
// A single device allocation holds the getrf workspace followed by the
// int64_t pivot array.
98 auto buffer = temp_buffer(d_lwork *
sizeof(T) + d_piv_count *
sizeof(int64_t));
// Start of the region after the workspace (the pivot storage, presumably).
102 auto last = device_buffer.data(d_lwork);
106 auto host_buffer = temp_host_buffer<T>(h_lwork);
// LU factorization of A.
108 checkCudaErrors(cusolverDnXgetrf(cusolver,
117 device_buffer.data(),
// Solve using the LU factors and pivots.
123 checkCudaErrors(cusolverDnXgetrs(cusolver,
// Copy the context label so the callback owns it.
137 std::string label{this->label()};
141 [info = std::move(info), label = std::move(label), params]()
mutable
// The label is a string, so it needs %s (as in the Cholesky path), not %f.
145 MUDA_KERNEL_WARN_WITH_LOCATION(
"In calling label %s: A*x=b solving failed. The %d-th parameter in LU factorization is wrong.",
// Release the params handle created above (params is captured by the callback).
149 checkCudaErrors(cusolverDnDestroyParams(params));
// Solves A_to_fact * x = b_to_x by dispatching on A_to_fact.is_sym():
// symmetric matrices go to sysv (Cholesky), all others to gesv (LU).
// NOTE(review): partial extract; the braces/else between these lines are not
// visible. Judging by the parameter names, A_to_fact is presumably
// overwritten by its factorization and b_to_x by the solution. Confirm.
155void LinearSystemContext::solve(DenseMatrixView<T> A_to_fact, DenseVectorView<T> b_to_x)
157 if(A_to_fact.is_sym())
159 sysv(A_to_fact, b_to_x);
163 gesv(A_to_fact, b_to_x);