MUDA
Loading...
Searching...
No Matches
host_call.h
1#pragma once
2#include <muda/launch/launch_base.h>
3
4namespace muda
5{
6namespace details
7{
8 template <typename F, typename UserTag>
9 MUDA_HOST void CUDARTAPI generic_host_call(void* userdata)
10 {
11 auto f = reinterpret_cast<F*>(userdata);
12 (*f)();
13 }
14
15 template <typename F, typename UserTag>
16 MUDA_HOST void CUDARTAPI delete_function_object(void* userdata)
17 {
18 auto f = reinterpret_cast<F*>(userdata);
19 delete f;
20 }
21} // namespace details
22
23
24class HostCall : public LaunchBase<HostCall>
25{
26 public:
27 MUDA_HOST HostCall(cudaStream_t stream = nullptr)
28 : LaunchBase(stream)
29 {
30 }
31
32 template <typename F, typename UserTag = DefaultTag>
33 MUDA_HOST HostCall& apply(F&& f, UserTag tag = {})
34 {
35 MUDA_ASSERT(ComputeGraphBuilder::is_phase_none(),
36 "HostCall must be can't appear in a compute graph");
37 using CallableType = raw_type_t<F>;
38 static_assert(std::is_invocable_v<CallableType>, "f:void (void)");
39 auto userdata = new CallableType(std::forward<F>(f));
40 checkCudaErrors(cudaLaunchHostFunc(
41 this->stream(), details::generic_host_call<CallableType, UserTag>, userdata));
42 checkCudaErrors(cudaLaunchHostFunc(
43 this->stream(), details::delete_function_object<CallableType, UserTag>, userdata));
44 return *this;
45 }
46
55 template <typename F, typename UserTag = DefaultTag>
56 MUDA_NODISCARD MUDA_HOST auto as_node_parms(F&& f, UserTag tag = {})
57 {
58 using CallableType = raw_type_t<F>;
59 auto parms = std::make_shared<HostNodeParms<CallableType>>(std::forward<F>(f));
60 parms->fn((cudaHostFn_t)details::generic_host_call<CallableType, UserTag>);
61 return parms;
62 }
63};
64} // namespace muda
Definition host_call.h:25
MUDA_NODISCARD MUDA_HOST auto as_node_parms(F &&f, UserTag tag={})
Definition host_call.h:56
Definition launch_base.h:86