// File: compute_graph_var_manager.h
// Path (documentation file list): compute_graph/compute_graph_var_manager.h

#pragma once
#include <driver_types.h>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <memory>
#include <muda/mstl/span.h>
#include <muda/compute_graph/compute_graph_flag.h>
#include <muda/compute_graph/compute_graph_fwd.h>
#include <muda/compute_graph/graphviz_options.h>
namespace muda
{
// Central registry for compute graphs and compute-graph variables.
//
// The manager keeps a name -> variable map, a flat list of variables and a set
// of graphs (see the private data members), and exposes "is_using" / "sync" /
// "sync_on" operations in three flavours: no variables given, a pack of
// variable references, or a span of variable pointers.
//
// Only graphs() is defined inline; every other member is declared here and
// defined out of line (presumably in details/compute_graph_var_manager.inl,
// which this header includes at its end -- confirm).
//
// NOTE(review): variables and graphs are tracked through raw pointers, a
// destructor is user-declared, and the copy operations are left implicit. If
// the destructor releases the tracked variables, copying a manager would be
// unsafe -- confirm against the destructor definition.
class ComputeGraphVarManager
{
    // Class-private shorthand for std::shared_ptr<T>.
    template <class T>
    using S = std::shared_ptr<T>;

  public:
    ComputeGraphVarManager() = default;
    ~ComputeGraphVarManager();

    // ------------------------------------------------------------------
    // Graph API
    // ------------------------------------------------------------------

    // Creates a graph and returns shared ownership of it. `name` defaults to
    // "graph" and `flags` to a value-initialized ComputeGraphFlag.
    // Presumably the new graph is also recorded in graphs() -- confirm.
    S<ComputeGraph> create_graph(std::string_view name  = "graph",
                                 ComputeGraphFlag flags = {});

    // Read-only view of the set of graphs tracked by this manager.
    const std::unordered_set<ComputeGraph*>& graphs() const
    {
        return m_graphs;
    }

    // Writes a graphviz description to `os`, configured by `options`.
    // NOTE(review): exact output format is defined out of line -- confirm.
    void graphviz(std::ostream& os, const ComputeGraphGraphvizOptions& options = {}) const;

    // ------------------------------------------------------------------
    // GraphVar API
    // ------------------------------------------------------------------

    // Creates a variable of type T registered under `name`; the second
    // overload additionally takes an initial value.
    // NOTE(review): behaviour for an already-used name is not visible here.
    template <class T>
    ComputeGraphVar<T>& create_var(std::string_view name);
    template <class T>
    ComputeGraphVar<T>& create_var(std::string_view name, const T& init_value);

    // Looks a variable up by name. Returns a pointer rather than a reference,
    // so presumably nullptr signals "not found" -- confirm.
    template <class T>
    ComputeGraphVar<T>* find_var(std::string_view name);

    // --- is_using: query whether variables are in use ------------------
    // The no-argument form presumably covers every managed variable -- confirm.
    bool is_using() const;
    template <class... T>
    bool is_using(const ComputeGraphVar<T>&... vars) const;
    bool is_using(const span<const ComputeGraphVarBase*> vars) const;

    // --- sync: wait on the given variables (or all, for the no-arg form;
    //     presumably -- confirm) ----------------------------------------
    void sync() const;
    template <class... T>
    void sync(const ComputeGraphVar<T>&... vars) const;
    void sync(const span<const ComputeGraphVarBase*> vars) const;

    // --- sync_on: same families as sync(), but taking a CUDA stream.
    //     NOTE(review): presumably makes `stream` wait instead of blocking
    //     the host -- confirm in the definition. --------------------------
    void sync_on(cudaStream_t stream) const;
    template <class... T>
    void sync_on(cudaStream_t stream, const ComputeGraphVar<T>&... vars) const;
    void sync_on(cudaStream_t stream, const span<const ComputeGraphVarBase*> vars) const;

  private:
    // These classes access the private members below directly.
    friend class ComputeGraph;
    friend class ComputeGraphNodeBase;
    friend class ComputeGraphClosure;

    // Returns graphs associated with `vars`; going by the name, without
    // duplicates -- confirm.
    std::vector<ComputeGraph*> unique_graphs(span<const ComputeGraphVarBase*> vars) const;

    // Span over the managed variables (presumably a view of m_vars -- confirm).
    span<const ComputeGraphVarBase*> var_span() const;

    // Name -> variable lookup table.
    std::unordered_map<std::string, ComputeGraphVarBase*> m_vars_map;
    // Flat list of managed variables.
    std::vector<ComputeGraphVarBase*> m_vars;
    // Graphs tracked by this manager; exposed through graphs().
    std::unordered_set<ComputeGraph*> m_graphs;
};
}  // namespace muda

#include "details/compute_graph_var_manager.inl"