Skip to content

File data_type_mapper.h

File List > ext > linear_system > type_mapper > data_type_mapper.h

Go to the documentation of this file

#pragma once
#include <cstdint>
#include <type_traits>
#include <cublas_v2.h>
#include <cusparse_v2.h>
#include <muda/type_traits/always.h>
namespace muda
{
/**
 * @brief Compile-time mapping from a scalar element type to its
 *        cudaDataType_t enumerator.
 *
 * Top-level const/volatile qualifiers on T are stripped before matching, so
 * e.g. cuda_data_type<const float>() yields the same value as
 * cuda_data_type<float>().
 *
 * @tparam T One of float, double, cuComplex or cuDoubleComplex
 *           (optionally cv-qualified).
 * @return CUDA_R_32F, CUDA_R_64F, CUDA_C_32F or CUDA_C_64F respectively.
 *
 * Instantiating with any other type is a compile error ("not supported type").
 */
template <typename T>
inline constexpr cudaDataType_t cuda_data_type()
{
    // Match on the unqualified type so that `const T` / `volatile T` are
    // accepted instead of falling through to the static_assert.
    using U = std::remove_cv_t<T>;

    if constexpr(std::is_same_v<U, float>)
    {
        return CUDA_R_32F;
    }
    else if constexpr(std::is_same_v<U, double>)
    {
        return CUDA_R_64F;
    }
    else if constexpr(std::is_same_v<U, cuComplex>)
    {
        return CUDA_C_32F;
    }
    else if constexpr(std::is_same_v<U, cuDoubleComplex>)
    {
        return CUDA_C_64F;
    }
    else
    {
        // always_false_v<T> is dependent on T, so the assertion only fires
        // when this branch is actually instantiated.
        static_assert(always_false_v<T>, "not supported type");
    }
}

/**
 * @brief Translate a boolean "transpose?" flag into a cuBLAS operation value.
 *
 * @param b true to select the transpose operation, false for no operation.
 * @return CUBLAS_OP_T when b is true, CUBLAS_OP_N otherwise.
 *         CUBLAS_OP_C (conjugate transpose) is never returned by this helper.
 */
constexpr cublasOperation_t cublas_trans_operation(bool b)
{
    if(b)
    {
        return CUBLAS_OP_T;
    }
    return CUBLAS_OP_N;
}

/**
 * @brief Compile-time mapping from an integer index type to its
 *        cusparseIndexType_t enumerator.
 *
 * Top-level const/volatile qualifiers on T are stripped before matching, so
 * e.g. cusparse_index_type<const int>() yields the same value as
 * cusparse_index_type<int>().
 *
 * @tparam T One of int, int64_t or uint16_t (optionally cv-qualified).
 * @return CUSPARSE_INDEX_32I, CUSPARSE_INDEX_64I or CUSPARSE_INDEX_16U
 *         respectively.
 *
 * Instantiating with any other type is a compile error ("Unsupported type").
 */
template <typename T>
constexpr cusparseIndexType_t cusparse_index_type()
{
    // Match on the unqualified type so that `const T` / `volatile T` are
    // accepted instead of falling through to the static_assert.
    using U = std::remove_cv_t<T>;

    if constexpr(std::is_same_v<U, int>)
        return cusparseIndexType_t::CUSPARSE_INDEX_32I;
    else if constexpr(std::is_same_v<U, int64_t>)
        return cusparseIndexType_t::CUSPARSE_INDEX_64I;
    else if constexpr(std::is_same_v<U, uint16_t>)
        return cusparseIndexType_t::CUSPARSE_INDEX_16U;
    else
        // always_false_v<T> is dependent on T, so the assertion only fires
        // when this branch is actually instantiated.
        static_assert(always_false_v<T>, "Unsupported type");
}
}  // namespace muda