MUDA
Loading...
Searching...
No Matches
linear_system_handles.h
1#pragma once
2#include <cublas_v2.h>
3#include <cusparse_v2.h>
4#include <cusolverDn.h>
5#include <cusolverSp.h>
6#include <muda/muda_def.h>
7#include <muda/check/check_cusparse.h>
8#include <muda/check/check_cublas.h>
9#include <muda/check/check_cusolver.h>
10#include <muda/check/check.h>
11
12namespace muda
13{
14class LinearSystemContext;
16{
17 friend class LinearSystemContext;
18 cudaStream_t m_stream = nullptr;
19 cublasHandle_t m_cublas = nullptr;
20 cusparseHandle_t m_cusparse = nullptr;
21 cusolverDnHandle_t m_cusolver_dn = nullptr;
22 cusolverSpHandle_t m_cusolver_sp = nullptr;
23 bool m_pointer_mode_device = false;
24 float m_reserve_ratio = 1.5f;
25
26 public:
27 LinearSystemHandles(cudaStream_t s)
28 : m_stream(s)
29 {
30 checkCudaErrors(cusparseCreate(&m_cusparse));
31 checkCudaErrors(cublasCreate(&m_cublas));
32 checkCudaErrors(cusolverDnCreate(&m_cusolver_dn));
33 checkCudaErrors(cusparseSetStream(m_cusparse, m_stream));
34 checkCudaErrors(cublasSetStream(m_cublas, m_stream));
35 checkCudaErrors(cusolverDnSetStream(m_cusolver_dn, m_stream));
36 checkCudaErrors(cusolverSpCreate(&m_cusolver_sp));
37 checkCudaErrors(cusolverSpSetStream(m_cusolver_sp, m_stream));
38 set_pointer_mode_host();
39 }
41 {
42 if(m_cusparse)
43 checkCudaErrors(cusparseDestroy(m_cusparse));
44 if(m_cublas)
45 checkCudaErrors(cublasDestroy(m_cublas));
46 if(m_cusolver_dn)
47 checkCudaErrors(cusolverDnDestroy(m_cusolver_dn));
48 if(m_cusolver_sp)
49 checkCudaErrors(cusolverSpDestroy(m_cusolver_sp));
50 }
51
52 void stream(cudaStream_t s)
53 {
54 m_stream = s;
55 checkCudaErrors(cusparseSetStream(m_cusparse, m_stream));
56 checkCudaErrors(cublasSetStream(m_cublas, m_stream));
57 checkCudaErrors(cusolverDnSetStream(m_cusolver_dn, m_stream));
58 checkCudaErrors(cusolverSpSetStream(m_cusolver_sp, m_stream));
59 }
60
61 MUDA_INLINE void set_pointer_mode_device()
62 {
63 if(m_pointer_mode_device)
64 return;
65 checkCudaErrors(cusparseSetPointerMode(m_cusparse, CUSPARSE_POINTER_MODE_DEVICE));
66 checkCudaErrors(cublasSetPointerMode(m_cublas, CUBLAS_POINTER_MODE_DEVICE));
67 m_pointer_mode_device = true;
68 }
69
70 MUDA_INLINE void set_pointer_mode_host()
71 {
72 if(!m_pointer_mode_device)
73 return;
74 checkCudaErrors(cusparseSetPointerMode(m_cusparse, CUSPARSE_POINTER_MODE_HOST));
75 checkCudaErrors(cublasSetPointerMode(m_cublas, CUBLAS_POINTER_MODE_HOST));
76 m_pointer_mode_device = false;
77 }
78
79 cudaStream_t stream() const { return m_stream; }
80 cublasHandle_t cublas() const { return m_cublas; }
81 cusparseHandle_t cusparse() const { return m_cusparse; }
82 cusolverDnHandle_t cusolver_dn() const { return m_cusolver_dn; }
83 cusolverSpHandle_t cusolver_sp() const { return m_cusolver_sp; }
84 auto reserve_ratio() const { return m_reserve_ratio; }
85};
86} // namespace muda
Definition linear_system_context.h:28
Definition linear_system_handles.h:16