MUDA
Loading...
Searching...
No Matches
device_select.h
1#pragma once
2#include <muda/cub/device/cub_wrapper.h>
3#include "details/cub_wrapper_macro_def.inl"
4#ifndef __INTELLISENSE__
5#include <cub/device/device_select.cuh>
6#endif
7
8namespace muda
9{
10//ref: https://nvlabs.github.io/cub/structcub_1_1_device_select.html
11class DeviceSelect : public CubWrapper<DeviceSelect>
12{
14
15 public:
16 using Base::Base;
17
18 template <typename InputIteratorT, typename FlagIterator, typename OutputIteratorT, typename NumSelectedIteratorT>
19 DeviceSelect& Flagged(InputIteratorT d_in,
20 FlagIterator d_flags,
21 OutputIteratorT d_out,
22 NumSelectedIteratorT d_num_selected_out,
23 int num_items)
24 {
25 MUDA_CUB_WRAPPER_IMPL(cub::DeviceSelect::Flagged(
26 d_temp_storage, temp_storage_bytes, d_in, d_flags, d_out, d_num_selected_out, num_items, _stream, false));
27 }
28
29 template <typename InputIteratorT, typename OutputIteratorT, typename NumSelectedIteratorT, typename SelectOp>
30 DeviceSelect& If(InputIteratorT d_in,
31 OutputIteratorT d_out,
32 NumSelectedIteratorT d_num_selected_out,
33 int num_items,
34 SelectOp select_op)
35 {
36 MUDA_CUB_WRAPPER_IMPL(cub::DeviceSelect::If(
37 d_temp_storage, temp_storage_bytes, d_in, d_out, d_num_selected_out, num_items, select_op, _stream, false));
38 }
39
40 template <typename InputIteratorT, typename OutputIteratorT, typename NumSelectedIteratorT>
41 DeviceSelect& Unique(InputIteratorT d_in,
42 OutputIteratorT d_out,
43 NumSelectedIteratorT d_num_selected_out,
44 int num_items)
45 {
46 MUDA_CUB_WRAPPER_IMPL(cub::DeviceSelect::Unique(
47 d_temp_storage, temp_storage_bytes, d_in, d_out, d_num_selected_out, num_items, _stream, false));
48 }
49#if CUB_VERSION >= 200200
50 template <typename KeyInputIteratorT, typename ValueInputIteratorT, typename KeyOutputIteratorT, typename ValueOutputIteratorT, typename NumSelectedIteratorT>
51 DeviceSelect& UniqueByKey(KeyInputIteratorT d_keys_in,
52 ValueInputIteratorT d_values_in,
53 KeyOutputIteratorT d_keys_out,
54 ValueOutputIteratorT d_values_out,
55 NumSelectedIteratorT d_num_selected_out,
56 int num_items)
57 {
58 MUDA_CUB_WRAPPER_IMPL(cub::DeviceSelect::UniqueByKey(d_temp_storage,
59 temp_storage_bytes,
60 d_keys_in,
61 d_values_in,
62 d_keys_out,
63 d_values_out,
64 d_num_selected_out,
65 num_items,
66 _stream,
67 false));
68 }
69#endif
70
71 // Origin:
72
73 template <typename InputIteratorT, typename FlagIterator, typename OutputIteratorT, typename NumSelectedIteratorT>
74 DeviceSelect& Flagged(void* d_temp_storage,
75 size_t& temp_storage_bytes,
76 InputIteratorT d_in,
77 FlagIterator d_flags,
78 OutputIteratorT d_out,
79 NumSelectedIteratorT d_num_selected_out,
80 int num_items)
81 {
82 MUDA_CUB_WRAPPER_FOR_COMPUTE_GRAPH_IMPL(cub::DeviceSelect::Flagged(
83 d_temp_storage, temp_storage_bytes, d_in, d_flags, d_out, d_num_selected_out, num_items, _stream, false));
84 }
85
86 template <typename InputIteratorT, typename OutputIteratorT, typename NumSelectedIteratorT, typename SelectOp>
87 DeviceSelect& If(void* d_temp_storage,
88 size_t& temp_storage_bytes,
89 InputIteratorT d_in,
90 OutputIteratorT d_out,
91 NumSelectedIteratorT d_num_selected_out,
92 int num_items,
93 SelectOp select_op)
94 {
95 MUDA_CUB_WRAPPER_FOR_COMPUTE_GRAPH_IMPL(cub::DeviceSelect::If(
96 d_temp_storage, temp_storage_bytes, d_in, d_out, d_num_selected_out, num_items, select_op, _stream, false));
97 }
98
99 template <typename InputIteratorT, typename OutputIteratorT, typename NumSelectedIteratorT>
100 DeviceSelect& Unique(void* d_temp_storage,
101 size_t& temp_storage_bytes,
102 InputIteratorT d_in,
103 OutputIteratorT d_out,
104 NumSelectedIteratorT d_num_selected_out,
105 int num_items)
106 {
107 MUDA_CUB_WRAPPER_FOR_COMPUTE_GRAPH_IMPL(cub::DeviceSelect::Unique(
108 d_temp_storage, temp_storage_bytes, d_in, d_out, d_num_selected_out, num_items, _stream, false));
109 }
110#if CUB_VERSION >= 200200
111 template <typename KeyInputIteratorT, typename ValueInputIteratorT, typename KeyOutputIteratorT, typename ValueOutputIteratorT, typename NumSelectedIteratorT>
112 DeviceSelect& UniqueByKey(void* d_temp_storage,
113 size_t& temp_storage_bytes,
114 KeyInputIteratorT d_keys_in,
115 ValueInputIteratorT d_values_in,
116 KeyOutputIteratorT d_keys_out,
117 ValueOutputIteratorT d_values_out,
118 NumSelectedIteratorT d_num_selected_out,
119 int num_items)
120 {
121 MUDA_CUB_WRAPPER_FOR_COMPUTE_GRAPH_IMPL(
122 cub::DeviceSelect::UniqueByKey(d_temp_storage,
123 temp_storage_bytes,
124 d_keys_in,
125 d_values_in,
126 d_keys_out,
127 d_values_out,
128 d_num_selected_out,
129 num_items,
130 _stream,
131 false));
132 }
133#endif
134};
135} // namespace muda
136
137#include "details/cub_wrapper_macro_undef.inl"
Definition cub_wrapper.h:14
Definition device_select.h:12
Definition launch_base.h:42