File bcoo_vector_view.h
File List > ext > linear_system > bcoo_vector_view.h
Go to the documentation of this file
#pragma once
#include <muda/ext/linear_system/doublet_vector_view.h>
#include <muda/ext/linear_system/bcoo_vector_viewer.h>
namespace muda
{
// BCOOVectorView is not a distinct type: it is exactly DoubletVectorView with
// the same template arguments, exposed under the BCOO name.
template <typename Value, int N>
using BCOOVectorView = DoubletVectorView<Value, N>;

// Read-only counterpart of BCOOVectorView; identical to CDoubletVectorView.
template <typename Value, int N>
using CBCOOVectorView = CDoubletVectorView<Value, N>;
}  // namespace muda
namespace muda
{
/**
 * @brief View over a COO sparse vector stored as two parallel arrays
 *        (indices and values); each (index, value) entry is called a
 *        "doublet" in this file.
 *
 * The view covers the doublets
 *     [doublet_index_offset, doublet_index_offset + doublet_count)
 * of an underlying array that holds total_doublet_count doublets in all.
 * The class declares no destructor: every copy it produces (as_const(),
 * subview()) shares the same index/value pointers and the same cuSPARSE
 * descriptor handle.
 *
 * @tparam IsConst true for a read-only view: the viewer type becomes CViewer
 *                 and the data pointers are declared through auto_const_t
 *                 (presumably const-qualified when IsConst is true).
 * @tparam T       element type; must not itself be const-qualified.
 */
template <bool IsConst, typename T>
class COOVectorViewBase : public ViewBase<IsConst>
{
    using Base = ViewBase<IsConst>;
    template <typename U>
    using auto_const_t = typename Base::template auto_const_t<U>;

  public:
    static_assert(!std::is_const_v<T>, "T must be non-const");

    using NonConstView = COOVectorViewBase<false, T>;
    using ConstView    = COOVectorViewBase<true, T>;
    using ThisView     = COOVectorViewBase<IsConst, T>;

    using CViewer    = CCOOVectorViewer<T>;
    using Viewer     = COOVectorViewer<T>;
    using ThisViewer = std::conditional_t<IsConst, CViewer, Viewer>;

  protected:
    // vector info: the value reported by vector_size()
    int m_size = 0;

    // doublet info: this view's window into the doublet arrays
    int m_doublet_index_offset = 0;
    int m_doublet_count        = 0;
    int m_total_doublet_count  = 0;

    // data: parallel index / value arrays (not freed by this class)
    auto_const_t<int>* m_indices = nullptr;
    auto_const_t<T>*   m_values  = nullptr;

    // cuSPARSE sparse-vector descriptor. subview() passes it on unchanged, so
    // NOTE(review): it presumably describes the whole vector, not the
    // sub-range -- confirm before handing a subview's descr() to cuSPARSE.
    mutable cusparseSpVecDescr_t m_descr = nullptr;

  public:
    MUDA_GENERIC COOVectorViewBase() = default;

    /**
     * @brief Build a view over doublets
     *        [doublet_index_offset, doublet_index_offset + doublet_count).
     * @param size                 value returned by vector_size()
     * @param doublet_index_offset first doublet covered by this view
     * @param doublet_count        number of doublets covered by this view
     * @param total_doublet_count  number of doublets in the underlying arrays
     * @param indices              index array (not owned)
     * @param values               value array (not owned)
     * @param descr                cuSPARSE descriptor handle (not owned)
     *
     * Asserts that the covered range ends within total_doublet_count.
     */
    MUDA_GENERIC COOVectorViewBase(int                  size,
                                   int                  doublet_index_offset,
                                   int                  doublet_count,
                                   int                  total_doublet_count,
                                   auto_const_t<int>*   indices,
                                   auto_const_t<T>*     values,
                                   cusparseSpVecDescr_t descr)
        : m_size(size)
        , m_doublet_index_offset(doublet_index_offset)
        , m_doublet_count(doublet_count)
        , m_total_doublet_count(total_doublet_count)
        , m_indices(indices)
        , m_values(values)
        , m_descr(descr)
    {
        MUDA_KERNEL_ASSERT(doublet_index_offset + doublet_count <= total_doublet_count,
                           "COOVectorView: out of range, m_total_doublet_count=%d, "
                           "your doublet_index_offset=%d, doublet_count=%d",
                           total_doublet_count,
                           doublet_index_offset,
                           doublet_count);
    }

    /// @brief Read-only copy of this view (same pointers, same descriptor).
    MUDA_GENERIC auto as_const() const -> ConstView
    {
        return ConstView{m_size,
                         m_doublet_index_offset,
                         m_doublet_count,
                         m_total_doublet_count,
                         m_indices,
                         m_values,
                         m_descr};
    }

    /// @brief Implicit conversion to the read-only view; forwards to as_const().
    MUDA_GENERIC operator ConstView() const { return as_const(); }

    // ---- non-const accessors -------------------------------------------

    /// @brief Viewer matching this view's constness (Viewer or CViewer).
    MUDA_GENERIC auto viewer()
    {
        return ThisViewer{
            m_size, m_doublet_index_offset, m_doublet_count, m_total_doublet_count, m_indices, m_values};
    }

    /**
     * @brief View of doublets [offset, offset + count) relative to this view.
     *
     * The range is checked against THIS view. The constructor only checks
     * against total_doublet_count, so on its own it would accept a negative
     * count or a window that escapes its parent.
     */
    MUDA_GENERIC auto subview(int offset, int count)
    {
        MUDA_KERNEL_ASSERT(offset >= 0 && count >= 0 && offset + count <= m_doublet_count,
                           "COOVectorView: subview out of range, m_doublet_count=%d, "
                           "your offset=%d, count=%d",
                           m_doublet_count,
                           offset,
                           count);
        return ThisView{m_size,
                        m_doublet_index_offset + offset,
                        count,
                        m_total_doublet_count,
                        m_indices,
                        m_values,
                        m_descr};
    }

    /// @brief View of doublets from offset to the end of this view.
    MUDA_GENERIC auto subview(int offset)
    {
        return subview(offset, m_doublet_count - offset);
    }

    // ---- const accessors -----------------------------------------------

    /// @brief Read-only view of doublets [offset, offset + count) of this view.
    MUDA_GENERIC ConstView subview(int offset, int count) const
    {
        return remove_const(*this).subview(offset, count);
    }

    /// @brief Read-only view of doublets from offset to the end of this view.
    MUDA_GENERIC ConstView subview(int offset) const
    {
        return remove_const(*this).subview(offset);
    }

    /**
     * @brief Read-only viewer, whatever this view's constness.
     *
     * Goes through as_const() so that a non-const view does not hand out a
     * writable Viewer from a const accessor.
     */
    MUDA_GENERIC auto cviewer() const -> CViewer
    {
        return as_const().viewer();
    }

    /// @brief The size value this view was constructed with.
    MUDA_GENERIC auto vector_size() const { return m_size; }

    /// @brief Index of this view's first doublet in the underlying arrays.
    MUDA_GENERIC auto doublet_index_offset() const
    {
        return m_doublet_index_offset;
    }

    /// @brief Number of doublets covered by this view.
    MUDA_GENERIC auto doublet_count() const { return m_doublet_count; }

    /// @brief Number of doublets in the underlying arrays.
    MUDA_GENERIC auto total_doublet_count() const
    {
        return m_total_doublet_count;
    }

    /// @brief cuSPARSE descriptor handle shared with the view this came from.
    MUDA_GENERIC auto descr() const { return m_descr; }
};

/// Mutable COO vector view.
template <typename T>
using COOVectorView = COOVectorViewBase<false, T>;

/// Read-only COO vector view.
template <typename T>
using CCOOVectorView = COOVectorViewBase<true, T>;
}  // namespace muda
namespace muda
{
template <typename T>
struct read_only_viewer<COOVectorView<T>>
{
using type = CCOOVectorView<T>;
};
template <typename T>
struct read_write_viewer<CCOOVectorView<T>>
{
using type = COOVectorView<T>;
};
} // namespace muda
#include "details/bcoo_vector_view.inl"