MUDA
Loading...
Searching...
No Matches
mv.inl
1namespace muda
2{
3namespace details::linear_system
4{
5 template <typename T>
6 MUDA_INLINE void mv_common_check(CDenseMatrixView<T> A,
7 CDenseVectorView<T> x,
8 DenseVectorView<T> y)
9 {
10 MUDA_ASSERT(A.col() == y.size(), "A.col() must be equal to y.size()");
11 MUDA_ASSERT(A.row() == x.size(), "A.row() must be equal to x.size()");
12 MUDA_ASSERT(A.data(), "Matrix A is empty");
13 MUDA_ASSERT(x.data(), "Vector x is empty");
14 MUDA_ASSERT(y.data(), "Vector y is empty");
15 }
16
17 template <typename T>
18 void gemv(cublasHandle_t handle,
19 cublasOperation_t trans,
20 int64_t m,
21 int64_t n,
22 const T* alpha,
23 const T* A,
24 int64_t lda,
25 const T* x,
26 int64_t incx,
27 const T* beta,
28 T* y,
29 int64_t incy)
30 {
31 if constexpr(std::is_same_v<T, float>)
32 {
33 checkCudaErrors(cublasSgemv_v2_64(
34 handle, trans, m, n, alpha, A, lda, x, incx, beta, y, incy));
35 }
36 else if constexpr(std::is_same_v<T, double>)
37 {
38 checkCudaErrors(cublasDgemv_v2_64(
39 handle, trans, m, n, alpha, A, lda, x, incx, beta, y, incy));
40 }
41 }
42
43 template <typename T>
44 void symv(cublasHandle_t handle,
45 int64_t m,
46 const T* alpha,
47 const T* A,
48 int64_t lda,
49 const T* x,
50 int64_t incx,
51 const T* beta,
52 T* y,
53 int64_t incy)
54 {
55 if constexpr(std::is_same_v<T, float>)
56 {
57 checkCudaErrors(cublasSsymv_v2_64(
58 handle, cublasFillMode_t::CUBLAS_FILL_MODE_LOWER, m, alpha, A, lda, x, incx, beta, y, incy));
59 }
60 else if constexpr(std::is_same_v<T, double>)
61 {
62 checkCudaErrors(cublasDsymv_v2_64(
63 handle, cublasFillMode_t::CUBLAS_FILL_MODE_LOWER, m, alpha, A, lda, x, incx, beta, y, incy));
64 }
65 }
66} // namespace details::linear_system
67template <typename T>
68void LinearSystemContext::mv(CDenseMatrixView<T> A,
69 const T& alpha,
70 CDenseVectorView<T> x,
71 const T& beta,
72 DenseVectorView<T> y)
73{
74 set_pointer_mode_host();
75 details::linear_system::mv_common_check(A, x, y);
76
77 if(A.is_sym())
78 {
79 MUDA_ASSERT(A.row() == A.col(), "A must be square matrix");
80
81 details::linear_system::symv<T>(cublas(),
82 A.row(),
83 &alpha,
84 A.data(),
85 A.lda(),
86 x.data(),
87 (int64_t)x.inc(),
88 &beta,
89 y.data(),
90 (int64_t)y.inc());
91 }
92 else
93 {
94 details::linear_system::gemv<T>(cublas(),
95 cublas_trans_operation(A.is_trans()),
96 A.row(),
97 A.col(),
98 &alpha,
99 A.data(),
100 A.lda(),
101 x.data(),
102 (int64_t)x.inc(),
103 &beta,
104 y.data(),
105 (int64_t)y.inc());
106 }
107}
108
109template <typename T>
110void LinearSystemContext::mv(CDenseMatrixView<T> A,
111 CVarView<T> alpha,
112 CDenseVectorView<T> x,
113 CVarView<T> beta,
114 DenseVectorView<T> y)
115{
116 set_pointer_mode_device();
117 details::linear_system::mv_common_check(A, x, y);
118
119 if(A.is_sym())
120 {
121 MUDA_ASSERT(A.row() == A.col(), "A must be square matrix");
122
123 details::linear_system::symv<T>(cublas(),
124 A.row(),
125 alpha.data(),
126 A.data(),
127 A.lda(),
128 x.data(),
129 (int64_t)x.inc(),
130 beta.data(),
131 y.data(),
132 (int64_t)y.inc());
133 }
134 else
135 {
136 details::linear_system::gemv<T>(cublas(),
137 cublas_trans_operation(A.is_trans()),
138 A.row(),
139 A.col(),
140 alpha.data(),
141 A.data(),
142 A.lda(),
143 x.data(),
144 (int64_t)x.inc(),
145 beta.data(),
146 y.data(),
147 (int64_t)y.inc());
148 }
149}
150
151template <typename T>
152void LinearSystemContext::mv(CDenseMatrixView<T> A, CDenseVectorView<T> x, DenseVectorView<T> y)
153{
154 mv(A, T(1), x, T(0), y);
155}
156} // namespace muda