File device_bcoo_vector.h
File List > ext > linear_system > device_bcoo_vector.h
Go to the documentation of this file
#pragma once
#include <muda/buffer/device_buffer.h>
#include <muda/ext/linear_system/bcoo_vector_view.h>
#include <muda/ext/linear_system/device_doublet_vector.h>
namespace muda
{
/**
 * Device-side sparse vector whose stored entries are segments of N scalars
 * of type T (see SegmentVector).
 *
 * The class adds no data members of its own: all storage is inherited from
 * DeviceDoubletVector<T, N>, and every special member function is the
 * compiler-generated one.
 */
template <typename T, int N>
class DeviceBCOOVector : public DeviceDoubletVector<T, N>
{
    // The format converter is granted access to the inherited storage.
    friend class details::MatrixFormatConverter<T, N>;

  public:
    /// Column vector with N rows of T: the value type of one stored segment.
    using SegmentVector = Eigen::Matrix<T, N, 1>;

    // Construction, copy and move are all defaulted.
    DeviceBCOOVector()                        = default;
    DeviceBCOOVector(const DeviceBCOOVector&) = default;
    DeviceBCOOVector(DeviceBCOOVector&&)      = default;

    DeviceBCOOVector& operator=(const DeviceBCOOVector&) = default;
    DeviceBCOOVector& operator=(DeviceBCOOVector&&)      = default;

    ~DeviceBCOOVector() = default;

    /// Number of entries currently held in the inherited m_values buffer.
    auto non_zeros() const
    {
        const auto stored_count = this->m_values.size();
        return stored_count;
    }
};
/**
 * Specialization of DeviceBCOOVector for N == 1 (scalar entries).
 *
 * In addition to the storage inherited from DeviceDoubletVector<T, 1>, this
 * specialization owns a lazily created cuSPARSE sparse-vector descriptor
 * (see descr()). The descriptor is never shared between objects:
 *   - copying leaves the destination without a descriptor (it is re-created
 *     on the next descr() call),
 *   - moving transfers the descriptor and nulls it in the source,
 *   - the destructor destroys it.
 *
 * NOTE(review): the descriptor captures m_size, m_values.size() and the raw
 * data pointers at creation time. Nothing visible in this class invalidates
 * it when the inherited buffers are resized through the base class -- confirm
 * that callers do not resize after calling descr().
 */
template <typename T>
class DeviceBCOOVector<T, 1> : public DeviceDoubletVector<T, 1>
{
    template <typename U, int N>
    friend class details::MatrixFormatConverter;

  protected:
    // Cached cuSPARSE descriptor; mutable so that the const accessor descr()
    // can create it on first use. nullptr means "not created yet".
    mutable cusparseSpVecDescr_t m_descr = nullptr;

  public:
    DeviceBCOOVector() = default;

    /// Destroys the cached descriptor, if one was created.
    ~DeviceBCOOVector() { destroy_descr(); }

    /// Copies the stored data; the new object starts without a descriptor.
    DeviceBCOOVector(const DeviceBCOOVector& other)
        : DeviceDoubletVector<T, 1>(other)
        , m_descr(nullptr)
    {
    }

    /// Takes over the stored data and the cached descriptor of `other`.
    DeviceBCOOVector(DeviceBCOOVector&& other)
        : DeviceDoubletVector<T, 1>(std::move(other))
        // Only the base subobject was moved from above; the derived member
        // other.m_descr is still intact here.
        , m_descr(other.m_descr)
    {
        other.m_descr = nullptr;
    }

    /// Copies the stored data from `other` and drops the cached descriptor.
    DeviceBCOOVector& operator=(const DeviceBCOOVector& other)
    {
        if(this != &other)
        {
            // Drop the descriptor first: it refers to the buffers that are
            // about to be overwritten, and must not stay cached if the base
            // assignment fails part-way.
            destroy_descr();
            DeviceDoubletVector<T, 1>::operator=(other);
        }
        return *this;
    }

    /// Takes over the stored data and the cached descriptor of `other`.
    DeviceBCOOVector& operator=(DeviceBCOOVector&& other)
    {
        if(this != &other)
        {
            // Release our own descriptor before adopting the other one.
            destroy_descr();
            DeviceDoubletVector<T, 1>::operator=(std::move(other));
            m_descr       = other.m_descr;
            other.m_descr = nullptr;
        }
        return *this;
    }

    /// Number of entries currently held in the inherited m_values buffer.
    auto non_zeros() const { return this->m_values.size(); }

    /**
     * Returns the cuSPARSE sparse-vector descriptor for this vector,
     * creating it on the first call and returning the cached handle on
     * later calls. The handle stays owned by this object.
     */
    auto descr() const
    {
        if(!m_descr)
        {
            // m_indices lives in a dependent base class, so its value_type
            // is a dependent type and needs `typename`.
            using IndexType = typename decltype(this->m_indices)::value_type;

            // The C-style casts strip the const added by this const member
            // function; cusparseCreateSpVec takes non-const pointers.
            checkCudaErrors(cusparseCreateSpVec(
                &m_descr,
                this->m_size,
                this->m_values.size(),
                (IndexType*)this->m_indices.data(),
                (T*)this->m_values.data(),
                cusparse_index_type<IndexType>(),
                CUSPARSE_INDEX_BASE_ZERO,
                cuda_data_type<T>()));
        }
        return m_descr;
    }

  private:
    /// Destroys the cached descriptor (if any) and resets it to nullptr.
    void destroy_descr() const
    {
        if(m_descr)
        {
            checkCudaErrors(cusparseDestroySpVec(m_descr));
            m_descr = nullptr;
        }
    }
};
/// Shorthand for the N == 1 case of DeviceBCOOVector.
template <typename Scalar>
using DeviceCOOVector = DeviceBCOOVector<Scalar, 1>;
} // namespace muda
#include "details/device_bcoo_vector.inl"