File compute_graph_accessor.h
File List > compute_graph > compute_graph_accessor.h
Go to the documentation of this file
#pragma once
#include <cuda_runtime.h>
#include <muda/compute_graph/compute_graph_fwd.h>
#include <muda/graph/kernel_node.h>
#include <muda/graph/memory_node.h>
#include <muda/graph/event_node.h>
namespace muda
{
namespace details
{
// allow devlopers to access some internal function
class ComputeGraphAccessor
{
friend class ComputeGraph;
ComputeGraph& m_cg;
template <typename T>
using S = std::shared_ptr<T>;
public:
ComputeGraphAccessor();
ComputeGraphAccessor(ComputeGraph& graph);
ComputeGraphAccessor(ComputeGraph* graph);
/************************************************************************************
*
* Graph Add/Update node API
*
* Automatically add or update graph node by parms (distincted by ComputeGraphPhase)
*
*************************************************************************************/
template <typename T>
void set_kernel_node(const S<KernelNodeParms<T>>& kernelParms);
void set_memcpy_node(void* dst, const void* src, size_t size_bytes, cudaMemcpyKind kind);
void set_memcpy_node(const cudaMemcpy3DParms& parms);
void set_memset_node(const cudaMemsetParams& parms);
void set_event_record_node(cudaEvent_t event);
void set_event_wait_node(cudaEvent_t event);
void set_capture_node(cudaGraph_t sub_graph);
/************************************************************************************
*
* Current State Query API
*
*************************************************************************************/
auto current_closure() const
-> const std::pair<std::string, ComputeGraphClosure*>&;
auto current_closure() -> std::pair<std::string, ComputeGraphClosure*>&;
template <typename T>
T* current_node();
const ComputeGraphNodeBase* current_node() const;
ComputeGraphNodeBase* current_node();
cudaStream_t current_stream() const;
cudaStream_t capture_stream() const;
bool is_topo_built() const;
/************************************************************************************
*
* Current State Check API
*
*************************************************************************************/
void check_allow_var_eval() const;
void check_allow_node_adding() const;
private:
friend class muda::ComputeGraphVarBase;
void set_var_usage(VarId id, ComputeGraphVarUsage usage);
template <typename T>
void add_kernel_node(const S<KernelNodeParms<T>>& kernelParms);
template <typename T>
void update_kernel_node(const S<KernelNodeParms<T>>& kernelParms);
void add_memcpy_node(void* dst, const void* src, size_t size_bytes, cudaMemcpyKind kind);
void update_memcpy_node(void* dst, const void* src, size_t size_bytes, cudaMemcpyKind kind);
void add_memcpy_node(const cudaMemcpy3DParms& parms);
void update_memcpy_node(const cudaMemcpy3DParms& parms);
void add_memset_node(const cudaMemsetParams& parms);
void update_memset_node(const cudaMemsetParams& parms);
void add_event_record_node(cudaEvent_t event);
void update_event_record_node(cudaEvent_t event);
void add_event_wait_node(cudaEvent_t event);
void update_event_wait_node(cudaEvent_t event);
void add_capture_node(cudaGraph_t sub_graph);
void update_capture_node(cudaGraph_t sub_graph);
template <typename F>
void access_graph(F&& f);
template <typename F>
void access_graph_exec(F&& f);
//auto&& temp_var_usage()
//{
// return std::move(m_cg.m_temp_node_info.var_usage);
//}
template <typename NodeType, typename F>
NodeType* get_or_create_node(F&& f);
};
} // namespace details
} // namespace muda
#include "details/compute_graph_accessor.inl"