MUDA
Loading...
Searching...
No Matches
bsr_matrix_view.h
#pragma once
#include <type_traits>
#include <cusparse_v2.h>
#include <muda/view/view_base.h>
4namespace muda
5{
6template <bool IsConst, typename Ty, int N>
7class BSRMatrixViewBase : public ViewBase<IsConst>
8{
10 template <typename U>
11 using auto_const_t = typename Base::template auto_const_t<U>;
12
13 public:
14 static_assert(!std::is_const_v<Ty>, "Ty must be non-const");
18
19 using BlockMatrix = Eigen::Matrix<Ty, N, N>;
20
21 protected:
22 // data
23 int m_row = 0;
24 int m_col = 0;
25
26 auto_const_t<int>* m_block_row_offsets = nullptr;
27 auto_const_t<int>* m_block_col_indices = nullptr;
28 auto_const_t<BlockMatrix>* m_block_values = nullptr;
29 int m_non_zeros = 0;
30
31 mutable cusparseMatDescr_t m_legacy_descr = nullptr;
32 mutable cusparseSpMatDescr_t m_descr = nullptr;
33
34 bool m_trans = false;
35
36 public:
37 BSRMatrixViewBase() = default;
38 BSRMatrixViewBase(int row,
39 int col,
40 auto_const_t<int>* block_row_offsets,
41 auto_const_t<int>* block_col_indices,
42 auto_const_t<BlockMatrix>* block_values,
43 int non_zeros,
44 cusparseSpMatDescr_t descr,
45 cusparseMatDescr_t legacy_descr,
46 bool trans)
47 : m_row(row)
48 , m_col(col)
49 , m_block_row_offsets(block_row_offsets)
50 , m_block_col_indices(block_col_indices)
51 , m_block_values(block_values)
52 , m_non_zeros(non_zeros)
53 , m_descr(descr)
54 , m_legacy_descr(legacy_descr)
55 , m_trans(trans)
56
57 {
58 }
59
60 // explicit conversion to non-const
61 ConstView as_const() const
62 {
63 return ConstView{m_row,
64 m_col,
65 m_block_row_offsets,
66 m_block_col_indices,
67 m_block_values,
68 m_non_zeros,
69 m_descr,
70 m_legacy_descr,
71 m_trans};
72 }
73
74 // implicit conversion to const
75 operator ConstView() const { return as_const(); }
76
77 // non-const access
78 auto_const_t<BlockMatrix>* block_values() { return m_block_values; }
79 auto_const_t<int>* block_row_offsets() { return m_block_row_offsets; }
80 auto_const_t<int>* block_col_indices() { return m_block_col_indices; }
81
82 // const access
83 auto block_values() const { return m_block_values; }
84 auto block_row_offsets() const { return m_block_row_offsets; }
85 auto block_col_indices() const { return m_block_col_indices; }
86
87 auto block_rows() const { return m_row; }
88 auto block_cols() const { return m_col; }
89 auto non_zero_blocks() const { return m_non_zeros; }
90
91 auto legacy_descr() const { return m_legacy_descr; }
92 auto descr() const { return m_descr; }
93 auto is_trans() const { return m_trans; }
94
95 auto T() const
96 {
97 return ThisView{m_row,
98 m_col,
99 m_block_row_offsets,
100 m_block_col_indices,
101 m_block_values,
102 m_non_zeros,
103 m_descr,
104 m_legacy_descr,
105 !m_trans};
106 }
107};
108
109template <typename Ty, int N>
111template <typename Ty, int N>
113} // namespace muda
114
115namespace muda
116{
117template <typename Ty, int N>
119{
121};
122
123template <typename Ty, int N>
125{
127};
128} // namespace muda
129
130
131#include "details/bsr_matrix_view.inl"
Definition bsr_matrix_view.h:8
Definition view_base.h:8
Definition type_modifier.h:22
Definition type_modifier.h:28