Skip to content

File device_bsr_matrix.h

File List > ext > linear_system > device_bsr_matrix.h

Go to the documentation of this file

#pragma once
#include <muda/buffer/device_buffer.h>
#include <muda/ext/linear_system/bsr_matrix_view.h>
#include <cusparse.h>

namespace muda::details
{
// Forward declaration only: DeviceBSRMatrix names this template as a friend,
// so the name has to be visible before the class is defined.
template <typename Scalar, int BlockDim>
class MatrixFormatConverter;
}  // namespace muda::details

namespace muda
{
template <typename Ty, int N>
class DeviceBSRMatrix
{
    template <typename U, int M>
    friend class details::MatrixFormatConverter;

  public:
    using ValueT = std::conditional_t<N == 1, Ty, Eigen::Matrix<Ty, N, N>>;
    static constexpr bool IsBlockMatrix = (N > 1);

  protected:
    muda::DeviceBuffer<ValueT>   m_values;
    muda::DeviceBuffer<int>      m_row_offsets;
    muda::DeviceBuffer<int>      m_col_indices;
    mutable cusparseSpMatDescr_t m_descr        = nullptr;
    mutable cusparseMatDescr_t   m_legacy_descr = nullptr;

    int m_row = 0;
    int m_col = 0;

  public:
    DeviceBSRMatrix() = default;
    ~DeviceBSRMatrix();

    DeviceBSRMatrix(const DeviceBSRMatrix&);
    DeviceBSRMatrix(DeviceBSRMatrix&&);

    DeviceBSRMatrix& operator=(const DeviceBSRMatrix&);
    DeviceBSRMatrix& operator=(DeviceBSRMatrix&&);

    void reshape(int row, int col);
    void reserve(int non_zero_blocks);
    void reserve_offsets(int size);
    void resize(int non_zero_blocks);

    static constexpr int block_size() { return N; }

    auto values() { return m_values.view(); }
    auto values() const { return m_values.view(); }

    auto row_offsets() { return m_row_offsets.view(); }
    auto row_offsets() const { return m_row_offsets.view(); }

    auto col_indices() { return m_col_indices.view(); }
    auto col_indices() const { return m_col_indices.view(); }

    auto rows() const { return m_row; }
    auto cols() const { return m_col; }
    auto non_zeros() const { return m_values.size(); }

    cusparseSpMatDescr_t descr() const;
    cusparseMatDescr_t   legacy_descr() const;

    auto view()
    {
        return BSRMatrixView<Ty, N>{m_row,
                                    m_col,
                                    m_row_offsets.data(),
                                    m_col_indices.data(),
                                    m_values.data(),
                                    (int)m_values.size(),
                                    descr(),
                                    legacy_descr(),
                                    false};
    }

    operator BSRMatrixView<Ty, N>() { return view(); }

    auto view() const
    {
        return CBSRMatrixView<Ty, N>{m_row,
                                     m_col,
                                     m_row_offsets.data(),
                                     m_col_indices.data(),
                                     m_values.data(),
                                     (int)m_values.size(),
                                     descr(),
                                     legacy_descr(),
                                     false};
    }

    operator CBSRMatrixView<Ty, N>() const { return view(); }

    auto cview() const { return view(); }

    auto T() const { return view().T(); }
    auto T() { return view().T(); }

    void clear();

  private:
    void destroy_all_descr() const;
};
}  // namespace muda

#include "details/device_bsr_matrix.inl"