Skip to content

File matrix_format_converter.h

File List > ext > linear_system > matrix_format_converter.h

Go to the documentation of this file

#pragma once
#include <cstddef>
#include <functional>
#include <memory>
#include <unordered_map>
#include <utility>

#include <muda/ext/linear_system/linear_system_handles.h>
#include <muda/ext/linear_system/device_dense_matrix.h>
#include <muda/ext/linear_system/device_dense_vector.h>
#include <muda/ext/linear_system/device_triplet_matrix.h>
#include <muda/ext/linear_system/device_doublet_vector.h>
#include <muda/ext/linear_system/device_bcoo_matrix.h>
#include <muda/ext/linear_system/device_bcoo_vector.h>
#include <muda/ext/linear_system/device_bsr_matrix.h>
#include <muda/ext/linear_system/device_csr_matrix.h>

namespace muda::details
{
// Base type under which converter implementations are stored
// (muda::MatrixFormatConverter keeps them as unique_ptr<MatrixFormatConverterBase>).
class MatrixFormatConverterBase;

// Implementation specialized on scalar type T and the compile-time int N used
// by the Device*Matrix<T, N> / Device*Vector<T, N> types (presumably the block
// dimension).
template <typename T, int N>
class MatrixFormatConverter;

/**
 * Lookup key for one converter implementation: the CUDA data type of the
 * scalar plus the integer N. Two keys are equal exactly when both fields match.
 */
class MatrixFormatConverterType
{
  public:
    cudaDataType_t data_type;  // CUDA scalar type tag
    int            N;          // compile-time N of the converted types

    friend bool operator==(const MatrixFormatConverterType& a,
                           const MatrixFormatConverterType& b)
    {
        if(a.N != b.N)
            return false;
        return a.data_type == b.data_type;
    }
};
}  // namespace muda::details

namespace std
{
/**
 * Hash for muda::details::MatrixFormatConverterType so it can key a
 * std::unordered_map.
 *
 * The hashes of data_type and N are merged with a boost-style hash_combine
 * step. A plain `(h(data_type) << 8) ^ h(N)` lets the high bits of N land on
 * the shifted data_type bits, so with an identity std::hash<int> the keys
 * {data_type = 0, N = 256} and {data_type = 1, N = 0} would collide.
 */
template <>
struct hash<muda::details::MatrixFormatConverterType>
{
    size_t operator()(const muda::details::MatrixFormatConverterType& x) const noexcept
    {
        std::size_t seed = std::hash<int>()(static_cast<int>(x.data_type));
        // hash_combine mixing: the golden-ratio constant and the shifts spread
        // N's contribution across the whole word instead of a plain XOR.
        // All arithmetic is on unsigned std::size_t, so wrap-around is well defined.
        seed ^= std::hash<int>()(x.N) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
        return seed;
    }
};
}  // namespace std


namespace muda
{
/**
 * @brief Converts between the device-side matrix/vector storage formats of the
 * linear-system extension: Triplet, BCOO, COO, BSR, CSR and Dense matrices, and
 * Doublet, BCOO, COO and Dense vectors.
 *
 * Every conversion is an overload of convert(from, to[, clear]); overload
 * resolution on the argument types selects the direction.
 *
 * Implementation objects live in m_impls, keyed by
 * details::MatrixFormatConverterType (CUDA data type + N). They are presumably
 * created/looked up on demand by impl<T, N>(); the definitions come from
 * details/matrix_format_converter.inl, included at the end of this header.
 *
 * Lifetime: only a reference to the LinearSystemHandles passed to the
 * constructor is stored, so that object must outlive this converter.
 *
 * Copying is deleted explicitly. The implementations are owned through
 * std::unique_ptr, and `current` is a raw pointer, so a copy has no meaningful
 * semantics. Such a copy already failed to compile; it is now rejected up front.
 */
class MatrixFormatConverter
{
    template <typename T>
    using U = std::unique_ptr<T>;

    // Non-owning reference; the referenced handles must outlive *this.
    LinearSystemHandles& m_handles;

    // Not referenced in this header; kept because out-of-line definitions may use it.
    using TypeN = std::pair<cudaDataType_t, int>;

    // At most one implementation per (data type, N) key.
    std::unordered_map<details::MatrixFormatConverterType, U<details::MatrixFormatConverterBase>> m_impls;

    // Raw, non-owning pointer; presumably refers to an object owned by m_impls
    // (e.g. the most recently selected implementation) -- TODO confirm in the .inl.
    details::MatrixFormatConverterBase* current = nullptr;

    // Returns the implementation for (T, N); presumably inserts it into
    // m_impls on first use -- TODO confirm in the .inl.
    template <typename T, int N>
    details::MatrixFormatConverter<T, N>& impl();

  public:
    /// @param handles Library handles used by the conversions; stored by reference.
    MatrixFormatConverter(LinearSystemHandles& handles)
        : m_handles(handles)
    {
    }

    // Must be defined where details::MatrixFormatConverterBase is a complete
    // type: it is only forward-declared here, and destroying the
    // std::unique_ptr values in m_impls requires the complete type.
    ~MatrixFormatConverter();

    MatrixFormatConverter(const MatrixFormatConverter&)            = delete;
    MatrixFormatConverter& operator=(const MatrixFormatConverter&) = delete;

    // ---------------------------------------------------------------------
    // Block (N-parameterized) formats
    // ---------------------------------------------------------------------

    // Triplet -> BCOO
    template <typename T, int N>
    void convert(const DeviceTripletMatrix<T, N>& from, DeviceBCOOMatrix<T, N>& to);

    // BCOO -> Dense Matrix
    // clear_dense_matrix: presumably zeroes `to` before writing -- TODO confirm.
    template <typename T, int N>
    void convert(const DeviceBCOOMatrix<T, N>& from,
                 DeviceDenseMatrix<T>&         to,
                 bool                          clear_dense_matrix = true);

    // BCOO -> COO
    template <typename T, int N>
    void convert(const DeviceBCOOMatrix<T, N>& from, DeviceCOOMatrix<T>& to);

    // BCOO -> BSR
    template <typename T, int N>
    void convert(const DeviceBCOOMatrix<T, N>& from, DeviceBSRMatrix<T, N>& to);

    // Doublet -> BCOO
    template <typename T, int N>
    void convert(const DeviceDoubletVector<T, N>& from, DeviceBCOOVector<T, N>& to);

    // BCOO -> Dense Vector
    // clear_dense_vector: presumably zeroes `to` before writing -- TODO confirm.
    template <typename T, int N>
    void convert(const DeviceBCOOVector<T, N>& from,
                 DeviceDenseVector<T>&         to,
                 bool                          clear_dense_vector = true);

    // Doublet -> Dense Vector (general N; for N == 1 the more specialized
    // DeviceDoubletVector<T, 1> overload below wins partial ordering)
    template <typename T, int N>
    void convert(const DeviceDoubletVector<T, N>& from,
                 DeviceDenseVector<T>&            to,
                 bool                             clear_dense_vector = true);

    // BSR -> CSR
    template <typename T, int N>
    void convert(const DeviceBSRMatrix<T, N>& from, DeviceCSRMatrix<T>& to);

    // ---------------------------------------------------------------------
    // Scalar (N == 1) formats
    // ---------------------------------------------------------------------

    // Triplet -> COO
    template <typename T>
    void convert(const DeviceTripletMatrix<T, 1>& from, DeviceCOOMatrix<T>& to);

    // COO -> Dense Matrix
    template <typename T>
    void convert(const DeviceCOOMatrix<T>& from,
                 DeviceDenseMatrix<T>&     to,
                 bool                      clear_dense_matrix = true);

    // COO -> CSR
    template <typename T>
    void convert(const DeviceCOOMatrix<T>& from, DeviceCSRMatrix<T>& to);
    // COO -> CSR from an rvalue source (presumably allowed to reuse `from`'s
    // storage -- TODO confirm in the .inl)
    template <typename T>
    void convert(DeviceCOOMatrix<T>&& from, DeviceCSRMatrix<T>& to);

    // Doublet -> COO
    template <typename T>
    void convert(const DeviceDoubletVector<T, 1>& from, DeviceCOOVector<T>& to);

    // COO -> Dense Vector
    template <typename T>
    void convert(const DeviceCOOVector<T>& from,
                 DeviceDenseVector<T>&     to,
                 bool                      clear_dense_vector = true);

    // Doublet -> Dense Vector (N == 1)
    template <typename T>
    void convert(const DeviceDoubletVector<T, 1>& from,
                 DeviceDenseVector<T>&            to,
                 bool                             clear_dense_vector = true);
};
}  // namespace muda

#include "details/matrix_format_converter.inl"