File linear_system_context.h
File List > ext > linear_system > linear_system_context.h
Go to the documentation of this file
#pragma once
#include <cublas_v2.h>
#include <cusparse_v2.h>
#include <cusolverDn.h>
#include <cusolverSp.h>
#include <list>
#include <muda/buffer/device_buffer.h>
#include <muda/literal/unit.h>
#include <muda/mstl/span.h>
#include <muda/ext/linear_system/dense_vector_view.h>
#include <muda/ext/linear_system/dense_matrix_view.h>
#include <muda/ext/linear_system/matrix_format_converter.h>
#include <muda/ext/linear_system/linear_system_handles.h>
#include <muda/ext/linear_system/linear_system_solve_tolerance.h>
#include <muda/ext/linear_system/linear_system_solve_reorder.h>
namespace muda
{
class LinearSystemContextCreateInfo
{
public:
cudaStream_t stream = nullptr;
// base size of temp buffer, if buffer is not enough
// we create a new buffer with size = buffer_byte_size_base * 2 / 4 / 8 / 16 / ...
// and we will not release the old buffer because of safety
size_t buffer_byte_size_base = 256_M;
};
class LinearSystemContext
{
private:
LinearSystemHandles m_handles;
std::list<DeviceBuffer<std::byte>> m_buffers;
std::list<std::vector<std::byte>> m_host_buffers;
DeviceBuffer<std::byte> m_scalar_buffer;
LinearSystemContextCreateInfo m_create_info;
std::list<std::function<void()>> m_sync_callbacks;
std::string m_current_label;
void set_pointer_mode_device();
void set_pointer_mode_host();
void shrink_temp_buffers();
void add_sync_callback(std::function<void()>&& callback);
BufferView<std::byte> temp_buffer(size_t size);
span<std::byte> temp_host_buffer(size_t size);
template <typename T>
BufferView<T> temp_buffer(size_t size);
template <typename T>
span<T> temp_host_buffer(size_t size);
template <typename T>
std::vector<T*> temp_buffers(size_t size_in_buffer, size_t num_buffer);
template <typename T>
std::vector<T*> temp_host_buffers(size_t size_in_buffer, size_t num_buffer);
LinearSystemSolveTolerance m_tolerance;
LinearSystemSolveReorder m_reorder;
MatrixFormatConverter m_converter;
private:
auto cublas() const { return m_handles.cublas(); }
auto cusparse() const { return m_handles.cusparse(); }
auto cusolver_dn() const { return m_handles.cusolver_dn(); }
auto cusolver_sp() const { return m_handles.cusolver_sp(); }
public:
LinearSystemContext(const LinearSystemContextCreateInfo& info = {});
LinearSystemContext(const LinearSystemContext&) = delete;
LinearSystemContext& operator=(const LinearSystemContext&) = delete;
LinearSystemContext(LinearSystemContext&&) = delete;
LinearSystemContext& operator=(LinearSystemContext&&) = delete;
~LinearSystemContext();
void label(std::string_view label) { m_current_label = label; }
auto label() const -> std::string_view { return m_current_label; }
auto stream() const { return m_handles.stream(); }
void stream(cudaStream_t stream);
void sync();
/***********************************************************************************************
Settings
***********************************************************************************************/
auto& tolerance() { return m_tolerance; }
auto& reorder() { return m_reorder; }
auto reserve_ratio() const { return m_handles.m_reserve_ratio; }
void reserve_ratio(float ratio) { m_handles.m_reserve_ratio = ratio; }
public:
/***********************************************************************************************
Converter
***********************************************************************************************/
// Triplet -> BCOO
template <typename T, int N>
void convert(const DeviceTripletMatrix<T, N>& from, DeviceBCOOMatrix<T, N>& to);
// BCOO -> Dense Matrix
template <typename T, int N>
void convert(const DeviceBCOOMatrix<T, N>& from,
DeviceDenseMatrix<T>& to,
bool clear_dense_matrix = true);
// BCOO -> COO
template <typename T, int N>
void convert(const DeviceBCOOMatrix<T, N>& from, DeviceCOOMatrix<T>& to);
// BCOO -> BSR
template <typename T, int N>
void convert(const DeviceBCOOMatrix<T, N>& from, DeviceBSRMatrix<T, N>& to);
// Doublet -> BCOO
template <typename T, int N>
void convert(const DeviceDoubletVector<T, N>& from, DeviceBCOOVector<T, N>& to);
// BCOO -> Dense Vector
template <typename T, int N>
void convert(const DeviceBCOOVector<T, N>& from,
DeviceDenseVector<T>& to,
bool clear_dense_vector = true);
// Doublet -> Dense Vector
template <typename T, int N>
void convert(const DeviceDoubletVector<T, N>& from,
DeviceDenseVector<T>& to,
bool clear_dense_vector = true);
// BSR -> CSR
template <typename T, int N>
void convert(const DeviceBSRMatrix<T, N>& from, DeviceCSRMatrix<T>& to);
// Triplet -> COO
template <typename T>
void convert(const DeviceTripletMatrix<T, 1>& from, DeviceCOOMatrix<T>& to);
// COO -> Dense Matrix
template <typename T>
void convert(const DeviceCOOMatrix<T>& from,
DeviceDenseMatrix<T>& to,
bool clear_dense_matrix = true);
// COO -> CSR
template <typename T>
void convert(const DeviceCOOMatrix<T>& from, DeviceCSRMatrix<T>& to);
template <typename T>
void convert(DeviceCOOMatrix<T>&& from, DeviceCSRMatrix<T>& to);
// Doublet -> COO
template <typename T>
void convert(const DeviceDoubletVector<T, 1>& from, DeviceCOOVector<T>& to);
// COO -> Dense Vector
template <typename T>
void convert(const DeviceCOOVector<T>& from,
DeviceDenseVector<T>& to,
bool clear_dense_vector = true);
// Doublet -> Dense Vector
template <typename T>
void convert(const DeviceDoubletVector<T, 1>& from,
DeviceDenseVector<T>& to,
bool clear_dense_vector = true);
public:
/***********************************************************************************************
Norm
***********************************************************************************************/
template <typename T>
T norm(CDenseVectorView<T> x);
template <typename T>
void norm(CDenseVectorView<T> x, VarView<T> result);
template <typename T>
void norm(CDenseVectorView<T> x, T* result);
/***********************************************************************************************
Dot
***********************************************************************************************/
template <typename T>
T dot(CDenseVectorView<T> x, CDenseVectorView<T> y);
template <typename T>
void dot(CDenseVectorView<T> x, CDenseVectorView<T> y, VarView<T> result);
template <typename T>
void dot(CDenseVectorView<T> x, CDenseVectorView<T> y, T* result);
/***********************************************************************************************
Max/Min
***********************************************************************************************/
//TODO:
/***********************************************************************************************
Axpby
y = alpha * x + beta * y
***********************************************************************************************/
// y = alpha * x + beta * y
template <typename T>
void axpby(const T& alpha, CDenseVectorView<T> x, const T& beta, DenseVectorView<T> y);
// y = alpha * x + beta * y
template <typename T>
void axpby(CVarView<T> alpha, CDenseVectorView<T> x, CVarView<T> beta, DenseVectorView<T> y);
// z = x + y
template <typename T>
void plus(CDenseVectorView<T> x, CDenseVectorView<T> y, DenseVectorView<T> z);
/***********************************************************************************************
Spmv
y = a * A * x + b * y
***********************************************************************************************/
// BSR
template <typename T, int N>
void spmv(const T& a,
CBSRMatrixView<T, N> A,
CDenseVectorView<T> x,
const T& b,
DenseVectorView<T>& y);
template <typename T, int N>
void spmv(CBSRMatrixView<T, N> A, CDenseVectorView<T> x, DenseVectorView<T> y);
// CSR
template <typename T>
void spmv(const T& a, CCSRMatrixView<T> A, CDenseVectorView<T> x, const T& b, DenseVectorView<T>& y);
template <typename T>
void spmv(CCSRMatrixView<T> A, CDenseVectorView<T> x, DenseVectorView<T> y);
// BCOO & Triplet
template <typename T, int N>
void spmv(const T& a,
CTripletMatrixView<T, N> A,
CDenseVectorView<T> x,
const T& b,
DenseVectorView<T>& y);
template <typename T, int N>
void spmv(CTripletMatrixView<T, N> A, CDenseVectorView<T> x, DenseVectorView<T> y);
// COO
template <typename T>
void spmv(const T& a, CCOOMatrixView<T> A, CDenseVectorView<T> x, const T& b, DenseVectorView<T>& y);
template <typename T>
void spmv(CCOOMatrixView<T> A, CDenseVectorView<T> x, DenseVectorView<T> y);
/***********************************************************************************************
Mv
y = a * A * x + b * y
***********************************************************************************************/
template <typename T>
void mv(CDenseMatrixView<T> A,
const T& alpha,
CDenseVectorView<T> x,
const T& beta,
DenseVectorView<T> y);
template <typename T>
void mv(CDenseMatrixView<T> A,
CVarView<T> alpha,
CDenseVectorView<T> x,
CVarView<T> beta,
DenseVectorView<T> y);
template <typename T>
void mv(CDenseMatrixView<T> A, CDenseVectorView<T> x, DenseVectorView<T> y);
/***********************************************************************************************
Solve
A * x = b
***********************************************************************************************/
// solve Ax = b, A will be modified for factorization
// and b will be modified to store the solution
template <typename T>
void solve(DenseMatrixView<T> A_to_fact, DenseVectorView<T> b_to_x);
// solve Ax = b
// A is the CSR Matrix
template <typename T>
void solve(DenseVectorView<T> x, CCSRMatrixView<T> A, CDenseVectorView<T> b);
private:
template <typename T>
void generic_spmv(const T& a,
cusparseOperation_t op,
cusparseSpMatDescr_t A,
const cusparseDnVecDescr* x,
const T& b,
cusparseDnVecDescr_t y);
template <typename T>
void sysv(DenseMatrixView<T> A_to_fact, DenseVectorView<T> b_to_x);
template <typename T>
void gesv(DenseMatrixView<T> A_to_fact, DenseVectorView<T> b_to_x);
};
} // namespace muda
#include "details/linear_system_context.inl"
#include "details/routines/convert.inl"
#include "details/routines/norm.inl"
#include "details/routines/dot.inl"
#include "details/routines/axpby.inl"
#include "details/routines/spmv.inl"
#include "details/routines/mv.inl"
#include "details/routines/solve.inl"
#include "details/routines/mm.inl"