// File: host_call.h
// File List > launch > host_call.h
// (Go to the documentation of this file)
#pragma once
#include <muda/launch/launch_base.h>
namespace muda
{
namespace details
{
// Host-function trampoline with a `void(void*)` signature.
//
// Interprets `userdata` as a pointer to a callable of type F and invokes that
// callable with no arguments. The callable is not destroyed here.
//
// UserTag is not used in the body; it only yields a distinct instantiation
// per tag type.
template <typename F, typename UserTag>
MUDA_HOST void CUDARTAPI generic_host_call(void* userdata)
{
    F& callable = *static_cast<F*>(userdata);
    callable();
}
// Host-function trampoline with a `void(void*)` signature.
//
// Interprets `userdata` as a pointer to an object of type F and destroys it
// with `delete`, so the pointer must come from a matching `new`.
//
// UserTag is not used in the body; it only yields a distinct instantiation
// per tag type.
template <typename F, typename UserTag>
MUDA_HOST void CUDARTAPI delete_function_object(void* userdata)
{
    delete static_cast<F*>(userdata);
}
} // namespace details
// Launcher for host-side callables: either enqueues them on a CUDA stream via
// cudaLaunchHostFunc (apply) or packages them as host-node parameters
// (as_node_parms).
class HostCall : public LaunchBase<HostCall>
{
  public:
    // Creates a launcher bound to `stream`; the stream is forwarded to
    // LaunchBase and defaults to nullptr.
    MUDA_HOST HostCall(cudaStream_t stream = nullptr)
        : LaunchBase(stream)
    {
    }

    // Enqueues `f` as a host function on this launcher's stream.
    //
    // `f` must be invocable with no arguments. A heap copy of it (built from
    // the forwarded argument) is made, one host function is enqueued to invoke
    // that copy, and a second one is enqueued right after to delete it.
    //
    // `tag` is only used to deduce UserTag, which selects the instantiation of
    // the trampolines.
    //
    // Asserts that no compute graph is being built. Returns *this.
    template <typename F, typename UserTag = DefaultTag>
    MUDA_HOST HostCall& apply(F&& f, UserTag tag = {})
    {
        MUDA_ASSERT(ComputeGraphBuilder::is_phase_none(),
                    "HostCall can't appear in a compute graph");

        using CallableType = raw_type_t<F>;
        static_assert(std::is_invocable_v<CallableType>, "f:void (void)");

        // Hold the copy in a unique_ptr until the first enqueue has succeeded,
        // so it is not leaked if checkCudaErrors unwinds on failure.
        auto owner = std::make_unique<CallableType>(std::forward<F>(f));

        checkCudaErrors(cudaLaunchHostFunc(
            this->stream(), details::generic_host_call<CallableType, UserTag>, owner.get()));

        // The stream now refers to the object, so it must outlive this scope.
        // Ownership passes to the stream-ordered deleter enqueued below. If
        // that enqueue fails the object is intentionally left alive: freeing
        // it here could race with the already-enqueued call.
        CallableType* userdata = owner.release();

        checkCudaErrors(cudaLaunchHostFunc(
            this->stream(), details::delete_function_object<CallableType, UserTag>, userdata));

        return *this;
    }

    // Builds host-node parameters for `f` instead of launching it.
    //
    // Returns a std::shared_ptr<HostNodeParms<CallableType>> constructed from
    // the forwarded `f`, with its host function set to the generic trampoline
    // for CallableType/UserTag. `f` must be invocable with no arguments.
    //
    // `tag` is only used to deduce UserTag.
    template <typename F, typename UserTag = DefaultTag>
    MUDA_NODISCARD MUDA_HOST auto as_node_parms(F&& f, UserTag tag = {})
    {
        using CallableType = raw_type_t<F>;
        // Same requirement as apply(): the trampoline calls the object with no
        // arguments, so reject anything else with a readable diagnostic.
        static_assert(std::is_invocable_v<CallableType>, "f:void (void)");

        auto parms = std::make_shared<HostNodeParms<CallableType>>(std::forward<F>(f));
        parms->fn((cudaHostFn_t)details::generic_host_call<CallableType, UserTag>);
        return parms;
    }
};
} // namespace muda