MUDA
Loading...
Searching...
No Matches
dense_matrix_viewer.h
1#pragma once
2#include <muda/ext/eigen/eigen_core_cxx20.h>
3#include <muda/buffer/buffer_2d_view.h>
4#include <muda/viewer/viewer_base.h>
5#include <cublas_v2.h>
6namespace muda
7{
8template <bool IsConst, typename T>
9class DenseMatrixViewerBase : public ViewerBase<IsConst>
10{
11 static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>,
12 "now only support real number");
13 static_assert(!std::is_const_v<T>, "T must be non-const type");
14
16 template <typename U>
17 using auto_const_t = typename Base::template auto_const_t<U>;
18
19 public:
20 using CBuffer2DView = CBuffer2DView<T>;
21 using Buffer2DView = Buffer2DView<T>;
22 using ThisBuffer2DView = std::conditional_t<IsConst, CBuffer2DView, Buffer2DView>;
23
26 using ThisViewer = std::conditional_t<IsConst, ConstViewer, NonConstViewer>;
27
28 using MatrixType = Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>;
29 template <typename U>
30 using MapMatrixT =
31 Eigen::Map<U, Eigen::AlignmentType::Unaligned, Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic>>;
32 using MapMatrix = MapMatrixT<MatrixType>;
33 using CMapMatrix = MapMatrixT<const MatrixType>;
34 using ThisMapMatrix = std::conditional_t<IsConst, CMapMatrix, MapMatrix>;
35
36 protected:
37 ThisBuffer2DView m_view;
38 size_t m_row_offset = 0;
39 size_t m_col_offset = 0;
40 size_t m_row_size = 0;
41 size_t m_col_size = 0;
42
43 public:
44 MUDA_GENERIC DenseMatrixViewerBase(ThisBuffer2DView view,
45 size_t row_offset,
46 size_t col_offset,
47 size_t row_size,
48 size_t col_size)
49 : m_view(view)
50 , m_row_offset(row_offset)
51 , m_col_offset(col_offset)
52 , m_row_size(row_size)
53 , m_col_size(col_size)
54 {
55 }
56
57 // implicit conversion
58
59 MUDA_GENERIC auto as_const() const
60 {
61 return ConstViewer{m_view, m_row_offset, m_col_offset, m_row_size, m_col_size};
62 }
63
64 MUDA_GENERIC operator ConstViewer() const { return as_const(); }
65
66 // non-const accessor
67
68 MUDA_GENERIC ThisViewer block(size_t row_offset, size_t col_offset, size_t row_size, size_t col_size);
69 template <int M, int N>
70 MUDA_GENERIC ThisViewer block(int row_offset, int col_offset)
71 {
72 return block(row_offset, col_offset, M, N);
73 }
74 MUDA_GENERIC Eigen::Block<ThisMapMatrix> as_eigen();
75 MUDA_GENERIC operator Eigen::Block<CMapMatrix>();
76 MUDA_GENERIC auto_const_t<T>& operator()(size_t i, size_t j);
77 MUDA_GENERIC auto buffer_view() { return m_view; }
78
79 // const accessor
80
81 MUDA_GENERIC ConstViewer block(size_t row_offset, size_t col_offset, size_t row_size, size_t col_size) const
82 {
83 return remove_const(*this).block(row_offset, col_offset, row_size, col_size);
84 }
85 template <int M, int N>
86 MUDA_GENERIC ConstViewer block(int row_offset, int col_offset) const
87 {
88 return remove_const(*this).block<M, N>(row_offset, col_offset);
89 }
90 MUDA_GENERIC Eigen::Block<CMapMatrix> as_eigen() const;
91 MUDA_GENERIC operator Eigen::Block<CMapMatrix>() const
92 {
93 return as_eigen();
94 }
95 MUDA_GENERIC const T& operator()(size_t i, size_t j) const
96 {
97 return remove_const(*this)(i, j);
98 }
99
100 MUDA_GENERIC size_t row() const { return m_row_size; }
101 MUDA_GENERIC size_t col() const { return m_col_size; }
102 MUDA_GENERIC size_t origin_row() const;
103 MUDA_GENERIC size_t origin_col() const;
104 MUDA_GENERIC auto buffer_view() const { return m_view; }
105 MUDA_GENERIC auto row_offset() const { return m_row_offset; }
106 MUDA_GENERIC auto col_offset() const { return m_col_offset; }
107};
108
109template <typename T>
111{
112 MUDA_VIEWER_COMMON_NAME(CDenseMatrixViewer);
113
115 using CMapMatrix = typename Base::CMapMatrix;
116
117 public:
118 using Base::Base;
119
120 MUDA_GENERIC CDenseMatrixViewer(const Base& base)
121 : Base(base)
122 {
123 }
124
125 MUDA_GENERIC CDenseMatrixViewer block(size_t row_offset,
126 size_t col_offset,
127 size_t row_size,
128 size_t col_size) const
129 {
130 return Base::block(row_offset, col_offset, row_size, col_size);
131 }
132
133 template <size_t M, size_t N>
134 MUDA_GENERIC CDenseMatrixViewer block(size_t row_offset, size_t col_offset) const
135 {
136 return Base::template block<M, N>(row_offset, col_offset);
137 }
138};
139
140template <typename T>
142{
143 MUDA_VIEWER_COMMON_NAME(DenseMatrixViewer);
144
146 using MapMatrix = typename Base::MapMatrix;
147 using CMapMatrix = typename Base::CMapMatrix;
148
149 public:
150 using Base::Base;
151
152 MUDA_GENERIC DenseMatrixViewer(const Base& base)
153 : Base(base)
154 {
155 }
156
157 MUDA_GENERIC DenseMatrixViewer(const CDenseMatrixViewer<T>&) = delete;
158
159 MUDA_GENERIC DenseMatrixViewer block(size_t row_offset, size_t col_offset, size_t row_size, size_t col_size)
160 {
161 return Base::block(row_offset, col_offset, row_size, col_size);
162 }
163
164 template <size_t M, size_t N>
165 MUDA_GENERIC DenseMatrixViewer block(size_t row_offset, size_t col_offset)
166 {
167 return Base::template block<M, N>(row_offset, col_offset);
168 }
169
170 MUDA_DEVICE T atomic_add(size_t i, size_t j, T val);
171
172 template <int M, int N>
173 MUDA_DEVICE Eigen::Matrix<T, M, N> atomic_add(const Eigen::Matrix<T, M, N>& other);
174
175 template <int M, int N>
176 MUDA_GENERIC DenseMatrixViewer& operator=(const Eigen::Matrix<T, M, N>& other);
177
178 private:
179 MUDA_GENERIC void check_size_matching(int M, int N) const;
180};
181
182} // namespace muda
183
184//namespace muda
185//{
186//template <typename T>
187//struct read_only_viewer<DenseMatrixViewer<T>>
188//{
189// using type = CDenseMatrixViewer<T>;
190//};
191//
192//template <typename T>
193//struct read_write_viewer<CDenseMatrixViewer<T>>
194//{
195// using type = DenseMatrixViewer<T>;
196//};
197//} // namespace muda
198
199#include "details/dense_matrix_viewer.inl"
Definition dense_matrix_viewer.h:111
Definition dense_matrix_viewer.h:10
Definition dense_matrix_viewer.h:142
Definition viewer_base.h:18