File dense_matrix_viewer.h
File List > ext > linear_system > dense_matrix_viewer.h
Go to the documentation of this file
#pragma once
#include <muda/ext/eigen/eigen_core_cxx20.h>
#include <muda/buffer/buffer_2d_view.h>
#include <muda/viewer/viewer_base.h>
#include <cublas_v2.h>
namespace muda
{
/**
 * @brief Shared implementation of the read-only and read-write dense matrix viewers.
 *
 * A viewer stores a 2D buffer view together with a block description
 * (row/column offset and row/column size). Element access and Eigen interop
 * are declared here; members declared without a body are defined out of line
 * (details/dense_matrix_viewer.inl is included at the end of this header).
 *
 * @tparam IsConst true for a read-only viewer, false for a read-write viewer.
 * @tparam T       scalar type; only float and double are accepted.
 */
template <bool IsConst, typename T>
class DenseMatrixViewerBase : public ViewerBase<IsConst>
{
    static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>,
                  "now only support real number");
    static_assert(!std::is_const_v<T>, "T must be non-const type");

    using Base = ViewerBase<IsConst>;
    // Provided by ViewerBase; presumably yields const U for read-only viewers.
    template <typename U>
    using auto_const_t = typename Base::template auto_const_t<U>;

  public:
    // The right-hand sides are namespace-qualified on purpose: the unqualified
    // form `using CBuffer2DView = CBuffer2DView<T>;` changes the meaning of the
    // name CBuffer2DView inside this class, which is ill-formed (no diagnostic
    // required) and is rejected by some compilers.
    using CBuffer2DView    = muda::CBuffer2DView<T>;
    using Buffer2DView     = muda::Buffer2DView<T>;
    using ThisBuffer2DView = std::conditional_t<IsConst, CBuffer2DView, Buffer2DView>;

    using ConstViewer    = DenseMatrixViewerBase<true, T>;
    using NonConstViewer = DenseMatrixViewerBase<false, T>;
    using ThisViewer = std::conditional_t<IsConst, ConstViewer, NonConstViewer>;

    // Dynamic-size, column-major Eigen matrix, and unaligned maps over it with
    // runtime (outer, inner) strides.
    using MatrixType = Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>;
    template <typename U>
    using MapMatrixT =
        Eigen::Map<U, Eigen::AlignmentType::Unaligned, Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic>>;
    using MapMatrix     = MapMatrixT<MatrixType>;
    using CMapMatrix    = MapMatrixT<const MatrixType>;
    using ThisMapMatrix = std::conditional_t<IsConst, CMapMatrix, MapMatrix>;

  protected:
    ThisBuffer2DView m_view;            // underlying 2D buffer view
    size_t           m_row_offset = 0;  // first row of the viewed block
    size_t           m_col_offset = 0;  // first column of the viewed block
    size_t           m_row_size   = 0;  // number of rows in the viewed block
    size_t           m_col_size   = 0;  // number of columns in the viewed block

  public:
    /// Builds a viewer over the row_size x col_size block of `view` that
    /// starts at (row_offset, col_offset).
    MUDA_GENERIC DenseMatrixViewerBase(ThisBuffer2DView view,
                                       size_t           row_offset,
                                       size_t           col_offset,
                                       size_t           row_size,
                                       size_t           col_size)
        : m_view(view)
        , m_row_offset(row_offset)
        , m_col_offset(col_offset)
        , m_row_size(row_size)
        , m_col_size(col_size)
    {
    }

    /// Returns a read-only viewer over the same buffer view and block.
    MUDA_GENERIC auto as_const() const
    {
        return ConstViewer{m_view, m_row_offset, m_col_offset, m_row_size, m_col_size};
    }

    /// Implicit conversion to the read-only viewer type.
    MUDA_GENERIC operator ConstViewer() const { return as_const(); }

    // ---- non-const accessors ----

    /// Sub-block of size row_size x col_size at (row_offset, col_offset).
    /// Defined out of line.
    MUDA_GENERIC ThisViewer block(size_t row_offset, size_t col_offset, size_t row_size, size_t col_size);

    /// Fixed-size M x N sub-block at (row_offset, col_offset).
    /// Offsets are size_t to match the runtime-size overload and the derived
    /// viewers, which forward size_t offsets here (int offsets would truncate).
    template <int M, int N>
    MUDA_GENERIC ThisViewer block(size_t row_offset, size_t col_offset)
    {
        return block(row_offset, col_offset, M, N);
    }

    /// Eigen view of the data (defined out of line).
    MUDA_GENERIC Eigen::Block<ThisMapMatrix> as_eigen();
    MUDA_GENERIC operator Eigen::Block<CMapMatrix>();
    /// Element (i, j); the reference type follows auto_const_t (defined out of line).
    MUDA_GENERIC auto_const_t<T>& operator()(size_t i, size_t j);
    MUDA_GENERIC auto buffer_view() { return m_view; }

    // ---- const accessors ----
    // These forward to the non-const overloads through remove_const (a muda
    // helper, presumably a const_cast) and hand back read-only results.

    MUDA_GENERIC ConstViewer block(size_t row_offset, size_t col_offset, size_t row_size, size_t col_size) const
    {
        return remove_const(*this).block(row_offset, col_offset, row_size, col_size);
    }

    template <int M, int N>
    MUDA_GENERIC ConstViewer block(size_t row_offset, size_t col_offset) const
    {
        // `template` is required: remove_const(*this) is type-dependent, so
        // without it `block < M` would be parsed as a less-than comparison.
        return remove_const(*this).template block<M, N>(row_offset, col_offset);
    }

    MUDA_GENERIC Eigen::Block<CMapMatrix> as_eigen() const;
    MUDA_GENERIC operator Eigen::Block<CMapMatrix>() const
    {
        return as_eigen();
    }
    MUDA_GENERIC const T& operator()(size_t i, size_t j) const
    {
        return remove_const(*this)(i, j);
    }

    /// Number of rows / columns of the viewed block.
    MUDA_GENERIC size_t row() const { return m_row_size; }
    MUDA_GENERIC size_t col() const { return m_col_size; }
    /// Defined out of line; presumably the row/column count of the whole
    /// underlying buffer rather than the block (TODO confirm in the .inl).
    MUDA_GENERIC size_t origin_row() const;
    MUDA_GENERIC size_t origin_col() const;
    MUDA_GENERIC auto buffer_view() const { return m_view; }
    MUDA_GENERIC auto row_offset() const { return m_row_offset; }
    MUDA_GENERIC auto col_offset() const { return m_col_offset; }
};
/**
 * @brief Read-only dense matrix viewer.
 *
 * Thin wrapper over DenseMatrixViewerBase<true, T>. Its block() overloads
 * return CDenseMatrixViewer rather than the base type.
 */
template <typename T>
class CDenseMatrixViewer : public DenseMatrixViewerBase<true, T>
{
    MUDA_VIEWER_COMMON_NAME(CDenseMatrixViewer);
    using Base       = DenseMatrixViewerBase<true, T>;
    using CMapMatrix = typename Base::CMapMatrix;

  public:
    // Reuse the (view, row_offset, col_offset, row_size, col_size) constructor.
    using Base::Base;

    /// Wraps an existing base viewer; non-explicit so base results convert implicitly.
    MUDA_GENERIC CDenseMatrixViewer(const Base& base)
        : Base(base)
    {
    }

    /// Read-only sub-block of size row_size x col_size at (row_offset, col_offset).
    MUDA_GENERIC CDenseMatrixViewer block(size_t row_offset,
                                          size_t col_offset,
                                          size_t row_size,
                                          size_t col_size) const
    {
        Base sub = Base::block(row_offset, col_offset, row_size, col_size);
        return sub;
    }

    /// Read-only fixed-size M x N sub-block at (row_offset, col_offset).
    template <size_t M, size_t N>
    MUDA_GENERIC CDenseMatrixViewer block(size_t row_offset, size_t col_offset) const
    {
        Base sub = Base::template block<M, N>(row_offset, col_offset);
        return sub;
    }
};
/**
 * @brief Read-write dense matrix viewer.
 *
 * Wraps DenseMatrixViewerBase<false, T>. Its block() overloads return
 * DenseMatrixViewer. Construction from a CDenseMatrixViewer is deleted, so a
 * read-only viewer cannot be turned back into a writable one.
 */
template <typename T>
class DenseMatrixViewer : public DenseMatrixViewerBase<false, T>
{
    MUDA_VIEWER_COMMON_NAME(DenseMatrixViewer);
    using Base       = DenseMatrixViewerBase<false, T>;
    using MapMatrix  = typename Base::MapMatrix;
    using CMapMatrix = typename Base::CMapMatrix;

  public:
    // Reuse the (view, row_offset, col_offset, row_size, col_size) constructor.
    using Base::Base;

    /// Wraps an existing base viewer; non-explicit so base results convert implicitly.
    MUDA_GENERIC DenseMatrixViewer(const Base& base)
        : Base(base)
    {
    }

    /// Forbidden: a read-only viewer must not yield a writable one.
    MUDA_GENERIC DenseMatrixViewer(const CDenseMatrixViewer<T>&) = delete;

    /// Writable sub-block of size row_size x col_size at (row_offset, col_offset).
    MUDA_GENERIC DenseMatrixViewer block(size_t row_offset, size_t col_offset, size_t row_size, size_t col_size)
    {
        Base sub = Base::block(row_offset, col_offset, row_size, col_size);
        return sub;
    }

    /// Writable fixed-size M x N sub-block at (row_offset, col_offset).
    template <size_t M, size_t N>
    MUDA_GENERIC DenseMatrixViewer block(size_t row_offset, size_t col_offset)
    {
        Base sub = Base::template block<M, N>(row_offset, col_offset);
        return sub;
    }

    // Device-only, defined out of line. Presumably atomically adds `val` to
    // element (i, j); what the returned T holds is defined in the .inl.
    MUDA_DEVICE T atomic_add(size_t i, size_t j, T val);

    // Device-only, defined out of line; presumably an element-wise atomic add
    // of an M x N matrix into this viewer (TODO confirm in the .inl).
    template <int M, int N>
    MUDA_DEVICE Eigen::Matrix<T, M, N> atomic_add(const Eigen::Matrix<T, M, N>& other);

    // Defined out of line; presumably writes `other` into the viewed block.
    template <int M, int N>
    MUDA_GENERIC DenseMatrixViewer& operator=(const Eigen::Matrix<T, M, N>& other);

  private:
    // Defined out of line; the name suggests it verifies the block is M x N.
    MUDA_GENERIC void check_size_matching(int M, int N) const;
};
} // namespace muda
//namespace muda
//{
//template <typename T>
//struct read_only_viewer<DenseMatrixViewer<T>>
//{
// using type = CDenseMatrixViewer<T>;
//};
//
//template <typename T>
//struct read_write_viewer<CDenseMatrixViewer<T>>
//{
// using type = DenseMatrixViewer<T>;
//};
//} // namespace muda
#include "details/dense_matrix_viewer.inl"