MUDA
Loading...
Searching...
No Matches
device_dense_matrix.inl
1#include <cublas_v2.h>
2
3namespace muda
4{
5template <typename Ty>
6DeviceDenseMatrix<Ty>::DeviceDenseMatrix(size_t row, size_t col, bool sym)
7 : m_row{row}
8 , m_col{col}
9 , m_sym{sym}
10 , m_data{muda::Extent2D{col, row}}
11{
12}
13template <typename Ty>
14DeviceDenseMatrix<Ty>::DeviceDenseMatrix(DeviceDenseMatrix&& other)
15 : m_row{other.m_row}
16 , m_col{other.m_col}
17 , m_data{std::move(other.m_data)}
18{
19 other.m_row = 0;
20 other.m_col = 0;
21}
22template <typename Ty>
23DeviceDenseMatrix<Ty>& DeviceDenseMatrix<Ty>::operator=(DeviceDenseMatrix&& other)
24{
25 if(this != &other)
26 {
27 m_row = other.m_row;
28 m_col = other.m_col;
29 m_data = std::move(other.m_data);
30 other.m_row = 0;
31 other.m_col = 0;
32 }
33 return *this;
34}
35
36template <typename Ty>
37void DeviceDenseMatrix<Ty>::reshape(size_t row, size_t col)
38{
39 m_data.resize(muda::Extent2D{col, row});
40 m_row = row;
41 m_col = col;
42}
43
44template <typename Ty>
45void DeviceDenseMatrix<Ty>::fill(Ty value)
46{
47 m_data.fill(value);
48}
49
50template <typename Ty>
51void DeviceDenseMatrix<Ty>::copy_to(Eigen::MatrixX<Ty>& mat) const
52{
53 std::vector<Ty> host_data;
54 m_data.copy_to(host_data);
55 mat.resize(m_row, m_col);
56
57 for(size_t i = 0; i < m_row; ++i)
58 {
59 for(size_t j = 0; j < m_col; ++j)
60 {
61 mat(i, j) = host_data[j * m_row + i];
62 }
63 }
64}
65template <typename Ty>
66void DeviceDenseMatrix<Ty>::copy_to(std::vector<Ty>& vec) const
67{
68 m_data.copy_to(vec);
69}
70template <typename Ty>
71DeviceDenseMatrix<Ty>::DeviceDenseMatrix(const Eigen::MatrixX<Ty>& mat)
72{
73 reshape(mat.rows(), mat.cols());
74 std::vector<Ty> host_data(m_row * m_col);
75
76 for(size_t i = 0; i < m_row; ++i)
77 {
78 for(size_t j = 0; j < m_col; ++j)
79 {
80 host_data[j * m_row + i] = mat(i, j);
81 }
82 }
83 m_data.copy_from(host_data);
84}
85template <typename Ty>
86DeviceDenseMatrix<Ty>& DeviceDenseMatrix<Ty>::operator=(const Eigen::MatrixX<Ty>& mat)
87{
88 if(mat.rows() != m_row || mat.cols() != m_col)
89 {
90 reshape(mat.rows(), mat.cols());
91 }
92 std::vector<Ty> host_data(m_row * m_col);
93
94 for(size_t i = 0; i < m_row; ++i)
95 {
96 for(size_t j = 0; j < m_col; ++j)
97 {
98 host_data[j * m_row + i] = mat(i, j);
99 }
100 }
101
102 m_data.copy_from(host_data);
103 return *this;
104}
105template <typename Ty>
106DenseMatrixView<Ty> DeviceDenseMatrix<Ty>::T()
107{
108 return DenseMatrixView{m_data, m_row, m_col, true, m_sym};
109}
110
111template <typename Ty>
112CDenseMatrixView<Ty> DeviceDenseMatrix<Ty>::T() const
113{
114 return CDenseMatrixView{m_data, m_row, m_col, true, m_sym};
115}
116
117template <typename Ty>
118DenseMatrixView<Ty> DeviceDenseMatrix<Ty>::view()
119{
120 return DenseMatrixView<Ty>{m_data.view(), m_row, m_col, false, m_sym};
121}
122
123template <typename Ty>
124CDenseMatrixView<Ty> DeviceDenseMatrix<Ty>::view() const
125{
126 return CDenseMatrixView<Ty>{m_data.view(), m_row, m_col, false, m_sym};
127}
128
129template <typename Ty>
130DeviceDenseMatrix<Ty>::operator CDenseMatrixView<Ty>() const
131{
132 return view();
133}
134template <typename Ty>
135DeviceDenseMatrix<Ty>::operator DenseMatrixView<Ty>()
136{
137 return view();
138}
139} // namespace muda
Definition extent.h:10