MUDA
Loading...
Searching...
No Matches
device_segmented_radix_sort.h
1#pragma once
2#include <muda/cub/device/cub_wrapper.h>
3#include "details/cub_wrapper_macro_def.inl"
4#ifndef __INTELLISENSE__
5#include <cub/device/device_segmented_radix_sort.cuh>
6#endif
7
8namespace muda
9{
//ref: https://nvlabs.github.io/cub/structcub_1_1_device_segmented_radix_sort.html
11class DeviceSegmentedRadixSort : public CubWrapper<DeviceSegmentedRadixSort>
12{
14
15 public:
16 using Base::Base;
17
18 template <typename KeyT, typename ValueT, typename BeginOffsetIteratorT, typename EndOffsetIteratorT>
19 DeviceSegmentedRadixSort& SortPairs(const KeyT* d_keys_in,
20 KeyT* d_keys_out,
21 const ValueT* d_values_in,
22 ValueT* d_values_out,
23 int num_items,
24 int num_segments,
25 BeginOffsetIteratorT d_begin_offsets,
26 EndOffsetIteratorT d_end_offsets,
27 int begin_bit,
28 int end_bit)
29 {
30 MUDA_CUB_WRAPPER_IMPL(
31 cub::DeviceSegmentedRadixSort::SortPairs(d_temp_storage,
32 temp_storage_bytes,
33 d_keys_in,
34 d_keys_out,
35 d_values_in,
36 d_values_out,
37 num_items,
38 num_segments,
39 d_begin_offsets,
40 d_end_offsets,
41 begin_bit,
42 end_bit,
43 _stream,
44 false));
45 }
46
47
48 template <typename KeyT, typename ValueT, typename BeginOffsetIteratorT, typename EndOffsetIteratorT>
49 DeviceSegmentedRadixSort& SortPairs(cub::DoubleBuffer<KeyT>& d_keys,
50 cub::DoubleBuffer<ValueT>& d_values,
51 int num_items,
52 int num_segments,
53 BeginOffsetIteratorT d_begin_offsets,
54 EndOffsetIteratorT d_end_offsets,
55 int begin_bit,
56 int end_bit)
57 {
58 MUDA_CUB_WRAPPER_IMPL(cub::DeviceSegmentedRadixSort::SortPairs(d_temp_storage,
59 temp_storage_bytes,
60 d_keys,
61 d_values,
62 num_items,
63 num_segments,
64 d_begin_offsets,
65 d_end_offsets,
66 begin_bit,
67 end_bit,
68 _stream,
69 false));
70 }
71
72
73 template <typename KeyT, typename ValueT, typename BeginOffsetIteratorT, typename EndOffsetIteratorT>
74 DeviceSegmentedRadixSort& SortPairsDescending(const KeyT* d_keys_in,
75 KeyT* d_keys_out,
76 const ValueT* d_values_in,
77 ValueT* d_values_out,
78 int num_items,
79 int num_segments,
80 BeginOffsetIteratorT d_begin_offsets,
81 EndOffsetIteratorT d_end_offsets,
82 int begin_bit,
83 int end_bit)
84 {
85 MUDA_CUB_WRAPPER_IMPL(
86 cub::DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage,
87 temp_storage_bytes,
88 d_keys_in,
89 d_keys_out,
90 d_values_in,
91 d_values_out,
92 num_items,
93 num_segments,
94 d_begin_offsets,
95 d_end_offsets,
96 begin_bit,
97 end_bit,
98 _stream,
99 false));
100 }
101
102
103 template <typename KeyT, typename ValueT, typename BeginOffsetIteratorT, typename EndOffsetIteratorT>
104 DeviceSegmentedRadixSort& SortPairsDescending(cub::DoubleBuffer<KeyT>& d_keys,
105 cub::DoubleBuffer<ValueT>& d_values,
106 int num_items,
107 int num_segments,
108 BeginOffsetIteratorT d_begin_offsets,
109 EndOffsetIteratorT d_end_offsets,
110 int begin_bit,
111 int end_bit)
112 {
113 MUDA_CUB_WRAPPER_IMPL(
114 cub::DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage,
115 temp_storage_bytes,
116 d_keys,
117 d_values,
118 num_items,
119 num_segments,
120 d_begin_offsets,
121 d_end_offsets,
122 begin_bit,
123 end_bit,
124 _stream,
125 false));
126 }
127
128
129 template <typename KeyT, typename BeginOffsetIteratorT, typename EndOffsetIteratorT>
130 DeviceSegmentedRadixSort& SortKeys(const KeyT* d_keys_in,
131 KeyT* d_keys_out,
132 int num_items,
133 int num_segments,
134 BeginOffsetIteratorT d_begin_offsets,
135 EndOffsetIteratorT d_end_offsets,
136 int begin_bit,
137 int end_bit)
138 {
139 MUDA_CUB_WRAPPER_IMPL(cub::DeviceSegmentedRadixSort::SortKeys(d_temp_storage,
140 temp_storage_bytes,
141 d_keys_in,
142 d_keys_out,
143 num_items,
144 num_segments,
145 d_begin_offsets,
146 d_end_offsets,
147 begin_bit,
148 end_bit,
149 _stream,
150 false));
151 }
152
153
154 template <typename KeyT, typename BeginOffsetIteratorT, typename EndOffsetIteratorT>
155 DeviceSegmentedRadixSort& SortKeys(cub::DoubleBuffer<KeyT>& d_keys,
156 int num_items,
157 int num_segments,
158 BeginOffsetIteratorT d_begin_offsets,
159 EndOffsetIteratorT d_end_offsets,
160 int begin_bit,
161 int end_bit)
162 {
163 MUDA_CUB_WRAPPER_IMPL(cub::DeviceSegmentedRadixSort::SortKeys(d_temp_storage,
164 temp_storage_bytes,
165 d_keys,
166 num_items,
167 num_segments,
168 d_begin_offsets,
169 d_end_offsets,
170 begin_bit,
171 end_bit,
172 _stream,
173 false));
174 }
175
176
177 template <typename KeyT, typename BeginOffsetIteratorT, typename EndOffsetIteratorT>
178 DeviceSegmentedRadixSort& SortKeysDescending(const KeyT* d_keys_in,
179 KeyT* d_keys_out,
180 int num_items,
181 int num_segments,
182 BeginOffsetIteratorT d_begin_offsets,
183 EndOffsetIteratorT d_end_offsets,
184 int begin_bit,
185 int end_bit)
186 {
187 MUDA_CUB_WRAPPER_IMPL(
188 cub::DeviceSegmentedRadixSort::SortKeysDescending(d_temp_storage,
189 temp_storage_bytes,
190 d_keys_in,
191 d_keys_out,
192 num_items,
193 num_segments,
194 d_begin_offsets,
195 d_end_offsets,
196 begin_bit,
197 end_bit,
198 _stream,
199 false));
200 }
201
202
203 template <typename KeyT, typename BeginOffsetIteratorT, typename EndOffsetIteratorT>
204 DeviceSegmentedRadixSort& SortKeysDescending(cub::DoubleBuffer<KeyT>& d_keys,
205 int num_items,
206 int num_segments,
207 BeginOffsetIteratorT d_begin_offsets,
208 EndOffsetIteratorT d_end_offsets,
209 int begin_bit,
210 int end_bit)
211 {
212 MUDA_CUB_WRAPPER_IMPL(
213 cub::DeviceSegmentedRadixSort::SortKeysDescending(d_temp_storage,
214 temp_storage_bytes,
215 d_keys,
216 num_items,
217 num_segments,
218 d_begin_offsets,
219 d_end_offsets,
220 begin_bit,
221 end_bit,
222 _stream,
223 false));
224 }
225
226
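// Origin: the overloads below keep the original CUB-style parameter list,
// i.e. the caller passes d_temp_storage and temp_storage_bytes explicitly.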
228
229 template <typename KeyT, typename ValueT, typename BeginOffsetIteratorT, typename EndOffsetIteratorT>
230 DeviceSegmentedRadixSort& SortPairs(void* d_temp_storage,
231 size_t& temp_storage_bytes,
232 const KeyT* d_keys_in,
233 KeyT* d_keys_out,
234 const ValueT* d_values_in,
235 ValueT* d_values_out,
236 int num_items,
237 int num_segments,
238 BeginOffsetIteratorT d_begin_offsets,
239 EndOffsetIteratorT d_end_offsets,
240 int begin_bit,
241 int end_bit)
242 {
243 MUDA_CUB_WRAPPER_FOR_COMPUTE_GRAPH_IMPL(
244 cub::DeviceSegmentedRadixSort::SortPairs(d_temp_storage,
245 temp_storage_bytes,
246 d_keys_in,
247 d_keys_out,
248 d_values_in,
249 d_values_out,
250 num_items,
251 num_segments,
252 d_begin_offsets,
253 d_end_offsets,
254 begin_bit,
255 end_bit,
256 _stream,
257 false));
258 }
259
260
261 template <typename KeyT, typename ValueT, typename BeginOffsetIteratorT, typename EndOffsetIteratorT>
262 DeviceSegmentedRadixSort& SortPairs(void* d_temp_storage,
263 size_t& temp_storage_bytes,
264 cub::DoubleBuffer<KeyT>& d_keys,
265 cub::DoubleBuffer<ValueT>& d_values,
266 int num_items,
267 int num_segments,
268 BeginOffsetIteratorT d_begin_offsets,
269 EndOffsetIteratorT d_end_offsets,
270 int begin_bit,
271 int end_bit)
272 {
273 MUDA_CUB_WRAPPER_FOR_COMPUTE_GRAPH_IMPL(
274 cub::DeviceSegmentedRadixSort::SortPairs(d_temp_storage,
275 temp_storage_bytes,
276 d_keys,
277 d_values,
278 num_items,
279 num_segments,
280 d_begin_offsets,
281 d_end_offsets,
282 begin_bit,
283 end_bit,
284 _stream,
285 false));
286 }
287
288
289 template <typename KeyT, typename ValueT, typename BeginOffsetIteratorT, typename EndOffsetIteratorT>
290 DeviceSegmentedRadixSort& SortPairsDescending(void* d_temp_storage,
291 size_t& temp_storage_bytes,
292 const KeyT* d_keys_in,
293 KeyT* d_keys_out,
294 const ValueT* d_values_in,
295 ValueT* d_values_out,
296 int num_items,
297 int num_segments,
298 BeginOffsetIteratorT d_begin_offsets,
299 EndOffsetIteratorT d_end_offsets,
300 int begin_bit,
301 int end_bit)
302 {
303 MUDA_CUB_WRAPPER_FOR_COMPUTE_GRAPH_IMPL(
304 cub::DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage,
305 temp_storage_bytes,
306 d_keys_in,
307 d_keys_out,
308 d_values_in,
309 d_values_out,
310 num_items,
311 num_segments,
312 d_begin_offsets,
313 d_end_offsets,
314 begin_bit,
315 end_bit,
316 _stream,
317 false));
318 }
319
320
321 template <typename KeyT, typename ValueT, typename BeginOffsetIteratorT, typename EndOffsetIteratorT>
322 DeviceSegmentedRadixSort& SortPairsDescending(void* d_temp_storage,
323 size_t& temp_storage_bytes,
324 cub::DoubleBuffer<KeyT>& d_keys,
325 cub::DoubleBuffer<ValueT>& d_values,
326 int num_items,
327 int num_segments,
328 BeginOffsetIteratorT d_begin_offsets,
329 EndOffsetIteratorT d_end_offsets,
330 int begin_bit,
331 int end_bit)
332 {
333 MUDA_CUB_WRAPPER_FOR_COMPUTE_GRAPH_IMPL(
334 cub::DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage,
335 temp_storage_bytes,
336 d_keys,
337 d_values,
338 num_items,
339 num_segments,
340 d_begin_offsets,
341 d_end_offsets,
342 begin_bit,
343 end_bit,
344 _stream,
345 false));
346 }
347
348
349 template <typename KeyT, typename BeginOffsetIteratorT, typename EndOffsetIteratorT>
350 DeviceSegmentedRadixSort& SortKeys(void* d_temp_storage,
351 size_t& temp_storage_bytes,
352 const KeyT* d_keys_in,
353 KeyT* d_keys_out,
354 int num_items,
355 int num_segments,
356 BeginOffsetIteratorT d_begin_offsets,
357 EndOffsetIteratorT d_end_offsets,
358 int begin_bit,
359 int end_bit)
360 {
361 MUDA_CUB_WRAPPER_FOR_COMPUTE_GRAPH_IMPL(
362 cub::DeviceSegmentedRadixSort::SortKeys(d_temp_storage,
363 temp_storage_bytes,
364 d_keys_in,
365 d_keys_out,
366 num_items,
367 num_segments,
368 d_begin_offsets,
369 d_end_offsets,
370 begin_bit,
371 end_bit,
372 _stream,
373 false));
374 }
375
376
377 template <typename KeyT, typename BeginOffsetIteratorT, typename EndOffsetIteratorT>
378 DeviceSegmentedRadixSort& SortKeys(void* d_temp_storage,
379 size_t& temp_storage_bytes,
380 cub::DoubleBuffer<KeyT>& d_keys,
381 int num_items,
382 int num_segments,
383 BeginOffsetIteratorT d_begin_offsets,
384 EndOffsetIteratorT d_end_offsets,
385 int begin_bit,
386 int end_bit)
387 {
388 MUDA_CUB_WRAPPER_FOR_COMPUTE_GRAPH_IMPL(
389 cub::DeviceSegmentedRadixSort::SortKeys(d_temp_storage,
390 temp_storage_bytes,
391 d_keys,
392 num_items,
393 num_segments,
394 d_begin_offsets,
395 d_end_offsets,
396 begin_bit,
397 end_bit,
398 _stream,
399 false));
400 }
401
402
403 template <typename KeyT, typename BeginOffsetIteratorT, typename EndOffsetIteratorT>
404 DeviceSegmentedRadixSort& SortKeysDescending(void* d_temp_storage,
405 size_t& temp_storage_bytes,
406 const KeyT* d_keys_in,
407 KeyT* d_keys_out,
408 int num_items,
409 int num_segments,
410 BeginOffsetIteratorT d_begin_offsets,
411 EndOffsetIteratorT d_end_offsets,
412 int begin_bit,
413 int end_bit)
414 {
415 MUDA_CUB_WRAPPER_FOR_COMPUTE_GRAPH_IMPL(
416 cub::DeviceSegmentedRadixSort::SortKeysDescending(d_temp_storage,
417 temp_storage_bytes,
418 d_keys_in,
419 d_keys_out,
420 num_items,
421 num_segments,
422 d_begin_offsets,
423 d_end_offsets,
424 begin_bit,
425 end_bit,
426 _stream,
427 false));
428 }
429
430
431 template <typename KeyT, typename BeginOffsetIteratorT, typename EndOffsetIteratorT>
432 DeviceSegmentedRadixSort& SortKeysDescending(void* d_temp_storage,
433 size_t& temp_storage_bytes,
434 cub::DoubleBuffer<KeyT>& d_keys,
435 int num_items,
436 int num_segments,
437 BeginOffsetIteratorT d_begin_offsets,
438 EndOffsetIteratorT d_end_offsets,
439 int begin_bit,
440 int end_bit)
441 {
442 MUDA_CUB_WRAPPER_FOR_COMPUTE_GRAPH_IMPL(
443 cub::DeviceSegmentedRadixSort::SortKeysDescending(d_temp_storage,
444 temp_storage_bytes,
445 d_keys,
446 num_items,
447 num_segments,
448 d_begin_offsets,
449 d_end_offsets,
450 begin_bit,
451 end_bit,
452 _stream,
453 false));
454 }
455};
456} // namespace muda
457
458#include "details/cub_wrapper_macro_undef.inl"
Definition cub_wrapper.h:14
Definition device_segmented_radix_sort.h:12
Definition launch_base.h:42