MUDA
Loading...
Searching...
No Matches
data_type_mapper.h
1#pragma once
#include <cstdint>
#include <type_traits>
#include <cublas_v2.h>
#include <cusparse_v2.h>
#include <muda/type_traits/always.h>
5namespace muda
6{
7template <typename T>
8inline constexpr cudaDataType_t cuda_data_type()
9{
10 if constexpr(std::is_same_v<T, float>)
11 {
12 return CUDA_R_32F;
13 }
14 else if constexpr(std::is_same_v<T, double>)
15 {
16 return CUDA_R_64F;
17 }
18 else if constexpr(std::is_same_v<T, cuComplex>)
19 {
20 return CUDA_C_32F;
21 }
22 else if constexpr(std::is_same_v<T, cuDoubleComplex>)
23 {
24 return CUDA_C_64F;
25 }
26 else
27 {
28 static_assert(always_false_v<T>, "not supported type");
29 }
30}
31
32constexpr cublasOperation_t cublas_trans_operation(bool b)
33{
34 return b ? CUBLAS_OP_T : CUBLAS_OP_N;
35}
36
37template <typename T>
38constexpr cusparseIndexType_t cusparse_index_type()
39{
40 if constexpr(std::is_same_v<T, int>)
41 return cusparseIndexType_t::CUSPARSE_INDEX_32I;
42 else if constexpr(std::is_same_v<T, int64_t>)
43 return cusparseIndexType_t::CUSPARSE_INDEX_64I;
44 else if constexpr(std::is_same_v<T, uint16_t>)
45 return cusparseIndexType_t::CUSPARSE_INDEX_16U;
46 else
47 static_assert(always_false_v<T>, "Unsupported type");
48}
49} // namespace muda