Skip to content

File device_triplet_matrix.h

File List > ext > linear_system > device_triplet_matrix.h

Go to the documentation of this file

#pragma once
#include <muda/buffer/device_buffer.h>
#include <muda/ext/linear_system/triplet_matrix_view.h>
namespace muda
{
namespace details
{
    // Declaration only (no definition in this header). DeviceTripletMatrix
    // grants friendship to every specialization of this class template.
    template <typename T, int N>
    class MatrixFormatConverter;
}  // namespace details
}  // namespace muda

namespace muda
{
/// Sparse matrix held in device buffers in triplet (coordinate) form, where
/// every stored entry is a dense N x N block (Eigen::Matrix<T, N, N>).
///
/// Storage consists of three DeviceBuffers: block values, block row indices
/// and block column indices. resize_triplets()/reserve_triplets() always act
/// on all three together. The logical shape (block rows x block cols) is
/// tracked separately by reshape() and is not checked against the stored
/// indices by this class.
template <typename T, int N>
class DeviceTripletMatrix
{
  public:
    // Every MatrixFormatConverter specialization may access the protected
    // buffers and shape fields directly.
    template <typename U, int M>
    friend class details::MatrixFormatConverter;
    using BlockMatrix = Eigen::Matrix<T, N, N>;

  protected:
    DeviceBuffer<BlockMatrix> m_block_values;
    DeviceBuffer<int>         m_block_row_indices;
    DeviceBuffer<int>         m_block_col_indices;

    int m_block_rows = 0;
    int m_block_cols = 0;

  public:
    DeviceTripletMatrix()                                      = default;
    ~DeviceTripletMatrix()                                     = default;
    DeviceTripletMatrix(const DeviceTripletMatrix&)            = default;
    DeviceTripletMatrix(DeviceTripletMatrix&&)                 = default;
    DeviceTripletMatrix& operator=(const DeviceTripletMatrix&) = default;
    DeviceTripletMatrix& operator=(DeviceTripletMatrix&&)      = default;

    /// Record the logical shape, counted in blocks. Triplet storage is not
    /// touched.
    void reshape(int row, int col)
    {
        m_block_cols = col;
        m_block_rows = row;
    }

    /// Resize values, row indices and column indices to exactly
    /// `nonzero_count` entries each.
    void resize_triplets(size_t nonzero_count)
    {
        auto resize_one = [nonzero_count](auto& buffer)
        { buffer.resize(nonzero_count); };
        resize_one(m_block_values);
        resize_one(m_block_row_indices);
        resize_one(m_block_col_indices);
    }

    /// Forward a capacity request of `nonzero_count` entries to all three
    /// triplet buffers.
    void reserve_triplets(size_t nonzero_count)
    {
        auto reserve_one = [nonzero_count](auto& buffer)
        { buffer.reserve(nonzero_count); };
        reserve_one(m_block_values);
        reserve_one(m_block_row_indices);
        reserve_one(m_block_col_indices);
    }

    /// Set the block shape first, then size the triplet buffers.
    void resize(int row, int col, size_t nonzero_count)
    {
        this->reshape(row, col);
        this->resize_triplets(nonzero_count);
    }

    /// Edge length of one stored block (the template parameter N).
    static constexpr int block_dim() { return N; }

    // Buffer views of the raw triplet arrays (mutable and const overloads).
    auto block_values() { return m_block_values.view(); }
    auto block_values() const { return m_block_values.view(); }
    auto block_row_indices() { return m_block_row_indices.view(); }
    auto block_row_indices() const { return m_block_row_indices.view(); }
    auto block_col_indices() { return m_block_col_indices.view(); }
    auto block_col_indices() const { return m_block_col_indices.view(); }

    // Shape (in blocks) and triplet bookkeeping.
    auto block_rows() const { return m_block_rows; }
    auto block_cols() const { return m_block_cols; }
    auto triplet_count() const { return m_block_values.size(); }
    auto triplet_capacity() const { return m_block_values.capacity(); }

    /// Non-owning mutable view over the current shape and triplet buffers.
    auto view()
    {
        // NOTE(review): the triplet count is narrowed from the buffer's size
        // type to int because the view stores it as an int.
        const int triplets = static_cast<int>(m_block_values.size());
        return TripletMatrixView<T, N>{m_block_rows,
                                       m_block_cols,
                                       triplets,
                                       m_block_row_indices.data(),
                                       m_block_col_indices.data(),
                                       m_block_values.data()};
    }

    /// Const view: builds the mutable view, then converts it with as_const().
    auto view() const { return remove_const(*this).view().as_const(); }

    /// Explicitly-const spelling of view().
    auto cview() const { return this->view(); }

    /// Viewer derived from the mutable view.
    auto viewer() { return this->view().viewer(); }

    /// Read-only viewer derived from the const view.
    auto cviewer() const { return this->view().cviewer(); }

    operator TripletMatrixView<T, N>() { return view(); }
    operator CTripletMatrixView<T, N>() const { return view(); }

    /// Reset the shape to 0 x 0 and clear all three triplet buffers.
    void clear()
    {
        reshape(0, 0);
        m_block_values.clear();
        m_block_row_indices.clear();
        m_block_col_indices.clear();
    }
};

/// Scalar (N == 1) specialization of DeviceTripletMatrix.
///
/// Entries are plain T values rather than Eigen blocks. Accessors use scalar
/// naming (values(), rows(), cols(), ...). The primary template's
/// block_dim(), cview() and triplet_capacity() are also provided, so generic
/// code that uses those members compiles for N == 1 too.
template <typename T>
class DeviceTripletMatrix<T, 1>
{
  public:
    // Every MatrixFormatConverter specialization may access the protected
    // buffers and shape fields directly.
    template <typename U, int M>
    friend class details::MatrixFormatConverter;

  protected:
    DeviceBuffer<T>   m_values;
    DeviceBuffer<int> m_row_indices;
    DeviceBuffer<int> m_col_indices;

    int m_rows = 0;
    int m_cols = 0;

  public:
    DeviceTripletMatrix()                                      = default;
    ~DeviceTripletMatrix()                                     = default;
    DeviceTripletMatrix(const DeviceTripletMatrix&)            = default;
    DeviceTripletMatrix(DeviceTripletMatrix&&)                 = default;
    DeviceTripletMatrix& operator=(const DeviceTripletMatrix&) = default;
    DeviceTripletMatrix& operator=(DeviceTripletMatrix&&)      = default;

    /// Record the logical shape (rows x cols). Triplet storage is not touched.
    void reshape(int row, int col)
    {
        m_rows = row;
        m_cols = col;
    }

    /// Resize values, row indices and column indices to exactly
    /// `nonzero_count` entries each.
    void resize_triplets(size_t nonzero_count)
    {
        m_values.resize(nonzero_count);
        m_row_indices.resize(nonzero_count);
        m_col_indices.resize(nonzero_count);
    }

    /// Forward a capacity request of `nonzero_count` entries to all three
    /// triplet buffers.
    void reserve_triplets(size_t nonzero_count)
    {
        m_values.reserve(nonzero_count);
        m_row_indices.reserve(nonzero_count);
        m_col_indices.reserve(nonzero_count);
    }

    /// Set the shape first, then size the triplet buffers.
    void resize(int row, int col, size_t nonzero_count)
    {
        reshape(row, col);
        resize_triplets(nonzero_count);
    }

    /// Size of one stored entry along each dimension; always 1 here.
    static constexpr int block_size() { return 1; }

    /// Same value as block_size(). Mirrors the primary template's
    /// block_dim() so generic code can query the block edge uniformly.
    static constexpr int block_dim() { return 1; }

    // Buffer views of the raw triplet arrays (mutable and const overloads).
    auto values() { return m_values.view(); }
    auto values() const { return m_values.view(); }
    auto row_indices() { return m_row_indices.view(); }
    auto row_indices() const { return m_row_indices.view(); }
    auto col_indices() { return m_col_indices.view(); }
    auto col_indices() const { return m_col_indices.view(); }

    // Shape and triplet bookkeeping.
    auto rows() const { return m_rows; }
    auto cols() const { return m_cols; }
    auto triplet_count() const { return m_values.size(); }

    /// Current capacity of the value buffer (mirrors the primary template).
    auto triplet_capacity() const { return m_values.capacity(); }

    /// Non-owning mutable view over the current shape and triplet buffers.
    auto view()
    {
        // NOTE(review): the triplet count is narrowed from the buffer's size
        // type to int because the view stores it as an int.
        return TripletMatrixView<T, 1>{m_rows,
                                       m_cols,
                                       static_cast<int>(m_values.size()),
                                       m_row_indices.data(),
                                       m_col_indices.data(),
                                       m_values.data()};
    }

    /// Const view: builds the mutable view, then converts it with as_const().
    /// Declared after the non-const overload, matching the primary template.
    auto view() const { return remove_const(*this).view().as_const(); }

    /// Explicitly-const spelling of view() (mirrors the primary template).
    auto cview() const { return view(); }

    /// Viewer derived from the mutable view.
    auto viewer() { return view().viewer(); }

    /// Read-only viewer derived from the const view.
    auto cviewer() const { return view().cviewer(); }

    operator TripletMatrixView<T, 1>() { return view(); }
    operator CTripletMatrixView<T, 1>() const { return view(); }

    /// Reset the shape to 0 x 0 and clear all three triplet buffers.
    void clear()
    {
        m_rows = 0;
        m_cols = 0;
        m_values.clear();
        m_row_indices.clear();
        m_col_indices.clear();
    }
};
}  // namespace muda
#include "details/device_triplet_matrix.inl"