MUDA documentation — source listing of device_csr_matrix.h
#pragma once
#include <cusparse.h>
#include <muda/buffer/device_buffer.h>
#include <muda/ext/linear_system/csr_matrix_view.h>

namespace muda
{
namespace details
{
    // Forward declaration only; the definition lives elsewhere in the library.
    // Declared here so DeviceCSRMatrix can name it as a friend without
    // pulling in the converter's header.
    template <typename T, int N>
    class MatrixFormatConverter;
}  // namespace details
}  // namespace muda

namespace muda
{
/**
 * @brief A CSR (compressed sparse row) matrix whose arrays are held in
 *        muda::DeviceBuffer storage.
 *
 * The object owns three buffers (row offsets, column indices, values) and two
 * cuSPARSE descriptor handles: a generic-API cusparseSpMatDescr_t and a
 * legacy cusparseMatDescr_t. These handles are returned by descr() and
 * legacy_descr(). view() and cview() return non-owning CSRMatrixView and
 * CCSRMatrixView objects built from the same buffers and descriptors.
 *
 * NOTE(review): the documentation extraction this listing came from dropped
 * several source lines: the class head, the friend declaration, the m_values
 * member and the destructor/copy/move constructor declarations. They are
 * restored below from how the rest of the class uses them. Confirm that they
 * match details/device_csr_matrix.inl.
 *
 * @tparam Ty element type of the stored values.
 */
template <typename Ty>
class DeviceCSRMatrix
{
    // Grant the (forward-declared) muda::details::MatrixFormatConverter
    // access to this class's internals.
    template <typename T, int N>
    friend class details::MatrixFormatConverter;

  public:
    // NOTE(review): the storage members are public even though a friend is
    // declared. They stay public so existing code that touches them directly
    // keeps compiling.
    int m_row = 0;  // row count, returned by rows()
    int m_col = 0;  // column count, returned by cols()

    muda::DeviceBuffer<int> m_row_offsets;  // CSR row-offset array
    muda::DeviceBuffer<int> m_col_indices;  // column index of each stored entry
    muda::DeviceBuffer<Ty>  m_values;       // value of each stored entry

    // cuSPARSE descriptor handles, null-initialised. They are `mutable`
    // because the const accessors descr() / legacy_descr() hand them out.
    // Presumably those accessors create them lazily; confirm in the .inl.
    mutable cusparseSpMatDescr_t m_descr        = nullptr;
    mutable cusparseMatDescr_t   m_legacy_descr = nullptr;

  public:
    DeviceCSRMatrix() = default;
    // Defined in the .inl. Presumably releases the descriptors through
    // destroy_all_descr().
    ~DeviceCSRMatrix();

    // Copy and move are user-declared (defined in the .inl). Compiler-generated
    // versions would shallow-copy the raw descriptor handles above.
    DeviceCSRMatrix(const DeviceCSRMatrix&);
    DeviceCSRMatrix(DeviceCSRMatrix&&) noexcept;

    DeviceCSRMatrix& operator=(const DeviceCSRMatrix&);
    DeviceCSRMatrix& operator=(DeviceCSRMatrix&&) noexcept;

    // Defined in the .inl. NOTE(review): presumably reshape() updates the
    // row/column counts and reserve() pre-allocates storage for `non_zeros`
    // entries; confirm there.
    void reshape(int row, int col);
    void reserve(int non_zeros);

    // Buffer views over the three CSR arrays.
    auto values() { return m_values.view(); }
    auto values() const { return m_values.view(); }

    auto row_offsets() { return m_row_offsets.view(); }
    auto row_offsets() const { return m_row_offsets.view(); }

    auto col_indices() { return m_col_indices.view(); }
    auto col_indices() const { return m_col_indices.view(); }

    auto rows() const { return m_row; }
    auto cols() const { return m_col; }
    // Number of stored entries, i.e. the element count of the values buffer.
    auto non_zeros() const { return m_values.size(); }

    cusparseSpMatDescr_t descr() const;
    cusparseMatDescr_t   legacy_descr() const;

    // Mutable, non-owning view over this matrix's buffers and descriptors.
    // The entry count is narrowed to int because the view stores it as int.
    // Counts above INT_MAX would be truncated.
    auto view()
    {
        return CSRMatrixView<Ty>{m_row,
                                 m_col,
                                 m_row_offsets.data(),
                                 m_col_indices.data(),
                                 m_values.data(),
                                 static_cast<int>(non_zeros()),
                                 descr(),
                                 legacy_descr(),
                                 // NOTE(review): presumably the view's
                                 // "transposed" flag; confirm in csr_matrix_view.h.
                                 false};
    }

    // Read-only counterpart of view(). Same arguments, const view type.
    auto view() const
    {
        return CCSRMatrixView<Ty>{m_row,
                                  m_col,
                                  m_row_offsets.data(),
                                  m_col_indices.data(),
                                  m_values.data(),
                                  static_cast<int>(non_zeros()),
                                  descr(),
                                  legacy_descr(),
                                  false};
    }

    // Always returns the read-only view, even on a non-const object.
    auto cview() const { return view(); }

    // Forward to the view's T(). Presumably a transposed view; see
    // csr_matrix_view.h.
    auto T() const { return view().T(); }
    auto T() { return view().T(); }

    // Implicit conversions so the matrix can be passed where a view is expected.
    operator CSRMatrixView<Ty>() { return view(); }
    operator CCSRMatrixView<Ty>() const { return view(); }

    // Defined in the .inl.
    void clear();

  private:
    // Defined in the .inl. Const because the descriptor members are mutable.
    void destroy_all_descr() const;
};
}  // namespace muda
#include "details/device_csr_matrix.inl"
Definition csr_matrix_view.h:8
A std::vector-like wrapper of CUDA device memory that allows the user to:
Definition device_buffer.h:46
Definition device_csr_matrix.h:16
Definition matrix_format_converter_impl.h:53
A lightweight wrapper of CUDA device memory. Like std::vector, it allows the user to resize, ...