File bcoo_matrix_view.h
File List > ext > linear_system > bcoo_matrix_view.h
Go to the documentation of this file
#pragma once
#include <muda/ext/linear_system/triplet_matrix_view.h>
#include <muda/ext/linear_system/bcoo_matrix_viewer.h>
namespace muda
{
template <typename T, int N>
using BCOOMatrixView = TripletMatrixView<T, N>;
template <typename T, int N>
using CBCOOMatrixView = CTripletMatrixView<T, N>;
} // namespace muda
namespace muda
{
template <bool IsConst, typename Ty>
class COOMatrixViewT : public ViewBase<IsConst>
{
using Base = ViewBase<IsConst>;
template <typename U>
using auto_const_t = typename Base::template auto_const_t<U>;
template <bool OtherIsConst, typename U>
friend class COOMatrixViewT;
public:
static_assert(!std::is_const_v<Ty>, "Ty must be non-const");
using NonConstView = COOMatrixViewT<false, Ty>;
using ConstView = COOMatrixViewT<true, Ty>;
using ThisView = COOMatrixViewT<IsConst, Ty>;
protected:
// matrix info
int m_rows = 0;
int m_cols = 0;
// triplet info
int m_triplet_index_offset = 0;
int m_triplet_count = 0;
int m_total_triplet_count = 0;
// sub matrix info
int2 m_submatrix_offset = {0, 0};
int2 m_submatrix_extent = {0, 0};
// data
auto_const_t<int>* m_row_indices;
auto_const_t<int>* m_col_indices;
auto_const_t<Ty>* m_values;
mutable cusparseMatDescr_t m_legacy_descr = nullptr;
mutable cusparseSpMatDescr_t m_descr = nullptr;
bool m_trans = false;
public:
MUDA_GENERIC COOMatrixViewT() = default;
MUDA_GENERIC COOMatrixViewT(int rows,
int cols,
int triplet_index_offset,
int triplet_count,
int total_triplet_count,
int2 submatrix_offset,
int2 submatrix_extent,
auto_const_t<int>* row_indices,
auto_const_t<int>* col_indices,
auto_const_t<Ty>* values,
cusparseSpMatDescr_t descr,
cusparseMatDescr_t legacy_descr,
bool trans)
: m_rows(rows)
, m_cols(cols)
, m_triplet_index_offset(triplet_index_offset)
, m_triplet_count(triplet_count)
, m_total_triplet_count(total_triplet_count)
, m_row_indices(row_indices)
, m_col_indices(col_indices)
, m_values(values)
, m_submatrix_offset(submatrix_offset)
, m_submatrix_extent(submatrix_extent)
, m_descr(descr)
, m_legacy_descr(legacy_descr)
, m_trans(trans)
{
MUDA_KERNEL_ASSERT(triplet_index_offset + triplet_count <= total_triplet_count,
"COOMatrixView: out of range, m_total_triplet_count=%d, "
"your triplet_index_offset=%d, triplet_count=%d",
total_triplet_count,
triplet_index_offset,
triplet_count);
MUDA_KERNEL_ASSERT(submatrix_offset.x >= 0 && submatrix_offset.y >= 0,
"TripletMatrixView: submatrix_offset is out of range, submatrix_offset.x=%d, submatrix_offset.y=%d",
submatrix_offset.x,
submatrix_offset.y);
MUDA_KERNEL_ASSERT(submatrix_offset.x + submatrix_extent.x <= rows,
"TripletMatrixView: submatrix is out of range, submatrix_offset.x=%d, submatrix_extent.x=%d, total_block_rows=%d",
submatrix_offset.x,
submatrix_extent.x,
rows);
MUDA_KERNEL_ASSERT(submatrix_offset.y + submatrix_extent.y <= cols,
"TripletMatrixView: submatrix is out of range, submatrix_offset.y=%d, submatrix_extent.y=%d, total_block_cols=%d",
submatrix_offset.y,
submatrix_extent.y,
cols);
}
MUDA_GENERIC COOMatrixViewT(int rows,
int cols,
int total_triplet_count,
auto_const_t<int>* row_indices,
auto_const_t<int>* col_indices,
auto_const_t<Ty>* values,
cusparseSpMatDescr_t descr,
cusparseMatDescr_t legacy_descr,
bool trans)
: COOMatrixViewT(rows,
cols,
0,
total_triplet_count,
total_triplet_count,
{0, 0},
{rows, cols},
row_indices,
col_indices,
values,
descr,
legacy_descr,
trans)
{
}
template <bool OtherIsConst>
MUDA_GENERIC COOMatrixViewT(const COOMatrixViewT<OtherIsConst, Ty>& other)
: m_rows(other.m_rows)
, m_cols(other.m_cols)
, m_triplet_index_offset(other.m_triplet_index_offset)
, m_triplet_count(other.m_triplet_count)
, m_total_triplet_count(other.m_total_triplet_count)
, m_submatrix_offset(other.m_submatrix_offset)
, m_submatrix_extent(other.m_submatrix_extent)
, m_row_indices(other.m_row_indices)
, m_col_indices(other.m_col_indices)
, m_values(other.m_values)
, m_descr(other.m_descr)
, m_legacy_descr(other.m_legacy)
{
}
MUDA_GENERIC auto as_const() const
{
return ConstView{m_rows,
m_cols,
m_triplet_index_offset,
m_triplet_count,
m_total_triplet_count,
m_submatrix_offset,
m_submatrix_extent,
m_row_indices,
m_col_indices,
m_values,
m_descr,
m_legacy_descr,
m_trans};
}
MUDA_GENERIC auto cviewer() const
{
MUDA_KERNEL_ASSERT(!m_trans,
"COOMatrixView: cviewer() is not supported for "
"transposed matrix, please use a non-transposed view of this matrix");
return CTripletMatrixViewer<Ty, 1>{m_rows,
m_cols,
m_triplet_index_offset,
m_triplet_count,
m_total_triplet_count,
m_submatrix_offset,
m_submatrix_extent,
m_row_indices,
m_col_indices,
m_values};
}
MUDA_GENERIC auto viewer()
{
MUDA_ASSERT(!m_trans,
"COOMatrixView: viewer() is not supported for "
"transposed matrix, please use a non-transposed view of this matrix");
return TripletMatrixViewer<Ty, 1>{m_rows,
m_cols,
m_triplet_index_offset,
m_triplet_count,
m_total_triplet_count,
m_submatrix_offset,
m_submatrix_extent,
m_row_indices,
m_col_indices,
m_values};
}
// const access
auto values() const { return m_values; }
auto row_indices() const { return m_row_indices; }
auto col_indices() const { return m_col_indices; }
auto rows() const { return m_rows; }
auto cols() const { return m_cols; }
auto triplet_count() const { return m_triplet_count; }
auto tripet_index_offset() const { return m_triplet_index_offset; }
auto total_triplet_count() const { return m_total_triplet_count; }
auto is_trans() const { return m_trans; }
auto legacy_descr() const { return m_legacy_descr; }
auto descr() const { return m_descr; }
};
template <typename Ty>
using COOMatrixView = COOMatrixViewT<false, Ty>;
template <typename Ty>
using CCOOMatrixView = COOMatrixViewT<true, Ty>;
} // namespace muda
namespace muda
{
template <typename T>
struct read_only_view<COOMatrixView<T>>
{
using type = CCOOMatrixView<T>;
};
template <typename T>
struct read_write_view<CCOOMatrixView<T>>
{
using type = COOMatrixView<T>;
};
} // namespace muda
#include "details/bcoo_matrix_view.inl"