MUDA
Loading...
Searching...
No Matches
matrix_format_converter.inl
1#include <muda/ext/linear_system/matrix_format_converter_impl.h>
2
3namespace muda
4{
5template <typename T, int N>
6details::MatrixFormatConverter<T, N>& MatrixFormatConverter::impl()
7{
8 using namespace details;
9 constexpr auto ask_data_type = cuda_data_type<T>();
10 constexpr auto ask_N = N;
11
12 if(current)
13 {
14 if(current->data_type() == ask_data_type && current->dim() == ask_N)
15 {
16 return *static_cast<details::MatrixFormatConverter<T, N>*>(current);
17 }
18 }
19
20 MatrixFormatConverterType type{ask_data_type, ask_N};
21 auto it = m_impls.find(type);
22 if(it != m_impls.end())
23 {
24 current = it->second.get();
25 return *static_cast<details::MatrixFormatConverter<T, N>*>(current);
26 }
27
28 auto impl = std::make_unique<details::MatrixFormatConverter<T, N>>(m_handles);
29 current = impl.get();
30 m_impls.emplace(type, std::move(impl));
31 return *static_cast<details::MatrixFormatConverter<T, N>*>(current);
32}
33
34inline MatrixFormatConverter::~MatrixFormatConverter() {}
35} // namespace muda
36
37namespace muda
38{
39// Triplet -> BCOO
40template <typename T, int N>
41void MatrixFormatConverter::convert(const DeviceTripletMatrix<T, N>& from,
42 DeviceBCOOMatrix<T, N>& to)
43{
44 impl<T, N>().convert(from, to);
45}
46
47// BCOO -> Dense Matrix
48template <typename T, int N>
49void MatrixFormatConverter::convert(const DeviceBCOOMatrix<T, N>& from,
50 DeviceDenseMatrix<T>& to,
51 bool clear_dense_matrix)
52{
53 impl<T, N>().convert(from, to, clear_dense_matrix);
54}
55
56// BCOO -> COO
57template <typename T, int N>
58void MatrixFormatConverter::convert(const DeviceBCOOMatrix<T, N>& from,
59 DeviceCOOMatrix<T>& to)
60{
61 impl<T, N>().convert(from, to);
62}
63
64// BCOO -> BSR
65template <typename T, int N>
66void MatrixFormatConverter::convert(const DeviceBCOOMatrix<T, N>& from,
67 DeviceBSRMatrix<T, N>& to)
68{
69 impl<T, N>().convert(from, to);
70}
71
72// Doublet -> BCOO
73template <typename T, int N>
74void MatrixFormatConverter::convert(const DeviceDoubletVector<T, N>& from,
75 DeviceBCOOVector<T, N>& to)
76{
77 impl<T, N>().convert(from, to);
78}
79
80// BCOO -> Dense Vector
81template <typename T, int N>
82void MatrixFormatConverter::convert(const DeviceBCOOVector<T, N>& from,
83 DeviceDenseVector<T>& to,
84 bool clear_dense_vector)
85{
86
87 impl<T, N>().convert(from, to, clear_dense_vector);
88}
89
90// Doublet -> Dense Vector
91template <typename T, int N>
92void MatrixFormatConverter::convert(const DeviceDoubletVector<T, N>& from,
93 DeviceDenseVector<T>& to,
94 bool clear_dense_vector)
95{
96
97 impl<T, N>().convert(from, to, clear_dense_vector);
98}
99
100// BSR -> CSR
101template <typename T, int N>
102void MatrixFormatConverter::convert(const DeviceBSRMatrix<T, N>& from,
103 DeviceCSRMatrix<T>& to)
104{
105 impl<T, N>().convert(from, to);
106}
107
108// Triplet -> COO
109template <typename T>
110void MatrixFormatConverter::convert(const DeviceTripletMatrix<T, 1>& from,
111 DeviceCOOMatrix<T>& to)
112{
113 impl<T, 1>().convert(from, to);
114}
115
116// COO -> Dense Matrix
117template <typename T>
118void MatrixFormatConverter::convert(const DeviceCOOMatrix<T>& from,
119 DeviceDenseMatrix<T>& to,
120 bool clear_dense_matrix)
121{
122 impl<T, 1>().convert(from, to, clear_dense_matrix);
123}
124
125// COO -> CSR
126template <typename T>
127void MatrixFormatConverter::convert(const DeviceCOOMatrix<T>& from, DeviceCSRMatrix<T>& to)
128{
129 impl<T, 1>().convert(from, to);
130}
131template <typename T>
132void MatrixFormatConverter::convert(DeviceCOOMatrix<T>&& from, DeviceCSRMatrix<T>& to)
133{
134 impl<T, 1>().convert(std::move(from), to);
135}
136
137// Doublet -> COO
138template <typename T>
139void MatrixFormatConverter::convert(const DeviceDoubletVector<T, 1>& from,
140 DeviceCOOVector<T>& to)
141{
142 impl<T, 1>().convert(from, to);
143}
144
145// COO -> Dense Vector
146template <typename T>
147void MatrixFormatConverter::convert(const DeviceCOOVector<T>& from,
148 DeviceDenseVector<T>& to,
149 bool clear_dense_vector)
150{
151 impl<T, 1>().convert(from, to, clear_dense_vector);
152}
153
154// Doublet -> Dense Vector
155template <typename T>
156void MatrixFormatConverter::convert(const DeviceDoubletVector<T, 1>& from,
157 DeviceDenseVector<T>& to,
158 bool clear_dense_vector)
159{
160 impl<T, 1>().convert(from, to, clear_dense_vector);
161}
162} // namespace muda