MUDA
Loading...
Searching...
No Matches
logger.inl
1#include <algorithm>
2#include <sstream>
3#include <muda/mstl/span.h>
4#include <muda/cub/device/device_radix_sort.h>
5namespace muda
6{
7template <typename T>
8MUDA_INLINE const T& LoggerMetaData::as()
9{
10 if constexpr(std::is_same_v<T, int8_t>)
11 {
12 MUDA_ASSERT(type == LoggerBasicType::Int8, "");
13 }
14 else if constexpr(std::is_same_v<T, int16_t>)
15 {
16 MUDA_ASSERT(type == LoggerBasicType::Int16, "");
17 }
18 else if constexpr(std::is_same_v<T, int32_t>)
19 {
20 MUDA_ASSERT(type == LoggerBasicType::Int32, "");
21 }
22 else if constexpr(std::is_same_v<T, int64_t>)
23 {
24 MUDA_ASSERT(type == LoggerBasicType::Int64, "");
25 }
26 else if constexpr(std::is_same_v<T, uint8_t>)
27 {
28 MUDA_ASSERT(type == LoggerBasicType::UInt8, "");
29 }
30 else if constexpr(std::is_same_v<T, uint16_t>)
31 {
32 MUDA_ASSERT(type == LoggerBasicType::UInt16, "");
33 }
34 else if constexpr(std::is_same_v<T, uint32_t>)
35 {
36 MUDA_ASSERT(type == LoggerBasicType::UInt32, "");
37 }
38 else if constexpr(std::is_same_v<T, uint64_t>)
39 {
40 MUDA_ASSERT(type == LoggerBasicType::UInt64, "");
41 }
42 else if constexpr(std::is_same_v<T, float>)
43 {
44 MUDA_ASSERT(type == LoggerBasicType::Float, "");
45 }
46 else if constexpr(std::is_same_v<T, double>)
47 {
48 MUDA_ASSERT(type == LoggerBasicType::Double, "");
49 }
50 else
51 {
52 MUDA_ASSERT(type == LoggerBasicType::Object, "");
53 };
54 return *reinterpret_cast<const T*>(data);
55}
56
57MUDA_INLINE Logger::Logger(LoggerViewer* global_viewer, size_t meta_size, size_t buffer_size)
58 : m_meta_data_id(meta_size)
59 , m_meta_data(meta_size)
60 , m_sorted_meta_data_id(meta_size)
61 , m_sorted_meta_data(meta_size)
62 , m_h_meta_data(meta_size)
63 , m_buffer(buffer_size)
64 , m_h_buffer(buffer_size)
65 , m_log_viewer_ptr(global_viewer)
66 , m_offset(1)
67{
68 upload();
69}
70
71MUDA_INLINE Logger::Logger(Logger&& other) noexcept
72 : m_meta_data_id(std::move(other.m_meta_data_id))
73 , m_meta_data(std::move(other.m_meta_data))
74 , m_sorted_meta_data_id(std::move(other.m_sorted_meta_data_id))
75 , m_sorted_meta_data(std::move(other.m_sorted_meta_data))
76 , m_h_meta_data(std::move(other.m_h_meta_data))
77 , m_buffer(std::move(other.m_buffer))
78 , m_h_buffer(std::move(other.m_h_buffer))
79 , m_offset(std::move(other.m_offset))
80 , m_h_offset(std::move(other.m_h_offset))
81 , m_log_viewer_ptr(std::move(other.m_log_viewer_ptr))
82{
83 other.m_log_viewer_ptr = nullptr;
84 other.m_viewer = {};
85}
86
87MUDA_INLINE Logger& Logger::operator=(Logger&& other) noexcept
88{
89 if(this == &other)
90 return *this;
91
92 m_meta_data_id = std::move(other.m_meta_data_id);
93 m_meta_data = std::move(other.m_meta_data);
94 m_sorted_meta_data_id = std::move(other.m_sorted_meta_data_id);
95 m_sorted_meta_data = std::move(other.m_sorted_meta_data);
96 m_h_meta_data = std::move(other.m_h_meta_data);
97 m_buffer = std::move(other.m_buffer);
98 m_h_buffer = std::move(other.m_h_buffer);
99 m_offset = std::move(other.m_offset);
100 m_h_offset = std::move(other.m_h_offset);
101 m_log_viewer_ptr = std::move(other.m_log_viewer_ptr);
102 other.m_log_viewer_ptr = nullptr;
103 other.m_viewer = {};
104
105 return *this;
106}
107template <typename F>
108void Logger::_retrieve(F&& f)
109{
110 // don't allow automatic sync in this region
111 // or it may cause infinite loop
112 auto is_debug_sync = muda::Debug::is_debug_sync_all();
113 muda::Debug::debug_sync_all(false);
114
115 download();
116 //auto meta_data_span =
117 // span<details::LoggerMetaData>{m_h_meta_data}.subspan(0, m_h_offset.meta_data_offset);
118 //std::stable_sort(meta_data_span.begin(),
119 // meta_data_span.end(),
120 // [](const details::LoggerMetaData& a, const details::LoggerMetaData& b)
121 // { return a.id < b.id; });
122
123 auto meta_data_span =
124 span<details::LoggerMetaData>{m_h_meta_data}.subspan(0, m_h_offset.meta_data_offset);
125
126 f(meta_data_span);
127
128 expand_if_needed();
129 upload();
130 muda::Debug::debug_sync_all(is_debug_sync);
131}
132MUDA_INLINE void Logger::retrieve(std::ostream& os)
133{
134 std::stringstream ss;
135 Logger::_retrieve(
136 [&](const span<details::LoggerMetaData>& meta_data_span)
137 {
138 for(const auto& meta_data : meta_data_span)
139 {
140 if(meta_data.exceeded)
141 ss << "[log_id " << meta_data.id << ": buffer exceeded]";
142 else
143 put(ss, meta_data);
144 }
145 });
146 os << ss.str();
147}
148
149MUDA_INLINE LoggerDataContainer Logger::retrieve_meta()
150{
151 LoggerDataContainer ret;
152 Logger::_retrieve(
153 [&](const span<details::LoggerMetaData>& meta_data_span)
154 {
155 // copy buffer for safety
156 ret.m_buffer = m_h_buffer;
157 auto buffer = ret.m_buffer.data();
158 ret.m_meta_data.resize(meta_data_span.size());
159 std::transform(meta_data_span.begin(),
160 meta_data_span.end(),
161 ret.m_meta_data.begin(),
162 [buffer](const details::LoggerMetaData& meta_data)
163 {
164 return LoggerMetaData{meta_data.id,
165 meta_data.type,
166 buffer + meta_data.offset,
167 meta_data.fmt_arg};
168 });
169 });
170 return ret;
171}
172
173MUDA_INLINE void Logger::expand_meta_data()
174{
175 auto new_size = m_meta_data.size() * 2;
176
177 m_meta_data_id.resize(new_size);
178 m_meta_data.resize(new_size);
179
180 m_sorted_meta_data_id.resize(new_size);
181 m_sorted_meta_data.resize(new_size);
182}
183
184MUDA_INLINE void Logger::expand_buffer()
185{
186 auto new_size = m_buffer.size() * 2;
187 m_buffer.resize(new_size);
188}
189
190MUDA_INLINE void Logger::upload()
191{
192 // reset
193 m_h_offset = {};
194 m_offset = {m_h_offset};
195
196 m_viewer.m_offset = m_offset.data();
197 m_viewer.m_meta_data_id = m_meta_data_id.data();
198 m_viewer.m_meta_data_id_size = m_meta_data_id.size();
199 m_viewer.m_meta_data = m_meta_data.data();
200 m_viewer.m_meta_data_size = m_meta_data.size();
201 m_viewer.m_buffer = m_buffer.data();
202 m_viewer.m_buffer_size = m_buffer.size();
203
204 if(m_log_viewer_ptr)
205 {
206 checkCudaErrors(cudaMemcpyAsync(
207 m_log_viewer_ptr, &m_viewer, sizeof(m_viewer), cudaMemcpyHostToDevice, nullptr));
208 }
209 checkCudaErrors(cudaDeviceSynchronize());
210}
211
212MUDA_INLINE void Logger::download()
213{
214 // copy back
215 std::vector<details::LoggerOffset> h_offset(1);
216 m_offset.copy_to(h_offset);
217 m_h_offset = h_offset[0];
218
219 // sort meta data
220
221 DeviceRadixSort().SortPairs(m_meta_data_id.data(),
222 m_sorted_meta_data_id.data(),
223 m_meta_data.data(),
224 m_sorted_meta_data.data(),
225 m_h_offset.meta_data_offset);
226
227 if(m_h_offset.meta_data_offset > 0)
228 {
229 m_h_meta_data.resize(m_h_offset.meta_data_offset);
230 checkCudaErrors(cudaMemcpyAsync(m_h_meta_data.data(),
231 m_sorted_meta_data.data(),
232 m_h_meta_data.size() * sizeof(details::LoggerMetaData),
233 cudaMemcpyDeviceToHost));
234 }
235
236 if(m_h_offset.buffer_offset > 0)
237 {
238 m_h_buffer.resize(m_h_offset.buffer_offset);
239 checkCudaErrors(cudaMemcpyAsync(
240 m_h_buffer.data(), m_buffer.data(), m_h_offset.buffer_offset, cudaMemcpyDeviceToHost));
241 }
242
243 checkCudaErrors(cudaDeviceSynchronize());
244}
245
246MUDA_INLINE void Logger::expand_if_needed()
247{
248 if(m_h_offset.exceed_meta_data)
249 {
250 auto old_size = m_meta_data.size();
251 expand_meta_data();
252 auto new_size = m_meta_data.size();
253
254 m_h_offset.exceed_meta_data = 0;
255 MUDA_KERNEL_WARN_WITH_LOCATION(
256 "Logger meta data buffer expanded %d => %d", old_size, new_size);
257 }
258 if(m_h_offset.exceed_buffer)
259 {
260 auto old_size = m_buffer.size();
261 expand_buffer();
262 auto new_size = m_buffer.size();
263
264 m_h_offset.exceed_buffer = 0;
265 MUDA_KERNEL_WARN_WITH_LOCATION("Logger buffer expanded %d => %d", old_size, new_size);
266 }
267}
268
269MUDA_INLINE void Logger::put(std::ostream& os, const details::LoggerMetaData& meta_data) const
270{
271 auto buffer = m_h_buffer.data();
272 auto offset = meta_data.offset;
273 auto type = meta_data.type;
274#define MUDA_PUT_CASE(EnumT, T) \
275 case LoggerBasicType::EnumT: \
276 os << *reinterpret_cast<const T*>(buffer + offset); \
277 break;
278
279 switch(type)
280 {
281 case LoggerBasicType::String:
282 os << buffer + offset;
283 break;
284 MUDA_PUT_CASE(Int8, int8_t);
285 MUDA_PUT_CASE(Int16, int16_t);
286 MUDA_PUT_CASE(Int32, int32_t);
287 MUDA_PUT_CASE(Int64, int64_t);
288 MUDA_PUT_CASE(UInt8, uint8_t);
289 MUDA_PUT_CASE(UInt16, uint16_t);
290 MUDA_PUT_CASE(UInt32, uint32_t);
291 MUDA_PUT_CASE(UInt64, uint64_t);
292 MUDA_PUT_CASE(Float, float);
293 MUDA_PUT_CASE(Double, double);
294 default:
295 MUDA_ERROR_WITH_LOCATION("Unknown type");
296 break;
297 }
298#undef MUDA_PUT_CASE
299}
300
301MUDA_INLINE Logger::~Logger() {}
302} // namespace muda