// Shorthand alias: S names std::shared_ptr<T>.
// NOTE(review): T is not declared in the visible lines — presumably this is an
// alias template whose `template <typename T>` header is not shown; confirm.
17 using S = std::shared_ptr<T>;
// Declaration only (no body visible). Linear-copy overload of set_memcpy_node.
//   dst        - destination pointer (writable)
//   src        - source pointer (read-only)
//   size_bytes - number of bytes to copy
//   kind       - cudaMemcpyKind describing the copy direction
// NOTE(review): presumably configures the memcpy node of the graph by either
// adding or updating it — confirm against the definition.
34 void set_memcpy_node(
void* dst,
const void* src,
size_t size_bytes, cudaMemcpyKind kind);
// Declaration only (no body visible). Overload of set_memcpy_node that takes a
// fully populated cudaMemcpy3DParms structure by const reference.
// NOTE(review): presumably the 3D counterpart of the pointer/size overload —
// confirm against the definition.
35 void set_memcpy_node(
const cudaMemcpy3DParms& parms);
// Declaration only (no body visible). Takes the memset description as a
// cudaMemsetParams structure by const reference.
// NOTE(review): presumably adds or updates a memset node in the graph —
// confirm against the definition.
36 void set_memset_node(
const cudaMemsetParams& parms);
// Declaration only (no body visible). Takes the cudaEvent_t handle by value.
// NOTE(review): presumably sets up a graph node that records `event` — confirm.
37 void set_event_record_node(cudaEvent_t event);
// Declaration only (no body visible). Takes the cudaEvent_t handle by value.
// NOTE(review): presumably sets up a graph node that waits on `event` — confirm.
38 void set_event_wait_node(cudaEvent_t event);
// Declaration only (no body visible). Takes a cudaGraph_t handle by value.
// NOTE(review): presumably `sub_graph` is a graph obtained from stream capture
// that gets embedded as a node — confirm; ownership of the handle is not
// visible from here.
39 void set_capture_node(cudaGraph_t sub_graph);
// Const accessor (trailing-return-type form): returns a const reference to a
// pair of (std::string, ComputeGraphClosure*).
// NOTE(review): presumably the string is the closure's name and the pointer is
// non-owning — confirm against the definition.
46 auto current_closure()
const
47 ->
const std::pair<std::string, ComputeGraphClosure*>&;
// Non-const overload: returns a mutable reference to the same
// (std::string, ComputeGraphClosure*) pair type, so callers may modify it.
48 auto current_closure() -> std::pair<std::string, ComputeGraphClosure*>&;
// Const accessor returning a cudaStream_t handle by value.
// NOTE(review): which stream counts as "current" is not visible here — confirm.
53 cudaStream_t current_stream()
const;
// Const accessor returning a cudaStream_t handle by value.
// NOTE(review): presumably the stream used for CUDA stream capture — confirm.
54 cudaStream_t capture_stream()
const;
// Const query returning a bool.
// NOTE(review): presumably true once the graph topology has been built —
// confirm against the definition.
56 bool is_topo_built()
const;
// Const check with no parameters and no return value.
// NOTE(review): presumably reports an error (assert/throw) when evaluating a
// graph variable is not permitted in the current state — failure mode is not
// visible here; confirm.
63 void check_allow_var_eval()
const;
// Const check with no parameters and no return value.
// NOTE(review): presumably reports an error (assert/throw) when adding a node
// is not permitted in the current state — failure mode is not visible here;
// confirm.
64 void check_allow_node_adding()
const;
// Declaration only (no body visible).
//   id    - VarId identifying the variable
//   usage - ComputeGraphVarUsage value to associate with it
// NOTE(review): where the association is stored, and whether repeated calls
// overwrite or merge, is not visible here — confirm.
68 void set_var_usage(
VarId id, ComputeGraphVarUsage usage);
// Declaration only (no body visible). Linear-copy overload of add_memcpy_node.
//   dst        - destination pointer (writable)
//   src        - source pointer (read-only)
//   size_bytes - number of bytes to copy
//   kind       - cudaMemcpyKind describing the copy direction
// NOTE(review): presumably creates a new memcpy node in the graph — confirm.
75 void add_memcpy_node(
void* dst,
const void* src,
size_t size_bytes, cudaMemcpyKind kind);
// Declaration only (no body visible). Linear-copy overload of
// update_memcpy_node; same parameter list as the pointer/size add_memcpy_node.
//   dst        - destination pointer (writable)
//   src        - source pointer (read-only)
//   size_bytes - number of bytes to copy
//   kind       - cudaMemcpyKind describing the copy direction
// NOTE(review): presumably rewrites the parameters of an existing memcpy node —
// how the target node is selected is not visible here; confirm.
76 void update_memcpy_node(
void* dst,
const void* src,
size_t size_bytes, cudaMemcpyKind kind);
// Declaration only (no body visible). Overload of add_memcpy_node taking a
// cudaMemcpy3DParms structure by const reference.
// NOTE(review): presumably creates a new 3D memcpy node in the graph — confirm.
77 void add_memcpy_node(
const cudaMemcpy3DParms& parms);
// Declaration only (no body visible). Overload of update_memcpy_node taking a
// cudaMemcpy3DParms structure by const reference.
// NOTE(review): presumably rewrites the parameters of an existing memcpy node —
// confirm.
78 void update_memcpy_node(
const cudaMemcpy3DParms& parms);
// Declaration only (no body visible). Takes the memset description as a
// cudaMemsetParams structure by const reference.
// NOTE(review): presumably creates a new memset node in the graph — confirm.
80 void add_memset_node(
const cudaMemsetParams& parms);
// Declaration only (no body visible). Takes the memset description as a
// cudaMemsetParams structure by const reference.
// NOTE(review): presumably rewrites the parameters of an existing memset node —
// confirm.
81 void update_memset_node(
const cudaMemsetParams& parms);
// Declaration only (no body visible). Takes the cudaEvent_t handle by value.
// NOTE(review): presumably creates a new event-record node for `event` — confirm.
83 void add_event_record_node(cudaEvent_t event);
// Declaration only (no body visible). Takes the cudaEvent_t handle by value.
// NOTE(review): presumably re-targets an existing event-record node to `event`
// — confirm.
84 void update_event_record_node(cudaEvent_t event);
// Declaration only (no body visible). Takes the cudaEvent_t handle by value.
// NOTE(review): presumably creates a new event-wait node for `event` — confirm.
86 void add_event_wait_node(cudaEvent_t event);
// Declaration only (no body visible). Takes the cudaEvent_t handle by value.
// NOTE(review): presumably re-targets an existing event-wait node to `event`
// — confirm.
87 void update_event_wait_node(cudaEvent_t event);
// Declaration only (no body visible). Takes a cudaGraph_t handle by value.
// NOTE(review): presumably inserts `sub_graph` as a new node of the enclosing
// graph — confirm; ownership of the handle is not visible from here.
89 void add_capture_node(cudaGraph_t sub_graph);
// Declaration only (no body visible). Takes a cudaGraph_t handle by value.
// NOTE(review): presumably replaces the graph held by an existing node with
// `sub_graph` — confirm; ownership of the handle is not visible from here.
90 void update_capture_node(cudaGraph_t sub_graph);
// Declaration only (no body visible). Accepts a callable-like argument `f` as F&&.
// NOTE(review): F is not declared in the visible lines — presumably a
// `template <typename F>` header precedes this declaration; confirm. What
// argument `f` is invoked with (presumably the underlying graph handle) is
// not visible here.
93 void access_graph(F&& f);
// Declaration only (no body visible). Accepts a callable-like argument `f` as F&&.
// NOTE(review): F is not declared in the visible lines — presumably a
// `template <typename F>` header precedes this declaration; confirm. What
// argument `f` is invoked with (presumably the instantiated/executable graph
// handle) is not visible here.
96 void access_graph_exec(F&& f);
// Function template declaration (no body visible).
//   NodeType - type of the node; the function returns a raw NodeType*
//   F        - type of the argument `f`, taken as a forwarding reference
// NodeType appears only in the return type, so callers must supply it
// explicitly; F can be deduced from the argument.
// NOTE(review): presumably returns an already existing node when there is one
// and otherwise invokes `f` to create it — confirm against the definition;
// ownership of the returned pointer is not visible from here.
103 template <
typename NodeType,
typename F>
104 NodeType* get_or_create_node(F&& f);