1#include <muda/check/check_cublas.h>
2#include <muda/check/check_cusolver.h>
3#include <muda/check/check_cusparse.h>
7MUDA_INLINE LinearSystemContext::LinearSystemContext(
const LinearSystemContextCreateInfo& info)
9 , m_handles(info.stream)
10 , m_converter(m_handles)
12 m_buffers.emplace_back(info.buffer_byte_size_base);
15MUDA_INLINE LinearSystemContext::~LinearSystemContext() {}
17MUDA_INLINE
void LinearSystemContext::stream(cudaStream_t stream) {}
19MUDA_INLINE
void LinearSystemContext::shrink_temp_buffers()
21 checkCudaErrors(cudaStreamSynchronize(m_handles.stream()));
23 auto first = m_buffers.begin();
24 auto last = std::prev(m_buffers.end());
25 std::iter_swap(first, last);
30MUDA_INLINE
void LinearSystemContext::set_pointer_mode_device()
32 m_handles.set_pointer_mode_device();
34MUDA_INLINE
void LinearSystemContext::set_pointer_mode_host()
36 m_handles.set_pointer_mode_host();
38MUDA_INLINE
void LinearSystemContext::add_sync_callback(std::function<
void()>&& callback)
40 m_sync_callbacks.emplace_back(std::move(callback));
43MUDA_INLINE span<std::byte> LinearSystemContext::temp_host_buffer(
size_t size)
45 for(
auto& b : m_host_buffers)
47 return span<std::byte>{b.data(), size};
48 auto base = m_create_info.buffer_byte_size_base;
50 auto rounded_size = ((size + base - 1) / base) * base;
51 return span<std::byte>{m_host_buffers.emplace_back(rounded_size).data(), size};
54MUDA_INLINE BufferView<std::byte> LinearSystemContext::temp_buffer(
size_t size)
56 for(
auto& b : m_buffers)
58 return b.view(0, size);
59 auto base = m_create_info.buffer_byte_size_base;
61 auto rounded_size = ((size + base - 1) / base) * base;
62 return m_buffers.emplace_back(rounded_size).view(0, size);
66BufferView<T> LinearSystemContext::temp_buffer(
size_t size)
68 BufferView<std::byte> bbuf = temp_buffer(size *
sizeof(T));
69 return BufferView<T>{
reinterpret_cast<T*
>(bbuf.data()), 0, size};
73span<T> LinearSystemContext::temp_host_buffer(
size_t size)
75 span<std::byte> bbuf = temp_host_buffer(size *
sizeof(T));
76 return span<T>{
reinterpret_cast<T*
>(bbuf.data()), size};
80std::vector<T*> LinearSystemContext::temp_buffers(
size_t size_in_buffer,
size_t num_buffer)
82 BufferView<T> total = temp_buffer<T>(size_in_buffer * num_buffer);
83 std::vector<T*> ret(num_buffer);
84 for(
int i = 0; i < num_buffer; ++i)
85 ret[i] = total.data(i * size_in_buffer);
90std::vector<T*> LinearSystemContext::temp_host_buffers(
size_t size_in_buffer,
size_t num_buffer)
92 span<T> total = temp_host_buffer<T>(size_in_buffer * num_buffer);
93 std::vector<T*> ret(num_buffer);
94 for(
int i = 0; i < num_buffer; ++i)
95 ret[i] = total.data() + i * size_in_buffer;
99MUDA_INLINE
void LinearSystemContext::sync()
103 if(m_buffers.size() > 1)
104 shrink_temp_buffers();
106 for(
auto& cb : m_sync_callbacks)
108 m_sync_callbacks.clear();