MUDA
Loading...
Searching...
No Matches
linear_system_context.inl
1#include <muda/check/check_cublas.h>
2#include <muda/check/check_cusolver.h>
3#include <muda/check/check_cusparse.h>
4
5namespace muda
6{
/// Builds a linear-system context from `info`.
///
/// Copies the create info, constructs the library handles bound to
/// `info.stream`, builds the converter on top of those handles, and seeds the
/// device temp-buffer pool with a single buffer of `buffer_byte_size_base`
/// bytes.
MUDA_INLINE LinearSystemContext::LinearSystemContext(const LinearSystemContextCreateInfo& info)
    : m_create_info(info)
    , m_handles(info.stream)
    // NOTE(review): members are initialized in declaration order, so this
    // requires m_handles to be declared before m_converter — confirm in the header.
    , m_converter(m_handles)
{
    // By the time the body runs, m_create_info already holds a copy of `info`.
    const auto initial_bytes = m_create_info.buffer_byte_size_base;
    m_buffers.emplace_back(initial_bytes);
}
14
/// Destructor: no extra teardown here; members are released by their own destructors.
MUDA_INLINE LinearSystemContext::~LinearSystemContext() = default;
16
/// Rebinds the context to `stream`.
///
/// The library handles were constructed from `info.stream` and are the source
/// of the stream used elsewhere in this file (e.g. shrink_temp_buffers
/// synchronizes `m_handles.stream()`), so the new stream is forwarded to them.
/// An empty body here would silently ignore the requested stream.
///
/// NOTE(review): assumes LinearSystemHandles exposes a `stream(cudaStream_t)`
/// setter that also rebinds its library handles — confirm against the handles header.
/// Work already queued on the previous stream is NOT awaited here; call sync()
/// first if temp buffers may still be in use by that work.
MUDA_INLINE void LinearSystemContext::stream(cudaStream_t stream)
{
    m_handles.stream(stream);
}
18
/// Releases every device temp buffer except the largest one.
///
/// Synchronizes the handles' stream first, because queued work may still be
/// reading or writing any of the temp buffers.
MUDA_INLINE void LinearSystemContext::shrink_temp_buffers()
{
    checkCudaErrors(cudaStreamSynchronize(m_handles.stream()));

    // With fewer than two buffers there is nothing to release. This also avoids
    // std::prev(end()) / iter_swap on an empty container, which is undefined behavior.
    if(m_buffers.size() <= 1)
        return;

    // Select the largest buffer by byte size explicitly, rather than assuming
    // the most recently appended buffer is the largest.
    auto largest = std::max_element(m_buffers.begin(),
                                    m_buffers.end(),
                                    [](const auto& lhs, const auto& rhs)
                                    { return lhs.size() < rhs.size(); });

    // Erase everything after it first, then everything before it. Erasing the
    // tail first keeps `largest` valid for vector-like containers, and no
    // (possibly self-) swap of buffers is needed.
    m_buffers.erase(std::next(largest), m_buffers.end());
    m_buffers.erase(m_buffers.begin(), largest);
}
29
/// Forwards to the handles' set_pointer_mode_device().
/// NOTE(review): presumably makes scalar arguments/results of library calls
/// device pointers — confirm in LinearSystemHandles.
MUDA_INLINE void LinearSystemContext::set_pointer_mode_device()
{
    auto& handles = this->m_handles;
    handles.set_pointer_mode_device();
}
/// Forwards to the handles' set_pointer_mode_host().
/// NOTE(review): presumably makes scalar arguments/results of library calls
/// host pointers — confirm in LinearSystemHandles.
MUDA_INLINE void LinearSystemContext::set_pointer_mode_host()
{
    auto& handles = this->m_handles;
    handles.set_pointer_mode_host();
}
/// Queues `callback` to be invoked by the next call to sync().
MUDA_INLINE void LinearSystemContext::add_sync_callback(std::function<void()>&& callback)
{
    m_sync_callbacks.push_back(std::move(callback));
}
42
/// Returns a host scratch span of exactly `size` bytes.
///
/// Reuses the first cached host buffer whose size is at least `size`.
/// Otherwise it appends a new host buffer whose size is `size` rounded up to a
/// multiple of `buffer_byte_size_base`. The returned span aliases pooled
/// storage and is only valid while that buffer stays in the pool.
MUDA_INLINE span<std::byte> LinearSystemContext::temp_host_buffer(size_t size)
{
    for(auto& b : m_host_buffers)
        if(b.size() >= size)
            return span<std::byte>{b.data(), size};

    // A base of 0 would make the rounding below divide by zero; treat it as 1
    // (i.e. no rounding) instead.
    const size_t base = std::max<size_t>(m_create_info.buffer_byte_size_base, 1);
    // Round up to a multiple of base.
    const size_t rounded_size = ((size + base - 1) / base) * base;
    return span<std::byte>{m_host_buffers.emplace_back(rounded_size).data(), size};
}
53
/// Returns a device scratch view of exactly `size` bytes.
///
/// Reuses the first pooled device buffer whose size is at least `size`.
/// Otherwise it appends a new buffer whose size is `size` rounded up to a
/// multiple of `buffer_byte_size_base`. The view aliases pooled storage; sync()
/// may release non-largest buffers, so views must not outlive the next sync().
MUDA_INLINE BufferView<std::byte> LinearSystemContext::temp_buffer(size_t size)
{
    for(auto& b : m_buffers)
        if(b.size() >= size)
            return b.view(0, size);

    // A base of 0 would make the rounding below divide by zero; treat it as 1
    // (i.e. no rounding) instead.
    const size_t base = std::max<size_t>(m_create_info.buffer_byte_size_base, 1);
    // Round up to a multiple of base.
    const size_t rounded_size = ((size + base - 1) / base) * base;
    return m_buffers.emplace_back(rounded_size).view(0, size);
}
64
/// Typed device scratch: borrows `size * sizeof(T)` bytes from the byte pool
/// and exposes them as a view of `size` elements of T.
template <typename T>
BufferView<T> LinearSystemContext::temp_buffer(size_t size)
{
    const size_t byte_count = size * sizeof(T);
    auto         raw        = temp_buffer(byte_count);
    T*           typed_ptr  = reinterpret_cast<T*>(raw.data());
    return BufferView<T>{typed_ptr, 0, size};
}
71
/// Typed host scratch: borrows `size * sizeof(T)` bytes from the host pool and
/// exposes them as a span of `size` elements of T.
/// NOTE(review): alignment of the pooled host storage for over-aligned T is not
/// established here — confirm before using with such types.
template <typename T>
span<T> LinearSystemContext::temp_host_buffer(size_t size)
{
    const size_t byte_count = size * sizeof(T);
    auto         raw        = temp_host_buffer(byte_count);
    T*           typed_ptr  = reinterpret_cast<T*>(raw.data());
    return span<T>{typed_ptr, size};
}
78
/// Allocates `num_buffer` device scratch arrays of `size_in_buffer` elements each.
///
/// A single contiguous temp buffer of `size_in_buffer * num_buffer` elements is
/// taken and sliced. Element i of the result points at offset
/// `i * size_in_buffer` within it.
template <typename T>
std::vector<T*> LinearSystemContext::temp_buffers(size_t size_in_buffer, size_t num_buffer)
{
    BufferView<T>   total = temp_buffer<T>(size_in_buffer * num_buffer);
    std::vector<T*> ret(num_buffer);
    // size_t index matches num_buffer's type: no signed/unsigned comparison and
    // no int overflow for very large buffer counts.
    for(size_t i = 0; i < num_buffer; ++i)
        ret[i] = total.data(i * size_in_buffer);
    return ret;
}
88
/// Allocates `num_buffer` host scratch arrays of `size_in_buffer` elements each.
///
/// A single contiguous host temp buffer of `size_in_buffer * num_buffer`
/// elements is taken and sliced. Element i of the result points
/// `i * size_in_buffer` elements past its start.
template <typename T>
std::vector<T*> LinearSystemContext::temp_host_buffers(size_t size_in_buffer, size_t num_buffer)
{
    span<T>         total = temp_host_buffer<T>(size_in_buffer * num_buffer);
    std::vector<T*> ret(num_buffer);
    // size_t index matches num_buffer's type: no signed/unsigned comparison and
    // no int overflow for very large buffer counts.
    for(size_t i = 0; i < num_buffer; ++i)
        ret[i] = total.data() + i * size_in_buffer;
    return ret;
}
98
/// Waits for the context's stream, shrinks the device temp-buffer pool to its
/// largest buffer, then runs and clears all callbacks queued with add_sync_callback().
MUDA_INLINE void LinearSystemContext::sync()
{
    on(stream()).wait();

    // After the wait, drop all temp device buffers except the largest one.
    if(m_buffers.size() > 1)
        shrink_temp_buffers();

    // Take the pending callbacks out of the member before invoking them.
    // Iterating m_sync_callbacks directly is unsafe if a callback calls
    // add_sync_callback(), because that mutates the container mid-iteration and
    // invalidates iterators for vector-like containers. Callbacks registered
    // while this loop runs are queued for the next sync().
    auto callbacks = std::move(m_sync_callbacks);
    m_sync_callbacks.clear(); // moved-from container: reset to a known empty state
    for(auto& cb : callbacks)
        cb();
}
110} // namespace muda