1#include <muda/atomic.h>
// DenseMatrixViewerBase::block - returns a viewer over a sub-block of this
// viewer.
//
// NOTE(review): this chunk shows only part of the parameter list
// (row_offset, col_size). The body also uses col_offset and row_size, which
// are presumably the remaining parameters; confirm against the declaration.
//
// - Asserts (MUDA_KERNEL_ASSERT) that row_offset + row_size <= m_row_size and
//   col_offset + col_size <= m_col_size. In other words, the requested block
//   must lie inside this viewer's current shape.
// - The new viewer is constructed from the same m_view. Its offsets are this
//   viewer's offsets plus the requested ones, so nested blocks keep
//   addressing the same underlying storage.
// - Its size is (row_size, col_size).
// - It copies this viewer's label via copy_label(*this).
// The return statement is not visible in this chunk.
5template <
bool IsConst,
typename T>
6MUDA_GENERIC
auto DenseMatrixViewerBase<IsConst, T>::block(
size_t row_offset,
9 size_t col_size) -> ThisViewer
11 MUDA_KERNEL_ASSERT(row_offset + row_size <= m_row_size && col_offset + col_size <= m_col_size,
12 "DenseMatrixViewerBase [%s:%s]: block index out of range, shape=(%lld,%lld), yours index=(%lld,%lld)",
20 auto ret = DenseMatrixViewerBase{
21 m_view, m_row_offset + row_offset, m_col_offset + col_offset, row_size, col_size};
22 ret.copy_label(*
this);
// DenseMatrixViewerBase::as_eigen (non-const overload) - exposes this
// viewer's region as an Eigen::Block over a ThisMapMatrix.
//
// How the map is built:
// - It starts at m_view.origin_data(), the start of the underlying view
//   rather than of this block.
// - It uses Eigen::Stride<Dynamic, Dynamic>{outer, 1}: the outer stride is
//   pitch_bytes() / sizeof(T) elements and the inner stride is 1.
// - The map is then narrowed with .block(m_row_offset, m_col_offset,
//   m_row_size, m_col_size), so only this viewer's region is exposed.
//
// NOTE(review): the outer stride is presumably computed by integer division
// and is then cast to int. This assumes pitch_bytes() is a multiple of
// sizeof(T) and that the result fits in an int; confirm.
// The Map constructor's size arguments are not visible in this chunk.
26template <
bool IsConst,
typename T>
27MUDA_GENERIC
auto DenseMatrixViewerBase<IsConst, T>::as_eigen()
28 -> Eigen::Block<ThisMapMatrix>
30 auto outer = m_view.pitch_bytes() /
sizeof(T);
32 return ThisMapMatrix{m_view.origin_data(),
35 Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic>{(int)outer, 1}}
36 .block(m_row_offset, m_col_offset, m_row_size, m_col_size);
// DenseMatrixViewerBase::operator()(i, j) - element access relative to this
// viewer's block. i is checked against m_row_size (row index) and j against
// m_col_size (column index).
//
// When DEBUG_VIEWER is enabled (compile-time `if constexpr`), it:
// - asserts the underlying data pointer (m_view.data(0)) is non-null;
// - asserts i < m_row_size && j < m_col_size.
// Both bounds-assertion branches test the same condition and differ only in
// their message. The first is used when the viewer has no offset (whole
// matrix); the second is used when the viewer is a block.
//
// The element is reached through m_view.data(j, i), so the column index is
// passed first.
// NOTE(review): the lines between the assertions and the return are not
// visible here. Presumably they add m_row_offset / m_col_offset to i / j;
// confirm. Also confirm that the (j, i) argument order matches the data()
// convention of the underlying view and the layout assumed by as_eigen().
39template <
bool IsConst,
typename T>
40MUDA_GENERIC
auto DenseMatrixViewerBase<IsConst, T>::operator()(
size_t i,
size_t j)
43 if constexpr(DEBUG_VIEWER)
45 MUDA_KERNEL_ASSERT(m_view.data(0),
46 "DenseMatrixViewer [%s:%s]: data is null",
49 if(m_row_offset == 0 && m_col_offset == 0)
51 MUDA_KERNEL_ASSERT(i < m_row_size && j < m_col_size,
52 "DenseMatrixViewer [%s:%s]: index out of range, shape=(%lld,%lld), yours index=(%lld,%lld)",
62 MUDA_KERNEL_ASSERT(i < m_row_size && j < m_col_size,
63 "DenseMatrixViewer [%s:%s]:index out of range, block shape=(%lld,%lld), your index=(%lld,%lld)",
74 return *m_view.data(j, i);
// DenseMatrixViewerBase::as_eigen (const overload) - same mapping as the
// non-const overload, but built on CMapMatrix (presumably a read-only map
// type; confirm). It returns an Eigen::Block<CMapMatrix>.
//
// - The map starts at m_view.origin_data().
// - The outer stride is pitch_bytes() / sizeof(T) elements and the inner
//   stride is 1.
// - The map is narrowed to this viewer's region with
//   .block(m_row_offset, m_col_offset, m_row_size, m_col_size).
//
// NOTE(review): as in the non-const overload, the outer stride is cast to
// int and assumes pitch_bytes() is a multiple of sizeof(T); confirm.
77template <
bool IsConst,
typename T>
78MUDA_GENERIC
auto DenseMatrixViewerBase<IsConst, T>::as_eigen() const
79 -> Eigen::Block<CMapMatrix>
81 auto outer = m_view.pitch_bytes() /
sizeof(T);
83 return CMapMatrix{m_view.origin_data(),
86 Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic>{(int)outer, 1}}
87 .block(m_row_offset, m_col_offset, m_row_size, m_col_size);
// DenseMatrixViewerBase::origin_row - number of rows of the underlying view.
// It ignores this viewer's block offset and size and is taken from
// m_view.extent().width().
// NOTE(review): rows are mapped to the extent's *width* here. The declaration
// of `ret` and the return statement are not visible in this chunk.
90template <
bool IsConst,
typename T>
91MUDA_GENERIC
size_t DenseMatrixViewerBase<IsConst, T>::origin_row()
const
94 ret = m_view.extent().width();
// DenseMatrixViewerBase::origin_col - number of columns of the underlying
// view. It ignores this viewer's block offset and size and is taken from
// m_view.extent().height().
// NOTE(review): columns are mapped to the extent's *height* here. The
// declaration of `ret` and the return statement are not visible in this
// chunk.
98template <
bool IsConst,
typename T>
99MUDA_GENERIC
size_t DenseMatrixViewerBase<IsConst, T>::origin_col()
const
102 ret = m_view.extent().height();
// DenseMatrixViewer::atomic_add (M x N Eigen matrix overload) - adds `other`
// into this viewer element by element. Device-only (MUDA_DEVICE).
//
// - check_size_matching(M, N) first asserts that this viewer's shape equals
//   M x N.
// - Each element is added with the scalar atomic_add(i, j, value). Every
//   single-element update is atomic, but the matrix update as a whole is not,
//   so other threads may observe a partially-updated matrix.
// - ret(i, j) stores the value returned by the scalar atomic_add for that
//   element (presumably the element's previous value; confirm).
// The return statement is not visible in this chunk.
130template <
int M,
int N>
131MUDA_DEVICE Eigen::Matrix<T, M, N> DenseMatrixViewer<T>::atomic_add(
const Eigen::Matrix<T, M, N>& other)
133 check_size_matching(M, N);
134 Eigen::Matrix<T, M, N> ret;
136 for(
int i = 0; i < M; ++i)
138 for(
int j = 0; j < N; ++j)
140 ret(i, j) = atomic_add(i, j, other(i, j));
// DenseMatrixViewer::operator= (from an M x N Eigen matrix) - copies `other`
// into the viewed region element by element via operator()(i, j).
// - check_size_matching(M, N) first asserts that this viewer's shape equals
//   M x N.
// - The writes are plain (non-atomic) stores through the returned element
//   reference.
// The return statement (presumably `return *this;`) is not visible in this
// chunk.
146template <
int M,
int N>
147MUDA_GENERIC DenseMatrixViewer<T>& DenseMatrixViewer<T>::operator=(
const Eigen::Matrix<T, M, N>& other)
149 check_size_matching(M, N);
151 for(
int i = 0; i < M; ++i)
153 for(
int j = 0; j < N; ++j)
154 (*
this)(i, j) = other(i, j);
// DenseMatrixViewer::atomic_add(i, j, val) - atomically adds `val` to the
// element at (i, j) of this viewer. Device-only (MUDA_DEVICE).
// - The element address is obtained through operator()(i, j), so the same
//   index handling (including the DEBUG_VIEWER bounds checks) applies.
// - muda::atomic_add then performs the update.
// NOTE(review): the template header and the return statement are not visible
// in this chunk. Presumably the function returns the result of
// muda::atomic_add; confirm.
159MUDA_DEVICE T DenseMatrixViewer<T>::atomic_add(
size_t i,
size_t j, T val)
161 auto ptr = &this->operator()(i, j);
162 muda::atomic_add(ptr, val);
// DenseMatrixViewer::check_size_matching - asserts (MUDA_KERNEL_ASSERT) that
// this viewer's shape (m_row_size, m_col_size) equals (M, N). The
// Eigen-matrix operator= and atomic_add overloads call it before touching
// any element.
167MUDA_INLINE MUDA_GENERIC
void DenseMatrixViewer<T>::check_size_matching(
int M,
int N)
const
169 MUDA_KERNEL_ASSERT(this->m_row_size == M && this->m_col_size == N,
170 "DenseMatrixViewer [%s:%s] shape mismatching, Viewer=(%lld,%lld), yours=(%lld,%lld)",