Skip to content

File triplet_matrix_viewer.h

File List > ext > linear_system > triplet_matrix_viewer.h

Go to the documentation of this file

#pragma once
#include <string>
#include <muda/viewer/viewer_base.h>
#include <muda/buffer/device_buffer.h>
#include <muda/tools/cuda_vec_utils.h>
#include <muda/ext/eigen/eigen_core_cxx20.h>


/*
* - 2024/2/23: removed the viewer's subview; the view's subview is sufficient.
*/

namespace muda
{
/**
 * @brief Viewer over a contiguous range of triplets (row index, column index,
 *        value) stored in three parallel arrays.
 *
 * The viewer exposes `m_triplet_count` triplets starting at
 * `m_triplet_index_offset` and reports/accepts their coordinates relative to a
 * submatrix (`m_submatrix_offset`, `m_submatrix_extent`) of the full matrix.
 * It only holds raw pointers to the arrays; nothing in this class allocates
 * or releases them.
 *
 * @tparam IsConst when true, operator() returns a read-only CTriplet;
 *                 when false, it returns a Proxy that can read or write.
 * @tparam T       scalar type.
 * @tparam N       when N == 1 a value is a plain T, otherwise an
 *                 Eigen::Matrix<T, N, N>.
 */
template <bool IsConst, typename T, int N>
class TripletMatrixViewerT : public ViewerBase<IsConst>
{
    using Base = ViewerBase<IsConst>;
    template <typename U>
    using auto_const_t = typename Base::template auto_const_t<U>;

    // Needed so the converting constructor can read the other
    // instantiation's protected members.
    template <bool OtherIsConst, typename U, int M>
    friend class TripletMatrixViewerT;

    MUDA_VIEWER_COMMON_NAME(TripletMatrixViewerT);

  public:
    using ValueT      = std::conditional_t<N == 1, T, Eigen::Matrix<T, N, N>>;
    using ConstViewer = TripletMatrixViewerT<true, T, N>;
    using NonConstViewer = TripletMatrixViewerT<false, T, N>;
    using ThisViewer     = TripletMatrixViewerT<IsConst, T, N>;

    /**
     * @brief Read-only triplet: submatrix-local coordinates plus a reference
     *        to the stored value.
     *
     * `value` refers to the element inside the viewer's value array, so a
     * CTriplet must not outlive that storage.
     */
    struct CTriplet
    {
        MUDA_GENERIC CTriplet(int row_index, int col_index, const ValueT& block)
            : row_index(row_index)
            , col_index(col_index)
            , value(block)
        {
        }

        int           row_index;
        int           col_index;
        const ValueT& value;
    };

    /**
     * @brief Temporary handle to the i-th triplet of a non-const viewer.
     *
     * Holds a reference to the viewer, and both read() and write() are
     * rvalue-qualified, so the intended use is `viewer(i).read()` or
     * `viewer(i).write(...)` directly on the returned temporary.
     */
    class Proxy
    {
        friend class TripletMatrixViewerT;
        const TripletMatrixViewerT& m_viewer;
        int                         m_index = 0;

      private:
        // Only the owning viewer creates proxies (see operator()).
        MUDA_GENERIC Proxy(const TripletMatrixViewerT& viewer, int index)
            : m_viewer(viewer)
            , m_index(index)
        {
        }

      public:
        /// Returns the triplet as a CTriplet with submatrix-local coordinates.
        MUDA_GENERIC auto read() && { return m_viewer.at(m_index); }

        /**
         * @brief Overwrites the triplet this proxy refers to.
         * @param row_index row inside the submatrix (checked against the extent)
         * @param col_index column inside the submatrix (checked against the extent)
         * @param block     value to store
         *
         * The indices written to the arrays are global: the submatrix offset
         * is added to the local coordinates.
         */
        MUDA_GENERIC
        void write(int row_index, int col_index, const ValueT& block) &&
        {
            auto index = m_viewer.get_index(m_index);

            m_viewer.check_in_submatrix(row_index, col_index);

            auto global_i = m_viewer.m_submatrix_offset.x + row_index;
            auto global_j = m_viewer.m_submatrix_offset.y + col_index;

            // The viewer reference is const, but the pointers it holds are
            // pointers to non-const data for a non-const viewer.
            m_viewer.m_row_indices[index] = global_i;
            m_viewer.m_col_indices[index] = global_j;
            m_viewer.m_values[index]      = block;
        }

        MUDA_GENERIC ~Proxy() = default;
    };

  protected:
    // matrix info
    int m_total_rows = 0;
    int m_total_cols = 0;

    // triplet info
    int m_triplet_index_offset = 0;
    int m_triplet_count        = 0;
    int m_total_triplet_count  = 0;

    // sub matrix info
    int2 m_submatrix_offset = {0, 0};
    int2 m_submatrix_extent = {0, 0};

    // data
    // Null by default so a default-constructed viewer never carries
    // indeterminate pointers (all other members already have defaults).
    auto_const_t<int>*    m_row_indices = nullptr;
    auto_const_t<int>*    m_col_indices = nullptr;
    auto_const_t<ValueT>* m_values      = nullptr;


  public:
    /// Creates an empty viewer: zero sizes and null data pointers.
    MUDA_GENERIC TripletMatrixViewerT() = default;

    /**
     * @brief Builds a viewer over `triplet_count` triplets starting at
     *        `triplet_index_offset`, restricted to a submatrix.
     *
     * @param total_block_rows     row count of the full matrix
     * @param total_block_cols     column count of the full matrix
     * @param triplet_index_offset index of the first visible triplet
     * @param triplet_count        number of visible triplets
     * @param total_triplet_count  number of triplets in the arrays
     * @param submatrix_offset     (row, col) origin of the submatrix
     * @param submatrix_extent     (rows, cols) size of the submatrix
     * @param block_row_indices    row-index array
     * @param block_col_indices    column-index array
     * @param block_values         value array
     *
     * The ranges are validated with MUDA_ASSERT.
     */
    MUDA_GENERIC TripletMatrixViewerT(int total_block_rows,
                                      int total_block_cols,
                                      int triplet_index_offset,
                                      int triplet_count,
                                      int total_triplet_count,

                                      int2 submatrix_offset,
                                      int2 submatrix_extent,

                                      auto_const_t<int>*    block_row_indices,
                                      auto_const_t<int>*    block_col_indices,
                                      auto_const_t<ValueT>* block_values)
        : m_total_rows(total_block_rows)
        , m_total_cols(total_block_cols)
        , m_triplet_index_offset(triplet_index_offset)
        , m_triplet_count(triplet_count)
        , m_total_triplet_count(total_triplet_count)
        , m_submatrix_offset(submatrix_offset)
        , m_submatrix_extent(submatrix_extent)
        , m_row_indices(block_row_indices)
        , m_col_indices(block_col_indices)
        , m_values(block_values)
    {
        // A negative offset or count would satisfy the upper-bound test
        // alone and later index before the start of the arrays.
        MUDA_ASSERT(triplet_index_offset >= 0 && triplet_count >= 0
                        && triplet_index_offset + triplet_count <= total_triplet_count,
                    "TripletMatrixViewer [%s:%s]: out of range, m_total_triplet_count=%d, "
                    "your triplet_index_offset=%d, triplet_count=%d. %s(%d)",
                    this->name(),
                    this->kernel_name(),
                    total_triplet_count,
                    triplet_index_offset,
                    triplet_count,
                    this->kernel_file(),
                    this->kernel_line());

        MUDA_ASSERT(submatrix_offset.x >= 0 && submatrix_offset.y >= 0,
                    "TripletMatrixViewer[%s:%s]: submatrix_offset is out of range, submatrix_offset.x=%d, submatrix_offset.y=%d. %s(%d)",
                    this->name(),
                    this->kernel_name(),
                    submatrix_offset.x,
                    submatrix_offset.y,
                    this->kernel_file(),
                    this->kernel_line());

        MUDA_ASSERT(submatrix_offset.x + submatrix_extent.x <= total_block_rows,
                    "TripletMatrixViewer[%s:%s]: submatrix is out of range, submatrix_offset.x=%d, submatrix_extent.x=%d, total_block_rows=%d. %s(%d)",
                    this->name(),
                    this->kernel_name(),
                    submatrix_offset.x,
                    submatrix_extent.x,
                    total_block_rows,
                    this->kernel_file(),
                    this->kernel_line());

        MUDA_ASSERT(submatrix_offset.y + submatrix_extent.y <= total_block_cols,
                    "TripletMatrixViewer[%s:%s]: submatrix is out of range, submatrix_offset.y=%d, submatrix_extent.y=%d, total_block_cols=%d. %s(%d)",
                    this->name(),
                    this->kernel_name(),
                    submatrix_offset.y,
                    submatrix_extent.y,
                    total_block_cols,
                    this->kernel_file(),
                    this->kernel_line());
    }

    /**
     * @brief Converting constructor from a viewer with a different constness.
     *
     * Copies all fields member-wise without re-running the range assertions.
     * Converting a const viewer into a non-const one does not compile,
     * because the pointer members cannot drop their const qualifier.
     */
    template <bool OtherIsConst>
    MUDA_GENERIC TripletMatrixViewerT(const TripletMatrixViewerT<OtherIsConst, T, N>& other)
        : m_total_rows(other.m_total_rows)
        , m_total_cols(other.m_total_cols)
        , m_triplet_index_offset(other.m_triplet_index_offset)
        , m_triplet_count(other.m_triplet_count)
        , m_total_triplet_count(other.m_total_triplet_count)
        , m_submatrix_offset(other.m_submatrix_offset)
        , m_submatrix_extent(other.m_submatrix_extent)
        , m_row_indices(other.m_row_indices)
        , m_col_indices(other.m_col_indices)
        , m_values(other.m_values)
    {
    }

    /// Returns a read-only viewer over the same triplets and submatrix.
    MUDA_GENERIC ConstViewer as_const() const
    {
        return ConstViewer{m_total_rows,
                           m_total_cols,
                           m_triplet_index_offset,
                           m_triplet_count,
                           m_total_triplet_count,
                           m_submatrix_offset,
                           m_submatrix_extent,
                           m_row_indices,
                           m_col_indices,
                           m_values};
    }

    /// Row count of the full matrix.
    MUDA_GENERIC auto total_rows() const { return m_total_rows; }

    /// Column count of the full matrix.
    MUDA_GENERIC auto total_cols() const { return m_total_cols; }

    /// (rows, cols) of the full matrix.
    MUDA_GENERIC auto total_extent() const
    {
        return int2{m_total_rows, m_total_cols};
    }

    /// (row, col) origin of the submatrix inside the full matrix.
    MUDA_GENERIC auto submatrix_offset() const { return m_submatrix_offset; }

    /// (rows, cols) of the submatrix.
    MUDA_GENERIC auto extent() const { return m_submatrix_extent; }

    /// Number of triplets visible through this viewer.
    MUDA_GENERIC auto triplet_count() const { return m_triplet_count; }

    /// Index of the first visible triplet inside the arrays.
    MUDA_GENERIC auto triplet_index_offset() const
    {
        return m_triplet_index_offset;
    }

    /// Misspelled predecessor of triplet_index_offset(); kept so existing
    /// callers keep compiling.
    MUDA_GENERIC auto tripet_index_offset() const
    {
        return triplet_index_offset();
    }

    /// Number of triplets in the arrays, including those outside this viewer's range.
    MUDA_GENERIC auto total_triplet_count() const
    {
        return m_total_triplet_count;
    }

    /**
     * @brief Accesses the i-th visible triplet.
     * @return a CTriplet for const viewers; a Proxy (to be consumed
     *         immediately via read() or write()) for non-const viewers.
     */
    MUDA_GENERIC auto operator()(int i) const
    {
        if constexpr(IsConst)
        {
            return at(i);
        }
        else
        {
            return Proxy{*this, i};
        }
    }

  protected:
    /// Reads the i-th visible triplet and converts its stored global
    /// coordinates into submatrix-local ones (checked against the extent).
    MUDA_GENERIC MUDA_INLINE CTriplet at(int i) const noexcept
    {
        auto index    = get_index(i);
        auto global_i = m_row_indices[index];
        auto global_j = m_col_indices[index];
        auto sub_i    = global_i - m_submatrix_offset.x;
        auto sub_j    = global_j - m_submatrix_offset.y;
        check_in_submatrix(sub_i, sub_j);
        return CTriplet{sub_i, sub_j, m_values[index]};
    }

    /// Maps a viewer-local triplet index to an index into the arrays,
    /// asserting that 0 <= i < m_triplet_count.
    MUDA_INLINE MUDA_GENERIC int get_index(int i) const noexcept
    {

        // The reported bound is the viewer's triplet count, so label it as such.
        MUDA_KERNEL_ASSERT(i >= 0 && i < m_triplet_count,
                           "TripletMatrixViewer [%s:%s]: triplet_index out of range, triplet_count=%d, your index=%d. %s(%d)",
                           this->name(),
                           this->kernel_name(),
                           m_triplet_count,
                           i,
                           this->kernel_file(),
                           this->kernel_line());
        auto index = i + m_triplet_index_offset;
        return index;
    }

    /// Asserts that submatrix-local coordinates (i, j) lie inside the
    /// submatrix extent.
    MUDA_INLINE MUDA_GENERIC void check_in_submatrix(int i, int j) const noexcept
    {
        MUDA_KERNEL_ASSERT(i >= 0 && i < m_submatrix_extent.x,
                           "TripletMatrixViewer [%s:%s]: row index out of submatrix range,  submatrix_extent.x=%d, your i=%d. %s(%d)",
                           this->name(),
                           this->kernel_name(),
                           m_submatrix_extent.x,
                           i,
                           this->kernel_file(),
                           this->kernel_line());

        MUDA_KERNEL_ASSERT(j >= 0 && j < m_submatrix_extent.y,
                           "TripletMatrixViewer [%s:%s]: col index out of submatrix range,  submatrix_extent.y=%d, your j=%d. %s(%d)",
                           this->name(),
                           this->kernel_name(),
                           m_submatrix_extent.y,
                           j,
                           this->kernel_file(),
                           this->kernel_line());
    }
};

// Mutable viewer: TripletMatrixViewerT instantiated with IsConst = false.
template <typename Scalar, int BlockDim>
using TripletMatrixViewer = TripletMatrixViewerT<false, Scalar, BlockDim>;

// Read-only viewer: TripletMatrixViewerT instantiated with IsConst = true.
template <typename Scalar, int BlockDim>
using CTripletMatrixViewer = TripletMatrixViewerT<true, Scalar, BlockDim>;
}  // namespace muda

#include "details/triplet_matrix_viewer.inl"