MUDA
Loading...
Searching...
No Matches
compute_graph_accessor.h
1#pragma once
2#include <cuda_runtime.h>
3#include <muda/compute_graph/compute_graph_fwd.h>
4#include <muda/graph/kernel_node.h>
5#include <muda/graph/memory_node.h>
6#include <muda/graph/event_node.h>
7namespace muda
8{
9namespace details
10{
 // NOTE(review): this listing is missing original lines 12, 20, 22 and 23.
 // Line 12 must be the class head (it precedes the "{" on line 13), so the
 // class name and those member declarations are not visible here; restore
 // them from the upstream header before relying on this listing.
11 // Allows developers to access some internal functions of a ComputeGraph.
 // The graph is held by reference in m_cg (not owned), so the referenced
 // ComputeGraph must outlive this accessor.
13 {
 // Lets ComputeGraph use this accessor's private members (friendship is
 // granted TO ComputeGraph; it does not give the accessor access to it).
14 friend class ComputeGraph;
 // Graph this accessor operates on (non-owning reference).
15 ComputeGraph& m_cg;
 // Shorthand used by the kernel-node API below: S<T> is std::shared_ptr<T>.
 // NOTE(review): <memory>/<string>/<utility> are not included directly here;
 // presumably they arrive via compute_graph_fwd.h - confirm.
16 template <typename T>
17 using S = std::shared_ptr<T>;
18
19 public:
21
24
25 /************************************************************************************
26 *
27 * Graph Add/Update node API
28 *
29 * Automatically adds or updates a graph node from the given parameters
 * (whether the node is added or updated is distinguished by the ComputeGraphPhase).
30 *
31 *************************************************************************************/
 // Kernel node; the parameter object is shared via S (std::shared_ptr).
32 template <typename T>
33 void set_kernel_node(const S<KernelNodeParms<T>>& kernelParms);
 // 1D memcpy node: size_bytes bytes from src to dst, copy direction given by kind.
34 void set_memcpy_node(void* dst, const void* src, size_t size_bytes, cudaMemcpyKind kind);
 // General memcpy node described by a full cudaMemcpy3DParms.
35 void set_memcpy_node(const cudaMemcpy3DParms& parms);
 // Memset node described by cudaMemsetParams.
36 void set_memset_node(const cudaMemsetParams& parms);
 // Event record / event wait nodes on the given CUDA event.
37 void set_event_record_node(cudaEvent_t event);
38 void set_event_wait_node(cudaEvent_t event);
 // Node built from an existing cudaGraph_t (presumably a previously captured
 // graph embedded as a child graph - TODO confirm in the .inl).
39 void set_capture_node(cudaGraph_t sub_graph);
40
41 /************************************************************************************
42 *
43 * Current State Query API
44 *
45 *************************************************************************************/
 // Current closure as a reference to a (std::string, ComputeGraphClosure*) pair;
 // presumably (closure name, closure) - confirm in the .inl.
46 auto current_closure() const
47 -> const std::pair<std::string, ComputeGraphClosure*>&;
48 auto current_closure() -> std::pair<std::string, ComputeGraphClosure*>&;
 // Current node viewed as T*; the kind of cast used is defined in the .inl.
 // A call to current_node() with no template argument cannot deduce T, so it
 // resolves to the non-template overloads below.
49 template <typename T>
50 T* current_node();
 // Current node as its base type (const and non-const overloads).
51 const ComputeGraphNodeBase* current_node() const;
52 ComputeGraphNodeBase* current_node();
 // CUDA streams exposed by the graph; their exact roles are defined in the .inl.
53 cudaStream_t current_stream() const;
54 cudaStream_t capture_stream() const;
55
 // Query whether the graph topology has been built (semantics in the .inl).
56 bool is_topo_built() const;
57
58 /************************************************************************************
59 *
60 * Current State Check API
61 *
62 *************************************************************************************/
 // Precondition checks on the current graph state; what happens on failure
 // (assert / throw / abort) is defined in the .inl and is not visible here.
63 void check_allow_var_eval() const;
64 void check_allow_node_adding() const;
65
66 private:
 // Grants ComputeGraphVarBase access to the private API below (e.g. set_var_usage).
67 friend class muda::ComputeGraphVarBase;
 // Sets the ComputeGraphVarUsage of the variable identified by id.
68 void set_var_usage(VarId id, ComputeGraphVarUsage usage);
69
 // Private add_*/update_* pairs, one per public set_* overload. Per the
 // Add/Update banner comment, set_* presumably forwards to add_* or update_*
 // depending on the ComputeGraphPhase - confirm in the .inl.
70 template <typename T>
71 void add_kernel_node(const S<KernelNodeParms<T>>& kernelParms);
72 template <typename T>
73 void update_kernel_node(const S<KernelNodeParms<T>>& kernelParms);
74
75 void add_memcpy_node(void* dst, const void* src, size_t size_bytes, cudaMemcpyKind kind);
76 void update_memcpy_node(void* dst, const void* src, size_t size_bytes, cudaMemcpyKind kind);
77 void add_memcpy_node(const cudaMemcpy3DParms& parms);
78 void update_memcpy_node(const cudaMemcpy3DParms& parms);
79
80 void add_memset_node(const cudaMemsetParams& parms);
81 void update_memset_node(const cudaMemsetParams& parms);
82
83 void add_event_record_node(cudaEvent_t event);
84 void update_event_record_node(cudaEvent_t event);
85
86 void add_event_wait_node(cudaEvent_t event);
87 void update_event_wait_node(cudaEvent_t event);
88
89 void add_capture_node(cudaGraph_t sub_graph);
90 void update_capture_node(cudaGraph_t sub_graph);
91
 // Presumably invoke f on the underlying graph / executable graph; the
 // argument passed to f is defined in the .inl - TODO confirm.
92 template <typename F>
93 void access_graph(F&& f);
94
95 template <typename F>
96 void access_graph_exec(F&& f);
97
 // NOTE(review): dead code kept only as a comment; remove it or restore it.
98 //auto&& temp_var_usage()
99 //{
100 // return std::move(m_cg.m_temp_node_info.var_usage);
101 //}
102
 // Presumably returns the current node as NodeType*, creating it via f when it
 // does not exist yet - TODO confirm against the .inl.
103 template <typename NodeType, typename F>
104 NodeType* get_or_create_node(F&& f);
105 };
106} // namespace details
107} // namespace muda
108
109#include "details/compute_graph_accessor.inl"
Definition compute_graph.h:38
Definition compute_graph_node.h:13
Definition compute_graph_var.h:17
Definition kernel_node.h:15
Definition compute_graph_var_id.h:6
Definition compute_graph_accessor.h:13