// File: dense_vector_viewer.h
// Location: ext/linear_system/dense_vector_viewer.h
#pragma once
#include <muda/ext/eigen/eigen_core_cxx20.h>
#include <muda/buffer/buffer_2d_view.h>
#include <muda/viewer/viewer_base.h>
#include <cublas_v2.h>
#include <muda/atomic.h>
namespace muda
{
// Non-owning view over a segment of a dense vector, usable from host and
// device code (members are marked MUDA_GENERIC / MUDA_DEVICE).
//
// The viewer stores a raw pointer to the beginning of the underlying
// ("origin") storage plus three integers: the offset of the visible segment
// inside that storage, the segment size, and the origin size. Element access
// is relative to the segment: element `i` of the viewer is `m_data[m_offset + i]`.
//
// Template parameters:
//   IsConst - selects the const flavour of the viewer; element types are
//             spelled through Base::auto_const_t (presumably `const T` when
//             IsConst is true - confirm against ViewerBase).
//   T       - scalar type; only `float` and `double` are accepted.
template <bool IsConst, typename T>
class DenseVectorViewerT : public ViewerBase<IsConst>
{
    static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>,
                  "now only support real number");

    using Base = ViewerBase<IsConst>;
    template <typename U>
    using auto_const_t = typename Base::template auto_const_t<U>;

    // The other const-flavour needs access to the protected members for the
    // converting constructor below.
    template <bool OtherIsConst, typename U>
    friend class DenseVectorViewerT;

  public:
    using CBufferView    = CBufferView<T>;
    using BufferView     = BufferView<T>;
    using ThisBufferView = std::conditional_t<IsConst, CBufferView, BufferView>;

    using ConstViewer    = DenseVectorViewerT<true, T>;
    using NonConstViewer = DenseVectorViewerT<false, T>;
    using ThisViewer = std::conditional_t<IsConst, ConstViewer, NonConstViewer>;

    using VectorType = Eigen::Vector<T, Eigen::Dynamic>;
    template <typename U>
    using MapVectorT =
        Eigen::Map<U, Eigen::AlignmentType::Unaligned, Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic>>;
    using MapVector     = MapVectorT<VectorType>;
    using CMapVector    = MapVectorT<const VectorType>;
    using ThisMapVector = std::conditional_t<IsConst, CMapVector, MapVector>;

    MUDA_VIEWER_COMMON_NAME(DenseVectorViewerT);

  protected:
    auto_const_t<T>* m_data;             // start of the origin storage (not of the segment)
    int              m_offset      = 0;  // first visible element, relative to m_data
    int              m_size        = 0;  // number of visible elements
    int              m_origin_size = 0;  // element count of the origin storage

  public:
    // Builds a viewer over `size` elements starting at `data[offset]`;
    // `origin_size` is the element count of the whole storage behind `data`.
    MUDA_GENERIC DenseVectorViewerT(auto_const_t<T>* data, int offset, int size, int origin_size)
        : m_data(data)
        , m_offset(offset)
        , m_size(size)
        , m_origin_size(origin_size)
    {
    }

    // Converting constructor from either const-flavour; only available when
    // this viewer is the const flavour (enforced by both the requirement and
    // the static_assert).
    template <bool OtherIsConst>
    MUDA_GENERIC DenseVectorViewerT(const DenseVectorViewerT<OtherIsConst, T>& other)
        MUDA_REQUIRES(IsConst)
        : m_data(other.m_data)
        , m_offset(other.m_offset)
        , m_size(other.m_size)
        , m_origin_size(other.m_origin_size)
    {
        static_assert(IsConst);
    }

    // Returns a ConstViewer over the same segment.
    MUDA_GENERIC auto as_const() const
    {
        return ConstViewer{m_data, m_offset, m_size, m_origin_size};
    }

    // Returns a viewer over `size` elements starting at element `offset` of
    // this viewer. The range is validated by check_segment() and the label of
    // this viewer is copied to the result.
    MUDA_GENERIC auto segment(int offset, int size) const
    {
        check_segment(offset, size);
        auto ret = ThisViewer{m_data, m_offset + offset, size, m_origin_size};
        ret.copy_label(*this);
        return ret;
    }

    // Fixed-size convenience overload: segment(offset, N).
    template <int N>
    MUDA_GENERIC auto segment(int offset) const
    {
        return segment(offset, N);
    }

    // Element access relative to the segment; the index is validated by index().
    MUDA_GENERIC auto_const_t<T>& operator()(int i) const
    {
        return m_data[index(i)];
    }

    // Exposes the segment as an Eigen vector block: the whole origin storage
    // is mapped (unit strides) and then narrowed to [m_offset, m_offset + m_size).
    MUDA_GENERIC Eigen::VectorBlock<ThisMapVector> as_eigen() const
    {
        check_data();
        return ThisMapVector{m_data,
                             (int)origin_size(),
                             Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic>{1, 1}}
            .segment(m_offset, m_size);
    }

    // Implicit conversion to the Eigen block. Const-qualified because it only
    // forwards to the const as_eigen(), so it also works on const viewer
    // objects.
    MUDA_GENERIC operator Eigen::VectorBlock<ThisMapVector>() const
    {
        return as_eigen();
    }

    MUDA_GENERIC auto size() const { return m_size; }
    MUDA_GENERIC auto offset() const { return m_offset; }
    MUDA_GENERIC auto origin_data() const { return m_data; }
    MUDA_GENERIC auto origin_size() const { return m_origin_size; }

    // Atomically adds `val` to element `i` and returns whatever
    // muda::atomic_add returns (presumably the previous value - confirm
    // against muda/atomic.h).
    MUDA_DEVICE T atomic_add(int i, T val) const MUDA_REQUIRES(!IsConst)
    {
        auto ptr = &this->operator()(i);
        return muda::atomic_add(ptr, val);
    }

    // Element-wise atomic add of a fixed-size vector; the viewer size must be
    // exactly N. Each component is added with its own atomic operation, so the
    // vector as a whole is not updated atomically.
    template <int N>
    MUDA_DEVICE Eigen::Vector<T, N> atomic_add(const Eigen::Vector<T, N>& val) const
        MUDA_REQUIRES(!IsConst)
    {
        this->check_size_matching(N);
        Eigen::Vector<T, N> ret;
#pragma unroll
        for(int i = 0; i < N; ++i)
        {
            ret(i) = atomic_add(i, val(i));
        }
        return ret;
    }

    // Scalar atomic add for a viewer of size 1. Const-qualified and restricted
    // to the non-const flavour so that it matches the other atomic_add
    // overloads and stays callable on const viewer objects.
    MUDA_DEVICE T atomic_add(const T& val) const MUDA_REQUIRES(!IsConst)
    {
        this->check_size_matching(1);
        T ret = atomic_add(0, val);
        return ret;
    }

    // Copies a fixed-size Eigen vector into the viewed elements; the viewer
    // size must be exactly N. Writes go through operator(), so this is only
    // usable when the element type is assignable.
    template <int N>
    MUDA_GENERIC DenseVectorViewerT& operator=(const Eigen::Vector<T, N>& other)
    {
        this->check_size_matching(N);
#pragma unroll
        for(int i = 0; i < N; ++i)
        {
            this->operator()(i) = other(i);
        }
        return *this;
    }

  protected:
    // Asserts that the viewer holds exactly N elements.
    MUDA_INLINE MUDA_GENERIC void check_size_matching(int N) const
    {
        MUDA_KERNEL_ASSERT(m_size == N,
                           "DenseVectorViewerBase [%s:%s]: size not match, yours size=%d, expected size=%d. %s(%d)",
                           this->name(),
                           this->kernel_name(),
                           m_size,
                           N,
                           this->kernel_file(),
                           this->kernel_line());
    }

    // Validates a segment-relative index and converts it to an index into the
    // origin storage.
    MUDA_INLINE MUDA_GENERIC int index(int i) const
    {
        check_data();
        // Reject negative indices as well: they would address memory in front
        // of the segment.
        MUDA_KERNEL_ASSERT(i >= 0 && i < m_size,
                           "DenseVectorViewerBase [%s:%s]: index out of range, size=%d, yours index=%d. %s(%d)",
                           this->name(),
                           this->kernel_name(),
                           m_size,
                           i,
                           this->kernel_file(),
                           this->kernel_line());
        return m_offset + i;
    }

    // Asserts that the viewer points at some storage.
    MUDA_INLINE MUDA_GENERIC void check_data() const
    {
        MUDA_KERNEL_ASSERT(origin_data(),
                           "DenseVectorViewerBase [%s:%s]: data is null. %s(%d)",
                           this->name(),
                           this->kernel_name(),
                           this->kernel_file(),
                           this->kernel_line());
    }

    // Asserts that [offset, offset + size) lies inside this viewer.
    MUDA_INLINE MUDA_GENERIC void check_segment(int offset, int size) const
    {
        // Negative offset/size are rejected, and the upper bound is written as
        // `offset <= m_size - size` so the comparison cannot overflow the way
        // `offset + size` could.
        MUDA_KERNEL_ASSERT(offset >= 0 && size >= 0 && offset <= m_size - size,
                           "DenseVectorViewerBase [%s:%s]: segment out of range, m_size=%d, offset=%d, size=%d. %s(%d)",
                           this->name(),
                           this->kernel_name(),
                           m_size,
                           offset,
                           size,
                           this->kernel_file(),
                           this->kernel_line());
    }
};
// Shorthand for the non-const flavour: DenseVectorViewerT with IsConst = false.
template <typename Scalar>
using DenseVectorViewer = DenseVectorViewerT<false, Scalar>;

// Shorthand for the const flavour: DenseVectorViewerT with IsConst = true.
template <typename Scalar>
using CDenseVectorViewer = DenseVectorViewerT<true, Scalar>;
} // namespace muda
#include "details/dense_vector_viewer.inl"