// MUDA — compute_graph_builder.h
#pragma once
#include <functional>
#include <string_view>
#include <muda/compute_graph/compute_graph_phase.h>
#include <muda/compute_graph/compute_graph_fwd.h>
6
7namespace muda
8{
10{
11 static ComputeGraphBuilder& instance();
12 using Phase = ComputeGraphPhase;
13 using PhaseAction = std::function<void()>;
14 using CaptureAction = std::function<void(cudaStream_t)>;
15
16 public:
17 static Phase current_phase();
18 static void capture(CaptureAction&& cap);
19 static void capture(std::string_view name, CaptureAction&& cap);
20 static bool is_phase_none();
21 static bool is_phase_serial_launching();
22 static bool is_topo_building();
23 static bool is_building();
24 // return true when no graph is building or the graph is in serial launching mode
25 static bool is_direct_launching();
26 static bool is_caturing();
27
28
29 // do_when_direct_launch
30 // do_when_set_node => do_when_add_node & do_when_update_node
31 // if do_when_topo_building_set_node == nullptr, do_when_set_node will be called
32 // if do_when_topo_building_set_node != nullptr, do_when_topo_building_set_node will be called
33 // copy this code to use:
34 /*
35 ComputeGraphBuilder::invoke_phase_actions(
36 [&] // do_when_direct_launch
37 {
38
39 },
40 [&] // do_when_set_node
41 {
42
43 },
44 [&] // do_when_topo_building_set_node
45 {
46
47 });
48 */
49 static void invoke_phase_actions(PhaseAction&& do_when_direct_launch,
50 PhaseAction&& do_when_set_node,
51 PhaseAction&& do_when_topo_building_set_node);
52
53 // copy this code to use:
54 /*
55 ComputeGraphBuilder::invoke_phase_actions(
56 [&] // do_when_direct_launch
57 {
58
59 },
60 [&] // do_when_set_node and do_when_topo_building_set_node
61 {
62
63 });
64 */
65 static void invoke_phase_actions(PhaseAction&& do_when_direct_launch,
66 PhaseAction&& do_when_set_node);
67
68 // copy this code to use:
69 /*
70 ComputeGraphBuilder::invoke_phase_actions(
71 [&] // do_in_every_phase
72 {
73
74 });
75 */
76 static void invoke_phase_actions(PhaseAction&& do_in_every_phase);
77
78 private:
79 friend class ComputeGraph;
80 friend class ComputeGraphVarBase;
81
82 static void current_graph(ComputeGraph* graph);
84 static auto current_graph() { return instance().m_current_graph; }
85
86 ComputeGraphBuilder() = default;
87 ~ComputeGraphBuilder() = default;
88
89 ComputeGraph* m_current_graph = nullptr;
90};
91} // namespace muda
92
93#include "details/compute_graph_builder.inl"
// Doxygen cross-reference targets for types used in the listing above:
//   Definition: compute_graph_builder.h:10
//   Definition: compute_graph.h:38
//   Definition: compute_graph_var.h:17
//   Definition: compute_graph_accessor.h:13