MUDA
Loading...
Searching...
No Matches
device_radix_sort.h
1#pragma once
2#include <muda/cub/device/cub_wrapper.h>
3#include "details/cub_wrapper_macro_def.inl"
4#ifndef __INTELLISENSE__
5#include <cub/device/device_radix_sort.cuh>
6#else
7//namespace cub
8//{
9//template <typename KeyT>
10//struct DoubleBuffer;
11//}
12#endif
13
14namespace muda
15{
16//ref: https://nvlabs.github.io/cub/structcub_1_1_device_radix_sort.html
17class DeviceRadixSort : public CubWrapper<DeviceRadixSort>
18{
20
21 public:
22 using Base::Base;
23
24 template <typename KeyT, typename ValueT, typename NumItemsT>
25 DeviceRadixSort& SortPairs(const KeyT* d_keys_in,
26 KeyT* d_keys_out,
27 const ValueT* d_values_in,
28 ValueT* d_values_out,
29 NumItemsT num_items,
30 int begin_bit = 0,
31 int end_bit = sizeof(KeyT) * 8)
32 {
33 MUDA_CUB_WRAPPER_IMPL(cub::DeviceRadixSort::SortPairs(d_temp_storage,
34 temp_storage_bytes,
35 d_keys_in,
36 d_keys_out,
37 d_values_in,
38 d_values_out,
39 num_items,
40 begin_bit,
41 end_bit,
42 _stream));
43 }
44
45 template <typename KeyT, typename ValueT, typename NumItemsT>
46 DeviceRadixSort& SortPairs(DeviceVector<std::byte>& external_buffer,
47 cub::DoubleBuffer<KeyT>& d_keys,
48 cub::DoubleBuffer<ValueT>& d_values,
49 NumItemsT num_items,
50 int begin_bit = 0,
51 int end_bit = sizeof(KeyT) * 8)
52 {
53 MUDA_CUB_WRAPPER_IMPL(cub::DeviceRadixSort::SortPairs(
54 d_temp_storage, temp_storage_bytes, d_keys, d_values, num_items, begin_bit, end_bit, _stream));
55 }
56
57 template <typename KeyT, typename ValueT, typename NumItemsT>
58 DeviceRadixSort& SortPairsDescending(const KeyT* d_keys_in,
59 KeyT* d_keys_out,
60 const ValueT* d_values_in,
61 ValueT* d_values_out,
62 NumItemsT num_items,
63 int begin_bit = 0,
64 int end_bit = sizeof(KeyT) * 8)
65 {
66 MUDA_CUB_WRAPPER_IMPL(cub::DeviceRadixSort::SortPairsDescending(
67 d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out, d_values_in, d_values_out, num_items, begin_bit, end_bit, _stream));
68 }
69
70 template <typename KeyT, typename ValueT, typename NumItemsT>
71 DeviceRadixSort& SortPairsDescending(cub::DoubleBuffer<KeyT>& d_keys,
72 cub::DoubleBuffer<ValueT>& d_values,
73 NumItemsT num_items,
74 int begin_bit = 0,
75 int end_bit = sizeof(KeyT) * 8)
76 {
77 MUDA_CUB_WRAPPER_IMPL(cub::DeviceRadixSort::SortPairsDescending(
78 d_temp_storage, temp_storage_bytes, d_keys, d_values, num_items, begin_bit, end_bit, _stream));
79 }
80
81 template <typename KeyT, typename NumItemsT>
82 DeviceRadixSort& SortKeys(const KeyT* d_keys_in,
83 KeyT* d_keys_out,
84 NumItemsT num_items,
85 int begin_bit = 0,
86 int end_bit = sizeof(KeyT) * 8)
87 {
88 MUDA_CUB_WRAPPER_IMPL(cub::DeviceRadixSort::SortKeys(
89 d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out, num_items, begin_bit, end_bit, _stream));
90 }
91
92 template <typename KeyT, typename NumItemsT>
93 DeviceRadixSort& SortKeys(cub::DoubleBuffer<KeyT>& d_keys,
94 NumItemsT num_items,
95 int begin_bit = 0,
96 int end_bit = sizeof(KeyT) * 8)
97 {
98 MUDA_CUB_WRAPPER_IMPL(cub::DeviceRadixSort::SortKeys(
99 d_temp_storage, temp_storage_bytes, d_keys, num_items, begin_bit, end_bit, _stream));
100 }
101
102 template <typename KeyT, typename NumItemsT>
103 DeviceRadixSort& SortKeysDescending(const KeyT* d_keys_in,
104 KeyT* d_keys_out,
105 NumItemsT num_items,
106 int begin_bit = 0,
107 int end_bit = sizeof(KeyT) * 8)
108 {
109 MUDA_CUB_WRAPPER_IMPL(cub::DeviceRadixSort::SortKeysDescending(
110 d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out, num_items, begin_bit, end_bit, _stream));
111 }
112
113 template <typename KeyT, typename NumItemsT>
114 DeviceRadixSort& SortKeysDescending(cub::DoubleBuffer<KeyT>& d_keys,
115 NumItemsT num_items,
116 int begin_bit = 0,
117 int end_bit = sizeof(KeyT) * 8)
118 {
119 MUDA_CUB_WRAPPER_IMPL(cub::DeviceRadixSort::SortKeysDescending(
120 d_temp_storage, temp_storage_bytes, d_keys, num_items, begin_bit, end_bit, _stream));
121 }
122
123 // Origin:
124
125 template <typename KeyT, typename ValueT, typename NumItemsT>
126 DeviceRadixSort& SortPairs(void* d_temp_storage,
127 size_t& temp_storage_bytes,
128 const KeyT* d_keys_in,
129 KeyT* d_keys_out,
130 const ValueT* d_values_in,
131 ValueT* d_values_out,
132 NumItemsT num_items,
133 int begin_bit = 0,
134 int end_bit = sizeof(KeyT) * 8)
135 {
136 MUDA_CUB_WRAPPER_FOR_COMPUTE_GRAPH_IMPL(cub::DeviceRadixSort::SortPairs(
137 d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out, d_values_in, d_values_out, num_items, begin_bit, end_bit, _stream));
138 }
139
140 template <typename KeyT, typename ValueT, typename NumItemsT>
141 DeviceRadixSort& SortPairs(void* d_temp_storage,
142 size_t& temp_storage_bytes,
143 cub::DoubleBuffer<KeyT>& d_keys,
144 cub::DoubleBuffer<ValueT>& d_values,
145 NumItemsT num_items,
146 int begin_bit = 0,
147 int end_bit = sizeof(KeyT) * 8)
148 {
149 MUDA_CUB_WRAPPER_FOR_COMPUTE_GRAPH_IMPL(cub::DeviceRadixSort::SortPairs(
150 d_temp_storage, temp_storage_bytes, d_keys, d_values, num_items, begin_bit, end_bit, _stream));
151 }
152
153 template <typename KeyT, typename ValueT, typename NumItemsT>
154 DeviceRadixSort& SortPairsDescending(void* d_temp_storage,
155 size_t& temp_storage_bytes,
156 const KeyT* d_keys_in,
157 KeyT* d_keys_out,
158 const ValueT* d_values_in,
159 ValueT* d_values_out,
160 NumItemsT num_items,
161 int begin_bit = 0,
162 int end_bit = sizeof(KeyT) * 8)
163 {
164 MUDA_CUB_WRAPPER_FOR_COMPUTE_GRAPH_IMPL(cub::DeviceRadixSort::SortPairsDescending(
165 d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out, d_values_in, d_values_out, num_items, begin_bit, end_bit, _stream));
166 }
167
168 template <typename KeyT, typename ValueT, typename NumItemsT>
169 DeviceRadixSort& SortPairsDescending(void* d_temp_storage,
170 size_t& temp_storage_bytes,
171 cub::DoubleBuffer<KeyT>& d_keys,
172 cub::DoubleBuffer<ValueT>& d_values,
173 NumItemsT num_items,
174 int begin_bit = 0,
175 int end_bit = sizeof(KeyT) * 8)
176 {
177 MUDA_CUB_WRAPPER_FOR_COMPUTE_GRAPH_IMPL(cub::DeviceRadixSort::SortPairsDescending(
178 d_temp_storage, temp_storage_bytes, d_keys, d_values, num_items, begin_bit, end_bit, _stream));
179 }
180
181 template <typename KeyT, typename NumItemsT>
182 DeviceRadixSort& SortKeys(void* d_temp_storage,
183 size_t& temp_storage_bytes,
184 const KeyT* d_keys_in,
185 KeyT* d_keys_out,
186 NumItemsT num_items,
187 int begin_bit = 0,
188 int end_bit = sizeof(KeyT) * 8)
189 {
190 MUDA_CUB_WRAPPER_FOR_COMPUTE_GRAPH_IMPL(cub::DeviceRadixSort::SortKeys(
191 d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out, num_items, begin_bit, end_bit, _stream));
192 }
193
194 template <typename KeyT, typename NumItemsT>
195 DeviceRadixSort& SortKeys(void* d_temp_storage,
196 size_t& temp_storage_bytes,
197 cub::DoubleBuffer<KeyT>& d_keys,
198 NumItemsT num_items,
199 int begin_bit = 0,
200 int end_bit = sizeof(KeyT) * 8)
201 {
202 MUDA_CUB_WRAPPER_FOR_COMPUTE_GRAPH_IMPL(cub::DeviceRadixSort::SortKeys(
203 d_temp_storage, temp_storage_bytes, d_keys, num_items, begin_bit, end_bit, _stream));
204 }
205
206 template <typename KeyT, typename NumItemsT>
207 DeviceRadixSort& SortKeysDescending(void* d_temp_storage,
208 size_t& temp_storage_bytes,
209 const KeyT* d_keys_in,
210 KeyT* d_keys_out,
211 NumItemsT num_items,
212 int begin_bit = 0,
213 int end_bit = sizeof(KeyT) * 8)
214 {
215 MUDA_CUB_WRAPPER_FOR_COMPUTE_GRAPH_IMPL(cub::DeviceRadixSort::SortKeysDescending(
216 d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out, num_items, begin_bit, end_bit, _stream));
217 }
218
219 template <typename KeyT, typename NumItemsT>
220 DeviceRadixSort& SortKeysDescending(void* d_temp_storage,
221 size_t& temp_storage_bytes,
222 cub::DoubleBuffer<KeyT>& d_keys,
223 NumItemsT num_items,
224 int begin_bit = 0,
225 int end_bit = sizeof(KeyT) * 8)
226 {
227 MUDA_CUB_WRAPPER_FOR_COMPUTE_GRAPH_IMPL(cub::DeviceRadixSort::SortKeysDescending(
228 d_temp_storage, temp_storage_bytes, d_keys, num_items, begin_bit, end_bit, _stream));
229 }
230};
231} // namespace muda
232
233#include "details/cub_wrapper_macro_undef.inl"
Definition cub_wrapper.h:14
Definition device_radix_sort.h:18
Definition vector.h:23
Definition launch_base.h:42