Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
namespace = "FlexFlow"
name = "ParallelComputationGraph"
type = "struct"
features = [ ]
features = [
"json",
]

includes = [
"utils/graph/labelled_kwarg_dataflow_graph/labelled_kwarg_dataflow_graph.h",
Expand All @@ -10,6 +12,10 @@ includes = [
"op-attrs/tensor_slot_name.dtg.h",
]

src_includes = [
"utils/graph/labelled_kwarg_dataflow_graph/json.h",
]

[[fields]]
name = "raw_graph"
type = "::FlexFlow::LabelledKwargDataflowGraph<::FlexFlow::ParallelLayerAttrs, ::FlexFlow::ParallelTensorAttrs, ::FlexFlow::TensorSlotName>"
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ features = [
"ord",
"hash",
"fmt",
"json",
]

includes = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ features = [
"ord",
"hash",
"fmt",
"json",
]

includes = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ features = [
"ord",
"hash",
"fmt",
"json",
]

includes = [
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_INSTANCES_UNORDERED_SET_KWARG_DATAFLOW_GRAPH_H
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_INSTANCES_UNORDERED_SET_KWARG_DATAFLOW_GRAPH_H

#include "utils/containers/filter.h"
#include "utils/containers/generate_map.h"
#include "utils/containers/keys.h"
#include "utils/graph/kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_edges.h"
#include "utils/graph/kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_outputs.h"
#include "utils/graph/kwarg_dataflow_graph/i_kwarg_dataflow_graph.h"
#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_edge_query.h"
#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_output_query.h"
#include "utils/graph/node/algorithms.h"
#include "utils/graph/node/node_source.h"

namespace FlexFlow {

template <typename SlotName>
struct UnorderedSetKwargDataflowGraph final
: public IKwargDataflowGraph<SlotName> {
public:
UnorderedSetKwargDataflowGraph() = default;

KwargNodeAddedResult<SlotName> add_node(
std::unordered_map<SlotName, KwargDataflowOutput<SlotName>> const &inputs,
std::unordered_set<SlotName> const &output_slots) override {
Node new_node = this->node_source.new_node();
this->nodes.insert(new_node);

for (auto const &[slot, src] : inputs) {
this->edges.insert(KwargDataflowEdge<SlotName>{
src,
KwargDataflowInput<SlotName>{new_node, slot},
});
}

std::unordered_map<SlotName, KwargDataflowOutput<SlotName>> outputs =
generate_map(
output_slots,
[&](SlotName const &slot) -> KwargDataflowOutput<SlotName> {
KwargDataflowOutput<SlotName> out{new_node, slot};
this->outputs.insert(out);
return out;
});

return KwargNodeAddedResult<SlotName>{
/*node=*/new_node,
/*outputs=*/outputs,
};
}

void add_node_unsafe(
Node const &node,
std::unordered_map<SlotName, KwargDataflowOutput<SlotName>> const &inputs,
std::unordered_map<SlotName, KwargDataflowOutput<SlotName>> const
&outputs) override {
this->nodes.insert(node);

for (auto const &[slot, src] : inputs) {
this->edges.insert(KwargDataflowEdge<SlotName>{
src,
KwargDataflowInput<SlotName>{node, slot},
});
}

for (auto const &[slot, out] : outputs) {
this->outputs.insert(out);
}
}

std::unordered_set<Node> query_nodes(NodeQuery const &q) const override {
return filter(this->nodes,
[&](Node const &n) { return includes(q.nodes, n); });
}

std::unordered_set<KwargDataflowEdge<SlotName>>
query_edges(KwargDataflowEdgeQuery<SlotName> const &q) const override {
return filter(this->edges, [&](KwargDataflowEdge<SlotName> const &e) {
return kwarg_dataflow_edge_query_includes(q, e);
});
}

std::unordered_set<KwargDataflowOutput<SlotName>> query_outputs(
KwargDataflowOutputQuery<SlotName> const &q) const override {
return filter(this->outputs, [&](KwargDataflowOutput<SlotName> const &o) {
return kwarg_dataflow_output_query_includes(q, o);
});
}

void inplace_materialize_from(
KwargDataflowGraphView<SlotName> const &view) override {
this->nodes = get_nodes(view);
this->edges = get_all_kwarg_dataflow_edges(view);
this->outputs = get_all_kwarg_dataflow_outputs(view);
}

UnorderedSetKwargDataflowGraph *clone() const override {
return new UnorderedSetKwargDataflowGraph{
this->node_source,
this->nodes,
this->edges,
this->outputs,
};
}

private:
UnorderedSetKwargDataflowGraph(
NodeSource const &node_source,
std::unordered_set<Node> const &nodes,
std::unordered_set<KwargDataflowEdge<SlotName>> const &edges,
std::unordered_set<KwargDataflowOutput<SlotName>> const &outputs)
: node_source(node_source), nodes(nodes), edges(edges), outputs(outputs) {
}

private:
NodeSource node_source;
std::unordered_set<Node> nodes;
std::unordered_set<KwargDataflowEdge<SlotName>> edges;
std::unordered_set<KwargDataflowOutput<SlotName>> outputs;
};
CHECK_RC_COPY_VIRTUAL_COMPLIANT(UnorderedSetKwargDataflowGraph<int>);

} // namespace FlexFlow

#endif
45 changes: 45 additions & 0 deletions lib/utils/include/utils/graph/kwarg_dataflow_graph/json.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_JSON_H
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_JSON_H

#include "utils/graph/instances/unordered_set_kwarg_dataflow_graph.h"
#include "utils/graph/kwarg_dataflow_graph/algorithms/view_as_open_kwarg_dataflow_graph.h"
#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_graph.h"
#include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_open_kwarg_dataflow_graph_data.h"
#include "utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_data.dtg.h"
#include "utils/graph/open_kwarg_dataflow_graph/algorithms/view_from_open_kwarg_dataflow_graph_data.h"
#include "utils/json/check_is_json_deserializable.h"
#include "utils/json/check_is_json_serializable.h"
#include "utils/json/monostate.h"
#include <nlohmann/json.hpp>
#include <variant>

namespace nlohmann {

template <typename SlotName>
struct adl_serializer<::FlexFlow::KwargDataflowGraph<SlotName>> {
static ::FlexFlow::KwargDataflowGraph<SlotName> from_json(json const &j) {
CHECK_IS_JSON_DESERIALIZABLE(
::FlexFlow::OpenKwargDataflowGraphData<std::monostate, SlotName>);

auto data = j.template get<
::FlexFlow::OpenKwargDataflowGraphData<std::monostate, SlotName>>();
::FlexFlow::OpenKwargDataflowGraphView<std::monostate, SlotName> view =
::FlexFlow::view_from_open_kwarg_dataflow_graph_data(data);
return ::FlexFlow::KwargDataflowGraph<SlotName>::template create_copy_of<
::FlexFlow::UnorderedSetKwargDataflowGraph<SlotName>>(view);
}

static void to_json(json &j,
::FlexFlow::KwargDataflowGraph<SlotName> const &g) {
CHECK_IS_JSON_SERIALIZABLE(
::FlexFlow::OpenKwargDataflowGraphData<std::monostate, SlotName>);

::FlexFlow::OpenKwargDataflowGraphView<std::monostate, SlotName> open_view =
::FlexFlow::view_as_open_kwarg_dataflow_graph<std::monostate>(g);
j = ::FlexFlow::get_open_kwarg_dataflow_graph_data(open_view);
}
};

} // namespace nlohmann

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ features = [
"ord",
"hash",
"fmt",
"json",
]

template_params = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ features = [
"ord",
"hash",
"fmt",
"json",
]

template_params = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ features = [
"ord",
"hash",
"fmt",
"json",
]

template_params = [
Expand Down
49 changes: 49 additions & 0 deletions lib/utils/include/utils/graph/labelled_dataflow_graph/json.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_JSON_H
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_JSON_H

#include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h"
#include "utils/graph/labelled_dataflow_graph/algorithms/view_as_labelled_open_dataflow_graph.h"
#include "utils/graph/labelled_dataflow_graph/labelled_dataflow_graph.h"
#include "utils/graph/labelled_open_dataflow_graph/algorithms/from_labelled_open_dataflow_graph_data.h"
#include "utils/graph/labelled_open_dataflow_graph/algorithms/get_graph_data.h"
#include "utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graph_data.dtg.h"
#include "utils/json/check_is_json_deserializable.h"
#include "utils/json/check_is_json_serializable.h"
#include <nlohmann/json.hpp>

namespace nlohmann {

template <typename NodeLabel, typename OutputLabel>
struct adl_serializer<
::FlexFlow::LabelledDataflowGraph<NodeLabel, OutputLabel>> {
static ::FlexFlow::LabelledDataflowGraph<NodeLabel, OutputLabel>
from_json(json const &j) {
CHECK_IS_JSON_DESERIALIZABLE(
::FlexFlow::LabelledOpenDataflowGraphData<NodeLabel, OutputLabel>);

auto data = j.template get<
::FlexFlow::LabelledOpenDataflowGraphData<NodeLabel, OutputLabel>>();
::FlexFlow::LabelledOpenDataflowGraphView<NodeLabel, OutputLabel>
open_view = ::FlexFlow::from_labelled_open_dataflow_graph_data(data);
return ::FlexFlow::LabelledDataflowGraph<NodeLabel, OutputLabel>::
template create_copy_of<
::FlexFlow::UnorderedSetLabelledOpenDataflowGraph<NodeLabel,
OutputLabel>>(
open_view);
}

static void to_json(
json &j,
::FlexFlow::LabelledDataflowGraph<NodeLabel, OutputLabel> const &g) {
CHECK_IS_JSON_SERIALIZABLE(
::FlexFlow::LabelledOpenDataflowGraphData<NodeLabel, OutputLabel>);

::FlexFlow::LabelledOpenDataflowGraphView<NodeLabel, OutputLabel>
open_view = ::FlexFlow::view_as_labelled_open_dataflow_graph(g);
j = ::FlexFlow::get_graph_data(open_view);
}
};

} // namespace nlohmann

#endif
77 changes: 77 additions & 0 deletions lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/json.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_KWARG_DATAFLOW_GRAPH_JSON_H
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_KWARG_DATAFLOW_GRAPH_JSON_H

#include "utils/graph/instances/unordered_set_labelled_open_kwarg_dataflow_graph.h"
#include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/view_as_labelled_open_kwarg_dataflow_graph.h"
#include "utils/graph/labelled_kwarg_dataflow_graph/labelled_kwarg_dataflow_graph.h"
#include "utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/get_labelled_open_kwarg_dataflow_graph_data.h"
#include "utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graph_data.dtg.h"
#include "utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/view_from_labelled_open_kwarg_dataflow_graph_data.h"
#include "utils/json/check_is_json_deserializable.h"
#include "utils/json/check_is_json_serializable.h"
#include "utils/json/monostate.h"
#include <nlohmann/json.hpp>
#include <variant>

namespace nlohmann {

template <typename NodeLabel, typename OutputLabel, typename SlotName>
struct adl_serializer<
::FlexFlow::LabelledKwargDataflowGraph<NodeLabel, OutputLabel, SlotName>> {
static ::FlexFlow::
LabelledKwargDataflowGraph<NodeLabel, OutputLabel, SlotName>
from_json(json const &j) {
CHECK_IS_JSON_DESERIALIZABLE(
::FlexFlow::LabelledOpenKwargDataflowGraphData<NodeLabel,
OutputLabel,
std::monostate,
SlotName>);

auto data = j.template get<
::FlexFlow::LabelledOpenKwargDataflowGraphData<NodeLabel,
OutputLabel,
std::monostate,
SlotName>>();
::FlexFlow::LabelledOpenKwargDataflowGraphView<NodeLabel,
OutputLabel,
std::monostate,
SlotName>
open_view =
::FlexFlow::view_from_labelled_open_kwarg_dataflow_graph_data(data);
return ::FlexFlow::
LabelledKwargDataflowGraph<NodeLabel, OutputLabel, SlotName>::
template create_copy_of<
::FlexFlow::UnorderedSetLabelledOpenKwargDataflowGraph<
NodeLabel,
OutputLabel,
std::monostate,
SlotName>>(open_view);
}

static void
to_json(json &j,
::FlexFlow::LabelledKwargDataflowGraph<NodeLabel,
OutputLabel,
SlotName> const &g) {
CHECK_IS_JSON_SERIALIZABLE(
::FlexFlow::LabelledOpenKwargDataflowGraphData<NodeLabel,
OutputLabel,
std::monostate,
SlotName>);

::FlexFlow::LabelledOpenKwargDataflowGraphView<NodeLabel,
OutputLabel,
std::monostate,
SlotName>
open_view = ::FlexFlow::view_as_labelled_open_kwarg_dataflow_graph<
NodeLabel,
OutputLabel,
std::monostate,
SlotName>(g);
j = ::FlexFlow::get_labelled_open_kwarg_dataflow_graph_data(open_view);
}
};

} // namespace nlohmann

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include "utils/containers/filtrans.h"
#include "utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graph_data.dtg.h"
#include "utils/graph/labelled_open_dataflow_graph/algorithms/with_labelling.h"
#include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h"
#include "utils/graph/open_dataflow_graph/algorithms/from_open_dataflow_graph_data.h"
#include "utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_data.dtg.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ features = [
"eq",
"hash",
"fmt",
"json",
]

template_params = ["NodeLabel", "ValueLabel"]
Expand Down
Loading
Loading