// File dense_vector_viewer.h
// File List > ext > linear_system > dense_vector_viewer.h
// Go to the documentation of this file
#pragma once
#include <muda/ext/eigen/eigen_core_cxx20.h>
#include <muda/buffer/buffer_2d_view.h>
#include <muda/viewer/viewer_base.h>
#include <cublas_v2.h>
#include <muda/atomic.h>
namespace muda
{
// View over a segment of a dense vector of `T`.
//
// The viewer stores a raw pointer to the start of the underlying storage
// (`m_data`), the position of the viewed segment inside that storage
// (`m_offset`), the segment length (`m_size`) and the length of the whole
// underlying storage (`m_origin_size`). Element `i` of the view is
// `m_data[m_offset + i]`. Nothing in this class allocates or frees `m_data`.
//
// Template parameters:
//   IsConst - when true, elements are exposed through `const T`.
//   T       - element type; only `float` and `double` are accepted.
template <bool IsConst, typename T>
class DenseVectorViewerBase : public ViewerBase<IsConst>
{
    static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>,
                  "now only support real number");

    using Base = ViewerBase<IsConst>;
    template <typename U>
    using auto_const_t = typename Base::template auto_const_t<U>;

  public:
    using CBufferView    = CBufferView<T>;
    using BufferView     = BufferView<T>;
    using ThisBufferView = std::conditional_t<IsConst, CBufferView, BufferView>;

    using ConstViewer    = DenseVectorViewerBase<true, T>;
    using NonConstViewer = DenseVectorViewerBase<false, T>;
    using ThisViewer = std::conditional_t<IsConst, ConstViewer, NonConstViewer>;

    using VectorType = Eigen::Vector<T, Eigen::Dynamic>;
    template <typename U>
    using MapVectorT =
        Eigen::Map<U, Eigen::AlignmentType::Unaligned, Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic>>;
    using MapVector     = MapVectorT<VectorType>;
    using CMapVector    = MapVectorT<const VectorType>;
    using ThisMapVector = std::conditional_t<IsConst, CMapVector, MapVector>;

  protected:
    auto_const_t<T>* m_data;             // start of the underlying storage
    int              m_offset      = 0;  // first viewed element, relative to m_data
    int              m_size        = 0;  // number of viewed elements
    int              m_origin_size = 0;  // length of the whole underlying storage

  public:
    // Builds a view of `size` elements starting at `data[offset]`;
    // `origin_size` is the length of the storage `data` points to.
    MUDA_GENERIC DenseVectorViewerBase(auto_const_t<T>* data, int offset, int size, int origin_size)
        : m_data(data)
        , m_offset(offset)
        , m_size(size)
        , m_origin_size(origin_size)
    {
    }

    // Returns a read-only viewer over the same segment.
    MUDA_GENERIC auto as_const() const
    {
        return ConstViewer{m_data, m_offset, m_size, m_origin_size};
    }

    MUDA_GENERIC operator ConstViewer() const { return as_const(); }

    // Returns a sub-view of `size` elements starting at element `offset` of
    // this view. The range is validated by check_segment() and the label of
    // this viewer is copied onto the result.
    MUDA_GENERIC auto segment(int offset, int size)
    {
        check_segment(offset, size);
        auto ret = ThisViewer{m_data, m_offset + offset, size, m_origin_size};
        ret.copy_label(*this);
        return ret;
    }

    // Fixed-length variant: sub-view of `N` elements starting at `offset`.
    template <int N>
    MUDA_GENERIC auto segment(int offset)
    {
        return segment(offset, N);
    }

    // Const overloads forward to the non-const implementation; the result is
    // still a `ThisViewer`, so constness of the elements is unchanged.
    MUDA_GENERIC auto segment(int offset, int size) const
    {
        return remove_const(*this).segment(offset, size);
    }

    template <int N>
    MUDA_GENERIC auto segment(int offset) const
    {
        return remove_const(*this).segment(offset, N);
    }

    // Element access relative to the start of the view (bounds-asserted).
    MUDA_GENERIC const T& operator()(int i) const { return m_data[index(i)]; }
    MUDA_GENERIC auto_const_t<T>& operator()(int i) { return m_data[index(i)]; }

    // Exposes the viewed segment as an Eigen block: the whole underlying
    // storage is mapped with unit strides and then narrowed to
    // [m_offset, m_offset + m_size).
    MUDA_GENERIC Eigen::VectorBlock<CMapVector> as_eigen() const
    {
        check_data();
        return CMapVector{m_data,
                          (int)origin_size(),
                          Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic>{1, 1}}
            .segment(m_offset, m_size);
    }

    MUDA_GENERIC operator Eigen::VectorBlock<CMapVector>() const
    {
        return as_eigen();
    }

    MUDA_GENERIC Eigen::VectorBlock<ThisMapVector> as_eigen()
    {
        check_data();
        return ThisMapVector{m_data,
                             (int)origin_size(),
                             Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic>{1, 1}}
            .segment(m_offset, m_size);
    }

    MUDA_GENERIC operator Eigen::VectorBlock<ThisMapVector>()
    {
        return as_eigen();
    }

    MUDA_GENERIC auto size() const { return m_size; }
    MUDA_GENERIC auto offset() const { return m_offset; }
    MUDA_GENERIC auto origin_data() const { return m_data; }
    MUDA_GENERIC auto origin_size() const { return m_origin_size; }

  protected:
    // Asserts that the view holds exactly `N` elements.
    MUDA_INLINE MUDA_GENERIC void check_size_matching(int N) const
    {
        MUDA_KERNEL_ASSERT(m_size == N,
                           "DenseVectorViewerBase [%s:%s]: size not match, yours size=%d, expected size=%d",
                           this->name(),
                           this->kernel_name(),
                           m_size,
                           N);
    }

    // Translates a view-relative index into an index into `m_data`,
    // asserting that the data pointer is set and that 0 <= i < m_size.
    MUDA_INLINE MUDA_GENERIC int index(int i) const
    {
        MUDA_KERNEL_ASSERT(origin_data(),
                           "DenseVectorViewerBase [%s:%s]: data is null",
                           this->name(),
                           this->kernel_name());
        // A negative index would silently address memory in front of the
        // viewed segment, so the lower bound is checked as well.
        MUDA_KERNEL_ASSERT(i >= 0 && i < m_size,
                           "DenseVectorViewerBase [%s:%s]: index out of range, size=%d, yours index=%d",
                           this->name(),
                           this->kernel_name(),
                           m_size,
                           i);
        return m_offset + i;
    }

    // Asserts that the data pointer is not null.
    MUDA_INLINE MUDA_GENERIC void check_data() const
    {
        MUDA_KERNEL_ASSERT(origin_data(),
                           "DenseVectorViewerBase [%s:%s]: data is null",
                           this->name(),
                           this->kernel_name());
    }

    // Asserts that [offset, offset + size) lies inside this view.
    MUDA_INLINE MUDA_GENERIC void check_segment(int offset, int size) const
    {
        // Negative values are rejected explicitly, and the upper bound is
        // written as `size <= m_size - offset` so it cannot overflow once
        // 0 <= offset <= m_size has been established.
        MUDA_KERNEL_ASSERT(offset >= 0 && size >= 0 && offset <= m_size
                               && size <= m_size - offset,
                           "DenseVectorViewerBase [%s:%s]: segment out of range, m_size=%d, offset=%d, size=%d",
                           this->name(),
                           this->kernel_name(),
                           m_size,
                           offset,
                           size);
    }
};
//template <typename T>
//using CDenseVectorViewer = DenseVectorViewerBase<true, T>;
//template <typename T>
//using DenseVectorViewer = DenseVectorViewerBase<false, T>;
// Read-only dense vector viewer: a DenseVectorViewerBase<true, T> whose
// segment() overloads return CDenseVectorViewer instead of the base type.
// MUDA_VIEWER_COMMON_NAME is defined elsewhere in the project.
template <typename T>
class CDenseVectorViewer : public DenseVectorViewerBase<true, T>
{
    MUDA_VIEWER_COMMON_NAME(CDenseVectorViewer);

    using Base       = DenseVectorViewerBase<true, T>;
    using MapVector  = typename Base::MapVector;
    using CMapVector = typename Base::CMapVector;

  public:
    using Base::Base;

    // Wraps an existing base viewer by copying it.
    MUDA_GENERIC CDenseVectorViewer(const Base& base)
        : Base(base)
    {
    }

    // Sub-view of `size` elements starting at element `offset`; the range
    // handling is delegated to the base class.
    MUDA_GENERIC CDenseVectorViewer segment(int offset, int size) const
    {
        auto sub_view = Base::segment(offset, size);
        return CDenseVectorViewer{sub_view};
    }

    // Sub-view with the compile-time length `N`.
    template <int N>
    MUDA_GENERIC auto segment(int offset) const
    {
        return this->segment(offset, N);
    }
};
// Writable dense vector viewer: a DenseVectorViewerBase<false, T> extended
// with atomic accumulation and assignment from fixed-size Eigen vectors.
// MUDA_VIEWER_COMMON_NAME is defined elsewhere in the project.
template <typename T>
class DenseVectorViewer : public DenseVectorViewerBase<false, T>
{
    MUDA_VIEWER_COMMON_NAME(DenseVectorViewer);

    using Base       = DenseVectorViewerBase<false, T>;
    using MapVector  = typename Base::MapVector;
    using CMapVector = typename Base::CMapVector;

  public:
    using Base::Base;

    // Wraps an existing base viewer by copying it.
    MUDA_GENERIC DenseVectorViewer(const Base& base)
        : Base(base)
    {
    }

    // Sub-view of `size` elements starting at element `offset`; the range
    // handling is delegated to the base class.
    MUDA_GENERIC DenseVectorViewer segment(int offset, int size)
    {
        auto sub_view = Base::segment(offset, size);
        return DenseVectorViewer{sub_view};
    }

    // Sub-view with the compile-time length `N`.
    template <int N>
    MUDA_GENERIC auto segment(int offset)
    {
        return this->segment(offset, N);
    }

    // Atomically adds `val` to element `i` and forwards the result of
    // muda::atomic_add (presumably the previous value - TODO confirm).
    MUDA_DEVICE T atomic_add(int i, T val)
    {
        return muda::atomic_add(&(*this)(i), val);
    }

    // Element-wise atomic add of an N-vector; the view must hold exactly N
    // elements. Each component is added with its own atomic operation, and
    // the per-component results are gathered into the returned vector.
    template <int N>
    MUDA_DEVICE Eigen::Vector<T, N> atomic_add(const Eigen::Vector<T, N>& val)
    {
        this->check_size_matching(N);
        Eigen::Vector<T, N> results;
#pragma unroll
        for(int k = 0; k < N; ++k)
        {
            results(k) = atomic_add(k, val(k));
        }
        return results;
    }

    // Scalar form: the view must hold exactly one element.
    MUDA_DEVICE T atomic_add(const T& val)
    {
        this->check_size_matching(1);
        return atomic_add(0, val);
    }

    // Copies the N components of `other` into the view; the view must hold
    // exactly N elements.
    template <int N>
    MUDA_GENERIC DenseVectorViewer& operator=(const Eigen::Vector<T, N>& other)
    {
        this->check_size_matching(N);
#pragma unroll
        for(int k = 0; k < N; ++k)
        {
            (*this)(k) = other(k);
        }
        return *this;
    }
};
} // namespace muda
namespace muda
{
// Trait specialization: the read-only counterpart of DenseVectorViewer<T>
// is CDenseVectorViewer<T>.
template <class T>
struct read_only_viewer<DenseVectorViewer<T>>
{
    typedef CDenseVectorViewer<T> type;
};
// Trait specialization: the writable counterpart of CDenseVectorViewer<T>
// is DenseVectorViewer<T>.
template <class T>
struct read_write_viewer<CDenseVectorViewer<T>>
{
    typedef DenseVectorViewer<T> type;
};
} // namespace muda
#include "details/dense_vector_viewer.inl"