MUDA
Loading...
Searching...
No Matches
csr_matrix_view.h
1#pragma once
2#include <muda/ext/linear_system/common.h>
3#include <muda/view/view_base.h>
4namespace muda
5{
6template <bool IsConst, typename Ty>
7class CSRMatrixViewBase : public ViewBase<IsConst>
8{
10 template <typename U>
11 using auto_const_t = typename Base::template auto_const_t<U>;
12
13 public:
14 static_assert(!std::is_const_v<Ty>, "Ty must be non-const");
18
19 protected:
20 // data
21 int m_row = 0;
22 int m_col = 0;
23
24
25 auto_const_t<int>* m_row_offsets = nullptr;
26 auto_const_t<int>* m_col_indices = nullptr;
27 auto_const_t<Ty>* m_values = nullptr;
28 int m_non_zero = 0;
29
30 mutable cusparseSpMatDescr_t m_descr = nullptr;
31 mutable cusparseMatDescr_t m_legacy_descr = nullptr;
32
33
34 bool m_trans = false;
35
36 public:
37 CSRMatrixViewBase() = default;
38 CSRMatrixViewBase(int row,
39 int col,
40 auto_const_t<int>* row_offsets,
41 auto_const_t<int>* col_indices,
42 auto_const_t<Ty>* values,
43 int non_zero,
44 cusparseSpMatDescr_t descr,
45 cusparseMatDescr_t legacy_descr,
46 bool trans)
47 : m_row(row)
48 , m_col(col)
49 , m_row_offsets(row_offsets)
50 , m_col_indices(col_indices)
51 , m_values(values)
52 , m_non_zero(non_zero)
53 , m_descr(descr)
54 , m_legacy_descr(legacy_descr)
55 , m_trans(trans)
56 {
57 }
58
59 ConstView as_const() const
60 {
61 return ConstView{
62 m_row, m_col, m_row_offsets, m_col_indices, m_values, m_non_zero, m_descr, m_legacy_descr, m_trans};
63 }
64
65 // implicit conversion to const
66 operator ConstView() const { return as_const(); }
67
68 auto_const_t<Ty>* values() { return m_values; }
69 auto_const_t<int>* row_offsets() { return m_row_offsets; }
70 auto_const_t<int>* col_indices() { return m_col_indices; }
71
72 auto values() const { return m_values; }
73 auto row_offsets() const { return m_row_offsets; }
74 auto col_indices() const { return m_col_indices; }
75 auto rows() const { return m_row; }
76 auto cols() const { return m_col; }
77 auto non_zeros() const { return m_non_zero; }
78 auto descr() const { return m_descr; }
79 auto legacy_descr() const { return m_legacy_descr; }
80 auto is_trans() const { return m_trans; }
81 auto T() const
82 {
83 return ThisView{
84 m_row, m_col, m_row_offsets, m_col_indices, m_values, m_non_zero, m_descr, m_legacy_descr, !m_trans};
85 }
86};
87
88template <typename Ty>
90template <typename Ty>
92} // namespace muda
93
94namespace muda
95{
96template <typename Ty>
98{
100};
101
102template <typename Ty>
104{
105 using type = CSRMatrixView<Ty>;
106};
107} // namespace muda
108
109
110#include "details/csr_matrix_view.inl"
Definition csr_matrix_view.h:8
Definition view_base.h:8
Definition type_modifier.h:22
Definition type_modifier.h:28