Skip to content

File device_csr_matrix.h

File List > ext > linear_system > device_csr_matrix.h

Go to the documentation of this file

#pragma once
#include <muda/buffer/device_buffer.h>
#include <cusparse.h>
#include <muda/ext/linear_system/csr_matrix_view.h>

namespace muda::details
{
template <typename T, int N>
class MatrixFormatConverter;
}

namespace muda
{
template <typename Ty>
class DeviceCSRMatrix
{
    template <typename T, int N>
    friend class details::MatrixFormatConverter;

  public:
    int m_row = 0;
    int m_col = 0;

    muda::DeviceBuffer<int> m_row_offsets;
    muda::DeviceBuffer<int> m_col_indices;
    muda::DeviceBuffer<Ty>  m_values;

    mutable cusparseSpMatDescr_t m_descr        = nullptr;
    mutable cusparseMatDescr_t   m_legacy_descr = nullptr;

  public:
    DeviceCSRMatrix() = default;
    ~DeviceCSRMatrix();

    DeviceCSRMatrix(const DeviceCSRMatrix&);
    DeviceCSRMatrix(DeviceCSRMatrix&&) noexcept;

    DeviceCSRMatrix& operator=(const DeviceCSRMatrix&);
    DeviceCSRMatrix& operator=(DeviceCSRMatrix&&) noexcept;

    void reshape(int row, int col);
    void reserve(int non_zeros);

    auto values() { return m_values.view(); }
    auto values() const { return m_values.view(); }

    auto row_offsets() { return m_row_offsets.view(); }
    auto row_offsets() const { return m_row_offsets.view(); }

    auto col_indices() { return m_col_indices.view(); }
    auto col_indices() const { return m_col_indices.view(); }

    auto rows() const { return m_row; }
    auto cols() const { return m_col; }
    auto non_zeros() const { return m_values.size(); }

    cusparseSpMatDescr_t descr() const;
    cusparseMatDescr_t   legacy_descr() const;

    auto view()
    {
        return CSRMatrixView<Ty>{m_row,
                                 m_col,
                                 m_row_offsets.data(),
                                 m_col_indices.data(),
                                 m_values.data(),
                                 (int)non_zeros(),
                                 descr(),
                                 legacy_descr(),
                                 false};
    }

    auto view() const
    {
        return CCSRMatrixView<Ty>{m_row,
                                  m_col,
                                  m_row_offsets.data(),
                                  m_col_indices.data(),
                                  m_values.data(),
                                  (int)non_zeros(),
                                  descr(),
                                  legacy_descr(),
                                  false};
    }

    auto cview() const { return view(); }

    auto T() const { return view().T(); }
    auto T() { return view().T(); }
    operator CSRMatrixView<Ty>() { return view(); }
    operator CCSRMatrixView<Ty>() const { return view(); }

    void clear();

  private:
    void destroy_all_descr() const;
};
}  // namespace muda
#include "details/device_csr_matrix.inl"