Skip to content

File linear_system_handles.h

File List > ext > linear_system > linear_system_handles.h

Go to the documentation of this file

#pragma once
#include <cublas_v2.h>
#include <cusparse_v2.h>
#include <cusolverDn.h>
#include <cusolverSp.h>
#include <muda/muda_def.h>
#include <muda/check/check_cusparse.h>
#include <muda/check/check_cublas.h>
#include <muda/check/check_cusolver.h>
#include <muda/check/check.h>

namespace muda
{
class LinearSystemContext;
class LinearSystemHandles
{
    friend class LinearSystemContext;
    cudaStream_t       m_stream              = nullptr;
    cublasHandle_t     m_cublas              = nullptr;
    cusparseHandle_t   m_cusparse            = nullptr;
    cusolverDnHandle_t m_cusolver_dn         = nullptr;
    cusolverSpHandle_t m_cusolver_sp         = nullptr;
    bool               m_pointer_mode_device = false;
    float              m_reserve_ratio         = 1.5f;

  public:
    LinearSystemHandles(cudaStream_t s)
        : m_stream(s)
    {
        checkCudaErrors(cusparseCreate(&m_cusparse));
        checkCudaErrors(cublasCreate(&m_cublas));
        checkCudaErrors(cusolverDnCreate(&m_cusolver_dn));
        checkCudaErrors(cusparseSetStream(m_cusparse, m_stream));
        checkCudaErrors(cublasSetStream(m_cublas, m_stream));
        checkCudaErrors(cusolverDnSetStream(m_cusolver_dn, m_stream));
        checkCudaErrors(cusolverSpCreate(&m_cusolver_sp));
        checkCudaErrors(cusolverSpSetStream(m_cusolver_sp, m_stream));
        set_pointer_mode_host();
    }
    ~LinearSystemHandles()
    {
        if(m_cusparse)
            checkCudaErrors(cusparseDestroy(m_cusparse));
        if(m_cublas)
            checkCudaErrors(cublasDestroy(m_cublas));
        if(m_cusolver_dn)
            checkCudaErrors(cusolverDnDestroy(m_cusolver_dn));
        if(m_cusolver_sp)
            checkCudaErrors(cusolverSpDestroy(m_cusolver_sp));
    }

    void stream(cudaStream_t s)
    {
        m_stream = s;
        checkCudaErrors(cusparseSetStream(m_cusparse, m_stream));
        checkCudaErrors(cublasSetStream(m_cublas, m_stream));
        checkCudaErrors(cusolverDnSetStream(m_cusolver_dn, m_stream));
        checkCudaErrors(cusolverSpSetStream(m_cusolver_sp, m_stream));
    }

    MUDA_INLINE void set_pointer_mode_device()
    {
        if(m_pointer_mode_device)
            return;
        checkCudaErrors(cusparseSetPointerMode(m_cusparse, CUSPARSE_POINTER_MODE_DEVICE));
        checkCudaErrors(cublasSetPointerMode(m_cublas, CUBLAS_POINTER_MODE_DEVICE));
        m_pointer_mode_device = true;
    }

    MUDA_INLINE void set_pointer_mode_host()
    {
        if(!m_pointer_mode_device)
            return;
        checkCudaErrors(cusparseSetPointerMode(m_cusparse, CUSPARSE_POINTER_MODE_HOST));
        checkCudaErrors(cublasSetPointerMode(m_cublas, CUBLAS_POINTER_MODE_HOST));
        m_pointer_mode_device = false;
    }

    cudaStream_t       stream() const { return m_stream; }
    cublasHandle_t     cublas() const { return m_cublas; }
    cusparseHandle_t   cusparse() const { return m_cusparse; }
    cusolverDnHandle_t cusolver_dn() const { return m_cusolver_dn; }
    cusolverSpHandle_t cusolver_sp() const { return m_cusolver_sp; }
    auto reserve_ratio() const { return m_reserve_ratio; }
};
}  // namespace muda