File device_doublet_vector.h
File List > ext > linear_system > device_doublet_vector.h
Go to the documentation of this file
#pragma once
#include <muda/buffer/device_buffer.h>
#include <muda/ext/eigen/eigen_core_cxx20.h>
#include <muda/ext/linear_system/doublet_vector_view.h>
namespace muda
{
namespace details
{
    // Only declared here (the definition is not part of this header).
    // DeviceDoubletVector befriends this template so the converter can
    // reach the vector's protected storage.
    template <typename T, int N>
    class MatrixFormatConverter;
}  // namespace details
}  // namespace muda
namespace muda
{
template <typename T, int N>
class DeviceDoubletVector
{
template <typename U, int M>
friend class details::MatrixFormatConverter;
public:
using SegmentVector = Eigen::Vector<T, N>;
protected:
muda::DeviceBuffer<SegmentVector> m_segment_values;
muda::DeviceBuffer<int> m_segment_indices;
int m_segment_count = 0;
public:
DeviceDoubletVector() = default;
~DeviceDoubletVector() = default;
void reshape(int num_segment) { m_segment_count = num_segment; }
void resize_doublets(size_t nonzero_count)
{
m_segment_values.resize(nonzero_count);
m_segment_indices.resize(nonzero_count);
}
void reserve_doublets(size_t nonzero_count)
{
m_segment_values.reserve(nonzero_count);
m_segment_indices.reserve(nonzero_count);
}
void resize(int num_segment, size_t nonzero_count)
{
reshape(num_segment);
resize_doublets(nonzero_count);
}
void clear()
{
m_segment_values.clear();
m_segment_indices.clear();
}
auto segment_count() const { return m_segment_count; }
auto segment_values() { return m_segment_values.view(); }
auto segment_values() const { return m_segment_values.view(); }
auto segment_indices() { return m_segment_indices.view(); }
auto segment_indices() const { return m_segment_indices.view(); }
auto doublet_count() const { return m_segment_values.size(); }
auto doublet_capacity() const { return m_segment_values.capacity(); }
auto view()
{
return DoubletVectorView<T, N>{m_segment_count,
(int)m_segment_values.size(),
m_segment_indices.data(),
m_segment_values.data()};
}
auto view() const { return remove_const(*this).view().as_const(); }
auto viewer() { return view().viewer(); }
auto viewer() const { return view().cviewer(); };
};
/**
 * @brief Scalar (N == 1) specialization of DeviceDoubletVector.
 *
 * Doublets are stored as two parallel `muda::DeviceBuffer`s: plain `T`
 * values and `int` indices, rather than 1-component Eigen vectors.
 * `m_size` is the logical length of the vector and is tracked independently
 * of how many doublets are stored.
 *
 * The doublet-buffer API (`resize_doublets`, `reserve_doublets`,
 * `doublet_count`, `doublet_capacity`) mirrors the primary template, so the
 * same buffer-management calls compile for N == 1. The historical
 * `resize_doublet` spelling is kept as a forwarding alias.
 *
 * @tparam T scalar type of the values
 */
template <typename T>
class DeviceDoubletVector<T, 1>
{
    // Grants the format converter access to the protected storage below.
    template <typename U, int M>
    friend class details::MatrixFormatConverter;

  protected:
    muda::DeviceBuffer<T>   m_values;    // value half of each doublet
    muda::DeviceBuffer<int> m_indices;   // index half of each doublet
    int                     m_size = 0;  // logical length of the vector

  public:
    DeviceDoubletVector()  = default;
    ~DeviceDoubletVector() = default;

    /// Set the logical length only; the doublet buffers are not touched.
    void reshape(int num) { m_size = num; }

    /// Resize both doublet buffers to `nonzero_count` entries.
    /// Same name as the primary template's method.
    void resize_doublets(size_t nonzero_count)
    {
        m_values.resize(nonzero_count);
        m_indices.resize(nonzero_count);
    }

    /// Backward-compatible alias of resize_doublets().
    void resize_doublet(size_t nonzero_count) { resize_doublets(nonzero_count); }

    /// Reserve room for `nonzero_count` doublets in both buffers
    /// (previously only available on the primary template).
    void reserve_doublets(size_t nonzero_count)
    {
        m_values.reserve(nonzero_count);
        m_indices.reserve(nonzero_count);
    }

    /// Set the logical length, then resize the doublet buffers.
    void resize(int num, size_t nonzero_count)
    {
        reshape(num);
        resize_doublets(nonzero_count);
    }

    /// Clear both doublet buffers. The logical length is NOT reset.
    void clear()
    {
        m_values.clear();
        m_indices.clear();
    }

    auto size() const { return m_size; }

    auto values() { return m_values.view(); }
    auto values() const { return m_values.view(); }

    auto indices() { return m_indices.view(); }
    auto indices() const { return m_indices.view(); }

    /// Number of stored doublets (size of the value buffer).
    auto doublet_count() const { return m_values.size(); }
    /// Capacity of the value buffer (mirrors the primary template).
    auto doublet_capacity() const { return m_values.capacity(); }

    /// Build a DoubletVectorView over the current buffers.
    /// NOTE(review): the doublet count is narrowed to `int`, matching the
    /// primary template.
    auto view()
    {
        return DoubletVectorView<T, 1>{
            m_size, (int)m_values.size(), m_indices.data(), m_values.data()};
    }

    /// Const view: reuse the non-const builder, then make the result read-only.
    auto view() const { return remove_const(*this).view().as_const(); }

    auto viewer() { return view().viewer(); }
    auto viewer() const { return view().cviewer(); }
};
} // namespace muda
#include "details/device_doublet_vector.inl"