MUDA
Loading...
Searching...
No Matches
dot.inl
1namespace muda
2{
3namespace details::linear_system
4{
5 template <typename T>
6 MUDA_INLINE void dot_common_check(CDenseVectorView<T> x, CDenseVectorView<T> y)
7 {
8 MUDA_ASSERT(x.data() && y.data(), "x.data() and y.data() should not be nullptr");
9 MUDA_ASSERT(x.size() / x.inc() == y.size() / y.inc(),
10 "x (size=%lld, inc=%d) should be the same as y (size=%lld, inc=%d)",
11 x.size(),
12 x.inc(),
13 y.size(),
14 y.inc());
15 }
16} // namespace details::linear_system
17
18
19template <typename T>
20void LinearSystemContext::dot(CDenseVectorView<T> x, CDenseVectorView<T> y, T* result)
21{
22 set_pointer_mode_host();
23 details::linear_system::dot_common_check(x, y);
24
25 auto type = cuda_data_type<T>();
26 auto size = x.size() / x.inc();
27
28 checkCudaErrors(cublasDotEx(
29 cublas(), size, x.data(), type, x.inc(), y.data(), type, y.inc(), result, type, type));
30}
31
32template <typename T>
33T LinearSystemContext::dot(CDenseVectorView<T> x, CDenseVectorView<T> y)
34{
35 T result;
36 dot(x, y, &result);
37 sync();
38 return result;
39}
40
41template <typename T>
42void LinearSystemContext::dot(CDenseVectorView<T> x, CDenseVectorView<T> y, VarView<T> result)
43{
44 set_pointer_mode_device();
45 details::linear_system::dot_common_check(x, y);
46
47 auto type = cuda_data_type<T>();
48 auto size = x.size() / x.inc();
49
50
51 checkCudaErrors(cublasDotEx(
52 cublas(), size, x.data(), type, x.inc(), y.data(), type, y.inc(), result.data(), type, type));
53}
54
55} // namespace muda