Skip to content

File compute_graph_node.h

File List > compute_graph > compute_graph_node.h

Go to the documentation of this file

#pragma once
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <string_view>
#include <muda/compute_graph/compute_graph_node_type.h>
#include <muda/compute_graph/compute_graph_node_id.h>
#include <muda/compute_graph/compute_graph_var_usage.h>
#include <muda/compute_graph/compute_graph_var_id.h>
#include <muda/compute_graph/compute_graph_fwd.h>

namespace muda
{
// Type-erased base for compute-graph nodes.
//
// Stores the data every node carries regardless of its concrete kind: a
// name, a NodeId, an access index, a ComputeGraphNodeType tag and the
// cudaGraphNode_t handle. Construction and handle manipulation are
// protected; ComputeGraph and ComputeGraphVarBase are granted access as
// friends.
class ComputeGraphNodeBase
{
  public:
    // Identifier passed to the constructor.
    NodeId node_id() const { return m_node_id; }

    // Access index passed to the constructor.
    // NOTE(review): presumably the node's position in the graph's access
    // sequence -- confirm against ComputeGraph.
    uint64_t access_index() const { return m_access_index; }

    // Node kind tag passed to the constructor.
    ComputeGraphNodeType type() const { return m_type; }

    // Non-owning view of the node name; it refers to storage owned by this
    // object and must not outlive it.
    std::string_view name() const { return std::string_view{m_name}; }

    virtual ~ComputeGraphNodeBase() = default;

  protected:
    // Shorthand for std::shared_ptr, also visible to derived classes.
    template <typename T>
    using S = std::shared_ptr<T>;

    friend class ComputeGraph;
    friend class ComputeGraphVarBase;

    // Stores a copy of `name` together with the given id, access index and
    // type. The CUDA handle starts out null (see m_cuda_node).
    ComputeGraphNodeBase(std::string_view name, NodeId node_id, uint64_t access_index, ComputeGraphNodeType type)
        : m_name(name)
        , m_node_id(node_id)
        , m_access_index(access_index)
        , m_type(type)
    {
    }

    std::string m_name;
    NodeId      m_node_id;
    uint64_t    m_access_index;

    ComputeGraphNodeType m_type;
    // Null until set_handle() stores a handle.
    cudaGraphNode_t m_cuda_node = nullptr;


    // Returns the stored CUDA graph node handle (may be null).
    cudaGraphNode_t handle() const { return m_cuda_node; }

    // Replaces the stored CUDA graph node handle.
    void set_handle(cudaGraphNode_t handle) { m_cuda_node = handle; }

    // True once a non-null handle has been stored. Returns an explicit bool
    // rather than the raw handle, so the predicate cannot be mistaken for
    // (or used as) a second handle getter.
    bool is_valid() const { return m_cuda_node != nullptr; }
};

// Typed layer on top of ComputeGraphNodeBase.
//
// NodeT is the wrapped node object type, held through a shared pointer;
// Type is a compile-time ComputeGraphNodeType value. Everything here is
// protected: ComputeGraph and details::ComputeGraphAccessor are friends.
// The constructor and set_node() are only declared in this header.
// NOTE(review): their definitions are presumably in
// details/compute_graph_node.inl -- confirm.
template <typename NodeT, ComputeGraphNodeType Type>
class ComputeGraphNode : public ComputeGraphNodeBase
{
  protected:
    friend class details::ComputeGraphAccessor;
    friend class ComputeGraph;

    // Builds a node from its id and access index.
    ComputeGraphNode(NodeId node_id, uint64_t access_graph_index);

    // Protected, virtual through the base-class destructor.
    ~ComputeGraphNode() override = default;

    // Takes a shared pointer to a NodeT object.
    // NOTE(review): presumably stored into m_node -- confirm in the .inl.
    void set_node(S<NodeT> node);

    // Shared pointer to the wrapped node object; empty by default.
    S<NodeT> m_node;
};
}  // namespace muda

#include "details/compute_graph_node.inl"