File host_device_string_cache.h
File List > muda > tools > host_device_string_cache.h
Go to the documentation of this file
#pragma once
#include <cuda_runtime.h>

#include <algorithm>
#include <cstring>
#include <memory>
#include <string>
#include <string_view>
#include <unordered_map>
#include <utility>
#include <vector>

#include <muda/literal/unit.h>
#include <muda/muda_def.h>
#include <muda/check/check_cuda_errors.h>
#include <muda/tools/string_pointer.h>
namespace muda::details
{
class HostDeviceStringCache
{
class StringLocation
{
public:
size_t buffer_index = ~0;
size_t offset = ~0;
size_t size = ~0;
};
std::unordered_map<std::string, StringLocation> m_string_map;
std::vector<char*> m_device_string_buffers;
std::vector<char*> m_host_string_buffers;
size_t m_current_buffer_offset;
size_t m_buffer_size;
StringPointer m_empty_string_pointer{};
public:
HostDeviceStringCache(size_t buffer_size = 4_M)
: m_buffer_size(buffer_size)
, m_current_buffer_offset(0)
{
m_device_string_buffers.reserve(32);
m_host_string_buffers.reserve(32);
char* s;
checkCudaErrors(cudaMalloc(&s, m_buffer_size * sizeof(char)));
m_device_string_buffers.emplace_back(s);
m_host_string_buffers.emplace_back(new char[m_buffer_size]);
m_empty_string_pointer = get_string_pointer(""); // insert empty string
}
~HostDeviceStringCache()
{
for(auto s : m_device_string_buffers)
cudaFree(s);
for(auto s : m_host_string_buffers)
delete[] s;
}
// delete copy
HostDeviceStringCache(const HostDeviceStringCache&) = delete;
HostDeviceStringCache& operator=(const HostDeviceStringCache&) = delete;
// move
HostDeviceStringCache(HostDeviceStringCache&&) = default;
HostDeviceStringCache& operator=(HostDeviceStringCache&&) = default;
StringPointer operator[](std::string_view s)
{
if(s.empty() || s == "")
{
return m_empty_string_pointer;
}
return get_string_pointer(s);
}
private:
StringPointer get_string_pointer(std::string_view s)
{
auto str = std::string{s};
auto it = m_string_map.find(str);
char* device_string = nullptr;
char* host_string = nullptr;
unsigned int str_length = 0;
if(it != m_string_map.end()) // cached
{
auto& loc = it->second;
device_string = m_device_string_buffers[loc.buffer_index] + loc.offset;
host_string = m_host_string_buffers[loc.buffer_index] + loc.offset;
str_length = static_cast<unsigned int>(loc.size - 1);
}
else // need insert
{
auto zero_end_length = str.size() + 1;
auto& loc = m_string_map[str]; // insert
if(m_current_buffer_offset + zero_end_length > m_buffer_size) // need new buffer
{
char* s;
checkCudaErrors(cudaMalloc(&s, m_buffer_size * sizeof(char)));
m_device_string_buffers.emplace_back(s);
m_host_string_buffers.emplace_back(new char[m_buffer_size]);
m_current_buffer_offset = 0;
}
auto device_buffer = m_device_string_buffers.back();
auto host_buffer = m_host_string_buffers.back();
// copy string to host buffer (with '\0' end)
host_buffer[m_current_buffer_offset + str.size()] = '\0';
std::memcpy(host_buffer + m_current_buffer_offset, str.data(), str.size());
// copy string from host buffer to device buffer
checkCudaErrors(cudaMemcpy(device_buffer + m_current_buffer_offset,
host_buffer + m_current_buffer_offset,
str.size() + 1,
cudaMemcpyHostToDevice));
loc.buffer_index = m_host_string_buffers.size() - 1;
loc.offset = m_current_buffer_offset;
loc.size = zero_end_length; // include '\0'
m_current_buffer_offset += zero_end_length;
device_string = device_buffer + loc.offset;
host_string = host_buffer + loc.offset;
str_length = static_cast<unsigned int>(loc.size - 1);
}
return StringPointer{device_string, host_string, str_length};
}
};
} // namespace muda::details