3namespace details::linear_system
6 MUDA_INLINE
void mv_common_check(CDenseMatrixView<T> A,
10 MUDA_ASSERT(A.col() == y.size(),
"A.col() must be equal to y.size()");
11 MUDA_ASSERT(A.row() == x.size(),
"A.row() must be equal to x.size()");
12 MUDA_ASSERT(A.data(),
"Matrix A is empty");
13 MUDA_ASSERT(x.data(),
"Vector x is empty");
14 MUDA_ASSERT(y.data(),
"Vector y is empty");
// Dense general matrix-vector product issued through cuBLAS. Per the cuBLAS
// gemv reference this computes y = alpha * op(A) * x + beta * y, with A stored
// column-major.
// Dispatches at compile time on T:
//   float  -> cublasSgemv_v2_64
//   double -> cublasDgemv_v2_64
// The status returned by cuBLAS is passed to checkCudaErrors.
// Parameters:
//   handle  cuBLAS handle the call is issued on
//   trans   op(A) selector forwarded to cuBLAS
//   m, n, alpha, A, lda, x, incx, beta, y, incy
//           forwarded unchanged to cuBLAS. Their declarations are not
//           visible in this view.
// NOTE(review): the handling of T other than float/double is not visible
// here. Confirm it is rejected, e.g. by a static_assert.
18 void gemv(cublasHandle_t handle,
19 cublasOperation_t trans,
// Single-precision path.
31 if constexpr(std::is_same_v<T, float>)
33 checkCudaErrors(cublasSgemv_v2_64(
34 handle, trans, m, n, alpha, A, lda, x, incx, beta, y, incy));
// Double-precision path.
36 else if constexpr(std::is_same_v<T, double>)
38 checkCudaErrors(cublasDgemv_v2_64(
39 handle, trans, m, n, alpha, A, lda, x, incx, beta, y, incy));
// Symmetric matrix-vector product issued through cuBLAS. Per the cuBLAS symv
// reference this computes y = alpha * A * x + beta * y for a symmetric m x m
// matrix A.
// The fill mode is always CUBLAS_FILL_MODE_LOWER. cuBLAS therefore reads only
// the lower triangle of A, so callers must keep valid data there.
// Dispatches at compile time on T:
//   float  -> cublasSsymv_v2_64
//   double -> cublasDsymv_v2_64
// The status returned by cuBLAS is passed to checkCudaErrors.
// NOTE(review): the remaining parameter declarations, and the handling of T
// other than float/double, are not visible in this view. Confirm.
44 void symv(cublasHandle_t handle,
// Single-precision path; only the lower triangle of A is referenced.
55 if constexpr(std::is_same_v<T, float>)
57 checkCudaErrors(cublasSsymv_v2_64(
58 handle, cublasFillMode_t::CUBLAS_FILL_MODE_LOWER, m, alpha, A, lda, x, incx, beta, y, incy));
// Double-precision path; only the lower triangle of A is referenced.
60 else if constexpr(std::is_same_v<T, double>)
62 checkCudaErrors(cublasDsymv_v2_64(
63 handle, cublasFillMode_t::CUBLAS_FILL_MODE_LOWER, m, alpha, A, lda, x, incx, beta, y, incy));
// Matrix-vector product into y, using the host pointer mode.
// Steps:
//   1. Switch the context to host pointer mode via set_pointer_mode_host().
//      Presumably this means alpha/beta are read from host memory; their
//      declarations are not visible in this view, so confirm.
//   2. Validate A, x and y with mv_common_check.
//   3. On one path, assert that A is square and call symv. The branch
//      condition is not visible here (presumably A.is_sym()), so confirm.
//      On the other path, call gemv with op(A) taken from
//      cublas_trans_operation(A.is_trans()).
68void LinearSystemContext::mv(CDenseMatrixView<T> A,
70 CDenseVectorView<T> x,
74 set_pointer_mode_host();
75 details::linear_system::mv_common_check(A, x, y);
// The symmetric path requires a square matrix.
79 MUDA_ASSERT(A.row() == A.col(),
"A must be square matrix");
81 details::linear_system::symv<T>(cublas(),
// General path: the transpose flag of the view selects op(A).
94 details::linear_system::gemv<T>(cublas(),
95 cublas_trans_operation(A.is_trans()),
// Matrix-vector product into y, using the device pointer mode.
// Steps:
//   1. Switch the context to device pointer mode via set_pointer_mode_device().
//      Presumably this means alpha/beta are read from device memory; their
//      declarations are not visible in this view, so confirm.
//   2. Validate A, x and y with mv_common_check.
//   3. On one path, assert that A is square and call symv. The branch
//      condition is not visible here (presumably A.is_sym()), so confirm.
//      On the other path, call gemv with op(A) taken from
//      cublas_trans_operation(A.is_trans()).
110void LinearSystemContext::mv(CDenseMatrixView<T> A,
112 CDenseVectorView<T> x,
114 DenseVectorView<T> y)
116 set_pointer_mode_device();
117 details::linear_system::mv_common_check(A, x, y);
// The symmetric path requires a square matrix.
121 MUDA_ASSERT(A.row() == A.col(),
"A must be square matrix");
123 details::linear_system::symv<T>(cublas(),
// General path: the transpose flag of the view selects op(A).
136 details::linear_system::gemv<T>(cublas(),
137 cublas_trans_operation(A.is_trans()),
152void LinearSystemContext::mv(CDenseMatrixView<T> A, CDenseVectorView<T> x, DenseVectorView<T> y)
154 mv(A, T(1), x, T(0), y);