MUDA
Loading...
Searching...
No Matches
dense_matrix_viewer.inl
1#include <muda/atomic.h>
2
3namespace muda
4{
/**
 * Returns a viewer onto the sub-matrix
 *   [row_offset, row_offset + row_size) x [col_offset, col_offset + col_size)
 * of this viewer.
 *
 * Offsets are relative to this viewer (which may itself already be a block),
 * so they are accumulated onto m_row_offset / m_col_offset. The label
 * (name / kernel name) is copied so diagnostics from the returned viewer stay
 * attributable.
 *
 * The range check is written as `size <= total - offset`, guarded by
 * `offset <= total`, rather than `offset + size <= total`. This way a huge
 * offset or size cannot wrap around and slip past the check.
 * Numeric arguments are cast to long long to match the %lld conversions.
 */
template <bool IsConst, typename T>
MUDA_GENERIC auto DenseMatrixViewerBase<IsConst, T>::block(size_t row_offset,
                                                           size_t col_offset,
                                                           size_t row_size,
                                                           size_t col_size) -> ThisViewer
{
    MUDA_KERNEL_ASSERT(row_offset <= m_row_size && row_size <= m_row_size - row_offset
                           && col_offset <= m_col_size && col_size <= m_col_size - col_offset,
                       "DenseMatrixViewerBase [%s:%s]: block index out of range, shape=(%lld,%lld), yours offset=(%lld,%lld), size=(%lld,%lld)",
                       this->name(),
                       this->kernel_name(),
                       static_cast<long long>(m_row_size),
                       static_cast<long long>(m_col_size),
                       static_cast<long long>(row_offset),
                       static_cast<long long>(col_offset),
                       static_cast<long long>(row_size),
                       static_cast<long long>(col_size));

    auto ret = DenseMatrixViewerBase{
        m_view, m_row_offset + row_offset, m_col_offset + col_offset, row_size, col_size};
    ret.copy_label(*this);
    return ret;
}
25
/**
 * Mutable Eigen view of the region covered by this viewer.
 *
 * The whole underlying pitched buffer is first mapped as an
 * origin_row() x origin_col() matrix. Its outer stride is the buffer pitch
 * expressed in elements and its inner stride is 1. That full map is then
 * narrowed to this viewer's block.
 *
 * NOTE(review): assumes pitch_bytes() is a multiple of sizeof(T); otherwise
 * the integer division silently truncates the stride — confirm.
 */
template <bool IsConst, typename T>
MUDA_GENERIC auto DenseMatrixViewerBase<IsConst, T>::as_eigen()
    -> Eigen::Block<ThisMapMatrix>
{
    using DynamicStride = Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic>;

    const auto elements_per_pitch = m_view.pitch_bytes() / sizeof(T);
    const DynamicStride stride{static_cast<int>(elements_per_pitch), 1};

    ThisMapMatrix whole_buffer{m_view.origin_data(),
                               static_cast<int>(origin_row()),
                               static_cast<int>(origin_col()),
                               stride};

    return whole_buffer.block(m_row_offset, m_col_offset, m_row_size, m_col_size);
}
38
/**
 * Element access at (i, j), relative to this viewer's block.
 *
 * When DEBUG_VIEWER is enabled, this:
 *  - asserts that the underlying data pointer is non-null;
 *  - asserts that (i, j) lies inside the m_row_size x m_col_size extent.
 * The two range-check branches test the same condition and differ only in
 * the wording of the message (whole matrix vs. sub-block).
 *
 * The block offsets are then added, and the element is fetched as
 * m_view.data(j, i): column index first, row index second. This is presumably
 * consistent with origin_row()/origin_col() taking the row count from the
 * buffer extent's width and the column count from its height — confirm
 * against Buffer2DView::data.
 *
 * Numeric arguments are cast to long long to match the %lld conversions.
 */
template <bool IsConst, typename T>
MUDA_GENERIC auto DenseMatrixViewerBase<IsConst, T>::operator()(size_t i, size_t j)
    -> auto_const_t<T>&
{
    if constexpr(DEBUG_VIEWER)
    {
        MUDA_KERNEL_ASSERT(m_view.data(0),
                           "DenseMatrixViewer [%s:%s]: data is null",
                           this->name(),
                           this->kernel_name());
        if(m_row_offset == 0 && m_col_offset == 0)
        {
            MUDA_KERNEL_ASSERT(i < m_row_size && j < m_col_size,
                               "DenseMatrixViewer [%s:%s]: index out of range, shape=(%lld,%lld), yours index=(%lld,%lld)",
                               this->name(),
                               this->kernel_name(),
                               static_cast<long long>(m_row_size),
                               static_cast<long long>(m_col_size),
                               static_cast<long long>(i),
                               static_cast<long long>(j));
        }
        else
        {
            MUDA_KERNEL_ASSERT(i < m_row_size && j < m_col_size,
                               "DenseMatrixViewer [%s:%s]:index out of range, block shape=(%lld,%lld), your index=(%lld,%lld)",
                               this->name(),
                               this->kernel_name(),
                               static_cast<long long>(m_row_size),
                               static_cast<long long>(m_col_size),
                               static_cast<long long>(i),
                               static_cast<long long>(j));
        }
    }
    // Translate block-relative indices into indices of the underlying buffer.
    i += m_row_offset;
    j += m_col_offset;
    return *m_view.data(j, i);
}
76
/**
 * Read-only Eigen view of the region covered by this viewer.
 *
 * The whole underlying pitched buffer is first mapped as an
 * origin_row() x origin_col() constant matrix. Its outer stride is the buffer
 * pitch expressed in elements and its inner stride is 1. That full map is
 * then narrowed to this viewer's block.
 *
 * NOTE(review): assumes pitch_bytes() is a multiple of sizeof(T); otherwise
 * the integer division silently truncates the stride — confirm.
 */
template <bool IsConst, typename T>
MUDA_GENERIC auto DenseMatrixViewerBase<IsConst, T>::as_eigen() const
    -> Eigen::Block<CMapMatrix>
{
    using DynamicStride = Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic>;

    const auto elements_per_pitch = m_view.pitch_bytes() / sizeof(T);
    const DynamicStride stride{static_cast<int>(elements_per_pitch), 1};

    CMapMatrix whole_buffer{m_view.origin_data(),
                            static_cast<int>(origin_row()),
                            static_cast<int>(origin_col()),
                            stride};

    return whole_buffer.block(m_row_offset, m_col_offset, m_row_size, m_col_size);
}
89
/**
 * Row count of the whole underlying buffer; this viewer's block offsets and
 * sizes are ignored. It is read from the width of the buffer's extent.
 */
template <bool IsConst, typename T>
MUDA_GENERIC size_t DenseMatrixViewerBase<IsConst, T>::origin_row() const
{
    return m_view.extent().width();
}
97
/**
 * Column count of the whole underlying buffer; this viewer's block offsets
 * and sizes are ignored. It is read from the height of the buffer's extent.
 */
template <bool IsConst, typename T>
MUDA_GENERIC size_t DenseMatrixViewerBase<IsConst, T>::origin_col() const
{
    return m_view.extent().height();
}
105
/**************************************************************************
*
* DenseMatrixViewer
*
**************************************************************************/
111
//template <typename T>
//MUDA_GENERIC DenseMatrixViewer<T> DenseMatrixViewer<T>::block(size_t row_offset,
//                                                              size_t col_offset,
//                                                              size_t row_size,
//                                                              size_t col_size) const
//{
//    return Base::block(row_offset, col_offset, row_size, col_size);
//}
//
//template <typename T>
//template <size_t M, size_t N>
//MUDA_GENERIC DenseMatrixViewer<T> DenseMatrixViewer<T>::block(size_t row_offset,
//                                                              size_t col_offset) const
//{
//    return Base::block<M, N>(row_offset, col_offset);
//}
128
/**
 * Atomically adds an M x N Eigen matrix into this viewer, element by element.
 *
 * First asserts (via check_size_matching) that this viewer is exactly M x N.
 * Then, for every (row, col), it calls the scalar atomic_add overload with
 * other(row, col).
 *
 * @return per element, whatever the scalar overload returned. In this file
 *         the scalar overload returns the value that was added.
 */
template <typename T>
template <int M, int N>
MUDA_DEVICE Eigen::Matrix<T, M, N> DenseMatrixViewer<T>::atomic_add(const Eigen::Matrix<T, M, N>& other)
{
    check_size_matching(M, N);

    Eigen::Matrix<T, M, N> result;
#pragma unroll
    for(int row = 0; row < M; ++row)
    {
#pragma unroll
        for(int col = 0; col < N; ++col)
        {
            const T added = atomic_add(row, col, other(row, col));
            result(row, col) = added;
        }
    }
    return result;
}
144
/**
 * Element-wise assignment from an M x N Eigen matrix.
 *
 * First asserts (via check_size_matching) that this viewer is exactly M x N.
 * Every element is then written through operator(). Block offsets and
 * debug-mode bounds checks therefore apply to each store.
 *
 * @return *this, to allow chaining.
 */
template <typename T>
template <int M, int N>
MUDA_GENERIC DenseMatrixViewer<T>& DenseMatrixViewer<T>::operator=(const Eigen::Matrix<T, M, N>& other)
{
    check_size_matching(M, N);

    DenseMatrixViewer<T>& self = *this;
#pragma unroll
    for(int row = 0; row < M; ++row)
    {
#pragma unroll
        for(int col = 0; col < N; ++col)
        {
            self(row, col) = other(row, col);
        }
    }
    return self;
}
157
/**
 * Atomically adds `val` to element (i, j) of this viewer.
 *
 * The element address is obtained through operator(). Block offsets and
 * debug-mode bounds checks are therefore applied before the atomic is issued.
 *
 * @return `val`, i.e. the value that was added.
 * NOTE(review): the result of muda::atomic_add is discarded. If callers
 * expect the element's previous value (atomicAdd-style), this needs
 * changing — confirm the intended contract.
 */
template <typename T>
MUDA_DEVICE T DenseMatrixViewer<T>::atomic_add(size_t i, size_t j, T val)
{
    auto* const target = &(*this)(i, j);
    muda::atomic_add(target, val);
    return val;
}
165
/**
 * Asserts that this viewer's shape is exactly M x N.
 *
 * Used by the Eigen-matrix overloads (atomic_add, operator=) before they
 * touch any element.
 *
 * All numeric arguments are cast to long long. The message uses %lld, while
 * M and N are `int` and the stored sizes may be `size_t`. Passing them
 * uncast would be a format/argument type mismatch: undefined behavior for
 * printf, and garbled values in device-side printf output.
 */
template <typename T>
MUDA_INLINE MUDA_GENERIC void DenseMatrixViewer<T>::check_size_matching(int M, int N) const
{
    MUDA_KERNEL_ASSERT(this->m_row_size == M && this->m_col_size == N,
                       "DenseMatrixViewer [%s:%s] shape mismatching, Viewer=(%lld,%lld), yours=(%lld,%lld)",
                       this->name(),
                       this->kernel_name(),
                       static_cast<long long>(this->m_row_size),
                       static_cast<long long>(this->m_col_size),
                       static_cast<long long>(M),
                       static_cast<long long>(N));
}
178} // namespace muda