File compute_graph_catpure_node.h
File List > compute_graph > nodes > compute_graph_catpure_node.h
Go to the documentation of this file
#pragma once
#include <muda/compute_graph/compute_graph_node.h>
#include <muda/graph/graph.h>
namespace muda
{
// Compute-graph node registered with type ComputeGraphNodeType::CaptureNode.
//
// The node keeps a cudaGraph_t handle in m_sub_graph. That handle is destroyed
// with cudaGraphDestroy() when it is replaced through update_sub_graph() and
// when the node itself is destroyed.
//
// NOTE(review): this class holds a raw handle and declares no copy/move
// operations of its own. If ComputeGraphNodeBase is copyable, a copy would
// destroy the same graph twice - confirm the base class forbids copying.
class ComputeGraphCaptureNode : public ComputeGraphNodeBase
{
  protected:
    friend class ComputeGraph;
    friend class details::ComputeGraphAccessor;

    // Constructs the node and appends the current capture name to m_name.
    //
    // node_id / access_index are forwarded unchanged to ComputeGraphNodeBase
    // together with the enum name and enum value of CaptureNode.
    // The suffix appended to m_name is ":<capture name>", or ":~" when
    // LaunchInfoCache reports an empty capture name.
    // (m_name is presumably initialised by the base class from the first
    // constructor argument - TODO confirm.)
    ComputeGraphCaptureNode(NodeId node_id, uint64_t access_index)
        : ComputeGraphNodeBase(enum_name(ComputeGraphNodeType::CaptureNode),
                               node_id,
                               access_index,
                               ComputeGraphNodeType::CaptureNode)
    {
        auto n = std::string_view{
            details::LaunchInfoCache::current_capture_name().auto_select()};
        if(n.empty())
            m_name += std::string(":~");
        else
            // Copy exactly the viewed characters; going through n.data() would
            // discard the view's length and rely on a NUL terminator instead.
            m_name += std::string(":") + std::string(n);
    }

    // Releases the held sub-graph (if any) by resetting it to nullptr.
    virtual ~ComputeGraphCaptureNode() override { update_sub_graph(nullptr); }

    // Stores the CUDA graph node handle by forwarding it to set_handle().
    void set_node(cudaGraphNode_t node) { set_handle(node); }

    // Replaces the held sub-graph with `sub_graph`.
    //
    // The previously held graph, when non-null, is destroyed with
    // cudaGraphDestroy() before the new handle is stored. Passing nullptr
    // therefore just releases the current graph.
    void update_sub_graph(cudaGraph_t sub_graph)
    {
        // Re-installing the handle we already hold must not destroy it:
        // otherwise m_sub_graph would keep a destroyed handle and the
        // destructor would destroy it a second time.
        if(sub_graph == m_sub_graph)
            return;
        if(m_sub_graph)
            checkCudaErrors(cudaGraphDestroy(m_sub_graph));
        m_sub_graph = sub_graph;
    }

    // Sub-graph handle held by this node; nullptr when none is held.
    cudaGraph_t m_sub_graph = nullptr;
};
} // namespace muda