Skip to content

File csr_matrix_view.h

File List > ext > linear_system > csr_matrix_view.h

Go to the documentation of this file

#pragma once
#include <muda/ext/linear_system/common.h>
#include <muda/view/view_base.h>
namespace muda
{
// View over a sparse matrix held as three CSR arrays (row offsets, column
// indices, values) together with two cuSPARSE descriptor handles.
//
// The class only stores the raw pointers and handles it is given; nothing in
// this class allocates, frees or validates them. Copying a view copies the
// pointers and handles, not the data behind them.
//
// IsConst selects the const or non-const flavour of the view through
// ViewBase<IsConst>::auto_const_t (defined outside this file; presumably it
// adds const to the pointee when IsConst is true - confirm in view_base.h).
template <bool IsConst, typename Ty>
class CSRMatrixViewBase : public ViewBase<IsConst>
{
    using Base = ViewBase<IsConst>;
    template <typename U>
    using auto_const_t = typename Base::template auto_const_t<U>;

  public:
    // Constness is expressed through IsConst, never through the value type.
    static_assert(!std::is_const_v<Ty>, "Ty must be non-const");

    using ConstView    = CSRMatrixViewBase<true, Ty>;
    using NonConstView = CSRMatrixViewBase<false, Ty>;
    using ThisView     = CSRMatrixViewBase<IsConst, Ty>;

  protected:
    // matrix dimensions as supplied by whoever created the view
    int m_row = 0;
    int m_col = 0;

    // CSR arrays and the non-zero count as supplied by whoever created the view
    auto_const_t<int>* m_row_offsets = nullptr;
    auto_const_t<int>* m_col_indices = nullptr;
    auto_const_t<Ty>*  m_values      = nullptr;
    int                m_non_zero    = 0;

    // cuSPARSE handles: generic sparse-matrix descriptor and legacy matrix descriptor
    mutable cusparseSpMatDescr_t m_descr        = nullptr;
    mutable cusparseMatDescr_t   m_legacy_descr = nullptr;

    // transpose flag; toggled by T(), reported by is_trans()
    bool m_trans = false;

  public:
    // An empty view: zero sizes, null pointers, null descriptors, not transposed.
    CSRMatrixViewBase() = default;

    // Stores every argument as-is; no validation is performed here.
    CSRMatrixViewBase(int                  row,
                      int                  col,
                      auto_const_t<int>*   row_offsets,
                      auto_const_t<int>*   col_indices,
                      auto_const_t<Ty>*    values,
                      int                  non_zero,
                      cusparseSpMatDescr_t descr,
                      cusparseMatDescr_t   legacy_descr,
                      bool                 trans)
        : m_row(row)
        , m_col(col)
        , m_row_offsets(row_offsets)
        , m_col_indices(col_indices)
        , m_values(values)
        , m_non_zero(non_zero)
        , m_descr(descr)
        , m_legacy_descr(legacy_descr)
        , m_trans(trans)
    {
    }

    // Builds the IsConst = true view from exactly the same fields.
    ConstView as_const() const
    {
        return ConstView{m_row,
                         m_col,
                         m_row_offsets,
                         m_col_indices,
                         m_values,
                         m_non_zero,
                         m_descr,
                         m_legacy_descr,
                         m_trans};
    }

    // Implicit conversion to the const view; forwards to as_const().
    operator ConstView() const { return as_const(); }

    // Array accessors callable on a non-const view object.
    auto_const_t<Ty>*  values() { return m_values; }
    auto_const_t<int>* row_offsets() { return m_row_offsets; }
    auto_const_t<int>* col_indices() { return m_col_indices; }

    // Array accessors callable on a const view object; they return the same
    // pointer types as the overloads above (the view object being const does
    // not change the pointee type).
    auto_const_t<Ty>*  values() const { return m_values; }
    auto_const_t<int>* row_offsets() const { return m_row_offsets; }
    auto_const_t<int>* col_indices() const { return m_col_indices; }

    // Stored sizes, returned unchanged regardless of the transpose flag.
    int rows() const { return m_row; }
    int cols() const { return m_col; }
    int non_zeros() const { return m_non_zero; }

    // Stored cuSPARSE handles.
    cusparseSpMatDescr_t descr() const { return m_descr; }
    cusparseMatDescr_t   legacy_descr() const { return m_legacy_descr; }

    // Current value of the transpose flag.
    bool is_trans() const { return m_trans; }

    // Returns a view over the same arrays and descriptors with the transpose
    // flag inverted. m_row and m_col are passed through unswapped.
    ThisView T() const
    {
        return ThisView{m_row,
                        m_col,
                        m_row_offsets,
                        m_col_indices,
                        m_values,
                        m_non_zero,
                        m_descr,
                        m_legacy_descr,
                        !m_trans};
    }
};

// CSRMatrixView<Ty>: CSRMatrixViewBase instantiated with IsConst = false.
// Kept as a direct alias so Ty stays deducible when the alias is used as a
// template argument in partial specializations.
template <typename Ty>
using CSRMatrixView = CSRMatrixViewBase<false, Ty>;
// CCSRMatrixView<Ty>: CSRMatrixViewBase instantiated with IsConst = true
// (the type CSRMatrixViewBase names as its ConstView).
template <typename Ty>
using CCSRMatrixView = CSRMatrixViewBase<true, Ty>;
}  // namespace muda

namespace muda
{
// Trait specialization for CSRMatrixView<Ty>: `type` names the IsConst = true
// instantiation of the same view template.
// NOTE(review): the primary read_only_viewer template is declared elsewhere;
// presumably it is consulted when a read-only view is requested - confirm there.
template <typename Ty>
struct read_only_viewer<CSRMatrixView<Ty>>
{
    // Identical type to CCSRMatrixView<Ty>, written out in full.
    using type = CSRMatrixViewBase<true, Ty>;
};

// Trait specialization for CSRMatrixView<Ty>: `type` names the IsConst = false
// instantiation, i.e. the view type itself.
// NOTE(review): the primary read_write_viewer template is declared elsewhere;
// presumably it is consulted when a writable view is requested - confirm there.
template <typename Ty>
struct read_write_viewer<CSRMatrixView<Ty>>
{
    // Identical type to CSRMatrixView<Ty>, written out in full.
    using type = CSRMatrixViewBase<false, Ty>;
};
}  // namespace muda


#include "details/csr_matrix_view.inl"