File bcoo_matrix_view.h
File List > ext > linear_system > bcoo_matrix_view.h
Go to the documentation of this file
#pragma once
#include <muda/ext/linear_system/triplet_matrix_view.h>
#include <muda/ext/linear_system/bcoo_matrix_viewer.h>
namespace muda
{
template <typename T, int N>
using BCOOMatrixView = TripletMatrixView<T, N>;
template <typename T, int N>
using CBCOOMatrixView = CTripletMatrixView<T, N>;
} // namespace muda
namespace muda
{
template <bool IsConst, typename Ty>
class COOMatrixViewBase : public ViewBase<IsConst>
{
using Base = ViewBase<IsConst>;
template<typename U>
using auto_const_t = typename Base::template auto_const_t<U>;
public:
static_assert(!std::is_const_v<Ty>, "Ty must be non-const");
using NonConstView = COOMatrixViewBase<false, Ty>;
using ConstView = COOMatrixViewBase<true, Ty>;
using ThisView = COOMatrixViewBase<IsConst, Ty>;
protected:
// matrix info
int m_rows = 0;
int m_cols = 0;
// triplet info
int m_triplet_index_offset = 0;
int m_triplet_count = 0;
int m_total_triplet_count = 0;
// sub matrix info
int2 m_submatrix_offset = {0, 0};
int2 m_submatrix_extent = {0, 0};
// data
auto_const_t<int>* m_row_indices;
auto_const_t<int>* m_col_indices;
auto_const_t<Ty>* m_values;
mutable cusparseMatDescr_t m_legacy_descr = nullptr;
mutable cusparseSpMatDescr_t m_descr = nullptr;
bool m_trans = false;
public:
MUDA_GENERIC COOMatrixViewBase() = default;
MUDA_GENERIC COOMatrixViewBase(int rows,
int cols,
int triplet_index_offset,
int triplet_count,
int total_triplet_count,
int2 submatrix_offset,
int2 submatrix_extent,
auto_const_t<int>* row_indices,
auto_const_t<int>* col_indices,
auto_const_t<Ty>* values,
cusparseSpMatDescr_t descr,
cusparseMatDescr_t legacy_descr,
bool trans)
: m_rows(rows)
, m_cols(cols)
, m_triplet_index_offset(triplet_index_offset)
, m_triplet_count(triplet_count)
, m_total_triplet_count(total_triplet_count)
, m_row_indices(row_indices)
, m_col_indices(col_indices)
, m_values(values)
, m_submatrix_offset(submatrix_offset)
, m_submatrix_extent(submatrix_extent)
, m_descr(descr)
, m_legacy_descr(legacy_descr)
, m_trans(trans)
{
MUDA_KERNEL_ASSERT(triplet_index_offset + triplet_count <= total_triplet_count,
"COOMatrixView: out of range, m_total_triplet_count=%d, "
"your triplet_index_offset=%d, triplet_count=%d",
total_triplet_count,
triplet_index_offset,
triplet_count);
MUDA_KERNEL_ASSERT(submatrix_offset.x >= 0 && submatrix_offset.y >= 0,
"TripletMatrixView: submatrix_offset is out of range, submatrix_offset.x=%d, submatrix_offset.y=%d",
submatrix_offset.x,
submatrix_offset.y);
MUDA_KERNEL_ASSERT(submatrix_offset.x + submatrix_extent.x <= rows,
"TripletMatrixView: submatrix is out of range, submatrix_offset.x=%d, submatrix_extent.x=%d, total_block_rows=%d",
submatrix_offset.x,
submatrix_extent.x,
rows);
MUDA_KERNEL_ASSERT(submatrix_offset.y + submatrix_extent.y <= cols,
"TripletMatrixView: submatrix is out of range, submatrix_offset.y=%d, submatrix_extent.y=%d, total_block_cols=%d",
submatrix_offset.y,
submatrix_extent.y,
cols);
}
MUDA_GENERIC COOMatrixViewBase(int rows,
int cols,
int total_triplet_count,
auto_const_t<int>* row_indices,
auto_const_t<int>* col_indices,
auto_const_t<Ty>* values,
cusparseSpMatDescr_t descr,
cusparseMatDescr_t legacy_descr,
bool trans)
: COOMatrixViewBase(rows,
cols,
0,
total_triplet_count,
total_triplet_count,
{0, 0},
{rows, cols},
row_indices,
col_indices,
values,
descr,
legacy_descr,
trans)
{
}
MUDA_GENERIC auto as_const() const
{
return ConstView{m_rows,
m_cols,
m_triplet_index_offset,
m_triplet_count,
m_total_triplet_count,
m_submatrix_offset,
m_submatrix_extent,
m_row_indices,
m_col_indices,
m_values,
m_descr,
m_legacy_descr,
m_trans};
}
MUDA_GENERIC operator ConstView() const { return as_const(); }
MUDA_GENERIC auto cviewer() const
{
MUDA_KERNEL_ASSERT(!m_trans,
"COOMatrixView: cviewer() is not supported for "
"transposed matrix, please use a non-transposed view of this matrix");
return CTripletMatrixViewer<Ty, 1>{m_rows,
m_cols,
m_triplet_index_offset,
m_triplet_count,
m_total_triplet_count,
m_submatrix_offset,
m_submatrix_extent,
m_row_indices,
m_col_indices,
m_values};
}
MUDA_GENERIC auto viewer()
{
MUDA_ASSERT(!m_trans,
"COOMatrixView: viewer() is not supported for "
"transposed matrix, please use a non-transposed view of this matrix");
return TripletMatrixViewer<Ty, 1>{m_rows,
m_cols,
m_triplet_index_offset,
m_triplet_count,
m_total_triplet_count,
m_submatrix_offset,
m_submatrix_extent,
m_row_indices,
m_col_indices,
m_values};
}
// non-const access
auto_const_t<Ty>* block_values() { return m_values; }
auto_const_t<int>* block_row_indices() { return m_row_indices; }
auto_const_t<int>* block_col_indices() { return m_col_indices; }
// const access
auto block_values() const { return m_values; }
auto block_row_indices() const { return m_row_indices; }
auto block_col_indices() const { return m_col_indices; }
auto block_rows() const { return m_rows; }
auto block_cols() const { return m_cols; }
auto triplet_count() const { return m_triplet_count; }
auto tripet_index_offset() const { return m_triplet_index_offset; }
auto total_triplet_count() const { return m_total_triplet_count; }
auto is_trans() const { return m_trans; }
auto legacy_descr() const { return m_legacy_descr; }
auto descr() const { return m_descr; }
};
template <typename Ty>
using COOMatrixView = COOMatrixViewBase<false, Ty>;
template <typename Ty>
using CCOOMatrixView = COOMatrixViewBase<true, Ty>;
} // namespace muda
namespace muda
{
template <typename T>
struct read_only_viewer<COOMatrixView<T>>
{
using type = CCOOMatrixView<T>;
};
template <typename T>
struct read_write_viewer<CCOOMatrixView<T>>
{
using type = COOMatrixView<T>;
};
} // namespace muda
#include "details/bcoo_matrix_view.inl"