MUDA
Loading...
Searching...
No Matches
device_scan.h
1#pragma once
2#include <muda/cub/device/cub_wrapper.h>
3#include "details/cub_wrapper_macro_def.inl"
4#ifndef __INTELLISENSE__
5#include <cub/device/device_scan.cuh>
6#else
// Stand-in declaration that is only compiled when __INTELLISENSE__ is defined,
// so that `cub::Equality` can be named as a default template argument without
// pulling in the real CUB header.
namespace cub
{
class Equality
{
};
}  // namespace cub
14#endif
15
16
17namespace muda
18{
19//ref: https://nvlabs.github.io/cub/structcub_1_1_device_scan.html
20class DeviceScan : public CubWrapper<DeviceScan>
21{
23
24 public:
25 using Base::Base;
26
27 template <typename InputIteratorT, typename OutputIteratorT>
28 DeviceScan& ExclusiveSum(InputIteratorT d_in, OutputIteratorT d_out, int num_items)
29 {
30 MUDA_CUB_WRAPPER_IMPL(cub::DeviceScan::ExclusiveSum(
31 d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, _stream, false));
32 }
33
34
35 template <typename InputIteratorT, typename OutputIteratorT, typename ScanOpT, typename InitValueT>
36 DeviceScan& ExclusiveScan(InputIteratorT d_in,
37 OutputIteratorT d_out,
38 ScanOpT scan_op,
39 InitValueT init_value,
40 int num_items)
41 {
42 MUDA_CUB_WRAPPER_IMPL(cub::DeviceScan::ExclusiveScan(
43 d_temp_storage, temp_storage_bytes, d_in, d_out, scan_op, init_value, num_items, _stream, false));
44 }
45
46
47 template <typename InputIteratorT, typename OutputIteratorT>
48 DeviceScan& InclusiveSum(InputIteratorT d_in, OutputIteratorT d_out, int num_items)
49 {
50 MUDA_CUB_WRAPPER_IMPL(cub::DeviceScan::InclusiveSum(
51 d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, _stream, false));
52 }
53
54 template <typename InputIteratorT, typename OutputIteratorT, typename ScanOpT>
55 DeviceScan& InclusiveScan(InputIteratorT d_in, OutputIteratorT d_out, ScanOpT scan_op, int num_items)
56 {
57 MUDA_CUB_WRAPPER_IMPL(cub::DeviceScan::InclusiveScan(
58 d_temp_storage, temp_storage_bytes, d_in, d_out, scan_op, num_items, _stream, false));
59 }
60
61 template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT, typename EqualityOpT = cub::Equality>
62 DeviceScan& ExclusiveSumByKey(KeysInputIteratorT d_keys_in,
63 ValuesInputIteratorT d_values_in,
64 ValuesOutputIteratorT d_values_out,
65 int num_items,
66 EqualityOpT equality_op = EqualityOpT())
67 {
68 MUDA_CUB_WRAPPER_IMPL(cub::DeviceScan::ExclusiveSumByKey(
69 d_temp_storage, temp_storage_bytes, d_keys_in, d_values_in, d_values_out, num_items, equality_op, _stream, false));
70 }
71
72 template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT, typename ScanOpT, typename InitValueT, typename EqualityOpT = cub::Equality>
73 DeviceScan& ExclusiveScanByKey(KeysInputIteratorT d_keys_in,
74 ValuesInputIteratorT d_values_in,
75 ValuesOutputIteratorT d_values_out,
76 ScanOpT scan_op,
77 InitValueT init_value,
78 int num_items,
79 EqualityOpT equality_op = EqualityOpT())
80 {
81 MUDA_CUB_WRAPPER_IMPL(cub::DeviceScan::ExclusiveScanByKey(d_temp_storage,
82 temp_storage_bytes,
83 d_keys_in,
84 d_values_in,
85 d_values_out,
86 scan_op,
87 init_value,
88 num_items,
89 equality_op,
90 _stream,
91 false));
92 }
93
94 template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT, typename EqualityOpT = cub::Equality>
95 DeviceScan& InclusiveSumByKey(KeysInputIteratorT d_keys_in,
96 ValuesInputIteratorT d_values_in,
97 ValuesOutputIteratorT d_values_out,
98 int num_items,
99 EqualityOpT equality_op = EqualityOpT())
100 {
101 MUDA_CUB_WRAPPER_IMPL(cub::DeviceScan::InclusiveSumByKey(
102 d_temp_storage, temp_storage_bytes, d_keys_in, d_values_in, d_values_out, num_items, equality_op, _stream, false));
103 }
104
105 template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT, typename ScanOpT, typename EqualityOpT = cub::Equality>
106 DeviceScan& InclusiveScanByKey(KeysInputIteratorT d_keys_in,
107 ValuesInputIteratorT d_values_in,
108 ValuesOutputIteratorT d_values_out,
109 ScanOpT scan_op,
110 int num_items,
111 EqualityOpT equality_op = EqualityOpT())
112 {
113 MUDA_CUB_WRAPPER_IMPL(cub::DeviceScan::InclusiveScanByKey(d_temp_storage,
114 temp_storage_bytes,
115 d_keys_in,
116 d_values_in,
117 d_values_out,
118 scan_op,
119 num_items,
120 equality_op,
121 _stream,
122 false));
123 }
124
125 // Origin:
126
127 template <typename InputIteratorT, typename OutputIteratorT>
128 DeviceScan& ExclusiveSum(void* d_temp_storage,
129 size_t& temp_storage_bytes,
130 InputIteratorT d_in,
131 OutputIteratorT d_out,
132 int num_items)
133 {
134 MUDA_CUB_WRAPPER_FOR_COMPUTE_GRAPH_IMPL(cub::DeviceScan::ExclusiveSum(
135 d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, _stream, false));
136 }
137
138
139 template <typename InputIteratorT, typename OutputIteratorT, typename ScanOpT, typename InitValueT>
140 DeviceScan& ExclusiveScan(void* d_temp_storage,
141 size_t& temp_storage_bytes,
142 InputIteratorT d_in,
143 OutputIteratorT d_out,
144 ScanOpT scan_op,
145 InitValueT init_value,
146 int num_items)
147 {
148 MUDA_CUB_WRAPPER_FOR_COMPUTE_GRAPH_IMPL(cub::DeviceScan::ExclusiveScan(
149 d_temp_storage, temp_storage_bytes, d_in, d_out, scan_op, init_value, num_items, _stream, false));
150 }
151
152
153 template <typename InputIteratorT, typename OutputIteratorT>
154 DeviceScan& InclusiveSum(void* d_temp_storage,
155 size_t& temp_storage_bytes,
156 InputIteratorT d_in,
157 OutputIteratorT d_out,
158 int num_items)
159 {
160 MUDA_CUB_WRAPPER_FOR_COMPUTE_GRAPH_IMPL(cub::DeviceScan::InclusiveSum(
161 d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, _stream, false));
162 }
163
164 template <typename InputIteratorT, typename OutputIteratorT, typename ScanOpT>
165 DeviceScan& InclusiveScan(void* d_temp_storage,
166 size_t& temp_storage_bytes,
167 InputIteratorT d_in,
168 OutputIteratorT d_out,
169 ScanOpT scan_op,
170 int num_items)
171 {
172 MUDA_CUB_WRAPPER_FOR_COMPUTE_GRAPH_IMPL(cub::DeviceScan::InclusiveScan(
173 d_temp_storage, temp_storage_bytes, d_in, d_out, scan_op, num_items, _stream, false));
174 }
175
176 template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT, typename EqualityOpT = cub::Equality>
177 DeviceScan& ExclusiveSumByKey(void* d_temp_storage,
178 size_t& temp_storage_bytes,
179 KeysInputIteratorT d_keys_in,
180 ValuesInputIteratorT d_values_in,
181 ValuesOutputIteratorT d_values_out,
182 int num_items,
183 EqualityOpT equality_op = EqualityOpT())
184 {
185 MUDA_CUB_WRAPPER_FOR_COMPUTE_GRAPH_IMPL(cub::DeviceScan::ExclusiveSumByKey(
186 d_temp_storage, temp_storage_bytes, d_keys_in, d_values_in, d_values_out, num_items, equality_op, _stream, false));
187 }
188
189 template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT, typename ScanOpT, typename InitValueT, typename EqualityOpT = cub::Equality>
190 DeviceScan& ExclusiveScanByKey(void* d_temp_storage,
191 size_t& temp_storage_bytes,
192 KeysInputIteratorT d_keys_in,
193 ValuesInputIteratorT d_values_in,
194 ValuesOutputIteratorT d_values_out,
195 ScanOpT scan_op,
196 InitValueT init_value,
197 int num_items,
198 EqualityOpT equality_op = EqualityOpT())
199 {
200 MUDA_CUB_WRAPPER_FOR_COMPUTE_GRAPH_IMPL(
201 cub::DeviceScan::ExclusiveScanByKey(d_temp_storage,
202 temp_storage_bytes,
203 d_keys_in,
204 d_values_in,
205 d_values_out,
206 scan_op,
207 init_value,
208 num_items,
209 equality_op,
210 _stream,
211 false));
212 }
213
214 template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT, typename EqualityOpT = cub::Equality>
215 DeviceScan& InclusiveSumByKey(void* d_temp_storage,
216 size_t& temp_storage_bytes,
217 KeysInputIteratorT d_keys_in,
218 ValuesInputIteratorT d_values_in,
219 ValuesOutputIteratorT d_values_out,
220 int num_items,
221 EqualityOpT equality_op = EqualityOpT())
222 {
223 MUDA_CUB_WRAPPER_FOR_COMPUTE_GRAPH_IMPL(cub::DeviceScan::InclusiveSumByKey(
224 d_temp_storage, temp_storage_bytes, d_keys_in, d_values_in, d_values_out, num_items, equality_op, _stream, false));
225 }
226
227 template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT, typename ScanOpT, typename EqualityOpT = cub::Equality>
228 DeviceScan& InclusiveScanByKey(void* d_temp_storage,
229 size_t& temp_storage_bytes,
230 KeysInputIteratorT d_keys_in,
231 ValuesInputIteratorT d_values_in,
232 ValuesOutputIteratorT d_values_out,
233 ScanOpT scan_op,
234 int num_items,
235 EqualityOpT equality_op = EqualityOpT())
236 {
237 MUDA_CUB_WRAPPER_FOR_COMPUTE_GRAPH_IMPL(cub::DeviceScan::InclusiveScanByKey(
238 d_temp_storage, temp_storage_bytes, d_keys_in, d_values_in, d_values_out, scan_op, num_items, equality_op, _stream, false));
239 }
240};
241} // namespace muda
242
243#include "details/cub_wrapper_macro_undef.inl"
Definition cub_wrapper.h:14
Definition device_scan.h:21
Definition launch_base.h:42