Skip to content

File dense_matrix_viewer.h

File List > ext > linear_system > dense_matrix_viewer.h

Go to the documentation of this file

#pragma once
#include <muda/ext/eigen/eigen_core_cxx20.h>
#include <muda/buffer/buffer_2d_view.h>
#include <muda/viewer/viewer_base.h>
#include <muda/atomic.h>

namespace muda
{
template <bool IsConst, typename T>
class DenseMatrixViewerT : public ViewerBase<IsConst>
{
    static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>,
                  "now only support real number");
    static_assert(!std::is_const_v<T>, "T must be non-const type");

    using Base = ViewerBase<IsConst>;
    template <typename U>
    using auto_const_t = typename Base::template auto_const_t<U>;

    template <bool OtherIsConst, typename U>
    friend class DenseMatrixViewerT;


  public:
    using CBuffer2DView = CBuffer2DView<T>;
    using Buffer2DView  = Buffer2DView<T>;
    using ThisBuffer2DView = std::conditional_t<IsConst, CBuffer2DView, Buffer2DView>;

    using ConstViewer    = DenseMatrixViewerT<true, T>;
    using NonConstViewer = DenseMatrixViewerT<false, T>;
    using ThisViewer = std::conditional_t<IsConst, ConstViewer, NonConstViewer>;

    using MatrixType = Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>;
    template <typename U>
    using MapMatrixT =
        Eigen::Map<U, Eigen::AlignmentType::Unaligned, Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic>>;
    using MapMatrix     = MapMatrixT<MatrixType>;
    using CMapMatrix    = MapMatrixT<const MatrixType>;
    using ThisMapMatrix = std::conditional_t<IsConst, CMapMatrix, MapMatrix>;

    MUDA_VIEWER_COMMON_NAME(DenseMatrixViewerT);

  protected:
    ThisBuffer2DView m_view;
    size_t           m_row_offset = 0;
    size_t           m_col_offset = 0;
    size_t           m_row_size   = 0;
    size_t           m_col_size   = 0;

  public:
    MUDA_GENERIC DenseMatrixViewerT(ThisBuffer2DView view,
                                    size_t           row_offset,
                                    size_t           col_offset,
                                    size_t           row_size,
                                    size_t           col_size)
        : m_view(view)
        , m_row_offset(row_offset)
        , m_col_offset(col_offset)
        , m_row_size(row_size)
        , m_col_size(col_size)
    {
    }

    template <bool OtherIsConst>
    MUDA_GENERIC DenseMatrixViewerT(const DenseMatrixViewerT<OtherIsConst, T>& other)
        MUDA_REQUIRES(IsConst)
        : m_view(other.m_view)
        , m_row_offset(other.m_row_offset)
        , m_col_offset(other.m_col_offset)
        , m_row_size(other.m_row_size)
        , m_col_size(other.m_col_size)
    {
        static_assert(IsConst);
    }

    MUDA_GENERIC auto as_const() const
    {
        return ConstViewer{m_view, m_row_offset, m_col_offset, m_row_size, m_col_size};
    }

    MUDA_GENERIC ThisViewer block(size_t row_offset, size_t col_offset, size_t row_size, size_t col_size) const
    {
        MUDA_ASSERT(row_offset + row_size <= m_row_size && col_offset + col_size <= m_col_size,
                    "DenseMatrixViewerBase [%s:%s]: block index out of range, shape=(%lld,%lld), yours index=(%lld,%lld). %s(%d)",
                    this->name(),
                    this->kernel_name(),
                    m_row_size,
                    m_col_size,
                    row_offset,
                    col_offset,
                    this->kernel_file(),
                    this->kernel_line());

        auto ret = DenseMatrixViewerT{
            m_view, m_row_offset + row_offset, m_col_offset + col_offset, row_size, col_size};
        ret.copy_label(*this);
        return ret;
    }

    template <int M, int N>
    MUDA_GENERIC ThisViewer block(int row_offset, int col_offset)
    {
        return block(row_offset, col_offset, M, N);
    }
    MUDA_GENERIC Eigen::Block<ThisMapMatrix> as_eigen() const
    {
        auto outer = m_view.pitch_bytes() / sizeof(T);

        return ThisMapMatrix{m_view.origin_data(),
                             (int)origin_row(),
                             (int)origin_col(),
                             Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic>{(int)outer, 1}}
            .block(m_row_offset, m_col_offset, m_row_size, m_col_size);
    }

    MUDA_GENERIC auto_const_t<T>& operator()(size_t i, size_t j) const
    {
        if constexpr(DEBUG_VIEWER)
        {
            MUDA_ASSERT(m_view.data(0),
                        "DenseMatrixViewer [%s:%s]: data is null",
                        this->name(),
                        this->kernel_name());
            if(m_row_offset == 0 && m_col_offset == 0)
            {
                MUDA_ASSERT(i < m_row_size && j < m_col_size,
                            "DenseMatrixViewer [%s:%s]: index out of range, shape=(%lld,%lld), yours index=(%lld,%lld)",
                            this->name(),
                            this->kernel_name(),
                            m_row_size,
                            m_col_size,
                            i,
                            j);
            }
            else
            {
                MUDA_ASSERT(i < m_row_size && j < m_col_size,
                            "DenseMatrixViewer [%s:%s]:index out of range, block shape=(%lld,%lld), your index=(%lld,%lld)",
                            this->name(),
                            this->kernel_name(),
                            m_row_size,
                            m_col_size,
                            i,
                            j);
            }
        }
        i += m_row_offset;
        j += m_col_offset;
        return *m_view.data(j, i);
    }

    MUDA_GENERIC auto buffer_view() const { return m_view; }

    MUDA_GENERIC size_t row() const { return m_row_size; }

    MUDA_GENERIC size_t col() const { return m_col_size; }

    MUDA_GENERIC size_t origin_row() const
    {
        size_t ret;
        ret = m_view.extent().width();
        return ret;
    }

    MUDA_GENERIC size_t origin_col() const
    {
        size_t ret;
        ret = m_view.extent().height();
        return ret;
    }

    MUDA_GENERIC auto row_offset() const { return m_row_offset; }

    MUDA_GENERIC auto col_offset() const { return m_col_offset; }

    MUDA_DEVICE T atomic_add(size_t i, size_t j, T val) const MUDA_REQUIRES(!IsConst)
    {
        static_assert(!IsConst, "Cannot assign to a const viewer");
        auto ptr = &this->operator()(i, j);
        muda::atomic_add(ptr, val);
        return val;
    }

    template <int M, int N>
    MUDA_DEVICE Eigen::Matrix<T, M, N> atomic_add(const Eigen::Matrix<T, M, N>& other) const
        MUDA_REQUIRES(!IsConst)
    {
        static_assert(!IsConst, "Cannot assign to a const viewer");
        check_size_matching(M, N);
        Eigen::Matrix<T, M, N> ret;
#pragma unroll
        for(int i = 0; i < M; ++i)
#pragma unroll
            for(int j = 0; j < N; ++j)
            {
                ret(i, j) = atomic_add(i, j, other(i, j));
            }
        return ret;
    }

    template <int M, int N>
    MUDA_GENERIC DenseMatrixViewerT& operator=(const Eigen::Matrix<T, M, N>& other) const
        MUDA_REQUIRES(!IsConst)
    {
        static_assert(!IsConst, "Cannot assign to a const viewer");
        check_size_matching(M, N);
#pragma unroll
        for(int i = 0; i < M; ++i)
#pragma unroll
            for(int j = 0; j < N; ++j)
                (*this)(i, j) = other(i, j);
        return *this;
    }

  private:
    MUDA_GENERIC void check_size_matching(int M, int N) const
    {
        MUDA_ASSERT(this->m_row_size == M && this->m_col_size == N,
                    "DenseMatrixViewer [%s:%s] shape mismatching, Viewer=(%lld,%lld), yours=(%lld,%lld). %s(%d)",
                    this->name(),
                    this->kernel_name(),
                    this->m_row_size,
                    this->m_col_size,
                    M,
                    N,
                    this->kernel_file(),
                    this->kernel_line());
    }
};

template <typename T>
using DenseMatrixViewer = DenseMatrixViewerT<false, T>;

template <typename T>
using CDenseMatrixViewer = DenseMatrixViewerT<true, T>;
}  // namespace muda

#include "details/dense_matrix_viewer.inl"