Skip to content

File triplet_matrix_viewer.h

File List > ext > linear_system > triplet_matrix_viewer.h

Go to the documentation of this file

#pragma once
#include <string>
#include <muda/viewer/viewer_base.h>
#include <muda/buffer/device_buffer.h>
#include <muda/tools/cuda_vec_utils.h>
#include <muda/ext/eigen/eigen_core_cxx20.h>


/*
* - 2024/2/23: removed the viewer-level subview; the view-level subview is sufficient.
*/

namespace muda
{
/**
 * Common implementation of the N x N block triplet-matrix viewer.
 *
 * Storage is three parallel arrays addressed by the same triplet index:
 * block row indices, block column indices and N x N block values.
 * A viewer exposes
 *  - the triplet window
 *    [triplet_index_offset, triplet_index_offset + triplet_count)
 *    out of total_triplet_count triplets, and
 *  - a submatrix window starting at submatrix_offset with size
 *    submatrix_extent (both in blocks) inside a
 *    total_block_rows x total_block_cols block matrix.
 *
 * The stored arrays hold global block indices; operator() hands out indices
 * relative to submatrix_offset and asserts that they fall inside the window.
 *
 * IsConst selects the read-only flavour (see ConstViewer).
 */
template <bool IsConst, typename T, int N>
class TripletMatrixViewerBase : public ViewerBase<IsConst>
{
    using Base = ViewerBase<IsConst>;
    template <typename U>
    using auto_const_t = typename Base::template auto_const_t<U>;

  public:
    using BlockMatrix    = Eigen::Matrix<T, N, N>;
    using ConstViewer    = TripletMatrixViewerBase<true, T, N>;
    using NonConstViewer = TripletMatrixViewerBase<false, T, N>;
    using ThisViewer     = TripletMatrixViewerBase<IsConst, T, N>;

    // Read-only view of one triplet. `block_value` is bound directly to an
    // element of the viewer's value array (no copy is made).
    struct CTriplet
    {
        MUDA_GENERIC CTriplet(int row_index, int col_index, const BlockMatrix& block)
            : block_row_index(row_index)
            , block_col_index(col_index)
            , block_value(block)
        {
        }
        int                block_row_index;
        int                block_col_index;
        const BlockMatrix& block_value;
    };


  protected:
    // matrix info (in blocks)
    int m_total_block_rows = 0;
    int m_total_block_cols = 0;

    // triplet info
    int m_triplet_index_offset = 0;
    int m_triplet_count        = 0;
    int m_total_triplet_count  = 0;

    // sub matrix info (in blocks)
    int2 m_submatrix_offset = {0, 0};
    int2 m_submatrix_extent = {0, 0};

    // data; null-initialized so that a default-constructed viewer holds
    // well-defined pointer values and can be copied safely
    auto_const_t<int>*         m_block_row_indices = nullptr;
    auto_const_t<int>*         m_block_col_indices = nullptr;
    auto_const_t<BlockMatrix>* m_block_values      = nullptr;


  public:
    MUDA_GENERIC TripletMatrixViewerBase() = default;
    MUDA_GENERIC TripletMatrixViewerBase(int total_block_rows,
                                         int total_block_cols,
                                         int triplet_index_offset,
                                         int triplet_count,
                                         int total_triplet_count,

                                         int2 submatrix_offset,
                                         int2 submatrix_extent,

                                         auto_const_t<int>* block_row_indices,
                                         auto_const_t<int>* block_col_indices,
                                         auto_const_t<BlockMatrix>* block_values)
        : m_total_block_rows(total_block_rows)
        , m_total_block_cols(total_block_cols)
        , m_triplet_index_offset(triplet_index_offset)
        , m_triplet_count(triplet_count)
        , m_total_triplet_count(total_triplet_count)
        , m_submatrix_offset(submatrix_offset)
        , m_submatrix_extent(submatrix_extent)
        , m_block_row_indices(block_row_indices)
        , m_block_col_indices(block_col_indices)
        , m_block_values(block_values)
    {
        // A negative offset/count would slip past the upper-bound check below
        // and later index before the start of the arrays.
        MUDA_KERNEL_ASSERT(triplet_index_offset >= 0 && triplet_count >= 0,
                           "TripletMatrixViewer [%s:%s]: negative triplet range, "
                           "triplet_index_offset=%d, triplet_count=%d",
                           this->name(),
                           this->kernel_name(),
                           triplet_index_offset,
                           triplet_count);

        MUDA_KERNEL_ASSERT(triplet_index_offset + triplet_count <= total_triplet_count,
                           "TripletMatrixViewer [%s:%s]: out of range, m_total_triplet_count=%d, "
                           "your triplet_index_offset=%d, triplet_count=%d",
                           this->name(),
                           this->kernel_name(),
                           total_triplet_count,
                           triplet_index_offset,
                           triplet_count);

        MUDA_KERNEL_ASSERT(submatrix_offset.x >= 0 && submatrix_offset.y >= 0,
                           "TripletMatrixViewer[%s:%s]: submatrix_offset is out of range, submatrix_offset.x=%d, submatrix_offset.y=%d",
                           this->name(),
                           this->kernel_name(),
                           submatrix_offset.x,
                           submatrix_offset.y);

        // A negative extent would make every index fail check_in_submatrix and
        // would weaken the upper-bound checks below.
        MUDA_KERNEL_ASSERT(submatrix_extent.x >= 0 && submatrix_extent.y >= 0,
                           "TripletMatrixViewer[%s:%s]: submatrix_extent is negative, submatrix_extent.x=%d, submatrix_extent.y=%d",
                           this->name(),
                           this->kernel_name(),
                           submatrix_extent.x,
                           submatrix_extent.y);

        MUDA_KERNEL_ASSERT(submatrix_offset.x + submatrix_extent.x <= total_block_rows,
                           "TripletMatrixViewer[%s:%s]: submatrix is out of range, submatrix_offset.x=%d, submatrix_extent.x=%d, total_block_rows=%d",
                           this->name(),
                           this->kernel_name(),
                           submatrix_offset.x,
                           submatrix_extent.x,
                           total_block_rows);

        MUDA_KERNEL_ASSERT(submatrix_offset.y + submatrix_extent.y <= total_block_cols,
                           "TripletMatrixViewer[%s:%s]: submatrix is out of range, submatrix_offset.y=%d, submatrix_extent.y=%d, total_block_cols=%d",
                           this->name(),
                           this->kernel_name(),
                           submatrix_offset.y,
                           submatrix_extent.y,
                           total_block_cols);
    }

    // Read-only copy of this viewer over the same storage and windows.
    MUDA_GENERIC ConstViewer as_const() const
    {
        return ConstViewer{m_total_block_rows,
                           m_total_block_cols,
                           m_triplet_index_offset,
                           m_triplet_count,
                           m_total_triplet_count,
                           m_submatrix_offset,
                           m_submatrix_extent,
                           m_block_row_indices,
                           m_block_col_indices,
                           m_block_values};
    }

    MUDA_GENERIC operator ConstViewer() const { return as_const(); }

    // const accessor

    MUDA_GENERIC auto total_block_rows() const { return m_total_block_rows; }
    MUDA_GENERIC auto total_block_cols() const { return m_total_block_cols; }
    MUDA_GENERIC auto total_extent() const
    {
        return int2{m_total_block_rows, m_total_block_cols};
    }

    MUDA_GENERIC auto submatrix_offset() const { return m_submatrix_offset; }
    MUDA_GENERIC auto extent() const { return m_submatrix_extent; }

    MUDA_GENERIC auto triplet_count() const { return m_triplet_count; }
    MUDA_GENERIC auto triplet_index_offset() const
    {
        return m_triplet_index_offset;
    }
    // Misspelled legacy name; kept so existing callers keep compiling.
    // Prefer triplet_index_offset().
    MUDA_GENERIC auto tripet_index_offset() const
    {
        return m_triplet_index_offset;
    }
    MUDA_GENERIC auto total_triplet_count() const
    {
        return m_total_triplet_count;
    }

    // Read the i-th triplet of this viewer's window (0 <= i < triplet_count()).
    // The returned row/col indices are relative to submatrix_offset().
    MUDA_GENERIC CTriplet operator()(int i) const
    {
        auto index    = get_index(i);
        auto global_i = m_block_row_indices[index];
        auto global_j = m_block_col_indices[index];
        auto sub_i    = global_i - m_submatrix_offset.x;
        auto sub_j    = global_j - m_submatrix_offset.y;
        check_in_submatrix(sub_i, sub_j);
        return CTriplet{sub_i, sub_j, m_block_values[index]};
    }

  protected:
    // Map a window-local triplet index to a position in the underlying arrays.
    MUDA_INLINE MUDA_GENERIC int get_index(int i) const noexcept
    {

        MUDA_KERNEL_ASSERT(i >= 0 && i < m_triplet_count,
                           "TripletMatrixViewer [%s:%s]: triplet_index out of range, block_count=%d, your index=%d",
                           this->name(),
                           this->kernel_name(),
                           m_triplet_count,
                           i);
        auto index = i + m_triplet_index_offset;
        return index;
    }

    // Assert that a submatrix-relative (i, j) lies inside the submatrix extent.
    MUDA_INLINE MUDA_GENERIC void check_in_submatrix(int i, int j) const noexcept
    {
        MUDA_KERNEL_ASSERT(i >= 0 && i < m_submatrix_extent.x,
                           "TripletMatrixViewer [%s:%s]: row index out of submatrix range,  submatrix_extent.x=%d, your i=%d",
                           this->name(),
                           this->kernel_name(),
                           m_submatrix_extent.x,
                           i);

        MUDA_KERNEL_ASSERT(j >= 0 && j < m_submatrix_extent.y,
                           "TripletMatrixViewer [%s:%s]: col index out of submatrix range,  submatrix_extent.y=%d, your j=%d",
                           this->name(),
                           this->kernel_name(),
                           m_submatrix_extent.y,
                           j);
    }
};

/**
 * Read-only viewer over an N x N block triplet matrix.
 *
 * All behaviour comes from TripletMatrixViewerBase<true, T, N>. The type
 * aliases are re-exported publicly so they match TripletMatrixViewer<T, N>.
 * Previously `BlockMatrix` was redeclared privately here, which hid the
 * base's public alias.
 */
template <typename T, int N>
class CTripletMatrixViewer : public TripletMatrixViewerBase<true, T, N>
{
    using Base = TripletMatrixViewerBase<true, T, N>;
    MUDA_VIEWER_COMMON_NAME(CTripletMatrixViewer);

  public:
    using BlockMatrix = typename Base::BlockMatrix;
    using CTriplet    = typename Base::CTriplet;
    // The const view of a const viewer is the viewer itself; this matches
    // CTripletMatrixViewer<T, 1> and TripletMatrixViewer<T, N>::ConstViewer.
    using ConstViewer = CTripletMatrixViewer<T, N>;

    using Base::Base;

    // Promote a plain base viewer (e.g. the result of as_const()) to this type.
    MUDA_GENERIC CTripletMatrixViewer(const Base& base)
        : Base(base)
    {
    }
};

/**
 * Mutable viewer over an N x N block triplet matrix.
 *
 * Calling operator() on a const viewer yields the read-only CTriplet
 * inherited from the base. Calling it on a mutable viewer yields a Proxy
 * whose read()/write() are rvalue-qualified, so they are used directly on
 * the temporary, e.g. `viewer(i).write(row, col, block)`.
 */
template <typename T, int N>
class TripletMatrixViewer : public TripletMatrixViewerBase<false, T, N>
{
    using Base = TripletMatrixViewerBase<false, T, N>;
    MUDA_VIEWER_COMMON_NAME(TripletMatrixViewer);

  public:
    using BlockMatrix    = typename Base::BlockMatrix;
    using CTriplet       = typename Base::CTriplet;
    using ConstViewer    = CTripletMatrixViewer<T, N>;
    using NonConstViewer = TripletMatrixViewer<T, N>;

    // every base constructor is usable on this type as well
    using Base::Base;

    // promote a plain base viewer to this type
    MUDA_GENERIC TripletMatrixViewer(const Base& base)
        : Base(base)
    {
    }

    // keep the const (read-only) overload visible beside the mutable one
    using Base::operator();

    class Proxy
    {
        friend class TripletMatrixViewer;

        TripletMatrixViewer& m_viewer;
        int                  m_index = 0;

      private:
        // proxies are only handed out by TripletMatrixViewer::operator()
        MUDA_GENERIC Proxy(TripletMatrixViewer& owner, int local_index)
            : m_viewer(owner)
            , m_index(local_index)
        {
        }

      public:
        // Read the addressed triplet via the const overload of operator().
        MUDA_GENERIC auto read() &&
        {
            const TripletMatrixViewer& const_viewer = m_viewer;
            return const_viewer(m_index);
        }

        // Overwrite the addressed triplet. The (row, col) arguments are
        // submatrix-relative block indices; they are stored shifted by the
        // submatrix offset.
        MUDA_GENERIC
        void write(int block_row_index, int block_col_index, const BlockMatrix& block) &&
        {
            const auto slot = m_viewer.get_index(m_index);
            m_viewer.check_in_submatrix(block_row_index, block_col_index);

            const auto& origin = m_viewer.m_submatrix_offset;
            m_viewer.m_block_row_indices[slot] = origin.x + block_row_index;
            m_viewer.m_block_col_indices[slot] = origin.y + block_col_index;
            m_viewer.m_block_values[slot]      = block;
        }

        MUDA_GENERIC ~Proxy() = default;
    };

    MUDA_GENERIC Proxy operator()(int i) { return Proxy{*this, i}; }
};


/**
 * Scalar (1 x 1 block) specialization of TripletMatrixViewerBase.
 *
 * The layout is the same as the block version: three parallel arrays of row
 * indices, column indices and values of type T. All sizes are in scalar
 * rows/columns. A viewer exposes
 *  - the triplet window
 *    [triplet_index_offset, triplet_index_offset + triplet_count)
 *    out of total_triplet_count triplets, and
 *  - a submatrix window starting at submatrix_offset with size
 *    submatrix_extent inside a total_rows x total_cols matrix.
 *
 * The stored arrays hold global indices; operator() hands out indices
 * relative to submatrix_offset and asserts that they fall inside the window.
 */
template <bool IsConst, typename T>
class TripletMatrixViewerBase<IsConst, T, 1> : public ViewerBase<IsConst>
{
    using Base = ViewerBase<IsConst>;
  protected:
    template <typename U>
    using auto_const_t = typename Base::template auto_const_t<U>;

  public:
    using ConstViewer    = TripletMatrixViewerBase<true, T, 1>;
    using NonConstViewer = TripletMatrixViewerBase<false, T, 1>;
    using ThisViewer     = TripletMatrixViewerBase<IsConst, T, 1>;

    // Read-only view of one triplet. `value` is bound directly to an element
    // of the viewer's value array (no copy is made).
    struct CTriplet
    {
        MUDA_GENERIC CTriplet(int row_index, int col_index, const T& block)
            : row_index(row_index)
            , col_index(col_index)
            , value(block)
        {
        }
        int      row_index;
        int      col_index;
        const T& value;
    };

  protected:
    // matrix info
    int m_total_rows = 0;
    int m_total_cols = 0;

    // triplet info
    int m_triplet_index_offset = 0;
    int m_triplet_count        = 0;
    int m_total_triplet_count  = 0;

    // sub matrix info
    int2 m_submatrix_offset = {0, 0};
    int2 m_submatrix_extent = {0, 0};

    // data; null-initialized so that a default-constructed viewer holds
    // well-defined pointer values and can be copied safely
    auto_const_t<int>* m_row_indices = nullptr;
    auto_const_t<int>* m_col_indices = nullptr;
    auto_const_t<T>*   m_values      = nullptr;

  public:
    MUDA_GENERIC TripletMatrixViewerBase() = default;
    MUDA_GENERIC TripletMatrixViewerBase(int total_rows,
                                         int total_cols,

                                         int triplet_index_offset,
                                         int triplet_count,
                                         int total_triplet_count,

                                         int2 submatrix_offset,
                                         int2 submatrix_extent,

                                         auto_const_t<int>* row_indices,
                                         auto_const_t<int>* col_indices,
                                         auto_const_t<T>*   values)
        : m_total_rows(total_rows)
        , m_total_cols(total_cols)
        , m_triplet_index_offset(triplet_index_offset)
        , m_triplet_count(triplet_count)
        , m_total_triplet_count(total_triplet_count)
        , m_submatrix_offset(submatrix_offset)
        , m_submatrix_extent(submatrix_extent)
        , m_row_indices(row_indices)
        , m_col_indices(col_indices)
        , m_values(values)
    {
        // A negative offset/count would slip past the upper-bound check below
        // and later index before the start of the arrays.
        MUDA_KERNEL_ASSERT(triplet_index_offset >= 0 && triplet_count >= 0,
                           "TripletMatrixViewer [%s:%s]: negative triplet range, "
                           "triplet_index_offset=%d, triplet_count=%d",
                           this->name(),
                           this->kernel_name(),
                           triplet_index_offset,
                           triplet_count);

        MUDA_KERNEL_ASSERT(triplet_index_offset + triplet_count <= total_triplet_count,
                           "TripletMatrixViewer [%s:%s]: out of range, m_total_triplet_count=%d, "
                           "your triplet_index_offset=%d, triplet_count=%d",
                           this->name(),
                           this->kernel_name(),
                           total_triplet_count,
                           triplet_index_offset,
                           triplet_count);

        MUDA_KERNEL_ASSERT(submatrix_offset.x >= 0 && submatrix_offset.y >= 0,
                           "TripletMatrixViewer [%s:%s]: submatrix_offset is out of range, submatrix_offset.x=%d, submatrix_offset.y=%d",
                           this->name(),
                           this->kernel_name(),
                           submatrix_offset.x,
                           submatrix_offset.y);

        // A negative extent would make every index fail check_in_submatrix and
        // would weaken the upper-bound checks below.
        MUDA_KERNEL_ASSERT(submatrix_extent.x >= 0 && submatrix_extent.y >= 0,
                           "TripletMatrixViewer [%s:%s]: submatrix_extent is negative, submatrix_extent.x=%d, submatrix_extent.y=%d",
                           this->name(),
                           this->kernel_name(),
                           submatrix_extent.x,
                           submatrix_extent.y);

        MUDA_KERNEL_ASSERT(submatrix_offset.x + submatrix_extent.x <= total_rows,
                           "TripletMatrixViewer [%s:%s]: submatrix is out of range, submatrix_offset.x=%d, submatrix_extent.x=%d, rows=%d",
                           this->name(),
                           this->kernel_name(),
                           submatrix_offset.x,
                           submatrix_extent.x,
                           total_rows);

        MUDA_KERNEL_ASSERT(submatrix_offset.y + submatrix_extent.y <= total_cols,
                           "TripletMatrixViewer [%s:%s]: submatrix is out of range, submatrix_offset.y=%d, submatrix_extent.y=%d, cols=%d",
                           this->name(),
                           this->kernel_name(),
                           submatrix_offset.y,
                           submatrix_extent.y,
                           total_cols);
    }

    // implicit conversion

    // Read-only copy of this viewer over the same storage and windows.
    MUDA_GENERIC ConstViewer as_const() const
    {
        return ConstViewer{m_total_rows,
                           m_total_cols,
                           m_triplet_index_offset,
                           m_triplet_count,
                           m_total_triplet_count,
                           m_submatrix_offset,
                           m_submatrix_extent,
                           m_row_indices,
                           m_col_indices,
                           m_values};
    }

    MUDA_GENERIC operator ConstViewer() const { return as_const(); }


    // Read the i-th triplet of this viewer's window (0 <= i < triplet_count()).
    // The returned row/col indices are relative to submatrix_offset().
    MUDA_GENERIC CTriplet operator()(int i) const
    {
        auto index = get_index(i);

        auto global_i = m_row_indices[index];
        auto global_j = m_col_indices[index];
        auto sub_i    = global_i - m_submatrix_offset.x;
        auto sub_j    = global_j - m_submatrix_offset.y;
        check_in_submatrix(sub_i, sub_j);
        return CTriplet{sub_i, sub_j, m_values[index]};
    }

    // Accessors are MUDA_GENERIC (as in the N x N primary template) so they
    // can be called from device code as well as host code.
    MUDA_GENERIC auto total_rows() const { return m_total_rows; }
    MUDA_GENERIC auto total_cols() const { return m_total_cols; }

    MUDA_GENERIC auto triplet_count() const { return m_triplet_count; }
    MUDA_GENERIC auto triplet_index_offset() const
    {
        return m_triplet_index_offset;
    }
    // Misspelled legacy name; kept so existing callers keep compiling.
    // Prefer triplet_index_offset().
    MUDA_GENERIC auto tripet_index_offset() const
    {
        return m_triplet_index_offset;
    }
    MUDA_GENERIC auto total_triplet_count() const
    {
        return m_total_triplet_count;
    }

    MUDA_GENERIC auto submatrix_offset() const { return m_submatrix_offset; }
    MUDA_GENERIC auto extent() const { return m_submatrix_extent; }
    MUDA_GENERIC auto total_extent() const
    {
        return int2{m_total_rows, m_total_cols};
    }

  protected:
    // Map a window-local triplet index to a position in the underlying arrays.
    MUDA_INLINE MUDA_GENERIC int get_index(int i) const noexcept
    {

        MUDA_KERNEL_ASSERT(i >= 0 && i < m_triplet_count,
                           "TripletMatrixViewer [%s:%s]: triplet_index out of range, block_count=%d, your index=%d",
                           this->name(),
                           this->kernel_name(),
                           m_triplet_count,
                           i);
        auto index = i + m_triplet_index_offset;
        return index;
    }

    // Assert that a submatrix-relative (i, j) lies inside the submatrix extent.
    MUDA_INLINE MUDA_GENERIC void check_in_submatrix(int i, int j) const noexcept
    {
        MUDA_KERNEL_ASSERT(i >= 0 && i < m_submatrix_extent.x,
                           "TripletMatrixViewer [%s:%s]: row index out of submatrix range, submatrix_extent.x=%d, yours=%d",
                           this->name(),
                           this->kernel_name(),
                           m_submatrix_extent.x,
                           i);

        MUDA_KERNEL_ASSERT(j >= 0 && j < m_submatrix_extent.y,
                           "TripletMatrixViewer [%s:%s]: col index out of submatrix range, submatrix_extent.y=%d, yours=%d",
                           this->name(),
                           this->kernel_name(),
                           m_submatrix_extent.y,
                           j);
    }
};

// Read-only viewer over a scalar (1 x 1 block) triplet matrix; all behaviour
// comes from TripletMatrixViewerBase<true, T, 1>.
template <typename T>
class CTripletMatrixViewer<T, 1> : public TripletMatrixViewerBase<true, T, 1>
{
    using Base = TripletMatrixViewerBase<true, T, 1>;
    MUDA_VIEWER_COMMON_NAME(CTripletMatrixViewer);

  public:
    // the const view of a const viewer is the viewer itself
    using ConstViewer = CTripletMatrixViewer<T, 1>;

    // every base constructor is usable on this type as well
    using Base::Base;

    // promote a plain base viewer (e.g. the result of as_const()) to this type
    MUDA_GENERIC CTripletMatrixViewer(const Base& base) : Base(base) {}
};

/**
 * Mutable viewer over a scalar (1 x 1 block) triplet matrix.
 *
 * Calling operator() on a const viewer yields the read-only CTriplet
 * inherited from the base. Calling it on a mutable viewer yields a Proxy
 * whose read()/write() are rvalue-qualified, so they are used directly on
 * the temporary, e.g. `viewer(i).write(row, col, value)`.
 */
template <typename T>
class TripletMatrixViewer<T, 1> : public TripletMatrixViewerBase<false, T, 1>
{
    using Base = TripletMatrixViewerBase<false, T, 1>;
    MUDA_VIEWER_COMMON_NAME(TripletMatrixViewer);

  public:
    using ConstViewer    = CTripletMatrixViewer<T, 1>;
    using NonConstViewer = TripletMatrixViewer<T, 1>;

    using Base::Base;
    using CTriplet = typename Base::CTriplet;

    // promote a plain base viewer to this type
    MUDA_GENERIC TripletMatrixViewer(const Base& base)
        : Base(base)
    {
    }

    class Proxy
    {
        friend class TripletMatrixViewer;
        TripletMatrixViewer& m_viewer;
        // window-local triplet index; translated through get_index() on use
        int m_index = 0;

      private:
        MUDA_GENERIC Proxy(TripletMatrixViewer& viewer, int index)
            : m_viewer(viewer)
            , m_index(index)
        {
        }

      public:
        // Read the addressed triplet via the const overload of operator().
        MUDA_GENERIC auto read() &&
        {
            const TripletMatrixViewer& const_viewer = m_viewer;
            return const_viewer(m_index);
        }

        // Overwrite the addressed triplet. (row_index, col_index) are
        // submatrix-relative; they are stored shifted by the submatrix offset.
        MUDA_GENERIC void write(int row_index, int col_index, const T& value) &&
        {
            auto index = m_viewer.get_index(m_index);
            m_viewer.check_in_submatrix(row_index, col_index);

            auto global_i = m_viewer.m_submatrix_offset.x + row_index;
            auto global_j = m_viewer.m_submatrix_offset.y + col_index;

            m_viewer.m_row_indices[index] = global_i;
            m_viewer.m_col_indices[index] = global_j;
            m_viewer.m_values[index]      = value;
        }

        MUDA_GENERIC ~Proxy() = default;
    };

    using Base::operator();

    MUDA_GENERIC Proxy operator()(int i)
    {
        // Hand the *local* index to the proxy: read()/write() already map it
        // through get_index(). Mapping it here as well (as before) applied
        // m_triplet_index_offset twice and range-checked the wrong value.
        // This now matches the N x N TripletMatrixViewer.
        return Proxy{*this, i};
    }
};
}  // namespace muda

#include "details/triplet_matrix_viewer.inl"