MUDA
Loading...
Searching...
No Matches
device_bsr_matrix.h
1#pragma once
3#include <muda/ext/linear_system/bsr_matrix_view.h>
4#include <cusparse.h>
5
// Forward declaration only. The converter template is defined elsewhere in
// the project (matrix_format_converter_impl.h according to the generated
// docs); it is declared here so that DeviceBSRMatrix can befriend it without
// pulling in that header.
namespace muda
{
namespace details
{
    template <typename ValueT, int BlockDim>
    class MatrixFormatConverter;
}  // namespace details
}  // namespace muda
11
12namespace muda
13{
14template <typename Ty, int N>
16{
17 friend class details::MatrixFormatConverter<Ty, N>;
18 static_assert(N >= 2, "Block size must be >= 2");
19
20 public:
21 using BlockMatrix = Eigen::Matrix<Ty, N, N>;
22
23 protected:
25 muda::DeviceBuffer<int> m_block_row_offsets;
26 muda::DeviceBuffer<int> m_block_col_indices;
27 mutable cusparseSpMatDescr_t m_descr = nullptr;
28 mutable cusparseMatDescr_t m_legacy_descr = nullptr;
29
30 int m_row = 0;
31 int m_col = 0;
32
33 public:
34 DeviceBSRMatrix() = default;
36
39
40 DeviceBSRMatrix& operator=(const DeviceBSRMatrix&);
41 DeviceBSRMatrix& operator=(DeviceBSRMatrix&&);
42
43 void reshape(int row, int col);
44 void reserve(int non_zero_blocks);
45 void reserve_offsets(int size);
46 void resize(int non_zero_blocks);
47
48 static constexpr int block_size() { return N; }
49
50 auto block_values() { return m_block_values.view(); }
51 auto block_values() const { return m_block_values.view(); }
52
53 auto block_row_offsets() { return m_block_row_offsets.view(); }
54 auto block_row_offsets() const { return m_block_row_offsets.view(); }
55
56 auto block_col_indices() { return m_block_col_indices.view(); }
57 auto block_col_indices() const { return m_block_col_indices.view(); }
58
59 auto block_rows() const { return m_row; }
60 auto block_cols() const { return m_col; }
61 auto non_zero_blocks() const { return m_block_values.size(); }
62
63 cusparseSpMatDescr_t descr() const;
64 cusparseMatDescr_t legacy_descr() const;
65
66 auto view()
67 {
68 return BSRMatrixView<Ty, N>{m_row,
69 m_col,
70 m_block_row_offsets.data(),
71 m_block_col_indices.data(),
72 m_block_values.data(),
73 (int)m_block_values.size(),
74 descr(),
75 legacy_descr(),
76 false};
77 }
78
79 operator BSRMatrixView<Ty, N>() { return view(); }
80
81 auto view() const
82 {
83 return CBSRMatrixView<Ty, N>{m_row,
84 m_col,
85 m_block_row_offsets.data(),
86 m_block_col_indices.data(),
87 m_block_values.data(),
88 (int)m_block_values.size(),
89 descr(),
90 legacy_descr(),
91 false};
92 }
93
94 operator CBSRMatrixView<Ty, N>() const { return view(); }
95
96 auto cview() const { return view(); }
97
98 auto T() const { return view().T(); }
99 auto T() { return view().T(); }
100
101 void clear();
102
103 private:
104 void destroy_all_descr() const;
105};
106} // namespace muda
107
108#include "details/device_bsr_matrix.inl"
Definition bsr_matrix_view.h:8
Definition device_bsr_matrix.h:16
A std::vector like wrapper of cuda device memory, allows user to:
Definition device_buffer.h:46
Definition matrix_format_converter_impl.h:53
A light-weight wrapper of cuda device memory. Like std::vector, allow user to resize,...