MUDA
Loading...
Searching...
No Matches
device_segmented_reduce.h
1#pragma once
2#include <muda/cub/device/cub_wrapper.h>
3#include "details/cub_wrapper_macro_def.inl"
4#ifndef __INTELLISENSE__
5#include <cub/device/device_segmented_reduce.cuh>
6#endif
7
8namespace muda
9{
//ref: https://nvlabs.github.io/cub/structcub_1_1_device_segmented_reduce.html
11class DeviceSegmentedReduce : public CubWrapper<DeviceSegmentedReduce>
12{
14
15 public:
16 using Base::Base;
17
/// Segmented reduction of d_in with a custom binary operator, forwarded to
/// cub::DeviceSegmentedReduce::Reduce on this wrapper's stream (`_stream`).
///
/// Temporary storage is not a parameter: the `d_temp_storage` and
/// `temp_storage_bytes` names used in the call below are provided by the
/// MUDA_CUB_WRAPPER_IMPL expansion.
///
/// @param d_in            input sequence
/// @param d_out           output sequence, one reduced value per segment
/// @param num_segments    number of segments
/// @param d_begin_offsets begin offset of each segment within d_in
/// @param d_end_offsets   end offset of each segment within d_in
/// @param reduction_op    binary reduction functor
/// @param initial_value   initial value of each segment's reduction
/// @return this wrapper (NOTE(review): the `return` lives inside the macro)
template <typename InputIteratorT,
          typename OutputIteratorT,
          typename BeginOffsetIteratorT,
          typename EndOffsetIteratorT,
          typename ReductionOp,
          typename T>
DeviceSegmentedReduce& Reduce(InputIteratorT       d_in,
                              OutputIteratorT      d_out,
                              int                  num_segments,
                              BeginOffsetIteratorT d_begin_offsets,
                              EndOffsetIteratorT   d_end_offsets,
                              ReductionOp          reduction_op,
                              T                    initial_value)
{
    // Stream-only CUB overload: the `debug_synchronous` variant is deprecated
    // since CUB 2.0, and the `false` previously passed matched its default.
    MUDA_CUB_WRAPPER_IMPL(cub::DeviceSegmentedReduce::Reduce(d_temp_storage,
                                                             temp_storage_bytes,
                                                             d_in,
                                                             d_out,
                                                             num_segments,
                                                             d_begin_offsets,
                                                             d_end_offsets,
                                                             reduction_op,
                                                             initial_value,
                                                             _stream));
}
41
/// Per-segment sum of d_in, forwarded to cub::DeviceSegmentedReduce::Sum on
/// this wrapper's stream (`_stream`).
///
/// Temporary storage is not a parameter: `d_temp_storage` and
/// `temp_storage_bytes` in the call below come from the MUDA_CUB_WRAPPER_IMPL
/// expansion.
///
/// @param d_in            input sequence
/// @param d_out           output sequence, one sum per segment
/// @param num_segments    number of segments
/// @param d_begin_offsets begin offset of each segment within d_in
/// @param d_end_offsets   end offset of each segment within d_in
/// @return this wrapper (NOTE(review): the `return` lives inside the macro)
template <typename InputIteratorT,
          typename OutputIteratorT,
          typename BeginOffsetIteratorT,
          typename EndOffsetIteratorT>
DeviceSegmentedReduce& Sum(InputIteratorT       d_in,
                           OutputIteratorT      d_out,
                           int                  num_segments,
                           BeginOffsetIteratorT d_begin_offsets,
                           EndOffsetIteratorT   d_end_offsets)
{
    // Stream-only CUB overload: the `debug_synchronous` variant is deprecated
    // since CUB 2.0, and the `false` previously passed matched its default.
    MUDA_CUB_WRAPPER_IMPL(cub::DeviceSegmentedReduce::Sum(
        d_temp_storage, temp_storage_bytes, d_in, d_out, num_segments, d_begin_offsets, d_end_offsets, _stream));
}
52
/// Per-segment minimum of d_in, forwarded to cub::DeviceSegmentedReduce::Min
/// on this wrapper's stream (`_stream`).
///
/// Temporary storage is not a parameter: `d_temp_storage` and
/// `temp_storage_bytes` in the call below come from the MUDA_CUB_WRAPPER_IMPL
/// expansion.
///
/// @param d_in            input sequence
/// @param d_out           output sequence, one minimum per segment
/// @param num_segments    number of segments
/// @param d_begin_offsets begin offset of each segment within d_in
/// @param d_end_offsets   end offset of each segment within d_in
/// @return this wrapper (NOTE(review): the `return` lives inside the macro)
template <typename InputIteratorT,
          typename OutputIteratorT,
          typename BeginOffsetIteratorT,
          typename EndOffsetIteratorT>
DeviceSegmentedReduce& Min(InputIteratorT       d_in,
                           OutputIteratorT      d_out,
                           int                  num_segments,
                           BeginOffsetIteratorT d_begin_offsets,
                           EndOffsetIteratorT   d_end_offsets)
{
    // Stream-only CUB overload: the `debug_synchronous` variant is deprecated
    // since CUB 2.0, and the `false` previously passed matched its default.
    MUDA_CUB_WRAPPER_IMPL(cub::DeviceSegmentedReduce::Min(
        d_temp_storage, temp_storage_bytes, d_in, d_out, num_segments, d_begin_offsets, d_end_offsets, _stream));
}
63
/// Per-segment arg-min of d_in, forwarded to
/// cub::DeviceSegmentedReduce::ArgMin on this wrapper's stream (`_stream`).
///
/// Temporary storage is not a parameter: `d_temp_storage` and
/// `temp_storage_bytes` in the call below come from the MUDA_CUB_WRAPPER_IMPL
/// expansion.
///
/// @param d_in            input sequence
/// @param d_out           output sequence; per CUB, each element receives a
///                        key-value pair (offset within the segment, value)
/// @param num_segments    number of segments
/// @param d_begin_offsets begin offset of each segment within d_in
/// @param d_end_offsets   end offset of each segment within d_in
/// @return this wrapper (NOTE(review): the `return` lives inside the macro)
template <typename InputIteratorT,
          typename OutputIteratorT,
          typename BeginOffsetIteratorT,
          typename EndOffsetIteratorT>
DeviceSegmentedReduce& ArgMin(InputIteratorT       d_in,
                              OutputIteratorT      d_out,
                              int                  num_segments,
                              BeginOffsetIteratorT d_begin_offsets,
                              EndOffsetIteratorT   d_end_offsets)
{
    // Stream-only CUB overload: the `debug_synchronous` variant is deprecated
    // since CUB 2.0, and the `false` previously passed matched its default.
    MUDA_CUB_WRAPPER_IMPL(cub::DeviceSegmentedReduce::ArgMin(
        d_temp_storage, temp_storage_bytes, d_in, d_out, num_segments, d_begin_offsets, d_end_offsets, _stream));
}
74
/// Per-segment maximum of d_in, forwarded to cub::DeviceSegmentedReduce::Max
/// on this wrapper's stream (`_stream`).
///
/// Temporary storage is not a parameter: `d_temp_storage` and
/// `temp_storage_bytes` in the call below come from the MUDA_CUB_WRAPPER_IMPL
/// expansion.
///
/// @param d_in            input sequence
/// @param d_out           output sequence, one maximum per segment
/// @param num_segments    number of segments
/// @param d_begin_offsets begin offset of each segment within d_in
/// @param d_end_offsets   end offset of each segment within d_in
/// @return this wrapper (NOTE(review): the `return` lives inside the macro)
template <typename InputIteratorT,
          typename OutputIteratorT,
          typename BeginOffsetIteratorT,
          typename EndOffsetIteratorT>
DeviceSegmentedReduce& Max(InputIteratorT       d_in,
                           OutputIteratorT      d_out,
                           int                  num_segments,
                           BeginOffsetIteratorT d_begin_offsets,
                           EndOffsetIteratorT   d_end_offsets)
{
    // Stream-only CUB overload: the `debug_synchronous` variant is deprecated
    // since CUB 2.0, and the `false` previously passed matched its default.
    MUDA_CUB_WRAPPER_IMPL(cub::DeviceSegmentedReduce::Max(
        d_temp_storage, temp_storage_bytes, d_in, d_out, num_segments, d_begin_offsets, d_end_offsets, _stream));
}
85
/// Per-segment arg-max of d_in, forwarded to
/// cub::DeviceSegmentedReduce::ArgMax on this wrapper's stream (`_stream`).
///
/// Temporary storage is not a parameter: `d_temp_storage` and
/// `temp_storage_bytes` in the call below come from the MUDA_CUB_WRAPPER_IMPL
/// expansion.
///
/// @param d_in            input sequence
/// @param d_out           output sequence; per CUB, each element receives a
///                        key-value pair (offset within the segment, value)
/// @param num_segments    number of segments
/// @param d_begin_offsets begin offset of each segment within d_in
/// @param d_end_offsets   end offset of each segment within d_in
/// @return this wrapper (NOTE(review): the `return` lives inside the macro)
template <typename InputIteratorT,
          typename OutputIteratorT,
          typename BeginOffsetIteratorT,
          typename EndOffsetIteratorT>
DeviceSegmentedReduce& ArgMax(InputIteratorT       d_in,
                              OutputIteratorT      d_out,
                              int                  num_segments,
                              BeginOffsetIteratorT d_begin_offsets,
                              EndOffsetIteratorT   d_end_offsets)
{
    // Stream-only CUB overload: the `debug_synchronous` variant is deprecated
    // since CUB 2.0, and the `false` previously passed matched its default.
    MUDA_CUB_WRAPPER_IMPL(cub::DeviceSegmentedReduce::ArgMax(
        d_temp_storage, temp_storage_bytes, d_in, d_out, num_segments, d_begin_offsets, d_end_offsets, _stream));
}
96
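// Origin: overloads mirroring CUB's native signatures; the caller supplies d_temp_storage and temp_storage_bytes.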
98
/// Segmented reduction with a custom binary operator, using caller-provided
/// temporary storage (mirrors the native CUB signature).
///
/// All arguments are forwarded unchanged to
/// cub::DeviceSegmentedReduce::Reduce on this wrapper's stream (`_stream`).
/// Per CUB, a null d_temp_storage only writes the required size into
/// temp_storage_bytes. Dispatch goes through
/// MUDA_CUB_WRAPPER_FOR_COMPUTE_GRAPH_IMPL (NOTE(review): the name suggests
/// compute-graph support; confirm in details/cub_wrapper_macro_def.inl).
///
/// @param d_temp_storage     caller-owned device scratch buffer (or null)
/// @param temp_storage_bytes size of d_temp_storage in bytes
/// @param d_in               input sequence
/// @param d_out              output sequence, one reduced value per segment
/// @param num_segments       number of segments
/// @param d_begin_offsets    begin offset of each segment within d_in
/// @param d_end_offsets      end offset of each segment within d_in
/// @param reduction_op       binary reduction functor
/// @param initial_value      initial value of each segment's reduction
/// @return this wrapper (NOTE(review): the `return` lives inside the macro)
template <typename InputIteratorT,
          typename OutputIteratorT,
          typename BeginOffsetIteratorT,
          typename EndOffsetIteratorT,
          typename ReductionOp,
          typename T>
DeviceSegmentedReduce& Reduce(void*                d_temp_storage,
                              size_t&              temp_storage_bytes,
                              InputIteratorT       d_in,
                              OutputIteratorT      d_out,
                              int                  num_segments,
                              BeginOffsetIteratorT d_begin_offsets,
                              EndOffsetIteratorT   d_end_offsets,
                              ReductionOp          reduction_op,
                              T                    initial_value)
{
    // Stream-only CUB overload: the `debug_synchronous` variant is deprecated
    // since CUB 2.0, and the `false` previously passed matched its default.
    MUDA_CUB_WRAPPER_FOR_COMPUTE_GRAPH_IMPL(
        cub::DeviceSegmentedReduce::Reduce(d_temp_storage,
                                           temp_storage_bytes,
                                           d_in,
                                           d_out,
                                           num_segments,
                                           d_begin_offsets,
                                           d_end_offsets,
                                           reduction_op,
                                           initial_value,
                                           _stream));
}
123
/// Per-segment sum using caller-provided temporary storage (mirrors the
/// native CUB signature).
///
/// All arguments are forwarded unchanged to cub::DeviceSegmentedReduce::Sum
/// on this wrapper's stream (`_stream`). Per CUB, a null d_temp_storage only
/// writes the required size into temp_storage_bytes. Dispatch goes through
/// MUDA_CUB_WRAPPER_FOR_COMPUTE_GRAPH_IMPL (NOTE(review): the name suggests
/// compute-graph support; confirm in details/cub_wrapper_macro_def.inl).
///
/// @param d_temp_storage     caller-owned device scratch buffer (or null)
/// @param temp_storage_bytes size of d_temp_storage in bytes
/// @param d_in               input sequence
/// @param d_out              output sequence, one sum per segment
/// @param num_segments       number of segments
/// @param d_begin_offsets    begin offset of each segment within d_in
/// @param d_end_offsets      end offset of each segment within d_in
/// @return this wrapper (NOTE(review): the `return` lives inside the macro)
template <typename InputIteratorT,
          typename OutputIteratorT,
          typename BeginOffsetIteratorT,
          typename EndOffsetIteratorT>
DeviceSegmentedReduce& Sum(void*                d_temp_storage,
                           size_t&              temp_storage_bytes,
                           InputIteratorT       d_in,
                           OutputIteratorT      d_out,
                           int                  num_segments,
                           BeginOffsetIteratorT d_begin_offsets,
                           EndOffsetIteratorT   d_end_offsets)
{
    // Stream-only CUB overload: the `debug_synchronous` variant is deprecated
    // since CUB 2.0, and the `false` previously passed matched its default.
    MUDA_CUB_WRAPPER_FOR_COMPUTE_GRAPH_IMPL(cub::DeviceSegmentedReduce::Sum(
        d_temp_storage, temp_storage_bytes, d_in, d_out, num_segments, d_begin_offsets, d_end_offsets, _stream));
}
136
/// Per-segment minimum using caller-provided temporary storage (mirrors the
/// native CUB signature).
///
/// All arguments are forwarded unchanged to cub::DeviceSegmentedReduce::Min
/// on this wrapper's stream (`_stream`). Per CUB, a null d_temp_storage only
/// writes the required size into temp_storage_bytes. Dispatch goes through
/// MUDA_CUB_WRAPPER_FOR_COMPUTE_GRAPH_IMPL (NOTE(review): the name suggests
/// compute-graph support; confirm in details/cub_wrapper_macro_def.inl).
///
/// @param d_temp_storage     caller-owned device scratch buffer (or null)
/// @param temp_storage_bytes size of d_temp_storage in bytes
/// @param d_in               input sequence
/// @param d_out              output sequence, one minimum per segment
/// @param num_segments       number of segments
/// @param d_begin_offsets    begin offset of each segment within d_in
/// @param d_end_offsets      end offset of each segment within d_in
/// @return this wrapper (NOTE(review): the `return` lives inside the macro)
template <typename InputIteratorT,
          typename OutputIteratorT,
          typename BeginOffsetIteratorT,
          typename EndOffsetIteratorT>
DeviceSegmentedReduce& Min(void*                d_temp_storage,
                           size_t&              temp_storage_bytes,
                           InputIteratorT       d_in,
                           OutputIteratorT      d_out,
                           int                  num_segments,
                           BeginOffsetIteratorT d_begin_offsets,
                           EndOffsetIteratorT   d_end_offsets)
{
    // Stream-only CUB overload: the `debug_synchronous` variant is deprecated
    // since CUB 2.0, and the `false` previously passed matched its default.
    MUDA_CUB_WRAPPER_FOR_COMPUTE_GRAPH_IMPL(cub::DeviceSegmentedReduce::Min(
        d_temp_storage, temp_storage_bytes, d_in, d_out, num_segments, d_begin_offsets, d_end_offsets, _stream));
}
149
/// Per-segment arg-min using caller-provided temporary storage (mirrors the
/// native CUB signature).
///
/// All arguments are forwarded unchanged to
/// cub::DeviceSegmentedReduce::ArgMin on this wrapper's stream (`_stream`).
/// Per CUB, a null d_temp_storage only writes the required size into
/// temp_storage_bytes. Dispatch goes through
/// MUDA_CUB_WRAPPER_FOR_COMPUTE_GRAPH_IMPL (NOTE(review): the name suggests
/// compute-graph support; confirm in details/cub_wrapper_macro_def.inl).
///
/// @param d_temp_storage     caller-owned device scratch buffer (or null)
/// @param temp_storage_bytes size of d_temp_storage in bytes
/// @param d_in               input sequence
/// @param d_out              output sequence; per CUB, each element receives
///                           a key-value pair (offset within segment, value)
/// @param num_segments       number of segments
/// @param d_begin_offsets    begin offset of each segment within d_in
/// @param d_end_offsets      end offset of each segment within d_in
/// @return this wrapper (NOTE(review): the `return` lives inside the macro)
template <typename InputIteratorT,
          typename OutputIteratorT,
          typename BeginOffsetIteratorT,
          typename EndOffsetIteratorT>
DeviceSegmentedReduce& ArgMin(void*                d_temp_storage,
                              size_t&              temp_storage_bytes,
                              InputIteratorT       d_in,
                              OutputIteratorT      d_out,
                              int                  num_segments,
                              BeginOffsetIteratorT d_begin_offsets,
                              EndOffsetIteratorT   d_end_offsets)
{
    // Stream-only CUB overload: the `debug_synchronous` variant is deprecated
    // since CUB 2.0, and the `false` previously passed matched its default.
    MUDA_CUB_WRAPPER_FOR_COMPUTE_GRAPH_IMPL(cub::DeviceSegmentedReduce::ArgMin(
        d_temp_storage, temp_storage_bytes, d_in, d_out, num_segments, d_begin_offsets, d_end_offsets, _stream));
}
162
/// Per-segment maximum using caller-provided temporary storage (mirrors the
/// native CUB signature).
///
/// All arguments are forwarded unchanged to cub::DeviceSegmentedReduce::Max
/// on this wrapper's stream (`_stream`). Per CUB, a null d_temp_storage only
/// writes the required size into temp_storage_bytes. Dispatch goes through
/// MUDA_CUB_WRAPPER_FOR_COMPUTE_GRAPH_IMPL (NOTE(review): the name suggests
/// compute-graph support; confirm in details/cub_wrapper_macro_def.inl).
///
/// @param d_temp_storage     caller-owned device scratch buffer (or null)
/// @param temp_storage_bytes size of d_temp_storage in bytes
/// @param d_in               input sequence
/// @param d_out              output sequence, one maximum per segment
/// @param num_segments       number of segments
/// @param d_begin_offsets    begin offset of each segment within d_in
/// @param d_end_offsets      end offset of each segment within d_in
/// @return this wrapper (NOTE(review): the `return` lives inside the macro)
template <typename InputIteratorT,
          typename OutputIteratorT,
          typename BeginOffsetIteratorT,
          typename EndOffsetIteratorT>
DeviceSegmentedReduce& Max(void*                d_temp_storage,
                           size_t&              temp_storage_bytes,
                           InputIteratorT       d_in,
                           OutputIteratorT      d_out,
                           int                  num_segments,
                           BeginOffsetIteratorT d_begin_offsets,
                           EndOffsetIteratorT   d_end_offsets)
{
    // Stream-only CUB overload: the `debug_synchronous` variant is deprecated
    // since CUB 2.0, and the `false` previously passed matched its default.
    MUDA_CUB_WRAPPER_FOR_COMPUTE_GRAPH_IMPL(cub::DeviceSegmentedReduce::Max(
        d_temp_storage, temp_storage_bytes, d_in, d_out, num_segments, d_begin_offsets, d_end_offsets, _stream));
}
175
176
// NOTE: cub::DeviceSegmentedReduce has no non-segmented ArgMax(d_temp_storage,
// temp_storage_bytes, d_in, d_out, num_items) signature. An overload copied
// from the DeviceReduce wrapper previously sat here; it could never compile
// when called (CUB would deduce `_stream`/`false` as the offset iterators).
// Use the segmented ArgMax overloads taking d_begin_offsets/d_end_offsets.
187
/// Per-segment arg-max using caller-provided temporary storage (mirrors the
/// native CUB signature).
///
/// All arguments are forwarded unchanged to
/// cub::DeviceSegmentedReduce::ArgMax on this wrapper's stream (`_stream`).
/// Per CUB, a null d_temp_storage only writes the required size into
/// temp_storage_bytes. Dispatch goes through
/// MUDA_CUB_WRAPPER_FOR_COMPUTE_GRAPH_IMPL (NOTE(review): the name suggests
/// compute-graph support; confirm in details/cub_wrapper_macro_def.inl).
///
/// @param d_temp_storage     caller-owned device scratch buffer (or null)
/// @param temp_storage_bytes size of d_temp_storage in bytes
/// @param d_in               input sequence
/// @param d_out              output sequence; per CUB, each element receives
///                           a key-value pair (offset within segment, value)
/// @param num_segments       number of segments
/// @param d_begin_offsets    begin offset of each segment within d_in
/// @param d_end_offsets      end offset of each segment within d_in
/// @return this wrapper (NOTE(review): the `return` lives inside the macro)
template <typename InputIteratorT,
          typename OutputIteratorT,
          typename BeginOffsetIteratorT,
          typename EndOffsetIteratorT>
DeviceSegmentedReduce& ArgMax(void*                d_temp_storage,
                              size_t&              temp_storage_bytes,
                              InputIteratorT       d_in,
                              OutputIteratorT      d_out,
                              int                  num_segments,
                              BeginOffsetIteratorT d_begin_offsets,
                              EndOffsetIteratorT   d_end_offsets)
{
    // Stream-only CUB overload: the `debug_synchronous` variant is deprecated
    // since CUB 2.0, and the `false` previously passed matched its default.
    MUDA_CUB_WRAPPER_FOR_COMPUTE_GRAPH_IMPL(cub::DeviceSegmentedReduce::ArgMax(
        d_temp_storage, temp_storage_bytes, d_in, d_out, num_segments, d_begin_offsets, d_end_offsets, _stream));
}
200};
201} // namespace muda
202
203#include "details/cub_wrapper_macro_undef.inl"
Definition cub_wrapper.h:14
Definition device_segmented_reduce.h:12
Definition launch_base.h:42