// MUDA - axpby.inl
namespace muda
{
namespace details::linear_system
{
    /// Host-side precondition checks for the two-vector routine
    /// (`LinearSystemContext::axpby`).
    ///
    /// Checks, via MUDA_ASSERT, that:
    ///  - both views have a non-null data pointer;
    ///  - both increments are positive (the logical length of a view is
    ///    computed as `size() / inc()`, so a zero increment would divide by
    ///    zero and a negative one would yield a negative element count);
    ///  - both views describe the same logical length.
    ///
    /// NOTE(review): whether a failed MUDA_ASSERT aborts, and whether it is
    /// compiled out in some build configurations, is defined elsewhere.
    template <typename T>
    void axpby_common_check(CDenseVectorView<T> x, DenseVectorView<T> y)
    {
        MUDA_ASSERT(x.data(), "Vector x is empty");
        MUDA_ASSERT(y.data(), "Vector y is empty");
        MUDA_ASSERT(x.inc() > 0, "Vector x has non-positive inc=%d", static_cast<int>(x.inc()));
        MUDA_ASSERT(y.inc() > 0, "Vector y has non-positive inc=%d", static_cast<int>(y.inc()));
        // Explicit casts keep the variadic arguments in agreement with the
        // %lld / %d conversions regardless of the integer types returned by
        // size() and inc().
        MUDA_ASSERT(x.size() / x.inc() == y.size() / y.inc(),
                    "Vector x and y have different size, x (size=%lld, inc=%d), y (size=%lld, inc=%d)",
                    static_cast<long long>(x.size()),
                    static_cast<int>(x.inc()),
                    static_cast<long long>(y.size()),
                    static_cast<int>(y.inc()));
    }

    /// Host-side precondition checks for the three-vector routine
    /// (`LinearSystemContext::plus`).
    ///
    /// Same checks as the two-vector overload, applied to x, y and z: all
    /// three must be non-empty, have a positive increment, and share the
    /// logical length of x.
    template <typename T>
    void axpby_common_check(CDenseVectorView<T> x, CDenseVectorView<T> y, DenseVectorView<T> z)
    {
        MUDA_ASSERT(x.data(), "Vector x is empty");
        MUDA_ASSERT(y.data(), "Vector y is empty");
        MUDA_ASSERT(z.data(), "Vector z is empty");
        MUDA_ASSERT(x.inc() > 0, "Vector x has non-positive inc=%d", static_cast<int>(x.inc()));
        MUDA_ASSERT(y.inc() > 0, "Vector y has non-positive inc=%d", static_cast<int>(y.inc()));
        MUDA_ASSERT(z.inc() > 0, "Vector z has non-positive inc=%d", static_cast<int>(z.inc()));
        MUDA_ASSERT(x.size() / x.inc() == y.size() / y.inc(),
                    "Vector x and y have different size, x (size=%lld, inc=%d), y (size=%lld, inc=%d)",
                    static_cast<long long>(x.size()),
                    static_cast<int>(x.inc()),
                    static_cast<long long>(y.size()),
                    static_cast<int>(y.inc()));
        MUDA_ASSERT(x.size() / x.inc() == z.size() / z.inc(),
                    "Vector x and z have different size, x (size=%lld, inc=%d), z (size=%lld, inc=%d)",
                    static_cast<long long>(x.size()),
                    static_cast<int>(x.inc()),
                    static_cast<long long>(z.size()),
                    static_cast<int>(z.inc()));
    }
}  // namespace details::linear_system

/// Element-wise y <- alpha * x + beta * y, evaluated in a __device__ lambda
/// launched through ParallelFor.
///
/// @param alpha scalar multiplying x (captured by value from the host)
/// @param x     input vector, read only
/// @param beta  scalar multiplying the current contents of y
/// @param y     input/output vector, overwritten in place
///
/// The number of processed elements is x.size() / x.inc(). Element i of a
/// view is located at buffer offset i * inc().
template <typename T>
void LinearSystemContext::axpby(const T&            alpha,
                                CDenseVectorView<T> x,
                                const T&            beta,
                                DenseVectorView<T>  y)
{
    details::linear_system::axpby_common_check(x, y);
    auto size = x.size() / x.inc();
    ParallelFor().apply(size,
                        [x     = x.buffer_view(),
                         x_inc = x.inc(),
                         y     = y.buffer_view(),
                         y_inc = y.inc(),
                         a     = alpha,
                         b     = beta] __device__(int i) mutable
                        {
                            auto& r_y = *y.data(i * y_inc);
                            auto& r_x = *x.data(i * x_inc);
                            r_y       = a * r_x + b * r_y;
                        });
}

/// Element-wise y <- alpha * x + beta * y, where alpha and beta are read
/// inside the __device__ lambda through the pointers returned by
/// CVarView::data().
///
/// This lets the scalars be produced on the device without a host round
/// trip. Vector layout and element count follow the host-scalar overload.
///
/// @param alpha view of the scalar multiplying x; must be non-empty
/// @param x     input vector, read only
/// @param beta  view of the scalar multiplying the current contents of y;
///              must be non-empty
/// @param y     input/output vector, overwritten in place
template <typename T>
void LinearSystemContext::axpby(CVarView<T>         alpha,
                                CDenseVectorView<T> x,
                                CVarView<T>         beta,
                                DenseVectorView<T>  y)
{
    // Catch missing scalars on the host instead of as an illegal-address
    // fault inside the kernel.
    MUDA_ASSERT(alpha.data(), "Var alpha is empty");
    MUDA_ASSERT(beta.data(), "Var beta is empty");
    details::linear_system::axpby_common_check(x, y);
    auto size = x.size() / x.inc();
    ParallelFor().apply(size,
                        [x     = x.buffer_view(),
                         x_inc = x.inc(),
                         y     = y.buffer_view(),
                         y_inc = y.inc(),
                         a     = alpha.data(),
                         b     = beta.data()] __device__(int i) mutable
                        {
                            auto& r_y = *y.data(i * y_inc);
                            auto& r_x = *x.data(i * x_inc);
                            r_y       = *a * r_x + *b * r_y;
                        });
}

/// Element-wise z <- x + y, evaluated in a __device__ lambda launched
/// through ParallelFor.
///
/// @param x first input vector, read only
/// @param y second input vector, read only
/// @param z output vector; element i is overwritten with x[i] + y[i]
///
/// The number of processed elements is x.size() / x.inc(). Element i of each
/// view is located at buffer offset i * inc().
template <typename T>
void LinearSystemContext::plus(CDenseVectorView<T> x, CDenseVectorView<T> y, DenseVectorView<T> z)
{
    details::linear_system::axpby_common_check(x, y, z);
    auto size = x.size() / x.inc();
    ParallelFor().apply(size,
                        [x     = x.buffer_view(),
                         x_inc = x.inc(),
                         y     = y.buffer_view(),
                         y_inc = y.inc(),
                         z     = z.buffer_view(),
                         z_inc = z.inc()] __device__(int i) mutable
                        {
                            auto& r_z = *z.data(i * z_inc);
                            auto& r_x = *x.data(i * x_inc);
                            auto& r_y = *y.data(i * y_inc);
                            r_z       = r_x + r_y;
                        });
}
}  // namespace muda