Skip to content

File doublet_vector_viewer.h

File List > ext > linear_system > doublet_vector_viewer.h

Go to the documentation of this file

#pragma once

#include <string>
#include <muda/viewer/viewer_base.h>
#include <muda/ext/eigen/eigen_core_cxx20.h>

/*
* - 2024/2/23: removed the viewer's subview; the view's subview is enough.
*/

namespace muda
{
/**
 * Viewer over a list of "doublets", i.e. (segment index, segment value) pairs
 * stored in two parallel arrays (`m_segment_indices` / `m_segment_values`).
 *
 * The viewer exposes `m_doublet_count` doublets starting at
 * `m_doublet_index_offset` in those arrays. Segment indices are stored as
 * global indices; on read they are translated to indices relative to
 * `m_subvector_offset`, and on write the offset is added back.
 *
 * @tparam IsConst  when true the data pointers are pointers-to-const and
 *                  `operator()` returns a read-only `CDoublet`; when false it
 *                  returns a `Proxy` that can read and write.
 * @tparam T        scalar type of a segment value.
 * @tparam N        segment dimension; `ValueT` is `T` for N == 1 and
 *                  `Eigen::Matrix<T, N, 1>` otherwise.
 */
template <bool IsConst, typename T, int N>
class DoubletVectorViewerT : public ViewerBase<IsConst>
{
    using Base = ViewerBase<IsConst>;
    template <typename U>
    using auto_const_t = typename Base::template auto_const_t<U>;

    // Every instantiation is a friend so the converting constructor can read
    // the protected members of a viewer with a different constness.
    template <bool OtherIsConst, typename U, int M>
    friend class DoubletVectorViewerT;

    MUDA_VIEWER_COMMON_NAME(DoubletVectorViewerT);

  public:
    using ValueT      = std::conditional_t<N == 1, T, Eigen::Matrix<T, N, 1>>;
    using ConstViewer = DoubletVectorViewerT<true, T, N>;
    using NonConstViewer = DoubletVectorViewerT<false, T, N>;
    using ThisViewer     = DoubletVectorViewerT<IsConst, T, N>;

    /**
     * Read-only doublet returned by `at()`.
     *
     * `index` is the segment index relative to the viewer's subvector offset.
     * `value` is a reference to the stored value (not a copy), so it is only
     * valid while the underlying storage is.
     */
    struct CDoublet
    {
        MUDA_GENERIC CDoublet(int index, const ValueT& segment)
            : index(index)
            , value(segment)
        {
        }
        int           index;
        const ValueT& value;
    };

    /**
     * Access proxy returned by `operator()` of a non-const viewer.
     *
     * It keeps a reference to the viewer that created it, so it must not
     * outlive that viewer. `read()` and `write()` are rvalue-qualified and can
     * therefore only be invoked on a temporary, e.g. `viewer(i).write(...)`.
     */
    class Proxy
    {
        friend class DoubletVectorViewerT;
        const DoubletVectorViewerT& m_viewer;
        int                         m_index = 0;

      private:
        // Only the owning viewer (a friend) can create a Proxy.
        MUDA_GENERIC Proxy(const DoubletVectorViewerT& viewer, int index)
            : m_viewer(viewer)
            , m_index(index)
        {
        }

      public:
        /// Reads the doublet at the proxied position; same result as `at(m_index)`.
        MUDA_GENERIC auto read() && { return m_viewer.at(m_index); }

        /**
         * Stores a doublet at the proxied position.
         *
         * @param segment_i segment index relative to the subvector; checked
         *                  against [0, m_subvector_extent).
         * @param value     segment value to store.
         */
        MUDA_GENERIC void write(int segment_i, const ValueT& value) &&
        {
            auto index = m_viewer.get_index(m_index);

            m_viewer.check_in_subvector(segment_i);

            // The stored index is global: add the subvector offset back.
            auto global_i = segment_i + m_viewer.m_subvector_offset;

            m_viewer.m_segment_indices[index] = global_i;
            m_viewer.m_segment_values[index]  = value;
        }

        MUDA_GENERIC ~Proxy() = default;
    };

  protected:
    // vector info
    int m_total_segment_count = 0;

    // doublet info
    int m_doublet_index_offset = 0;
    int m_doublet_count        = 0;
    int m_total_doublet_count  = 0;

    // subvector info
    int m_subvector_offset = 0;
    int m_subvector_extent = 0;

    // data
    // Initialized to nullptr so that a default-constructed viewer does not
    // carry indeterminate pointers (all other members already default to 0).
    auto_const_t<int>*    m_segment_indices = nullptr;
    auto_const_t<ValueT>* m_segment_values  = nullptr;

  public:
    /// Creates an empty viewer: all counts/offsets are 0 and both data pointers are null.
    MUDA_GENERIC DoubletVectorViewerT() = default;

    /**
     * Creates a viewer over a window of the doublet arrays.
     *
     * @param total_segment_count  total number of segments in the vector.
     * @param doublet_index_offset first doublet (array position) visible through this viewer.
     * @param doublet_count        number of doublets visible through this viewer.
     * @param total_doublet_count  total number of doublets in the arrays.
     * @param subvector_offset     segment offset subtracted on read / added on write.
     * @param subvector_extent     number of segments in the subvector.
     * @param segment_indices      array of stored (global) segment indices.
     * @param segment_values       array of segment values, parallel to `segment_indices`.
     *
     * Asserts (via MUDA_KERNEL_ASSERT) that the doublet window fits in
     * `total_doublet_count` and the subvector fits in `total_segment_count`.
     */
    MUDA_GENERIC DoubletVectorViewerT(int                total_segment_count,
                                      int                doublet_index_offset,
                                      int                doublet_count,
                                      int                total_doublet_count,
                                      int                subvector_offset,
                                      int                subvector_extent,
                                      auto_const_t<int>* segment_indices,
                                      auto_const_t<ValueT>* segment_values)
        : m_total_segment_count(total_segment_count)
        , m_doublet_index_offset(doublet_index_offset)
        , m_doublet_count(doublet_count)
        , m_total_doublet_count(total_doublet_count)
        , m_subvector_offset(subvector_offset)
        , m_subvector_extent(subvector_extent)
        , m_segment_indices(segment_indices)
        , m_segment_values(segment_values)
    {
        MUDA_KERNEL_ASSERT(doublet_index_offset + doublet_count <= total_doublet_count,
                           "DoubletVectorViewer: out of range, m_total_doublet_count=%d, "
                           "your doublet_index_offset=%d, doublet_count=%d. %s(%d)",
                           m_total_doublet_count,
                           doublet_index_offset,
                           doublet_count,
                           this->kernel_file(),
                           this->kernel_line());

        MUDA_KERNEL_ASSERT(subvector_offset + subvector_extent <= total_segment_count,
                           "DoubletVectorViewer: out of range, m_total_segment_count=%d, "
                           "your subvector_offset=%d, subvector_extent=%d. %s(%d)",
                           m_total_segment_count,
                           subvector_offset,
                           subvector_extent,
                           this->kernel_file(),
                           this->kernel_line());
    }

    /**
     * Converting constructor, usable only when this viewer is const
     * (`IsConst == true`); copies counts, offsets and data pointers from
     * `other`.
     *
     * NOTE(review): the ViewerBase sub-object is not initialized from `other`
     * here (it is default-constructed) - confirm whether base-held state is
     * meant to be carried over.
     */
    template <bool OtherIsConst>
    MUDA_GENERIC DoubletVectorViewerT(const DoubletVectorViewerT<OtherIsConst, T, N>& other) noexcept
        MUDA_REQUIRES(IsConst)
        : m_total_segment_count(other.m_total_segment_count)
        , m_doublet_index_offset(other.m_doublet_index_offset)
        , m_doublet_count(other.m_doublet_count)
        , m_total_doublet_count(other.m_total_doublet_count)
        , m_subvector_offset(other.m_subvector_offset)
        , m_subvector_extent(other.m_subvector_extent)
        , m_segment_indices(other.m_segment_indices)
        , m_segment_values(other.m_segment_values)
    {
        static_assert(IsConst);
    }

    /// Returns a ConstViewer built from the same counts, offsets and data pointers.
    MUDA_GENERIC ConstViewer as_const() const noexcept
    {
        return ConstViewer{m_total_segment_count,
                           m_doublet_index_offset,
                           m_doublet_count,
                           m_total_doublet_count,
                           m_subvector_offset,
                           m_subvector_extent,
                           m_segment_indices,
                           m_segment_values};
    }

    /// Number of doublets visible through this viewer.
    MUDA_GENERIC int doublet_count() const noexcept { return m_doublet_count; }

    /// Total number of doublets, as passed to the constructor.
    MUDA_GENERIC int total_doublet_count() const noexcept
    {
        return m_total_doublet_count;
    }

    /**
     * Accesses the i-th doublet of this viewer.
     *
     * @param i doublet position relative to this viewer, in [0, doublet_count()).
     * @return  a `CDoublet` for a const viewer, otherwise a `Proxy` bound to
     *          this viewer (use `.read()` / `.write(...)` on it immediately).
     */
    MUDA_GENERIC auto operator()(int i) const
    {
        if constexpr(IsConst)
        {
            return at(i);
        }
        else
        {
            return Proxy{*this, i};
        }
    }

  protected:
    /**
     * Reads the i-th doublet: looks up the stored (global) segment index,
     * converts it to an index relative to `m_subvector_offset`, checks that it
     * lies in the subvector and returns it with a reference to the value.
     */
    MUDA_INLINE MUDA_GENERIC CDoublet at(int i) const
    {
        auto index    = get_index(i);
        auto global_i = m_segment_indices[index];
        auto sub_i    = global_i - m_subvector_offset;

        check_in_subvector(sub_i);
        return CDoublet{sub_i, m_segment_values[index]};
    }


    /**
     * Maps a viewer-relative doublet position to a position in the data
     * arrays (`i + m_doublet_index_offset`), asserting 0 <= i < m_doublet_count.
     */
    MUDA_INLINE MUDA_GENERIC int get_index(int i) const noexcept
    {
        MUDA_KERNEL_ASSERT(i >= 0 && i < m_doublet_count,
                           "DoubletVectorViewer [%s:%s]: index out of range, m_doublet_count=%d, your index=%d. %s(%d)",
                           this->name(),
                           this->kernel_name(),
                           m_doublet_count,
                           i,
                           this->kernel_file(),
                           this->kernel_line());
        auto index = i + m_doublet_index_offset;
        return index;
    }

    /// Asserts that the subvector-relative segment index `i` is in [0, m_subvector_extent).
    MUDA_INLINE MUDA_GENERIC void check_in_subvector(int i) const noexcept
    {
        MUDA_KERNEL_ASSERT(i >= 0 && i < m_subvector_extent,
                           "DoubletVectorViewer [%s:%s]: index out of range, m_subvector_extent=%d, your index=%d. %s(%d)",
                           this->name(),
                           this->kernel_name(),
                           m_subvector_extent,
                           i,
                           this->kernel_file(),
                           this->kernel_line());
    }
};

// Mutable viewer alias: DoubletVectorViewerT with IsConst = false, so
// operator() yields a Proxy that can both read and write doublets.
template <typename T, int N>
using DoubletVectorViewer = DoubletVectorViewerT<false, T, N>;

// Read-only viewer alias: DoubletVectorViewerT with IsConst = true, so
// operator() yields a CDoublet directly.
template <typename T, int N>
using CDoubletVectorViewer = DoubletVectorViewerT<true, T, N>;
}  // namespace muda

#include "details/doublet_vector_viewer.inl"