MUDA
Loading...
Searching...
No Matches
host_device_string_cache.h
1#pragma once
2#include <cuda_runtime.h>
3#include <muda/literal/unit.h>
4#include <muda/muda_def.h>
5#include <muda/check/check_cuda_errors.h>
6#include <unordered_map>
7#include <string>
8#include <muda/tools/string_pointer.h>
9#include <vector>
10#include <cstring>
11
12namespace muda::details
13{
15{
// Bookkeeping record for one cached string: which buffer pair holds it,
// where it starts inside that buffer, and how many bytes it occupies.
// All fields default to the all-ones bit pattern until they are filled in.
class StringLocation
{
  public:
    // index into the parallel host/device buffer lists
    size_t buffer_index = static_cast<size_t>(-1);
    // byte offset of the first character inside that buffer
    size_t offset = static_cast<size_t>(-1);
    // stored byte count, including the trailing '\0'
    size_t size = static_cast<size_t>(-1);
};
23
24 std::unordered_map<std::string, StringLocation> m_string_map;
25
26 std::vector<char*> m_device_string_buffers;
27 std::vector<char*> m_host_string_buffers;
28
29 size_t m_current_buffer_offset;
30 size_t m_buffer_size;
31
32 StringPointer m_empty_string_pointer{};
33
34 public:
35 HostDeviceStringCache(size_t buffer_size = 4_M)
36 : m_buffer_size(buffer_size)
37 , m_current_buffer_offset(0)
38 {
39 m_device_string_buffers.reserve(32);
40 m_host_string_buffers.reserve(32);
41
42 char* s;
43 checkCudaErrors(cudaMalloc(&s, m_buffer_size * sizeof(char)));
44 m_device_string_buffers.emplace_back(s);
45 m_host_string_buffers.emplace_back(new char[m_buffer_size]);
46
47 m_empty_string_pointer = get_string_pointer(""); // insert empty string
48 }
50 {
51 for(auto s : m_device_string_buffers)
52 cudaFree(s);
53 for(auto s : m_host_string_buffers)
54 delete[] s;
55 }
56 // delete copy
58 HostDeviceStringCache& operator=(const HostDeviceStringCache&) = delete;
59 // move
61 HostDeviceStringCache& operator=(HostDeviceStringCache&&) = default;
62
63 StringPointer operator[](std::string_view s)
64 {
65 if(s.empty() || s == "")
66 {
67 return m_empty_string_pointer;
68 }
69 return get_string_pointer(s);
70 }
71
72 private:
73 StringPointer get_string_pointer(std::string_view s)
74 {
75 auto str = std::string{s};
76 auto it = m_string_map.find(str);
77 char* device_string = nullptr;
78 char* host_string = nullptr;
79 unsigned int str_length = 0;
80
81 if(it != m_string_map.end()) // cached
82 {
83 auto& loc = it->second;
84 device_string = m_device_string_buffers[loc.buffer_index] + loc.offset;
85 host_string = m_host_string_buffers[loc.buffer_index] + loc.offset;
86 str_length = static_cast<unsigned int>(loc.size - 1);
87 }
88 else // need insert
89 {
90 auto zero_end_length = str.size() + 1;
91 auto& loc = m_string_map[str]; // insert
92
93 if(m_current_buffer_offset + zero_end_length > m_buffer_size) // need new buffer
94 {
95 char* s;
96 checkCudaErrors(cudaMalloc(&s, m_buffer_size * sizeof(char)));
97 m_device_string_buffers.emplace_back(s);
98 m_host_string_buffers.emplace_back(new char[m_buffer_size]);
99 m_current_buffer_offset = 0;
100 }
101
102 auto device_buffer = m_device_string_buffers.back();
103 auto host_buffer = m_host_string_buffers.back();
104
105 // copy string to host buffer (with '\0' end)
106 host_buffer[m_current_buffer_offset + str.size()] = '\0';
107 std::memcpy(host_buffer + m_current_buffer_offset, str.data(), str.size());
108
109 // copy string from host buffer to device buffer
110 checkCudaErrors(cudaMemcpy(device_buffer + m_current_buffer_offset,
111 host_buffer + m_current_buffer_offset,
112 str.size() + 1,
113 cudaMemcpyHostToDevice));
114
115 loc.buffer_index = m_host_string_buffers.size() - 1;
116 loc.offset = m_current_buffer_offset;
117 loc.size = zero_end_length; // include '\0'
118
119 m_current_buffer_offset += zero_end_length;
120
121 device_string = device_buffer + loc.offset;
122 host_string = host_buffer + loc.offset;
123 str_length = static_cast<unsigned int>(loc.size - 1);
124 }
125 return StringPointer{device_string, host_string, str_length};
126 }
127};
128} // namespace muda::details
Definition host_device_string_cache.h:15
Definition string_pointer.h:7