MUDA
Loading...
Searching...
No Matches
matrix_format_converter.h
1#pragma once
2#include <muda/ext/linear_system/linear_system_handles.h>
3#include <muda/ext/linear_system/device_dense_matrix.h>
4#include <muda/ext/linear_system/device_dense_vector.h>
5#include <muda/ext/linear_system/device_triplet_matrix.h>
6#include <muda/ext/linear_system/device_doublet_vector.h>
7#include <muda/ext/linear_system/device_bcoo_matrix.h>
8#include <muda/ext/linear_system/device_bcoo_vector.h>
9#include <muda/ext/linear_system/device_bsr_matrix.h>
10#include <muda/ext/linear_system/device_csr_matrix.h>
11
12namespace muda::details
13{
14class MatrixFormatConverterBase;
15template <typename T, int N>
16class MatrixFormatConverter;
17
19{
20 public:
21 cudaDataType_t data_type;
22 int N;
23 bool friend operator==(const MatrixFormatConverterType& lhs,
25 {
26 return lhs.data_type == rhs.data_type && lhs.N == rhs.N;
27 }
28};
29} // namespace muda::details
30
31namespace std
32{
33template <>
34struct hash<muda::details::MatrixFormatConverterType>
35{
36 size_t operator()(const muda::details::MatrixFormatConverterType& x) const
37 {
38 return (std::hash<int>()(x.data_type) << 8) ^ std::hash<int>()(x.N);
39 }
40};
41} // namespace std
42
43
44namespace muda
45{
46
48{
49 template <typename T>
50 using U = std::unique_ptr<T>;
51 LinearSystemHandles& m_handles;
52 using TypeN = std::pair<cudaDataType_t, int>;
53 std::unordered_map<details::MatrixFormatConverterType, U<details::MatrixFormatConverterBase>> m_impls;
54 details::MatrixFormatConverterBase* current = nullptr;
55 template <typename T, int N>
57
58 public:
60 : m_handles(handles)
61 {
62 }
64
65 // Triplet -> BCOO
66 template <typename T, int N>
67 void convert(const DeviceTripletMatrix<T, N>& from, DeviceBCOOMatrix<T, N>& to);
68
69 // BCOO -> Dense Matrix
70 template <typename T, int N>
71 void convert(const DeviceBCOOMatrix<T, N>& from,
73 bool clear_dense_matrix = true);
74
75 // BCOO -> COO
76 template <typename T, int N>
77 void convert(const DeviceBCOOMatrix<T, N>& from, DeviceCOOMatrix<T>& to);
78
79 // BCOO -> BSR
80 template <typename T, int N>
81 void convert(const DeviceBCOOMatrix<T, N>& from, DeviceBSRMatrix<T, N>& to);
82
83 // Doublet -> BCOO
84 template <typename T, int N>
85 void convert(const DeviceDoubletVector<T, N>& from, DeviceBCOOVector<T, N>& to);
86
87 // BCOO -> Dense Vector
88 template <typename T, int N>
89 void convert(const DeviceBCOOVector<T, N>& from,
91 bool clear_dense_vector = true);
92
93 // Doublet -> Dense Vector
94 template <typename T, int N>
95 void convert(const DeviceDoubletVector<T, N>& from,
97 bool clear_dense_vector = true);
98
99 // BSR -> CSR
100 template <typename T, int N>
101 void convert(const DeviceBSRMatrix<T, N>& from, DeviceCSRMatrix<T>& to);
102
103 // Triplet -> COO
104 template <typename T>
105 void convert(const DeviceTripletMatrix<T, 1>& from, DeviceCOOMatrix<T>& to);
106
107 // COO -> Dense Matrix
108 template <typename T>
109 void convert(const DeviceCOOMatrix<T>& from,
111 bool clear_dense_matrix = true);
112
113 // COO -> CSR
114 template <typename T>
115 void convert(const DeviceCOOMatrix<T>& from, DeviceCSRMatrix<T>& to);
116 template <typename T>
117 void convert(DeviceCOOMatrix<T>&& from, DeviceCSRMatrix<T>& to);
118
119 // Doublet -> COO
120 template <typename T>
121 void convert(const DeviceDoubletVector<T, 1>& from, DeviceCOOVector<T>& to);
122
123 // COO -> Dense Vector
124 template <typename T>
125 void convert(const DeviceCOOVector<T>& from,
127 bool clear_dense_vector = true);
128
129 // Doublet -> Dense Vector
130 template <typename T>
131 void convert(const DeviceDoubletVector<T, 1>& from,
133 bool clear_dense_vector = true);
134};
135} // namespace muda
136
137#include "details/matrix_format_converter.inl"
Definition device_bcoo_matrix.h:18
Definition device_bcoo_vector.h:28
Definition device_bcoo_vector.h:10
Definition device_bsr_matrix.h:16
Definition device_csr_matrix.h:16
Definition device_dense_matrix.h:16
Definition device_dense_vector.h:16
Definition device_doublet_vector.h:16
Definition device_triplet_matrix.h:14
Definition linear_system_handles.h:16
Definition matrix_format_converter.h:48
Definition matrix_format_converter_impl.h:19
Definition matrix_format_converter_impl.h:53
Definition matrix_format_converter.h:19