#include <muda/ext/linear_system/matrix_format_converter_impl.h>
5template <
typename T,
int N>
6details::MatrixFormatConverter<T, N>& MatrixFormatConverter::impl()
8 using namespace details;
9 constexpr auto ask_data_type = cuda_data_type<T>();
10 constexpr auto ask_N = N;
14 if(current->data_type() == ask_data_type && current->dim() == ask_N)
16 return *
static_cast<details::MatrixFormatConverter<T, N>*
>(current);
20 MatrixFormatConverterType type{ask_data_type, ask_N};
21 auto it = m_impls.find(type);
22 if(it != m_impls.end())
24 current = it->second.get();
25 return *
static_cast<details::MatrixFormatConverter<T, N>*
>(current);
28 auto impl = std::make_unique<details::MatrixFormatConverter<T, N>>(m_handles);
30 m_impls.emplace(type, std::move(impl));
31 return *
static_cast<details::MatrixFormatConverter<T, N>*
>(current);
34inline MatrixFormatConverter::~MatrixFormatConverter() {}
40template <
typename T,
int N>
41void MatrixFormatConverter::convert(
const DeviceTripletMatrix<T, N>& from,
42 DeviceBCOOMatrix<T, N>& to)
44 impl<T, N>().convert(from, to);
48template <
typename T,
int N>
49void MatrixFormatConverter::convert(
const DeviceBCOOMatrix<T, N>& from,
50 DeviceDenseMatrix<T>& to,
51 bool clear_dense_matrix)
53 impl<T, N>().convert(from, to, clear_dense_matrix);
57template <
typename T,
int N>
58void MatrixFormatConverter::convert(
const DeviceBCOOMatrix<T, N>& from,
59 DeviceCOOMatrix<T>& to)
61 impl<T, N>().convert(from, to);
65template <
typename T,
int N>
66void MatrixFormatConverter::convert(
const DeviceBCOOMatrix<T, N>& from,
67 DeviceBSRMatrix<T, N>& to)
69 impl<T, N>().convert(from, to);
73template <
typename T,
int N>
74void MatrixFormatConverter::convert(
const DeviceDoubletVector<T, N>& from,
75 DeviceBCOOVector<T, N>& to)
77 impl<T, N>().convert(from, to);
81template <
typename T,
int N>
82void MatrixFormatConverter::convert(
const DeviceBCOOVector<T, N>& from,
83 DeviceDenseVector<T>& to,
84 bool clear_dense_vector)
87 impl<T, N>().convert(from, to, clear_dense_vector);
91template <
typename T,
int N>
92void MatrixFormatConverter::convert(
const DeviceDoubletVector<T, N>& from,
93 DeviceDenseVector<T>& to,
94 bool clear_dense_vector)
97 impl<T, N>().convert(from, to, clear_dense_vector);
101template <
typename T,
int N>
102void MatrixFormatConverter::convert(
const DeviceBSRMatrix<T, N>& from,
103 DeviceCSRMatrix<T>& to)
105 impl<T, N>().convert(from, to);
110void MatrixFormatConverter::convert(
const DeviceTripletMatrix<T, 1>& from,
111 DeviceCOOMatrix<T>& to)
113 impl<T, 1>().convert(from, to);
118void MatrixFormatConverter::convert(
const DeviceCOOMatrix<T>& from,
119 DeviceDenseMatrix<T>& to,
120 bool clear_dense_matrix)
122 impl<T, 1>().convert(from, to, clear_dense_matrix);
127void MatrixFormatConverter::convert(
const DeviceCOOMatrix<T>& from, DeviceCSRMatrix<T>& to)
129 impl<T, 1>().convert(from, to);
132void MatrixFormatConverter::convert(DeviceCOOMatrix<T>&& from, DeviceCSRMatrix<T>& to)
134 impl<T, 1>().convert(std::move(from), to);
139void MatrixFormatConverter::convert(
const DeviceDoubletVector<T, 1>& from,
140 DeviceCOOVector<T>& to)
142 impl<T, 1>().convert(from, to);
147void MatrixFormatConverter::convert(
const DeviceCOOVector<T>& from,
148 DeviceDenseVector<T>& to,
149 bool clear_dense_vector)
151 impl<T, 1>().convert(from, to, clear_dense_vector);
156void MatrixFormatConverter::convert(
const DeviceDoubletVector<T, 1>& from,
157 DeviceDenseVector<T>& to,
158 bool clear_dense_vector)
160 impl<T, 1>().convert(from, to, clear_dense_vector);