MUDA
Loading...
Searching...
No Matches
device_reduce.h
1#pragma once
2#include <muda/cub/device/cub_wrapper.h>
3#include "details/cub_wrapper_macro_def.inl"
4#ifndef __INTELLISENSE__
5#include <cub/device/device_reduce.cuh>
6#endif
7
8namespace muda
9{
10//ref: https://nvlabs.github.io/cub/structcub_1_1_device_reduce.html
11class DeviceReduce : public CubWrapper<DeviceReduce>
12{
14
15 public:
16 using Base::Base;
17
18 // DeviceVector:
19
20 template <typename InputIteratorT, typename OutputIteratorT, typename ReductionOpT, typename T>
21 DeviceReduce& Reduce(InputIteratorT d_in,
22 OutputIteratorT d_out,
23 int num_items,
24 ReductionOpT reduction_op,
25 T init)
26 {
27 MUDA_CUB_WRAPPER_IMPL(cub::DeviceReduce::Reduce(
28 d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, reduction_op, init, _stream, false));
29 }
30
31 template <typename InputIteratorT, typename OutputIteratorT>
32 DeviceReduce& Sum(InputIteratorT d_in, OutputIteratorT d_out, int num_items)
33 {
34 MUDA_CUB_WRAPPER_IMPL(cub::DeviceReduce::Sum(
35 d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, _stream, false));
36 }
37
38
39 template <typename InputIteratorT, typename OutputIteratorT>
40 DeviceReduce& Min(InputIteratorT d_in, OutputIteratorT d_out, int num_items)
41 {
42 MUDA_CUB_WRAPPER_IMPL(cub::DeviceReduce::Min(
43 d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, _stream, false));
44 }
45
46
47 template <typename InputIteratorT, typename OutputIteratorT>
48 DeviceReduce& ArgMin(InputIteratorT d_in, OutputIteratorT d_out, int num_items)
49 {
50 MUDA_CUB_WRAPPER_IMPL(cub::DeviceReduce::ArgMin(
51 d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, _stream, false));
52 }
53
54
55 template <typename InputIteratorT, typename OutputIteratorT>
56 DeviceReduce& Max(InputIteratorT d_in, OutputIteratorT d_out, int num_items)
57 {
58
59 MUDA_CUB_WRAPPER_IMPL(cub::DeviceReduce::Max(
60 d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, _stream, false));
61 }
62
63
64 template <typename InputIteratorT, typename OutputIteratorT>
65 DeviceReduce& ArgMax(InputIteratorT d_in, OutputIteratorT d_out, int num_items)
66 {
67 MUDA_CUB_WRAPPER_IMPL(cub::DeviceReduce::ArgMax(
68 d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, _stream, false));
69 }
70
71 template <typename KeysInputIteratorT, typename UniqueOutputIteratorT, typename ValuesInputIteratorT, typename AggregatesOutputIteratorT, typename NumRunsOutputIteratorT, typename ReductionOpT>
72 DeviceReduce& ReduceByKey(KeysInputIteratorT d_keys_in,
73 UniqueOutputIteratorT d_unique_out,
74 ValuesInputIteratorT d_values_in,
75 AggregatesOutputIteratorT d_aggregates_out,
76 NumRunsOutputIteratorT d_num_runs_out,
77 ReductionOpT reduction_op,
78 int num_items)
79 {
80 MUDA_CUB_WRAPPER_IMPL(cub::DeviceReduce::ReduceByKey(d_temp_storage,
81 temp_storage_bytes,
82 d_keys_in,
83 d_unique_out,
84 d_values_in,
85 d_aggregates_out,
86 d_num_runs_out,
87 reduction_op,
88 num_items));
89 }
90
91
92 // Origin:
93
94 template <typename InputIteratorT, typename OutputIteratorT, typename ReductionOpT, typename T>
95 DeviceReduce& Reduce(void* d_temp_storage,
96 size_t& temp_storage_bytes,
97 InputIteratorT d_in,
98 OutputIteratorT d_out,
99 int num_items,
100 ReductionOpT reduction_op,
101 T init)
102 {
103
104 MUDA_CUB_WRAPPER_FOR_COMPUTE_GRAPH_IMPL(cub::DeviceReduce::Reduce(
105 d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, reduction_op, init, _stream, false));
106 }
107
108 template <typename InputIteratorT, typename OutputIteratorT>
109 DeviceReduce& Sum(void* d_temp_storage,
110 size_t& temp_storage_bytes,
111 InputIteratorT d_in,
112 OutputIteratorT d_out,
113 int num_items)
114 {
115 MUDA_CUB_WRAPPER_FOR_COMPUTE_GRAPH_IMPL(cub::DeviceReduce::Sum(
116 d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, _stream, false));
117 }
118
119
120 template <typename InputIteratorT, typename OutputIteratorT>
121 DeviceReduce& Min(void* d_temp_storage,
122 size_t& temp_storage_bytes,
123 InputIteratorT d_in,
124 OutputIteratorT d_out,
125 int num_items)
126 {
127 MUDA_CUB_WRAPPER_FOR_COMPUTE_GRAPH_IMPL(cub::DeviceReduce::Min(
128 d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, _stream, false));
129 }
130
131 template <typename InputIteratorT, typename OutputIteratorT>
132 DeviceReduce& ArgMin(void* d_temp_storage,
133 size_t& temp_storage_bytes,
134 InputIteratorT d_in,
135 OutputIteratorT d_out,
136 int num_items)
137 {
138 MUDA_CUB_WRAPPER_FOR_COMPUTE_GRAPH_IMPL(cub::DeviceReduce::ArgMin(
139 d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, _stream, false));
140 }
141
142 template <typename InputIteratorT, typename OutputIteratorT>
143 DeviceReduce& Max(void* d_temp_storage,
144 size_t& temp_storage_bytes,
145 InputIteratorT d_in,
146 OutputIteratorT d_out,
147 int num_items)
148 {
149
150 MUDA_CUB_WRAPPER_FOR_COMPUTE_GRAPH_IMPL(cub::DeviceReduce::Max(
151 d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, _stream, false));
152 }
153
154 template <typename InputIteratorT, typename OutputIteratorT>
155 DeviceReduce& ArgMax(void* d_temp_storage,
156 size_t& temp_storage_bytes,
157 InputIteratorT d_in,
158 OutputIteratorT d_out,
159 int num_items)
160 {
161 MUDA_CUB_WRAPPER_FOR_COMPUTE_GRAPH_IMPL(cub::DeviceReduce::ArgMax(
162 d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, _stream, false));
163 }
164
165 template <typename KeysInputIteratorT, typename UniqueOutputIteratorT, typename ValuesInputIteratorT, typename AggregatesOutputIteratorT, typename NumRunsOutputIteratorT, typename ReductionOpT>
166 DeviceReduce& ReduceByKey(void* d_temp_storage,
167 size_t& temp_storage_bytes,
168 KeysInputIteratorT d_keys_in,
169 UniqueOutputIteratorT d_unique_out,
170 ValuesInputIteratorT d_values_in,
171 AggregatesOutputIteratorT d_aggregates_out,
172 NumRunsOutputIteratorT d_num_runs_out,
173 ReductionOpT reduction_op,
174 int num_items)
175 {
176 MUDA_CUB_WRAPPER_FOR_COMPUTE_GRAPH_IMPL(
177 cub::DeviceReduce::ReduceByKey(d_temp_storage,
178 temp_storage_bytes,
179 d_keys_in,
180 d_unique_out,
181 d_values_in,
182 d_aggregates_out,
183 d_num_runs_out,
184 reduction_op,
185 num_items));
186 }
187};
188} // namespace muda
189
190#include "details/cub_wrapper_macro_undef.inl"
Definition cub_wrapper.h:14
Definition device_reduce.h:12
Definition launch_base.h:42