MUDA
Loading...
Searching...
No Matches
spmv.inl
1#include <muda/ext/linear_system/type_mapper/algo_mapper.h>
2namespace muda
3{
4template <typename T>
5void LinearSystemContext::generic_spmv(const T& a,
6 cusparseOperation_t op,
7 cusparseSpMatDescr_t A,
8 const cusparseDnVecDescr* x,
9 const T& b,
10 cusparseDnVecDescr_t y)
11{
12 set_pointer_mode_host();
13
14 size_t buffer_size = 0;
15 checkCudaErrors(cusparseSpMV_bufferSize(
16 cusparse(), op, &a, A, x, &b, y, cuda_data_type<T>(), LinearSystemAlgorithm::SPMV_ALG_DEFAULT, &buffer_size));
17
18 auto buffer = temp_buffer(buffer_size);
19
20 checkCudaErrors(cusparseSpMV(cusparse(),
21 op,
22 &a,
23 A,
24 x,
25 &b,
26 y,
27 cuda_data_type<T>(),
28 LinearSystemAlgorithm::SPMV_ALG_DEFAULT,
29 buffer.data()));
30}
31} // namespace muda
32
33#include "spmv/coo_spmv.inl"
34#include "spmv/csr_spmv.inl"
35#include "spmv/bsr_spmv.inl"
36#include "spmv/triplet_spmv.inl"