MUDA
Loading...
Searching...
No Matches
cub_wrapper_macro_def.inl
// Don't place `#pragma once` at the beginning of this file: it is meant to be
// included in multiple files, and `#pragma once` would suppress every inclusion
// after the first one within a translation unit.
3
// MUDA_CUB_WRAPPER_IMPL(x)
//
// Expands to the complete body of a member function that evaluates the call
// expression `x` twice. `x` is expected to be a CUB device-wide algorithm call.
//   1. `x` is evaluated with `d_temp_storage == nullptr` and
//      `temp_storage_bytes == 0`. By CUB's convention, a call made with null
//      temporary storage only writes the required size into
//      `temp_storage_bytes`.
//   2. `prepare_buffer(temp_storage_bytes)` is called, and its result becomes
//      `d_temp_storage`.
//   3. `x` is evaluated again, now with the prepared buffer.
// The result of each evaluation is passed to `checkCudaErrors`.
//
// Contract for the expansion site:
//   * `x` may refer to the locals `_stream` (initialised from this->stream()),
//     `temp_storage_bytes` and `d_temp_storage`. Their names are part of this
//     macro's interface.
//   * `this->stream()` and `prepare_buffer(...)` must be callable there.
//   * The expansion ends with `return *this;`, so it must form the entire body
//     of a function that returns (a reference to) the enclosing object.
#define MUDA_CUB_WRAPPER_IMPL(x)                                              \
    void*        d_temp_storage     = nullptr;                                \
    size_t       temp_storage_bytes = 0;                                      \
    cudaStream_t _stream            = this->stream();                         \
                                                                              \
    /* first pass: storage is null, so a CUB call only reports its size */    \
    checkCudaErrors(x);                                                       \
                                                                              \
    /* second pass: give x a buffer of the reported size and run it */        \
    d_temp_storage = (void*)prepare_buffer(temp_storage_bytes);               \
    checkCudaErrors(x);                                                       \
                                                                              \
    return *this;
16
// MUDA_CUB_WRAPPER_FOR_COMPUTE_GRAPH_IMPL(x)
//
// Expands to the complete body of a member function. The expansion hands two
// actions to ComputeGraphBuilder::invoke_phase_actions. Which of them runs is
// decided by ComputeGraphBuilder and is not visible here.
//   * First action: evaluates `x` with `_stream` bound to this->stream(), and
//     passes the result to checkCudaErrors.
//   * Second action:
//       - Asserts that `d_temp_storage` is non-null whenever
//         ComputeGraphBuilder::is_building() is true, because the
//         temp-storage size query must not happen while a graph is being
//         built.
//       - Registers a callable with ComputeGraphBuilder::capture under the
//         enclosing function's name (__func__). That callable evaluates `x`
//         with `_stream` bound to the stream it receives.
//
// Contract for the expansion site:
//   * `x` may refer to `_stream`.
//   * `d_temp_storage` must already be in scope, e.g. as a parameter of the
//     wrapping function.
//   * The expansion ends with `return *this;`.
//
// NOTE(review): the callable passed to capture() captures by reference. That
// is only safe if capture() does not keep it past the wrapping function's
// return. Confirm against ComputeGraphBuilder.
//
// The assertion message is built from adjacent string literals, which the
// compiler concatenates without separators. Hence the trailing spaces below.
#define MUDA_CUB_WRAPPER_FOR_COMPUTE_GRAPH_IMPL(x)                                          \
    std::string_view name{__func__};                                                        \
    ComputeGraphBuilder::invoke_phase_actions(                                              \
        [&]                                                                                 \
        {                                                                                   \
            cudaStream_t _stream = this->stream();                                          \
            checkCudaErrors(x);                                                             \
        },                                                                                  \
        [&]                                                                                 \
        {                                                                                   \
            MUDA_ASSERT(!ComputeGraphBuilder::is_building() || d_temp_storage != nullptr,   \
                        "d_temp_storage must not be nullptr when building graph. "          \
                        "you should not query the temp_storage_size when building "         \
                        "a compute graph, please do it outside a compute graph.");          \
            ComputeGraphBuilder::capture(                                                   \
                name, [&](cudaStream_t _stream) { checkCudaErrors(x); });                   \
        });                                                                                 \
    return *this;