// File doublet_vector_viewer.h
// File List > ext > linear_system > doublet_vector_viewer.h
// Go to the documentation of this file
#pragma once
#include <string>
#include <muda/viewer/viewer_base.h>
#include <muda/ext/eigen/eigen_core_cxx20.h>
/*
 * - 2024/2/23: removed the viewer's subview; the view's subview is enough.
 */
namespace muda
{
/**
 * Viewer over a sparse vector stored as "doublets": parallel arrays of
 * (segment index, N-component segment value).
 *
 * The viewer exposes a window of `m_doublet_count` doublets starting at
 * `m_doublet_index_offset`, and interprets the stored (global) segment indices
 * relative to a sub-vector `[m_subvector_offset, m_subvector_offset + m_subvector_extent)`.
 *
 * @tparam IsConst  when true the underlying arrays are accessed through
 *                  pointers-to-const (via ViewerBase's auto_const_t).
 * @tparam T        scalar type of a segment.
 * @tparam N        number of scalars per segment.
 */
template <bool IsConst, typename T, int N>
class DoubletVectorViewerBase : public ViewerBase<IsConst>
{
    using Base = ViewerBase<IsConst>;
    template <typename U>
    using auto_const_t = typename Base::template auto_const_t<U>;

  public:
    using SegmentVector  = Eigen::Matrix<T, N, 1>;
    using ConstViewer    = DoubletVectorViewerBase<true, T, N>;
    using NonConstViewer = DoubletVectorViewerBase<false, T, N>;
    using ThisViewer     = DoubletVectorViewerBase<IsConst, T, N>;

    /**
     * Read-only result of accessing one doublet: the segment index (relative
     * to the sub-vector) and a reference to the stored segment value.
     * The reference points into the viewed value array, so it is only valid
     * while that storage is.
     */
    struct CDoublet
    {
        MUDA_GENERIC CDoublet(int index, const SegmentVector& segment)
            : index(index)
            , segment_value(segment)
        {
        }
        int                  index;
        const SegmentVector& segment_value;
    };

  protected:
    // vector info
    int m_total_segment_count = 0;

    // doublet info
    int m_doublet_index_offset = 0;
    int m_doublet_count        = 0;
    int m_total_doublet_count  = 0;

    // subvector info
    int m_subvector_offset = 0;
    int m_subvector_extent = 0;

    // data
    // Initialized to nullptr so that a default-constructed viewer never holds
    // indeterminate pointers (every other member already has a default).
    auto_const_t<int>*           m_segment_indices = nullptr;
    auto_const_t<SegmentVector>* m_segment_values  = nullptr;

  public:
    MUDA_GENERIC DoubletVectorViewerBase() = default;

    /**
     * @param total_segment_count   number of segments in the whole vector.
     * @param doublet_index_offset  first doublet slot visible through this viewer.
     * @param doublet_count         number of doublet slots visible through this viewer.
     * @param total_doublet_count   number of doublet slots in the underlying arrays.
     * @param subvector_offset      global segment index where the sub-vector starts.
     * @param subvector_extent      number of segments in the sub-vector.
     * @param segment_indices       array of global segment indices, one per doublet.
     * @param segment_values        array of segment values, one per doublet.
     *
     * Asserts (MUDA_KERNEL_ASSERT) that the doublet window fits in the doublet
     * storage and that the sub-vector fits in the whole vector.
     */
    MUDA_GENERIC DoubletVectorViewerBase(int                total_segment_count,
                                         int                doublet_index_offset,
                                         int                doublet_count,
                                         int                total_doublet_count,
                                         int                subvector_offset,
                                         int                subvector_extent,
                                         auto_const_t<int>* segment_indices,
                                         auto_const_t<SegmentVector>* segment_values)
        : m_total_segment_count(total_segment_count)
        , m_doublet_index_offset(doublet_index_offset)
        , m_doublet_count(doublet_count)
        , m_total_doublet_count(total_doublet_count)
        , m_subvector_offset(subvector_offset)
        , m_subvector_extent(subvector_extent)
        , m_segment_indices(segment_indices)
        , m_segment_values(segment_values)
    {
        MUDA_KERNEL_ASSERT(doublet_index_offset + doublet_count <= total_doublet_count,
                           "DoubletVectorViewer: out of range, m_total_doublet_count=%d, "
                           "your doublet_index_offset=%d, doublet_count=%d",
                           m_total_doublet_count,
                           doublet_index_offset,
                           doublet_count);

        MUDA_KERNEL_ASSERT(subvector_offset + subvector_extent <= total_segment_count,
                           "DoubletVectorViewer: out of range, m_total_segment_count=%d, "
                           "your subvector_offset=%d, subvector_extent=%d",
                           m_total_segment_count,
                           subvector_offset,
                           subvector_extent);
    }

    // implicit conversion
    // NOTE(review): only the members declared in this class are forwarded; the
    // ViewerBase sub-object of the result is default-initialized — confirm that
    // dropping any ViewerBase state here is intended.
    MUDA_GENERIC ConstViewer as_const() const noexcept
    {
        return ConstViewer{m_total_segment_count,
                           m_doublet_index_offset,
                           m_doublet_count,
                           m_total_doublet_count,
                           m_subvector_offset,
                           m_subvector_extent,
                           m_segment_indices,
                           m_segment_values};
    }

    MUDA_GENERIC operator ConstViewer() const noexcept { return as_const(); }

    // const access
    MUDA_GENERIC int doublet_count() const noexcept { return m_doublet_count; }
    MUDA_GENERIC int total_doublet_count() const noexcept
    {
        return m_total_doublet_count;
    }

    /**
     * Read the i-th doublet of this viewer's window.
     * The stored global segment index is rebased by the sub-vector offset and
     * asserted to lie inside the sub-vector before it is returned.
     */
    MUDA_GENERIC CDoublet operator()(int i) const
    {
        auto index    = get_index(i);
        auto global_i = m_segment_indices[index];
        auto sub_i    = global_i - m_subvector_offset;
        check_in_subvector(sub_i);
        return CDoublet{sub_i, m_segment_values[index]};
    }

  protected:
    // Map a window-local doublet index to a slot in the underlying arrays,
    // asserting that it lies in [0, m_doublet_count).
    MUDA_INLINE MUDA_GENERIC int get_index(int i) const noexcept
    {
        MUDA_KERNEL_ASSERT(i >= 0 && i < m_doublet_count,
                           "DoubletVectorViewer [%s:%s]: index out of range, m_doublet_count=%d, your index=%d",
                           this->name(),
                           this->kernel_name(),
                           m_doublet_count,
                           i);
        auto index = i + m_doublet_index_offset;
        return index;
    }

    // Assert that a sub-vector-relative segment index lies in [0, m_subvector_extent).
    MUDA_INLINE MUDA_GENERIC void check_in_subvector(int i) const noexcept
    {
        MUDA_KERNEL_ASSERT(i >= 0 && i < m_subvector_extent,
                           "DoubletVectorViewer [%s:%s]: index out of range, m_subvector_extent=%d, your index=%d",
                           this->name(),
                           this->kernel_name(),
                           m_subvector_extent,
                           i);
    }
};
/**
 * Read-only doublet-vector viewer for N-component segments.
 * Adds no state of its own: it is the const instantiation of
 * DoubletVectorViewerBase plus the common viewer naming macro.
 */
template <typename T, int N>
class CDoubletVectorViewer : public DoubletVectorViewerBase<true, T, N>
{
    using Base = DoubletVectorViewerBase<true, T, N>;
    MUDA_VIEWER_COMMON_NAME(CDoubletVectorViewer);

  public:
    using SegmentVector = typename Base::SegmentVector;

    // Re-export every constructor of the const base viewer.
    using Base::Base;

    // Allow building this viewer directly from a const base viewer.
    MUDA_GENERIC CDoubletVectorViewer(const Base& other)
        : Base(other)
    {
    }
};
/**
 * Read/write doublet-vector viewer for N-component segments.
 *
 * Const access goes through the inherited `operator()(int) const`, which
 * returns a CDoublet. Non-const access returns a short-lived Proxy whose
 * `read()` / `write()` must be called on the temporary (they are
 * rvalue-qualified).
 */
template <typename T, int N>
class DoubletVectorViewer : public DoubletVectorViewerBase<false, T, N>
{
    using Base = DoubletVectorViewerBase<false, T, N>;
    MUDA_VIEWER_COMMON_NAME(DoubletVectorViewer);

  public:
    using SegmentVector = typename Base::SegmentVector;
    using CDoublet      = typename Base::CDoublet;
    using Base::Base;

    // Allow building this viewer directly from a non-const base viewer.
    MUDA_GENERIC DoubletVectorViewer(const Base& base)
        : Base(base)
    {
    }

    // Keep the const overload visible next to the Proxy-returning one below.
    using Base::operator();

    /**
     * Handle to one doublet slot of a non-const viewer.
     * Holds a reference to the viewer, so it must not outlive it.
     */
    class Proxy
    {
        friend class DoubletVectorViewer;
        DoubletVectorViewer& m_viewer;
        int                  m_index = 0;

      private:
        MUDA_GENERIC Proxy(DoubletVectorViewer& viewer, int index)
            : m_viewer(viewer)
            , m_index(index)
        {
        }

      public:
        /**
         * Read the doublet at this slot through the viewer's const accessor.
         */
        MUDA_GENERIC auto read() &&
        {
            // A plain cast to const selects the inherited const operator();
            // this avoids std::as_const, whose header <utility> is not
            // included by this file and which is a host library function
            // being called from a MUDA_GENERIC member.
            return static_cast<const DoubletVectorViewer&>(m_viewer)(m_index);
        }

        /**
         * Store a doublet at this slot.
         *
         * @param segment_i  segment index relative to the viewer's sub-vector;
         *                   asserted to lie inside the sub-vector.
         * @param block      segment value to store.
         *
         * The index is written to storage as a global index
         * (segment_i + sub-vector offset).
         */
        MUDA_GENERIC void write(int segment_i, const SegmentVector& block) &&
        {
            auto index = m_viewer.get_index(m_index);

            m_viewer.check_in_subvector(segment_i);

            auto global_i = segment_i + m_viewer.m_subvector_offset;

            m_viewer.m_segment_indices[index] = global_i;
            m_viewer.m_segment_values[index]  = block;
        }

        MUDA_GENERIC ~Proxy() = default;
    };

    // Non-const access: returns a Proxy bound to slot i of this viewer.
    MUDA_GENERIC Proxy operator()(int i) { return Proxy{*this, i}; }
};
/**
 * Specialization of the doublet-vector viewer for scalar segments (N == 1):
 * parallel arrays of (index, value of type T).
 *
 * The viewer exposes a window of `m_doublet_count` doublets starting at
 * `m_doublet_index_offset`, and interprets the stored (global) indices
 * relative to a sub-vector `[m_subvector_offset, m_subvector_offset + m_subvector_extent)`.
 *
 * @tparam IsConst  when true the underlying arrays are accessed through
 *                  pointers-to-const (via ViewerBase's auto_const_t).
 * @tparam T        value type of an entry.
 */
template <bool IsConst, typename T>
class DoubletVectorViewerBase<IsConst, T, 1> : public ViewerBase<IsConst>
{
    using Base = ViewerBase<IsConst>;

  protected:
    template <typename U>
    using auto_const_t = typename Base::template auto_const_t<U>;

  public:
    using ConstViewer = DoubletVectorViewerBase<true, T, 1>;
    using Viewer      = DoubletVectorViewerBase<false, T, 1>;
    using ThisViewer  = DoubletVectorViewerBase<IsConst, T, 1>;

    /**
     * Read-only result of accessing one doublet: the index (relative to the
     * sub-vector) and a reference to the stored value. The reference points
     * into the viewed value array, so it is only valid while that storage is.
     */
    struct CDoublet
    {
        MUDA_GENERIC CDoublet(int index, const T& segment)
            : index(index)
            , value(segment)
        {
        }
        int      index;
        const T& value;
    };

  protected:
    // vector info
    int m_total_count = 0;

    // doublet info
    int m_doublet_index_offset = 0;
    int m_doublet_count        = 0;
    int m_total_doublet_count  = 0;

    // subvector info
    int m_subvector_offset = 0;
    int m_subvector_extent = 0;

    // data
    // Initialized to nullptr so that a default-constructed viewer never holds
    // indeterminate pointers (every other member already has a default).
    auto_const_t<int>* m_indices = nullptr;
    auto_const_t<T>*   m_values  = nullptr;

  public:
    MUDA_GENERIC DoubletVectorViewerBase() = default;

    /**
     * @param total_count           number of entries in the whole vector.
     * @param doublet_index_offset  first doublet slot visible through this viewer.
     * @param doublet_count         number of doublet slots visible through this viewer.
     * @param total_doublet_count   number of doublet slots in the underlying arrays.
     * @param subvector_offset      global index where the sub-vector starts.
     * @param subvector_extent      number of entries in the sub-vector.
     * @param indices               array of global indices, one per doublet.
     * @param values                array of values, one per doublet.
     *
     * Asserts (MUDA_KERNEL_ASSERT) that the doublet window fits in the doublet
     * storage and that the sub-vector fits in the whole vector.
     */
    MUDA_GENERIC DoubletVectorViewerBase(int                total_count,
                                         int                doublet_index_offset,
                                         int                doublet_count,
                                         int                total_doublet_count,
                                         int                subvector_offset,
                                         int                subvector_extent,
                                         auto_const_t<int>* indices,
                                         auto_const_t<T>*   values)
        : m_total_count(total_count)
        , m_doublet_index_offset(doublet_index_offset)
        , m_doublet_count(doublet_count)
        , m_total_doublet_count(total_doublet_count)
        // The sub-vector window must be stored: operator(), Proxy::write(),
        // extent(), subvector_offset() and check_in_subvector() all read it.
        , m_subvector_offset(subvector_offset)
        , m_subvector_extent(subvector_extent)
        , m_indices(indices)
        , m_values(values)
    {
        MUDA_KERNEL_ASSERT(doublet_index_offset + doublet_count <= total_doublet_count,
                           "DoubletVectorViewer: out of range, m_total_doublet_count=%d, "
                           "your doublet_index_offset=%d, doublet_count=%d",
                           m_total_doublet_count,
                           doublet_index_offset,
                           doublet_count);

        MUDA_KERNEL_ASSERT(subvector_offset + subvector_extent <= total_count,
                           "DoubletVectorViewer: out of range, m_total_segment_count=%d, "
                           "your subvector_offset=%d, subvector_extent=%d",
                           m_total_count,
                           subvector_offset,
                           subvector_extent);
    }

    // implicit conversion
    // NOTE(review): only the members declared in this class are forwarded; the
    // ViewerBase sub-object of the result is default-initialized — confirm that
    // dropping any ViewerBase state here is intended.
    MUDA_GENERIC ConstViewer as_const() const noexcept
    {
        return ConstViewer{m_total_count,
                           m_doublet_index_offset,
                           m_doublet_count,
                           m_total_doublet_count,
                           m_subvector_offset,
                           m_subvector_extent,
                           m_indices,
                           m_values};
    }

    MUDA_GENERIC operator ConstViewer() const noexcept { return as_const(); }

    // const access

    /**
     * Read the i-th doublet of this viewer's window.
     * `i` is a doublet slot (range-checked by get_index); the stored global
     * index is rebased by the sub-vector offset and it is that rebased index
     * which is asserted to lie inside the sub-vector.
     */
    MUDA_GENERIC CDoublet operator()(int i) const
    {
        auto index    = get_index(i);
        auto global_i = m_indices[index];
        auto sub_i    = global_i - m_subvector_offset;
        check_in_subvector(sub_i);
        return CDoublet{sub_i, m_values[index]};
    }

    MUDA_GENERIC int extent() const noexcept { return m_subvector_extent; }
    MUDA_GENERIC int total_extent() const noexcept { return m_total_count; }
    MUDA_GENERIC int subvector_offset() const noexcept
    {
        return m_subvector_offset;
    }

    MUDA_GENERIC int doublet_count() const noexcept { return m_doublet_count; }
    MUDA_GENERIC int total_doublet_count() const noexcept
    {
        return m_total_doublet_count;
    }

  protected:
    // Map a window-local doublet index to a slot in the underlying arrays,
    // asserting that it lies in [0, m_doublet_count).
    MUDA_INLINE MUDA_GENERIC int get_index(int i) const noexcept
    {
        MUDA_KERNEL_ASSERT(i >= 0 && i < m_doublet_count,
                           "DoubletVectorViewer [%s:%s]: index out of range, m_doublet_count=%d, your index=%d",
                           this->name(),
                           this->kernel_name(),
                           m_doublet_count,
                           i);
        auto index = i + m_doublet_index_offset;
        return index;
    }

    // Assert that a sub-vector-relative index lies in [0, m_subvector_extent).
    MUDA_INLINE MUDA_GENERIC void check_in_subvector(int i) const noexcept
    {
        MUDA_KERNEL_ASSERT(i >= 0 && i < m_subvector_extent,
                           "DoubletVectorViewer [%s:%s]: index out of range, m_subvector_extent=%d, your index=%d",
                           this->name(),
                           this->kernel_name(),
                           m_subvector_extent,
                           i);
    }
};
/**
 * Read-only doublet-vector viewer for scalar entries (N == 1).
 * Adds no state of its own: it is the const instantiation of the scalar
 * DoubletVectorViewerBase plus the common viewer naming macro.
 */
template <typename T>
class CDoubletVectorViewer<T, 1> : public DoubletVectorViewerBase<true, T, 1>
{
    using Base = DoubletVectorViewerBase<true, T, 1>;
    MUDA_VIEWER_COMMON_NAME(CDoubletVectorViewer);

  public:
    using ThisViewer = CDoubletVectorViewer<T, 1>;

    // Re-export every constructor of the const base viewer.
    using Base::Base;

    // Allow building this viewer directly from a const base viewer.
    MUDA_GENERIC CDoubletVectorViewer(const Base& other)
        : Base(other)
    {
    }
};
/**
 * Read/write doublet-vector viewer for scalar entries (N == 1).
 *
 * Const access goes through the inherited `operator()(int) const`, which
 * returns a CDoublet. Non-const access returns a short-lived Proxy whose
 * `read()` / `write()` must be called on the temporary (they are
 * rvalue-qualified).
 */
template <typename T>
class DoubletVectorViewer<T, 1> : public DoubletVectorViewerBase<false, T, 1>
{
    using Base = DoubletVectorViewerBase<false, T, 1>;
    MUDA_VIEWER_COMMON_NAME(DoubletVectorViewer);

  public:
    using CDoublet    = typename Base::CDoublet;
    using ThisViewer  = DoubletVectorViewer<T, 1>;
    using ConstViewer = CDoubletVectorViewer<T, 1>;
    using Base::Base;

    // Allow building this viewer directly from a non-const base viewer.
    MUDA_GENERIC DoubletVectorViewer(const Base& base)
        : Base(base)
    {
    }

    // Keep the const overload visible next to the Proxy-returning one below.
    using Base::operator();

    /**
     * Handle to one doublet slot of a non-const viewer.
     * Holds a reference to the viewer, so it must not outlive it.
     */
    class Proxy
    {
        friend class DoubletVectorViewer;
        DoubletVectorViewer& m_viewer;
        int                  m_index = 0;

      private:
        MUDA_GENERIC Proxy(DoubletVectorViewer& viewer, int index)
            : m_viewer(viewer)
            , m_index(index)
        {
        }

      public:
        /**
         * Read the doublet at this slot through the viewer's const accessor.
         */
        MUDA_GENERIC auto read() &&
        {
            // A plain cast to const selects the inherited const operator();
            // this avoids std::as_const, whose header <utility> is not
            // included by this file and which is a host library function
            // being called from a MUDA_GENERIC member.
            return static_cast<const DoubletVectorViewer&>(m_viewer)(m_index);
        }

        /**
         * Store a doublet at this slot.
         *
         * @param i      index relative to the viewer's sub-vector; asserted to
         *               lie inside the sub-vector.
         * @param value  value to store.
         *
         * The index is written to storage as a global index
         * (i + sub-vector offset).
         */
        MUDA_GENERIC void write(int i, const T& value) &&
        {
            m_viewer.check_in_subvector(i);

            auto index    = m_viewer.get_index(m_index);
            auto global_i = i + m_viewer.m_subvector_offset;

            m_viewer.m_indices[index] = global_i;
            m_viewer.m_values[index]  = value;
        }

        MUDA_GENERIC ~Proxy() = default;
    };

    // Non-const access: returns a Proxy bound to slot i of this viewer.
    MUDA_GENERIC Proxy operator()(int i) { return Proxy{*this, i}; }
};
} // namespace muda
#include "details/doublet_vector_viewer.inl"