3#include <muda/mstl/span.h>
4#include <muda/cub/device/device_radix_sort.h>
// Returns the value stored at `data`, reinterpreted as a const T&.
// Before the cast, MUDA_ASSERT checks that the recorded `type` tag matches T:
// each fixed-width integer type, float and double has its own
// LoggerBasicType enumerator; the final fall-through branch asserts
// LoggerBasicType::Object for any other T.
// NOTE(review): the `template <typename T>` header and the brace/`else`
// lines of this if-constexpr chain are not present in this copy — confirm
// against the full source.
8MUDA_INLINE
const T& LoggerMetaData::as()
10 if constexpr(std::is_same_v<T, int8_t>)
12 MUDA_ASSERT(type == LoggerBasicType::Int8,
"");
14 else if constexpr(std::is_same_v<T, int16_t>)
16 MUDA_ASSERT(type == LoggerBasicType::Int16,
"");
18 else if constexpr(std::is_same_v<T, int32_t>)
20 MUDA_ASSERT(type == LoggerBasicType::Int32,
"");
22 else if constexpr(std::is_same_v<T, int64_t>)
24 MUDA_ASSERT(type == LoggerBasicType::Int64,
"");
26 else if constexpr(std::is_same_v<T, uint8_t>)
28 MUDA_ASSERT(type == LoggerBasicType::UInt8,
"");
30 else if constexpr(std::is_same_v<T, uint16_t>)
32 MUDA_ASSERT(type == LoggerBasicType::UInt16,
"");
34 else if constexpr(std::is_same_v<T, uint32_t>)
36 MUDA_ASSERT(type == LoggerBasicType::UInt32,
"");
38 else if constexpr(std::is_same_v<T, uint64_t>)
40 MUDA_ASSERT(type == LoggerBasicType::UInt64,
"");
42 else if constexpr(std::is_same_v<T, float>)
44 MUDA_ASSERT(type == LoggerBasicType::Float,
"");
46 else if constexpr(std::is_same_v<T, double>)
48 MUDA_ASSERT(type == LoggerBasicType::Double,
"");
// Fall-through: any T not listed above must have been logged as an Object.
52 MUDA_ASSERT(type == LoggerBasicType::Object,
"");
// NOTE(review): `data` is presumably `buffer + offset` into a byte buffer
// (see retrieve_meta); this cast assumes that address is suitably aligned
// for T — confirm the device-side writer aligns offsets.
54 return *
reinterpret_cast<const T*
>(data);
// Constructs a logger and sizes its storage from the arguments.
// - `meta_size` elements each for the meta-data id / record arrays, their
//   sorted counterparts, and the host copy m_h_meta_data.
// - `buffer_size` elements each for m_buffer and its host copy m_h_buffer.
// `global_viewer` is stored as the destination that upload() copies the
// LoggerViewer descriptor into (via cudaMemcpyHostToDevice).
// m_offset / m_h_offset are not in this initializer list, so they are
// default-initialized.
// NOTE(review): the constructor body is not visible in this copy — check
// whether it calls upload() to publish the initial state.
57MUDA_INLINE Logger::Logger(LoggerViewer* global_viewer,
size_t meta_size,
size_t buffer_size)
58 : m_meta_data_id(meta_size)
59 , m_meta_data(meta_size)
60 , m_sorted_meta_data_id(meta_size)
61 , m_sorted_meta_data(meta_size)
62 , m_h_meta_data(meta_size)
63 , m_buffer(buffer_size)
64 , m_h_buffer(buffer_size)
65 , m_log_viewer_ptr(global_viewer)
// Move constructor: takes over the meta-data arrays, byte buffers, offset
// state and viewer pointer from `other`.
// m_viewer is not in this list. upload() repopulates it from the member
// buffers, so it does not need to be transferred.
71MUDA_INLINE Logger::Logger(Logger&& other) noexcept
72 : m_meta_data_id(std::move(other.m_meta_data_id))
73 , m_meta_data(std::move(other.m_meta_data))
74 , m_sorted_meta_data_id(std::move(other.m_sorted_meta_data_id))
75 , m_sorted_meta_data(std::move(other.m_sorted_meta_data))
76 , m_h_meta_data(std::move(other.m_h_meta_data))
77 , m_buffer(std::move(other.m_buffer))
78 , m_h_buffer(std::move(other.m_h_buffer))
79 , m_offset(std::move(other.m_offset))
80 , m_h_offset(std::move(other.m_h_offset))
81 , m_log_viewer_ptr(std::move(other.m_log_viewer_ptr))
// std::move on a pointer does not clear the source, so detach the
// moved-from logger from the viewer explicitly.
83 other.m_log_viewer_ptr =
nullptr;
// Move assignment: moves the meta-data arrays, byte buffers, offset state and
// viewer pointer from `other`, then nulls `other.m_log_viewer_ptr`.
// NOTE(review): the lines before the moves and the `return *this;` are not
// visible in this copy. If there is no self-assignment guard,
// `x = std::move(x)` ends with x.m_log_viewer_ptr == nullptr — confirm a
// guard exists.
87MUDA_INLINE Logger& Logger::operator=(Logger&& other)
noexcept
92 m_meta_data_id = std::move(other.m_meta_data_id);
93 m_meta_data = std::move(other.m_meta_data);
94 m_sorted_meta_data_id = std::move(other.m_sorted_meta_data_id);
95 m_sorted_meta_data = std::move(other.m_sorted_meta_data);
96 m_h_meta_data = std::move(other.m_h_meta_data);
97 m_buffer = std::move(other.m_buffer);
98 m_h_buffer = std::move(other.m_h_buffer);
99 m_offset = std::move(other.m_offset);
100 m_h_offset = std::move(other.m_h_offset);
101 m_log_viewer_ptr = std::move(other.m_log_viewer_ptr);
// Moving a pointer leaves the source unchanged; clear it explicitly.
102 other.m_log_viewer_ptr =
nullptr;
// Common driver for retrieving log records on the host.
// 1. Saves muda's global debug-sync setting and disables it.
// 2. Builds a span over the first m_h_offset.meta_data_offset entries of
//    m_h_meta_data, i.e. only the records filled in by the last download.
// 3. Restores the saved debug-sync setting at the end.
// The lambdas in retrieve() and retrieve_meta() take exactly this span type,
// so `f` is presumably invoked with it.
// NOTE(review): the template header and the statements between steps 1 and
// 2, and between 2 and 3 (presumably download()/expand_if_needed(), the call
// to `f`, and upload()), are not visible in this copy — confirm.
// NOTE(review): as far as visible, the restore is a plain statement rather
// than an RAII guard. If anything in between throws, debug-sync stays
// disabled.
108void Logger::_retrieve(F&& f)
112 auto is_debug_sync = muda::Debug::is_debug_sync_all();
113 muda::Debug::debug_sync_all(
false);
123 auto meta_data_span =
124 span<details::LoggerMetaData>{m_h_meta_data}.subspan(0, m_h_offset.meta_data_offset);
130 muda::Debug::debug_sync_all(is_debug_sync);
// Formats the retrieved log records into a local std::stringstream.
// A record whose `exceeded` flag is set is written as
// "[log_id <id>: buffer exceeded]" instead of its payload.
// NOTE(review): the `_retrieve(` call, the non-exceeded branch and the final
// write to `os` are not visible in this copy. They presumably call put(ss,
// meta_data) and stream `ss` into `os` — confirm.
132MUDA_INLINE
void Logger::retrieve(std::ostream& os)
134 std::stringstream ss;
136 [&](
const span<details::LoggerMetaData>& meta_data_span)
138 for(
const auto& meta_data : meta_data_span)
140 if(meta_data.exceeded)
141 ss <<
"[log_id " << meta_data.id <<
": buffer exceeded]";
// Packages the retrieved records into a LoggerDataContainer for programmatic
// access. The host byte buffer is copied into ret.m_buffer. Each internal
// details::LoggerMetaData is then converted into a public LoggerMetaData
// carrying its id and a pointer `buffer + meta_data.offset`.
// `buffer` is ret.m_buffer.data(), so those pointers refer to the container's
// own copy, not to m_h_buffer. They stay valid only while ret.m_buffer's
// storage is not reallocated.
// NOTE(review): the remaining constructor arguments, the end of the
// transform, and the return are not visible in this copy.
149MUDA_INLINE LoggerDataContainer Logger::retrieve_meta()
151 LoggerDataContainer ret;
153 [&](
const span<details::LoggerMetaData>& meta_data_span)
156 ret.m_buffer = m_h_buffer;
157 auto buffer = ret.m_buffer.data();
158 ret.m_meta_data.resize(meta_data_span.size());
159 std::transform(meta_data_span.begin(),
160 meta_data_span.end(),
161 ret.m_meta_data.begin(),
162 [buffer](
const details::LoggerMetaData& meta_data)
164 return LoggerMetaData{meta_data.id,
166 buffer + meta_data.offset,
// Doubles the element count of the four meta-data arrays: ids, records, and
// their sorted counterparts.
// The host copy m_h_meta_data is not touched here; download() resizes it to
// the number of records actually produced.
// NOTE(review): if m_meta_data.size() is 0, new_size is also 0 and the arrays
// never grow.
173MUDA_INLINE
void Logger::expand_meta_data()
175 auto new_size = m_meta_data.size() * 2;
177 m_meta_data_id.resize(new_size);
178 m_meta_data.resize(new_size);
180 m_sorted_meta_data_id.resize(new_size);
181 m_sorted_meta_data.resize(new_size);
// Doubles the size of the byte buffer m_buffer.
// m_h_buffer is not resized here; download() sizes it to the bytes actually
// used.
// NOTE(review): a buffer of size 0 stays at size 0.
184MUDA_INLINE
void Logger::expand_buffer()
186 auto new_size = m_buffer.size() * 2;
187 m_buffer.resize(new_size);
// Publishes the current logger state to the device.
// 1. Re-initializes m_offset from m_h_offset. m_offset is presumably a
//    one-element device buffer (download() reads it into a vector of size 1).
// 2. Refreshes m_viewer's pointer/size fields from the member containers, so
//    they reflect any reallocation done by expand_meta_data()/expand_buffer().
// 3. Copies m_viewer into *m_log_viewer_ptr on the default (null) stream,
//    then blocks in cudaDeviceSynchronize().
// NOTE(review): no null check on m_log_viewer_ptr is visible here. The move
// operations set it to nullptr on the moved-from logger — confirm whether the
// lines not shown in this copy guard against that.
190MUDA_INLINE
void Logger::upload()
194 m_offset = {m_h_offset};
196 m_viewer.m_offset = m_offset.data();
197 m_viewer.m_meta_data_id = m_meta_data_id.data();
198 m_viewer.m_meta_data_id_size = m_meta_data_id.size();
199 m_viewer.m_meta_data = m_meta_data.data();
200 m_viewer.m_meta_data_size = m_meta_data.size();
201 m_viewer.m_buffer = m_buffer.data();
202 m_viewer.m_buffer_size = m_buffer.size();
206 checkCudaErrors(cudaMemcpyAsync(
207 m_log_viewer_ptr, &m_viewer,
sizeof(m_viewer), cudaMemcpyHostToDevice,
nullptr));
209 checkCudaErrors(cudaDeviceSynchronize());
// Pulls the log output produced on the device back to the host.
// 1. Reads the single LoggerOffset from m_offset into m_h_offset.
// 2. Sorts (id, record) pairs by id into m_sorted_meta_data_id /
//    m_sorted_meta_data with DeviceRadixSort::SortPairs, over the first
//    meta_data_offset entries. The values-input argument (original line 223,
//    presumably m_meta_data.data()) is not visible in this copy.
// 3. If any records were produced, resizes m_h_meta_data to that count and
//    copies the sorted records device-to-host.
// 4. If any payload bytes were produced, resizes m_h_buffer and copies the
//    used prefix of m_buffer device-to-host. buffer_offset is used directly as
//    the byte count, so buffer elements are presumably bytes.
// 5. Waits with cudaDeviceSynchronize() so the default-stream async copies
//    have finished before the host reads them.
// When a count is 0, the host vector keeps stale contents. Callers only read
// the first meta_data_offset records (see the subspan in _retrieve).
// NOTE(review): this assumes meta_data_offset never exceeds the device array
// sizes, even when an exceed flag is set — confirm the device side clamps it.
212MUDA_INLINE
void Logger::download()
215 std::vector<details::LoggerOffset> h_offset(1);
216 m_offset.copy_to(h_offset);
217 m_h_offset = h_offset[0];
221 DeviceRadixSort().SortPairs(m_meta_data_id.data(),
222 m_sorted_meta_data_id.data(),
224 m_sorted_meta_data.data(),
225 m_h_offset.meta_data_offset);
227 if(m_h_offset.meta_data_offset > 0)
229 m_h_meta_data.resize(m_h_offset.meta_data_offset);
230 checkCudaErrors(cudaMemcpyAsync(m_h_meta_data.data(),
231 m_sorted_meta_data.data(),
232 m_h_meta_data.size() *
sizeof(details::LoggerMetaData),
233 cudaMemcpyDeviceToHost));
236 if(m_h_offset.buffer_offset > 0)
238 m_h_buffer.resize(m_h_offset.buffer_offset);
239 checkCudaErrors(cudaMemcpyAsync(
240 m_h_buffer.data(), m_buffer.data(), m_h_offset.buffer_offset, cudaMemcpyDeviceToHost));
243 checkCudaErrors(cudaDeviceSynchronize());
// Reacts to the overflow flags the device reported in m_h_offset.
// For the meta-data arrays and for the byte buffer, if the matching exceed
// flag is set, this records old_size / new_size, clears the flag, and emits a
// warning with both sizes.
// NOTE(review): the statements that actually grow the storage (original
// lines 251 and 261, presumably expand_meta_data() / expand_buffer()) are not
// visible in this copy. As shown, old_size and new_size are read identically.
// NOTE(review): old_size/new_size come from .size() (presumably size_t) but
// are formatted with %d, which is a printf-style argument mismatch. Prefer
// %zu, or cast both to int.
246MUDA_INLINE
void Logger::expand_if_needed()
248 if(m_h_offset.exceed_meta_data)
250 auto old_size = m_meta_data.size();
252 auto new_size = m_meta_data.size();
254 m_h_offset.exceed_meta_data = 0;
255 MUDA_KERNEL_WARN_WITH_LOCATION(
256 "Logger meta data buffer expanded %d => %d", old_size, new_size);
258 if(m_h_offset.exceed_buffer)
260 auto old_size = m_buffer.size();
262 auto new_size = m_buffer.size();
264 m_h_offset.exceed_buffer = 0;
265 MUDA_KERNEL_WARN_WITH_LOCATION(
"Logger buffer expanded %d => %d", old_size, new_size);
// Writes one log record's payload from the host buffer m_h_buffer to `os`,
// dispatching on meta_data.type:
// - String: streams `buffer + offset` as a C string (presumably NUL-terminated
//   char data).
// - Each numeric LoggerBasicType: reads a value of the matching C++ type at
//   `buffer + offset` via MUDA_PUT_CASE and streams it.
// - Anything else (presumably the switch's default, whose label is not
//   visible in this copy): reports "Unknown type" via
//   MUDA_ERROR_WITH_LOCATION.
// NOTE(review): on typical platforms int8_t/uint8_t are character types, so
// operator<< prints them as characters rather than numbers — confirm that is
// intended.
269MUDA_INLINE
void Logger::put(std::ostream& os,
const details::LoggerMetaData& meta_data)
const
271 auto buffer = m_h_buffer.data();
272 auto offset = meta_data.offset;
273 auto type = meta_data.type;
// MUDA_PUT_CASE(EnumT, T): a case label for LoggerBasicType::EnumT that
// streams the T stored at buffer + offset.
274#define MUDA_PUT_CASE(EnumT, T) \
275 case LoggerBasicType::EnumT: \
276 os << *reinterpret_cast<const T*>(buffer + offset); \
281 case LoggerBasicType::String:
282 os << buffer + offset;
284 MUDA_PUT_CASE(Int8, int8_t);
285 MUDA_PUT_CASE(Int16, int16_t);
286 MUDA_PUT_CASE(Int32, int32_t);
287 MUDA_PUT_CASE(Int64, int64_t);
288 MUDA_PUT_CASE(UInt8, uint8_t);
289 MUDA_PUT_CASE(UInt16, uint16_t);
290 MUDA_PUT_CASE(UInt32, uint32_t);
291 MUDA_PUT_CASE(UInt64, uint64_t);
292 MUDA_PUT_CASE(Float,
float);
293 MUDA_PUT_CASE(Double,
double);
295 MUDA_ERROR_WITH_LOCATION(
"Unknown type");
// Out-of-line destructor with an empty body; the member containers are
// destroyed implicitly afterwards.
301MUDA_INLINE Logger::~Logger() {}