Skip to content

File compute_graph.h

File List > compute_graph > compute_graph.h

Go to the documentation of this file

#pragma once
#include <cstddef>
#include <functional>
#include <iosfwd>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <string_view>
#include <unordered_map>
#include <utility>
#include <vector>
#include <muda/launch/stream.h>
#include <muda/launch/event.h>
#include <muda/mstl/span.h>
#include <muda/graph/graph.h>
#include <muda/graph/graph_viewer.h>
#include <muda/compute_graph/compute_graph_flag.h>
#include <muda/compute_graph/compute_graph_phase.h>
#include <muda/compute_graph/compute_graph_node_type.h>
#include <muda/compute_graph/compute_graph_node_id.h>
#include <muda/compute_graph/compute_graph_closure_id.h>
#include <muda/compute_graph/compute_graph_var_id.h>
#include <muda/compute_graph/compute_graph_var_usage.h>
#include <muda/compute_graph/compute_graph_dependency.h>
#include <muda/compute_graph/graphviz_options.h>
#include <muda/compute_graph/compute_graph_fwd.h>

namespace muda
{
namespace details
{
    // Distinct id type layered on U64IdWithType. It adds no state of its own
    // and only re-exports the base-class constructors.
    // NOTE(review): presumably used as an index into ComputeGraph's list of
    // related variables -- confirm against the .inl.
    class LocalVarId : public U64IdWithType
    {
      public:
        using U64IdWithType::U64IdWithType;
    };

    // Plain record pairing a LocalVarId with a raw ComputeGraphVarBase pointer.
    // Both fields have defaults: a value-initialized id and a null pointer.
    // Ownership of `var` is not visible from this header.
    class LocalVarInfo
    {
      public:
        LocalVarId           id{};
        ComputeGraphVarBase* var{nullptr};
    };
}  // namespace details

// ComputeGraph bundles a muda Graph, a shared GraphExec, the registered
// closures/nodes/dependencies and the per-graph variable bookkeeping that are
// declared below.
//
// Only declarations live in this header; apart from the small inline
// accessors, the member functions are defined out of line (presumably in
// details/compute_graph.inl, which this header includes at its end). Comments
// here therefore describe what the declarations show and hedge everything else.
//
// The type is neither copyable nor movable.
class ComputeGraph
{
  public:
    // Temporary handed out by create_node(). It stores the target graph and
    // the node name until a callable is streamed into it with operator<<.
    class AddNodeProxy
    {
        ComputeGraph& m_cg;
        std::string   m_node_name;

      public:
        AddNodeProxy(ComputeGraph& cg, std::string_view node_name);
        // Rvalue-qualified: can only be invoked on a temporary proxy, e.g.
        // `graph.create_node("name") << [&] { ... };`.
        // NOTE(review): presumably forwards to ComputeGraph::add_node and
        // returns the owning graph -- confirm in the .inl.
        ComputeGraph& operator<<(std::function<void()>&& f) &&;
    };
    // A depends on B : from B to A
    using Dependency = ComputeGraphDependency;

    // Scoped helper bound to one graph and constructed with a ComputeGraphPhase.
    // NOTE(review): presumably sets the graph's current phase on construction
    // and resets it on destruction -- confirm in the .inl.
    class GraphPhaseGuard
    {
        ComputeGraph& m_cg;

      public:
        GraphPhaseGuard(ComputeGraph& cg, ComputeGraphPhase phase);
        ~GraphPhaseGuard();

        // Not copyable: every copy would run ~GraphPhaseGuard() a second time
        // against the same graph.
        GraphPhaseGuard(const GraphPhaseGuard&)            = delete;
        GraphPhaseGuard& operator=(const GraphPhaseGuard&) = delete;
    };

    // delete copy
    ComputeGraph(const ComputeGraph&)            = delete;
    ComputeGraph& operator=(const ComputeGraph&) = delete;

    // delete move
    ComputeGraph(ComputeGraph&&)            = delete;
    ComputeGraph& operator=(ComputeGraph&&) = delete;

  private:
    //class TempNodeInfo
    //{
    //  public:
    //    std::map<VarId, ComputeGraphVarUsage> var_usage;
    //};

    // Short aliases for the standard smart pointers.
    template <typename T>
    using U = std::unique_ptr<T>;
    template <typename T>
    using S = std::shared_ptr<T>;

    friend class ComputeGraphVarBase;

    // NOTE: the declaration order of the data members below is significant
    // (it fixes construction/destruction order) and must not be rearranged.
    Graph        m_graph;
    S<GraphExec> m_graph_exec{nullptr};

    // cudaGraph_t handles keyed by the raw value of a NodeId.
    std::unordered_map<NodeId::value_type, cudaGraph_t> m_sub_graphs;

    // (name, closure) pairs; the closures are held by raw pointer.
    std::vector<std::pair<std::string, ComputeGraphClosure*>> m_closures;

    // Translation from a global VarId to this graph's LocalVarId, plus the
    // per-graph list of variable records.
    std::map<VarId, details::LocalVarId> m_global_to_local_var_id;
    std::vector<details::LocalVarInfo>   m_related_vars;
    void emplace_related_var(ComputeGraphVarBase* var);


    std::vector<ComputeGraphNodeBase*>              m_nodes;
    std::vector<std::vector<ComputeGraphNodeBase*>> m_graph_nodes;
    std::vector<Dependency>                         m_deps;

    // NOTE(review): looks like one int flag per closure -- confirm in the .inl.
    std::vector<int>        m_closure_need_update;
    ComputeGraphVarManager* m_var_manager = nullptr;

    friend class ComputeGraphVarManager;

    Event                      m_event;
    // `mutable` so that const member functions such as query() may modify it.
    mutable Event::QueryResult m_event_result = Event::QueryResult::eFinished;
    Flags<GraphInstantiateFlagBit> m_flags;

  public:
    // Constructs a graph attached to `manager`.
    // `name` defaults to "graph" and `flag` to ComputeGraphFlag::HostLaunch.
    ComputeGraph(ComputeGraphVarManager& manager,
                 std::string_view        name = "graph",
                 ComputeGraphFlag        flag = ComputeGraphFlag::HostLaunch);

    ~ComputeGraph();

    /**************************************************************
    * 
    * Info API
    * 
    ***************************************************************/

    // Returns a view of m_name; the view is only valid while this graph lives.
    std::string_view name() const { return m_name; }

    /**************************************************************
    * 
    * GraphNode API
    * 
    ***************************************************************/

    // Returns an AddNodeProxy for `node_name`; stream a std::function<void()>
    // into the returned temporary with operator<<.
    AddNodeProxy create_node(std::string_view node_name);


    /**************************************************************
    * 
    * Graph Launch API
    * 
    ***************************************************************/

    void update();

    void build();

    // Launches with an explicit `single_stream` switch; `s` defaults to nullptr.
    void launch(bool single_stream, cudaStream_t s = nullptr);

    // Convenience overload: same as launch(false, s).
    void launch(cudaStream_t s = nullptr) { launch(false, s); }

    /**************************************************************
    * 
    * Graph Event Query API
    * 
    ***************************************************************/

    // Returns an Event::QueryResult.
    // NOTE(review): presumably queries m_event and caches the answer in the
    // mutable m_event_result -- confirm in the .inl.
    Event::QueryResult query() const;

    /**************************************************************
    * 
    * Graph Closure Capture Node API
    * 
    ***************************************************************/

    // Both overloads take a callable that receives a cudaStream_t; the second
    // one additionally takes a name.
    void capture(std::function<void(cudaStream_t)>&& f);
    void capture(std::string_view name, std::function<void(cudaStream_t)>&& f);

    /**************************************************************
    * 
    * Graph Visualization API
    * 
    ***************************************************************/

    // Takes an output stream and options (default-constructed when omitted).
    void graphviz(std::ostream& o, const ComputeGraphGraphvizOptions& options = {});

    /**************************************************************
    * 
    * Graph Viewer API
    * 
    ***************************************************************/

    GraphViewer viewer();

    // Implicit conversion; equivalent to calling viewer().
    operator GraphViewer() { return viewer(); }

  private:  // internal method
    void topo_build();

    void cuda_graph_add_deps();

    void build_deps();

    void serial_launch();

    void _update();

    void check_vars_valid();

    friend class AddNodeProxy;
    ComputeGraph& add_node(std::string&& name, const std::function<void()>& f);

    friend class ComputeGraphNodeBase;
    friend class ComputeGraphClosure;
    // Returns a read-only span of Dependency described by (begin, count).
    // NOTE(review): presumably a window into m_deps -- confirm in the .inl.
    span<const Dependency> dep_span(size_t begin, size_t count) const;

    void set_current_graph_as_this();

    static void clear_current_graph();

    static Stream& shared_capture_stream();

    friend class ComputeGraphBuilder;
    // Trivial accessors for the "current" build state stored below.
    ClosureId current_closure_id() const { return m_current_closure_id; }

    NodeId current_node_id() const { return m_current_node_id; }

    size_t current_access_index() const { return m_access_graph_index; }

    ComputeGraphPhase current_graph_phase() const;

  private:  // internal data
    friend class muda::details::ComputeGraphAccessor;
    std::string       m_name;
    bool              m_need_update = false;
    ClosureId         m_current_closure_id;
    NodeId            m_current_node_id;
    ComputeGraphPhase m_current_graph_phase = ComputeGraphPhase::None;
    bool              m_allow_access_graph  = false;
    size_t            m_access_graph_index  = 0;
    bool              m_allow_node_adding   = true;
    // TempNodeInfo      m_temp_node_info;
    cudaStream_t m_current_single_stream = nullptr;
    bool         m_is_capturing          = false;
    // in capture func, we don't allow any var eval()
    bool m_is_in_capture_func = false;
    // if we have already built the topo, we don't do that again
    bool m_is_topo_built = false;
};
}  // namespace muda

#include "details/compute_graph.inl"