From e848c0206b78878cfaacab2e86d9ddbfe3e4cb1b Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Wed, 4 Feb 2026 11:31:04 -0800 Subject: [PATCH 01/88] Add make_dynamic_open_dataflow_graph_from_pcg. --- .../parallel_computation_graph.h | 6 ++ .../parallel_computation_graph.cc | 21 +++++ ...ake_dynamic_open_dataflow_graph_from_pcg.h | 14 ++++ ...ke_dynamic_open_dataflow_graph_from_pcg.cc | 77 +++++++++++++++++++ 4 files changed, 118 insertions(+) create mode 100644 lib/task-spec/include/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_pcg.h create mode 100644 lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_pcg.cc diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h index 25dc0721cd..3d948ac107 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h @@ -54,6 +54,9 @@ std::unordered_map std::unordered_set get_initial_layers(ParallelComputationGraph const &); +std::unordered_map + get_outgoing_tensors(ParallelComputationGraph const &, + parallel_layer_guid_t const &); std::unordered_map get_incoming_tensors(ParallelComputationGraph const &, parallel_layer_guid_t const &); @@ -107,6 +110,9 @@ ParallelTensorShape get_parallel_tensor_shape(ParallelComputationGraph const &, std::vector topological_ordering(ParallelComputationGraph const &); +std::unordered_map + get_parallel_layer_attrs_mapping(ParallelComputationGraph const &pcg); + parallel_layer_guid_t get_parallel_layer_by_name(ParallelComputationGraph const &pcg, std::string const &name); diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc index f83628b8e1..907dc05620 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc @@ -212,6 +212,16 @@ std::unordered_set [](Node const &n) { return parallel_layer_guid_t{n}; }); } +std::unordered_map + get_outgoing_tensors(ParallelComputationGraph const &pcg, + parallel_layer_guid_t const &l) { + return map_values(get_outgoing_kwarg_dataflow_outputs_for_node( + pcg.raw_graph, l.raw_graph_node), + [](KwargDataflowOutput const &o) { + return parallel_tensor_guid_t{o}; + }); +} + std::unordered_map get_incoming_tensors(ParallelComputationGraph const &pcg, parallel_layer_guid_t const &l) { @@ -378,6 +388,17 @@ std::vector [](Node const &n) { return parallel_layer_guid_t{n}; }); } +std::unordered_map + get_parallel_layer_attrs_mapping(ParallelComputationGraph const &pcg) { + std::unordered_map + layer_attrs_mapping; + for (parallel_layer_guid_t const &layer_guid : get_parallel_layers(pcg)) { + layer_attrs_mapping.insert( + {layer_guid, get_parallel_layer_attrs(pcg, layer_guid)}); + } + return layer_attrs_mapping; +} + parallel_layer_guid_t get_parallel_layer_by_name(ParallelComputationGraph const &pcg, std::string const &name) { diff --git a/lib/task-spec/include/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_pcg.h b/lib/task-spec/include/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_pcg.h new file mode 100644 index 0000000000..a71eb558c1 --- /dev/null +++ b/lib/task-spec/include/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_pcg.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_DYNAMIC_OPEN_DATAFLOW_GRAPH_FROM_PCG_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_DYNAMIC_OPEN_DATAFLOW_GRAPH_FROM_PCG_H + +#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" +#include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.dtg.h" + +namespace FlexFlow { + +DynamicOpenDataflowGraph + make_dynamic_open_dataflow_graph_from_pcg(ParallelComputationGraph const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_pcg.cc b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_pcg.cc new file mode 100644 index 0000000000..841be27dfd --- /dev/null +++ b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_pcg.cc @@ -0,0 +1,77 @@ +#include "task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_pcg.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/pcg_operator_attrs.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h" +#include "task-spec/dynamic_graph/dynamic_layer_guid_t.dtg.h" +#include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" +#include "task-spec/dynamic_graph/dynamic_tensor_role.h" +#include "utils/containers/generate_map.h" +#include +#include +#include + +namespace FlexFlow { + +DynamicOpenDataflowGraph make_dynamic_open_dataflow_graph_from_pcg( + ParallelComputationGraph const &pcg) { + DynamicOpenDataflowGraph result = make_empty_dynamic_open_dataflow_graph(); + + for (auto const &[layer, attrs] : get_parallel_layer_attrs_mapping(pcg)) { + DynamicNodeAttrs result_attrs{ + /*task_type=*/std::nullopt, + /*device_coord=*/std::nullopt, + /*mapping=*/std::nullopt, + /*op_attrs=*/attrs.op_attrs, + /*pcg_layer_guid=*/dynamic_layer_guid_t{layer}, + /*per_device_op_state=*/std::nullopt, + }; + + std::unordered_map result_inputs = + transform(get_incoming_tensors(pcg, layer), + [&](TensorSlotName const &slot_name, + parallel_tensor_guid_t const &tensor) { + ParallelTensorAttrs attrs = + get_parallel_tensor_attrs(pcg, tensor); + return std::pair{ + DynamicTensorSlot{ + /*slot_name=*/slot_name, + /*slot_tensor_role=*/std::nullopt, + }, + DynamicValueAttrs{ + /*tensor_guid=*/dynamic_tensor_guid_t{tensor}, + /*parallel_tensor_shape=*/attrs.shape, + /*shard_coord=*/std::nullopt, + /*accessor=*/std::nullopt, + /*role=*/std::nullopt, + }, + }; + }); + std::unordered_map result_outputs = + transform(get_outgoing_tensors(pcg, layer), + [&](TensorSlotName const &slot_name, + parallel_tensor_guid_t const &tensor) { + ParallelTensorAttrs attrs = + get_parallel_tensor_attrs(pcg, tensor); + return std::pair{ + DynamicTensorSlot{ + /*slot_name=*/slot_name, + /*slot_tensor_role=*/std::nullopt, + }, + DynamicValueAttrs{ + /*tensor_guid=*/dynamic_tensor_guid_t{tensor}, + /*parallel_tensor_shape=*/attrs.shape, + /*shard_coord=*/std::nullopt, + /*accessor=*/std::nullopt, + /*role=*/std::nullopt, + }, + }; + }); + + result.invocations.emplace(result_inputs, result_attrs, result_outputs); + } + + return result; +} + +} // namespace FlexFlow From 40c560952b723e307cc7650abe61959e86927739 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Wed, 4 Feb 2026 14:10:13 -0800 Subject: [PATCH 02/88] Empty skeleton of the realm-execution backend. --- .proj.toml | 7 +++++++ lib/CMakeLists.txt | 1 + lib/realm-execution/CMakeLists.txt | 21 +++++++++++++++++++ .../parallel_computation_graph_instance.h | 12 +++++++++++ .../parallel_computation_graph_instance.cc | 1 + lib/realm-execution/test/CMakeLists.txt | 15 +++++++++++++ .../test/src/realm-execution/test_e2e.cc | 9 ++++++++ 7 files changed, 66 insertions(+) create mode 100644 lib/realm-execution/CMakeLists.txt create mode 100644 lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h create mode 100644 lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc create mode 100644 lib/realm-execution/test/CMakeLists.txt create mode 100644 lib/realm-execution/test/src/realm-execution/test_e2e.cc diff --git a/.proj.toml b/.proj.toml index 38690f710b..5dbbfbcdd7 100644 --- a/.proj.toml +++ b/.proj.toml @@ -85,6 +85,13 @@ has-cpu-only-benchmarks = false has-cuda-tests = true has-cuda-benchmarks = false +[targets.realm-execution] +type = "lib" +has-cpu-only-tests = true +has-cpu-only-benchmarks = false +has-cuda-tests = true +has-cuda-benchmarks = false + # [targets.local-pcg-execution] # type = "lib" # has-cpu-only-tests = true diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index 2e71e577c0..cb3bd6d6ae 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -5,6 +5,7 @@ add_subdirectory(op-attrs) add_subdirectory(kernels) add_subdirectory(local-execution) add_subdirectory(local-pcg-execution) +add_subdirectory(realm-execution) add_subdirectory(task-spec) add_subdirectory(utils) add_subdirectory(ffi) diff --git a/lib/realm-execution/CMakeLists.txt b/lib/realm-execution/CMakeLists.txt new file mode 100644 index 0000000000..7a38f70607 --- /dev/null +++ b/lib/realm-execution/CMakeLists.txt @@ -0,0 +1,21 @@ +ff_add_library( + NAME + realm-execution + SRC_PATTERNS + src/*.cc + PUBLIC_INCLUDE + include/ + PRIVATE_INCLUDE + src/ + DEPS + op-attrs + utils + kernels + task-spec + pcg + spdlog + compiler + local-execution +) + +add_subdirectory(test) diff --git a/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h b/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h new file mode 100644 index 0000000000..58cc5234d9 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_PARALLEL_COMPUTATION_GRAPH_INSTANCE_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_PARALLEL_COMPUTATION_GRAPH_INSTANCE_H + +namespace FlexFlow { + +struct ParallelComputationGraphInstance { + public: +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc new file mode 100644 index 0000000000..a22f4730b7 --- /dev/null +++ b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc @@ -0,0 +1 @@ +#include "realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h" diff --git a/lib/realm-execution/test/CMakeLists.txt b/lib/realm-execution/test/CMakeLists.txt new file mode 100644 index 0000000000..b3beff42c0 --- /dev/null +++ b/lib/realm-execution/test/CMakeLists.txt @@ -0,0 +1,15 @@ +ff_add_test_executable( + NAME + realm-execution-tests + SRC_PATTERNS + src/*.cc + PRIVATE_INCLUDE + src/ + DEPS + doctest + utils-test-common + realm-execution + kernels + op-attrs + task-spec +) diff --git a/lib/realm-execution/test/src/realm-execution/test_e2e.cc b/lib/realm-execution/test/src/realm-execution/test_e2e.cc new file mode 100644 index 0000000000..55dfe427d5 --- /dev/null +++ b/lib/realm-execution/test/src/realm-execution/test_e2e.cc @@ -0,0 +1,9 @@ +#include "realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("RealmBackend e2e Training") { + } +} From 14c1b94e106b2ebc49bc2f0fbc2723a98e387146 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Wed, 4 Feb 2026 14:25:52 -0800 Subject: [PATCH 03/88] More Realm execution skeleton. --- .../parallel_computation_graph_instance.h | 52 ++++++++++++++++++- .../parallel_computation_graph_instance.cc | 45 ++++++++++++++++ .../test/src/realm-execution/test_e2e.cc | 3 +- 3 files changed, 97 insertions(+), 3 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h b/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h index 58cc5234d9..b0529761c1 100644 --- a/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h +++ b/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h @@ -1,12 +1,62 @@ #ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_PARALLEL_COMPUTATION_GRAPH_INSTANCE_H #define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_PARALLEL_COMPUTATION_GRAPH_INSTANCE_H +#include "kernels/accessor.h" +#include "kernels/allocation.h" +#include "kernels/device_handle_t.dtg.h" +#include "kernels/profiling_settings.dtg.h" +#include "op-attrs/ops/loss_functions/loss_attrs.dtg.h" +#include "pcg/device_id_t.dtg.h" +#include "pcg/optimizer_attrs.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" +#include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.dtg.h" +#include "task-spec/dynamic_graph/dynamic_tensor_accessor.dtg.h" +#include "task-spec/dynamic_graph/dynamic_value_attrs.dtg.h" +#include "task-spec/ff_iteration_config.dtg.h" +#include "utils/units/milliseconds_t.h" +#include + namespace FlexFlow { struct ParallelComputationGraphInstance { - public: +public: + ParallelComputationGraphInstance(DynamicOpenDataflowGraph, + Allocator &, + std::vector const &, + OptimizerAttrs const &, + std::optional const &, + std::optional); + DynamicOpenDataflowGraph const &get_dynamic_dataflow_graph() const; + Allocator &get_allocator() const; + std::vector const &get_topological_ordering() const; + OptimizerAttrs const &get_optimizer_attrs() const; + void update_optimizer_attrs_for_next_iter(); + std::optional const &get_loss_attrs() const; + std::optional get_loss_tensor_accessor() const; + +private: + DynamicOpenDataflowGraph dataflow_graph; + Allocator &allocator; + std::vector topological_ordering; + OptimizerAttrs optimizer_attrs; + std::optional loss_attrs; + std::optional logit_grad_tensor; }; +ParallelComputationGraphInstance create_parallel_computation_graph_instance( + ParallelComputationGraph const &pcg, + OptimizerAttrs const &optimizer_attrs, + std::optional const &loss_attrs, + std::optional label_tensor, + std::optional logit_tensor, + std::unordered_map const + &input_tensors, + Allocator &allocator, + ProfilingSettings const &profiling_settings, + device_handle_t const &device_handle, + FFIterationConfig const &iteration_config, + device_id_t device_idx); + } // namespace FlexFlow #endif diff --git a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc index a22f4730b7..2f001a2975 100644 --- a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc +++ b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc @@ -1 +1,46 @@ #include "realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h" +#include "pcg/optimizer_attrs.h" + +namespace FlexFlow { + +ParallelComputationGraphInstance::ParallelComputationGraphInstance( + DynamicOpenDataflowGraph dataflow_graph, + Allocator &allocator, + std::vector const &topological_ordering, + OptimizerAttrs const &optimizer_attrs, + std::optional const &loss_attrs, + std::optional logit_grad_tensor) + : dataflow_graph(dataflow_graph), allocator(allocator), + topological_ordering(topological_ordering), + optimizer_attrs(optimizer_attrs), loss_attrs(loss_attrs), + logit_grad_tensor(logit_grad_tensor) {} + +DynamicOpenDataflowGraph const & + ParallelComputationGraphInstance::get_dynamic_dataflow_graph() const { + return this->dataflow_graph; +} +Allocator &ParallelComputationGraphInstance::get_allocator() const { + return this->allocator; +} +std::vector const & + ParallelComputationGraphInstance::get_topological_ordering() const { + return this->topological_ordering; +} +OptimizerAttrs const & + ParallelComputationGraphInstance::get_optimizer_attrs() const { + return this->optimizer_attrs; +} +void ParallelComputationGraphInstance::update_optimizer_attrs_for_next_iter() { + this->optimizer_attrs = + get_optimizer_attrs_for_next_iter(this->optimizer_attrs); +} +std::optional const & + ParallelComputationGraphInstance::get_loss_attrs() const { + return this->loss_attrs; +} +std::optional + ParallelComputationGraphInstance::get_loss_tensor_accessor() const { + return this->logit_grad_tensor; +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/test/src/realm-execution/test_e2e.cc b/lib/realm-execution/test/src/realm-execution/test_e2e.cc index 55dfe427d5..78a57fb99f 100644 --- a/lib/realm-execution/test/src/realm-execution/test_e2e.cc +++ b/lib/realm-execution/test/src/realm-execution/test_e2e.cc @@ -4,6 +4,5 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("RealmBackend e2e Training") { - } + TEST_CASE("RealmBackend e2e Training") {} } From a9a365d0822bf2aa1b302029ea23201c67c9ba25 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Wed, 4 Feb 2026 14:59:55 -0800 Subject: [PATCH 04/88] Stub creation. --- .../parallel_computation_graph_instance.cc | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc index 2f001a2975..29683c4dba 100644 --- a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc +++ b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc @@ -1,5 +1,6 @@ #include "realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h" #include "pcg/optimizer_attrs.h" +#include "utils/exception.h" namespace FlexFlow { @@ -43,4 +44,20 @@ std::optional return this->logit_grad_tensor; } +ParallelComputationGraphInstance create_parallel_computation_graph_instance( + ParallelComputationGraph const &pcg, + OptimizerAttrs const &optimizer_attrs, + std::optional const &loss_attrs, + std::optional label_tensor, + std::optional logit_tensor, + std::unordered_map const + &input_tensors, + Allocator &allocator, + ProfilingSettings const &profiling_settings, + device_handle_t const &device_handle, + FFIterationConfig const &iteration_config, + device_id_t device_idx) { + NOT_IMPLEMENTED(); +} + } // namespace FlexFlow From 788300cd76be540c1ff44e2a5c1e6056d7df1ecb Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Wed, 4 Feb 2026 15:17:00 -0800 Subject: [PATCH 05/88] More passes. --- .../parallel_computation_graph_instance.cc | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc index 29683c4dba..8f878c90d8 100644 --- a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc +++ b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc @@ -1,5 +1,12 @@ #include "realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h" +#include "local-execution/device_state_initialization.h" +#include "local-execution/tensor_allocation.h" #include "pcg/optimizer_attrs.h" +#include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" +#include "task-spec/dynamic_graph/loss_insertion.h" +#include "task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_pcg.h" +#include "task-spec/dynamic_graph/pass_expansion.h" +#include "task-spec/dynamic_graph/update_insertion.h" #include "utils/exception.h" namespace FlexFlow { @@ -44,6 +51,15 @@ std::optional return this->logit_grad_tensor; } +static GenericTensorAccessorW + get_loss_tensor_accessor(DynamicOpenDataflowGraph const &dg, + DynamicValueAttrs const &value) { + return find_output_tensor(dg, value.tensor_guid, value.role) + .value() + .second.accessor.value() + .get(); +} + ParallelComputationGraphInstance create_parallel_computation_graph_instance( ParallelComputationGraph const &pcg, OptimizerAttrs const &optimizer_attrs, @@ -57,6 +73,36 @@ ParallelComputationGraphInstance create_parallel_computation_graph_instance( device_handle_t const &device_handle, FFIterationConfig const &iteration_config, device_id_t device_idx) { + + DynamicOpenDataflowGraph dg = make_dynamic_open_dataflow_graph_from_pcg(pcg); + dg = perform_pass_expansion(dg); + + std::unordered_map inputs = + input_tensors; + std::optional logit_grad_value; + if (loss_attrs) { + auto [dg2, label_v, logit_grad_v] = perform_loss_insertion( + dg, assert_unwrap(loss_attrs), assert_unwrap(logit_tensor)); + dg = dg2; + logit_grad_value = logit_grad_v; + inputs.insert(std::pair{label_v, assert_unwrap(label_tensor)}); + } + + dg = perform_update_insertion(dg, optimizer_attrs); + dg = perform_tensor_allocation(dg, inputs, allocator); + + std::optional logit_grad_tensor = + transform(logit_grad_value, [&](DynamicValueAttrs const &lgv) { + return get_loss_tensor_accessor(dg, lgv); + }); + + dg = perform_device_state_initialization(dg, + allocator, + profiling_settings, + device_handle, + iteration_config, + optimizer_attrs, + device_idx); NOT_IMPLEMENTED(); } From 762827109b8e6ecdd306e90141fc55d812bc978f Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Wed, 4 Feb 2026 16:39:30 -0800 Subject: [PATCH 06/88] Add Realm manager and test it. --- lib/realm-execution/CMakeLists.txt | 11 ++++---- .../include/realm-execution/realm_manager.h | 27 +++++++++++++++++++ .../src/realm-execution/realm_manager.cc | 22 +++++++++++++++ .../test/src/realm-execution/test_e2e.cc | 10 ++++++- 4 files changed, 64 insertions(+), 6 deletions(-) create mode 100644 lib/realm-execution/include/realm-execution/realm_manager.h create mode 100644 lib/realm-execution/src/realm-execution/realm_manager.cc diff --git a/lib/realm-execution/CMakeLists.txt b/lib/realm-execution/CMakeLists.txt index 7a38f70607..0a1b681b8d 100644 --- a/lib/realm-execution/CMakeLists.txt +++ b/lib/realm-execution/CMakeLists.txt @@ -8,14 +8,15 @@ ff_add_library( PRIVATE_INCLUDE src/ DEPS - op-attrs - utils + compiler kernels - task-spec + local-execution + op-attrs pcg + realm spdlog - compiler - local-execution + task-spec + utils ) add_subdirectory(test) diff --git a/lib/realm-execution/include/realm-execution/realm_manager.h b/lib/realm-execution/include/realm-execution/realm_manager.h new file mode 100644 index 0000000000..a08668e6cc --- /dev/null +++ b/lib/realm-execution/include/realm-execution/realm_manager.h @@ -0,0 +1,27 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_MANAGER_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_MANAGER_H + +#include "realm.h" + +namespace FlexFlow { + +struct RealmManager { +public: + RealmManager(int *argc, char ***argv); + + RealmManager() = delete; + RealmManager(RealmManager const &) = delete; + RealmManager(RealmManager &&) = delete; + + Realm::Runtime get_runtime(); + void shutdown(); + int wait_for_shutdown(); + +private: + Realm::Runtime runtime; + Realm::Event last_event = Realm::Event::NO_EVENT; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/src/realm-execution/realm_manager.cc b/lib/realm-execution/src/realm-execution/realm_manager.cc new file mode 100644 index 0000000000..5a085bc04b --- /dev/null +++ b/lib/realm-execution/src/realm-execution/realm_manager.cc @@ -0,0 +1,22 @@ +#include "realm-execution/realm_manager.h" +#include "utils/exception.h" + +namespace FlexFlow { + +RealmManager::RealmManager(int *argc, char ***argv) { + bool ok = this->runtime.init(argc, argv); + ASSERT(ok); +} + +Realm::Runtime RealmManager::get_runtime() { + return this->runtime; +} + +void RealmManager::shutdown() { + this->runtime.shutdown(this->last_event); +} + +int RealmManager::wait_for_shutdown() { + return this->runtime.wait_for_shutdown(); +} +} // namespace FlexFlow diff --git a/lib/realm-execution/test/src/realm-execution/test_e2e.cc b/lib/realm-execution/test/src/realm-execution/test_e2e.cc index 78a57fb99f..947a02e6be 100644 --- a/lib/realm-execution/test/src/realm-execution/test_e2e.cc +++ b/lib/realm-execution/test/src/realm-execution/test_e2e.cc @@ -1,8 +1,16 @@ #include "realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h" +#include "realm-execution/realm_manager.h" #include using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("RealmBackend e2e Training") {} + TEST_CASE("RealmBackend e2e Training") { + char fake_executable_name[] = "fake_executable_name"; + std::vector fake_args{fake_executable_name}; + int fake_argc = fake_args.size(); + char **fake_argv = fake_args.data(); + RealmManager manager(&fake_argc, &fake_argv); + manager.shutdown(); + } } From d2b3f01c8be7a13621124416700becb08102f59e Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Wed, 4 Feb 2026 16:45:57 -0800 Subject: [PATCH 07/88] Do not expose raw runtime and properly wait in test. --- lib/realm-execution/include/realm-execution/realm_manager.h | 1 - lib/realm-execution/src/realm-execution/realm_manager.cc | 5 +---- lib/realm-execution/test/src/realm-execution/test_e2e.cc | 2 ++ 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/realm_manager.h b/lib/realm-execution/include/realm-execution/realm_manager.h index a08668e6cc..f9fa9f7de7 100644 --- a/lib/realm-execution/include/realm-execution/realm_manager.h +++ b/lib/realm-execution/include/realm-execution/realm_manager.h @@ -13,7 +13,6 @@ struct RealmManager { RealmManager(RealmManager const &) = delete; RealmManager(RealmManager &&) = delete; - Realm::Runtime get_runtime(); void shutdown(); int wait_for_shutdown(); diff --git a/lib/realm-execution/src/realm-execution/realm_manager.cc b/lib/realm-execution/src/realm-execution/realm_manager.cc index 5a085bc04b..014a16718a 100644 --- a/lib/realm-execution/src/realm-execution/realm_manager.cc +++ b/lib/realm-execution/src/realm-execution/realm_manager.cc @@ -8,10 +8,6 @@ RealmManager::RealmManager(int *argc, char ***argv) { ASSERT(ok); } -Realm::Runtime RealmManager::get_runtime() { - return this->runtime; -} - void RealmManager::shutdown() { this->runtime.shutdown(this->last_event); } @@ -19,4 +15,5 @@ void RealmManager::shutdown() { int RealmManager::wait_for_shutdown() { return this->runtime.wait_for_shutdown(); } + } // namespace FlexFlow diff --git a/lib/realm-execution/test/src/realm-execution/test_e2e.cc b/lib/realm-execution/test/src/realm-execution/test_e2e.cc index 947a02e6be..b88807e079 100644 --- a/lib/realm-execution/test/src/realm-execution/test_e2e.cc +++ b/lib/realm-execution/test/src/realm-execution/test_e2e.cc @@ -12,5 +12,7 @@ TEST_SUITE(FF_TEST_SUITE) { char **fake_argv = fake_args.data(); RealmManager manager(&fake_argc, &fake_argv); manager.shutdown(); + int result = manager.wait_for_shutdown(); + ASSERT(result == 0); } } From 37f1d209f66034dc7f5e8d5531d5ea5b4fbbbc09 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Wed, 4 Feb 2026 17:11:02 -0800 Subject: [PATCH 08/88] Sketch more Realm manager APIs. --- .../parallel_computation_graph_instance.h | 13 ++++++------ .../include/realm-execution/realm_manager.h | 8 +++++++ .../parallel_computation_graph_instance.cc | 21 +++++++++---------- .../src/realm-execution/realm_manager.cc | 11 ++++++++++ 4 files changed, 35 insertions(+), 18 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h b/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h index b0529761c1..4ba77a7925 100644 --- a/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h +++ b/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h @@ -9,6 +9,7 @@ #include "pcg/device_id_t.dtg.h" #include "pcg/optimizer_attrs.dtg.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" +#include "realm-execution/realm_manager.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.dtg.h" #include "task-spec/dynamic_graph/dynamic_tensor_accessor.dtg.h" #include "task-spec/dynamic_graph/dynamic_value_attrs.dtg.h" @@ -20,8 +21,8 @@ namespace FlexFlow { struct ParallelComputationGraphInstance { public: - ParallelComputationGraphInstance(DynamicOpenDataflowGraph, - Allocator &, + ParallelComputationGraphInstance(RealmManager &, + DynamicOpenDataflowGraph, std::vector const &, OptimizerAttrs const &, std::optional const &, @@ -35,8 +36,8 @@ struct ParallelComputationGraphInstance { std::optional get_loss_tensor_accessor() const; private: + RealmManager &realm; DynamicOpenDataflowGraph dataflow_graph; - Allocator &allocator; std::vector topological_ordering; OptimizerAttrs optimizer_attrs; std::optional loss_attrs; @@ -44,6 +45,7 @@ struct ParallelComputationGraphInstance { }; ParallelComputationGraphInstance create_parallel_computation_graph_instance( + RealmManager &realm, ParallelComputationGraph const &pcg, OptimizerAttrs const &optimizer_attrs, std::optional const &loss_attrs, @@ -51,11 +53,8 @@ ParallelComputationGraphInstance create_parallel_computation_graph_instance( std::optional logit_tensor, std::unordered_map const &input_tensors, - Allocator &allocator, ProfilingSettings const &profiling_settings, - device_handle_t const &device_handle, - FFIterationConfig const &iteration_config, - device_id_t device_idx); + FFIterationConfig const &iteration_config); } // namespace FlexFlow diff --git a/lib/realm-execution/include/realm-execution/realm_manager.h b/lib/realm-execution/include/realm-execution/realm_manager.h index f9fa9f7de7..9261bc91f4 100644 --- a/lib/realm-execution/include/realm-execution/realm_manager.h +++ b/lib/realm-execution/include/realm-execution/realm_manager.h @@ -1,6 +1,9 @@ #ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_MANAGER_H #define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_MANAGER_H +#include "kernels/allocation.h" +#include "kernels/device_handle_t.dtg.h" +#include "pcg/device_id_t.dtg.h" #include "realm.h" namespace FlexFlow { @@ -16,6 +19,11 @@ struct RealmManager { void shutdown(); int wait_for_shutdown(); + Allocator &get_current_device_allocator() const; + + device_handle_t const &get_current_device_handle() const; + device_id_t const &get_current_device_idx() const; + private: Realm::Runtime runtime; Realm::Event last_event = Realm::Event::NO_EVENT; diff --git a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc index 8f878c90d8..64c9da2f4c 100644 --- a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc +++ b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc @@ -12,13 +12,13 @@ namespace FlexFlow { ParallelComputationGraphInstance::ParallelComputationGraphInstance( + RealmManager &realm, DynamicOpenDataflowGraph dataflow_graph, - Allocator &allocator, std::vector const &topological_ordering, OptimizerAttrs const &optimizer_attrs, std::optional const &loss_attrs, std::optional logit_grad_tensor) - : dataflow_graph(dataflow_graph), allocator(allocator), + : realm(realm), dataflow_graph(dataflow_graph), topological_ordering(topological_ordering), optimizer_attrs(optimizer_attrs), loss_attrs(loss_attrs), logit_grad_tensor(logit_grad_tensor) {} @@ -28,7 +28,7 @@ DynamicOpenDataflowGraph const & return this->dataflow_graph; } Allocator &ParallelComputationGraphInstance::get_allocator() const { - return this->allocator; + return this->realm.get_current_device_allocator(); } std::vector const & ParallelComputationGraphInstance::get_topological_ordering() const { @@ -61,6 +61,7 @@ static GenericTensorAccessorW } ParallelComputationGraphInstance create_parallel_computation_graph_instance( + RealmManager &realm, ParallelComputationGraph const &pcg, OptimizerAttrs const &optimizer_attrs, std::optional const &loss_attrs, @@ -68,11 +69,8 @@ ParallelComputationGraphInstance create_parallel_computation_graph_instance( std::optional logit_tensor, std::unordered_map const &input_tensors, - Allocator &allocator, ProfilingSettings const &profiling_settings, - device_handle_t const &device_handle, - FFIterationConfig const &iteration_config, - device_id_t device_idx) { + FFIterationConfig const &iteration_config) { DynamicOpenDataflowGraph dg = make_dynamic_open_dataflow_graph_from_pcg(pcg); dg = perform_pass_expansion(dg); @@ -89,7 +87,8 @@ ParallelComputationGraphInstance create_parallel_computation_graph_instance( } dg = perform_update_insertion(dg, optimizer_attrs); - dg = perform_tensor_allocation(dg, inputs, allocator); + dg = perform_tensor_allocation( + dg, inputs, realm.get_current_device_allocator()); std::optional logit_grad_tensor = transform(logit_grad_value, [&](DynamicValueAttrs const &lgv) { @@ -97,12 +96,12 @@ ParallelComputationGraphInstance create_parallel_computation_graph_instance( }); dg = perform_device_state_initialization(dg, - allocator, + realm.get_current_device_allocator(), profiling_settings, - device_handle, + realm.get_current_device_handle(), iteration_config, optimizer_attrs, - device_idx); + realm.get_current_device_idx()); NOT_IMPLEMENTED(); } diff --git a/lib/realm-execution/src/realm-execution/realm_manager.cc b/lib/realm-execution/src/realm-execution/realm_manager.cc index 014a16718a..b136b4c379 100644 --- a/lib/realm-execution/src/realm-execution/realm_manager.cc +++ b/lib/realm-execution/src/realm-execution/realm_manager.cc @@ -16,4 +16,15 @@ int RealmManager::wait_for_shutdown() { return this->runtime.wait_for_shutdown(); } +Allocator &RealmManager::get_current_device_allocator() const { + NOT_IMPLEMENTED(); +} + +device_handle_t const &RealmManager::get_current_device_handle() const { + NOT_IMPLEMENTED(); +} +device_id_t const &RealmManager::get_current_device_idx() const { + NOT_IMPLEMENTED(); +} + } // namespace FlexFlow From 1fe90c1d7557d4724bc16d591dd9c2377f6145b4 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Thu, 5 Feb 2026 09:57:00 -0800 Subject: [PATCH 09/88] Add controller functionality. --- .../include/realm-execution/realm_manager.h | 17 ++++-- .../src/realm-execution/realm_manager.cc | 60 +++++++++++++++++-- .../test/src/realm-execution/test_e2e.cc | 4 +- 3 files changed, 69 insertions(+), 12 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/realm_manager.h b/lib/realm-execution/include/realm-execution/realm_manager.h index 9261bc91f4..497a1f3958 100644 --- a/lib/realm-execution/include/realm-execution/realm_manager.h +++ b/lib/realm-execution/include/realm-execution/realm_manager.h @@ -11,22 +11,31 @@ namespace FlexFlow { struct RealmManager { public: RealmManager(int *argc, char ***argv); + ~RealmManager(); RealmManager() = delete; RealmManager(RealmManager const &) = delete; RealmManager(RealmManager &&) = delete; - void shutdown(); - int wait_for_shutdown(); + Realm::Event start_controller(void (*thunk)(RealmManager &)); + // Current device context Allocator &get_current_device_allocator() const; - device_handle_t const &get_current_device_handle() const; device_id_t const &get_current_device_idx() const; +private: + RealmManager(void const *, size_t, void const *, size_t, Realm::Processor); + + [[nodiscard]] Realm::Event merge_outstanding_events(); + + static void controller_task_wrapper( + void const *, size_t, void const *, size_t, Realm::Processor); + private: Realm::Runtime runtime; - Realm::Event last_event = Realm::Event::NO_EVENT; + std::vector outstanding_events; + bool is_root_runtime; }; } // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/realm_manager.cc b/lib/realm-execution/src/realm-execution/realm_manager.cc index b136b4c379..acc11936c7 100644 --- a/lib/realm-execution/src/realm-execution/realm_manager.cc +++ b/lib/realm-execution/src/realm-execution/realm_manager.cc @@ -3,17 +3,48 @@ namespace FlexFlow { -RealmManager::RealmManager(int *argc, char ***argv) { +RealmManager::RealmManager(int *argc, char ***argv) : is_root_runtime(true) { bool ok = this->runtime.init(argc, argv); ASSERT(ok); } -void RealmManager::shutdown() { - this->runtime.shutdown(this->last_event); +RealmManager::RealmManager(void const *args, + size_t arglen, + void const *userdata, + size_t userdatalen, + Realm::Processor proc) + : runtime(Realm::Runtime::get_runtime()), is_root_runtime(false) {} + +RealmManager::~RealmManager() { + Realm::Event outstanding = this->merge_outstanding_events(); + if (is_root_runtime) { + this->runtime.shutdown(outstanding); + this->runtime.wait_for_shutdown(); + } else { + outstanding.wait(); + } } -int RealmManager::wait_for_shutdown() { - return this->runtime.wait_for_shutdown(); +Realm::Event RealmManager::start_controller(void (*thunk)(RealmManager &)) { + constexpr int CONTROLLER_TASK_ID = Realm::Processor::TASK_ID_FIRST_AVAILABLE; + Realm::Event task_ready = Realm::Processor::register_task_by_kind( + Realm::Processor::LOC_PROC, + /*global=*/false, + CONTROLLER_TASK_ID, + Realm::CodeDescriptor(RealmManager::controller_task_wrapper), + Realm::ProfilingRequestSet(), + &thunk, + sizeof(thunk)); + + Realm::Processor target_proc = + Realm::Machine::ProcessorQuery(Realm::Machine::get_machine()) + .only_kind(Realm::Processor::LOC_PROC) + .first(); + + Realm::Event task_complete = this->runtime.collective_spawn( + target_proc, CONTROLLER_TASK_ID, &thunk, sizeof(thunk), task_ready); + this->outstanding_events.push_back(task_complete); + return task_complete; } Allocator &RealmManager::get_current_device_allocator() const { @@ -27,4 +58,23 @@ device_id_t const &RealmManager::get_current_device_idx() const { NOT_IMPLEMENTED(); } +Realm::Event RealmManager::merge_outstanding_events() { + Realm::Event result = Realm::Event::merge_events(this->outstanding_events); + this->outstanding_events.clear(); + return result; +} + +void RealmManager::controller_task_wrapper(void const *args, + size_t arglen, + void const *userdata, + size_t userlen, + Realm::Processor proc) { + assert(arglen == sizeof(void (*)(RealmManager &))); + void (*thunk)(RealmManager &) = + *reinterpret_cast(const_cast(args)); + + RealmManager manager(args, arglen, userdata, userlen, proc); + thunk(manager); +} + } // namespace FlexFlow diff --git a/lib/realm-execution/test/src/realm-execution/test_e2e.cc b/lib/realm-execution/test/src/realm-execution/test_e2e.cc index b88807e079..f09951e73c 100644 --- a/lib/realm-execution/test/src/realm-execution/test_e2e.cc +++ b/lib/realm-execution/test/src/realm-execution/test_e2e.cc @@ -11,8 +11,6 @@ TEST_SUITE(FF_TEST_SUITE) { int fake_argc = fake_args.size(); char **fake_argv = fake_args.data(); RealmManager manager(&fake_argc, &fake_argv); - manager.shutdown(); - int result = manager.wait_for_shutdown(); - ASSERT(result == 0); + manager.start_controller([](RealmManager &manager) {}); } } From 150d9f4073dace79c4f43c7ec6f85fd0c2b2101e Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Thu, 5 Feb 2026 12:29:19 -0800 Subject: [PATCH 10/88] Fix Realm tests. --- .flake/pkgs/legion.nix | 48 ------------------- .flake/pkgs/realm.nix | 44 +++++++++++++++++ flake.nix | 21 ++++---- lib/realm-execution/CMakeLists.txt | 2 +- .../src/realm-execution/realm_manager.cc | 2 +- .../test/src/realm-execution/realm_manager.cc | 22 +++++++++ 6 files changed, 78 insertions(+), 61 deletions(-) delete mode 100644 .flake/pkgs/legion.nix create mode 100644 .flake/pkgs/realm.nix create mode 100644 lib/realm-execution/test/src/realm-execution/realm_manager.cc diff --git a/.flake/pkgs/legion.nix b/.flake/pkgs/legion.nix deleted file mode 100644 index 361a66c4ff..0000000000 --- a/.flake/pkgs/legion.nix +++ /dev/null @@ -1,48 +0,0 @@ -{ lib -, stdenv -, fetchFromGitLab -, cmake -, cudaPackages ? { } -, cudaCapabilities ? [ "60" "70" "80" "86" ] -, maxDim ? 5 -}: - -# from https://codeberg.org/Uli/nix-things/src/commit/776519e382c81b136c1d0b10d8c7b52b4acb9192/overlays/cq/python/libclang-python.nix - -let - cmakeFlag = x: if x then "1" else "0"; - - inherit (cudaPackages) cudatoolkit; -in - -stdenv.mkDerivation rec { - pname = "legion"; - version = "2025-01-06"; - - src = fetchFromGitLab { - owner = "StanfordLegion"; - repo = "legion"; - rev = "7be1abd0207eb1126c7629b16d1123fa6f58ce9d"; - sha256 = "sha256-gTjnGYYTQwTsrV1WcY0qqpTrlwbzAPcndurRy6XnG8A="; - }; - - nativeBuildInputs = [ - cmake - ]; - - cmakeFlags = [ - "-DLegion_USE_CUDA=1" - "-DLegion_CUDA_ARCH=${lib.concatStringsSep "," cudaCapabilities}" - "-DLegion_MAX_DIM=${toString maxDim}" - ]; - - buildInputs = [ - cudatoolkit - ]; - - meta = with lib; { - description = "Legion is a parallel programming model for distributed, heterogeneous machines"; - homepage = "https://legion.stanford.edu/"; - license = licenses.asl20; - }; -} diff --git a/.flake/pkgs/realm.nix b/.flake/pkgs/realm.nix new file mode 100644 index 0000000000..1249c0ae28 --- /dev/null +++ b/.flake/pkgs/realm.nix @@ -0,0 +1,44 @@ +{ lib +, stdenv +, fetchFromGitHub +, cmake +, cudaPackages ? { } +, maxDim ? 5 +}: + +let + inherit (cudaPackages) cudatoolkit; +in + +stdenv.mkDerivation rec { + pname = "realm"; + version = "2025-01-06"; + + # This version is compatible with Legion 7be1abd0207eb1126c7629b16d1123fa6f58ce9d + src = fetchFromGitHub { + owner = "StanfordLegion"; + repo = "realm"; + rev = "0ef7edc8c012d4ab6a50805c044cec8a8edeae33"; + sha256 = "sha256-57/a1lAgs+ajpRn0y0Lk1gP5nKt+N08WW0DIJP4vdho="; + }; + + nativeBuildInputs = [ + cmake + ]; + + cmakeFlags = [ + "-DBUILD_SHARED_LIBS=ON" + "-DREALM_ENABLE_CUDA=ON" + "-DREALM_MAX_DIM=${toString maxDim}" + ]; + + buildInputs = [ + cudatoolkit + ]; + + meta = with lib; { + description = "Realm is a distributed, event–based tasking runtime for building high-performance applications that span clusters of CPUs, GPUs, and other accelerators"; + homepage = "https://legion.stanford.edu/realm"; + license = licenses.asl20; + }; +} diff --git a/flake.nix b/flake.nix index 6ccd5616cd..dad0e2fc32 100644 --- a/flake.nix +++ b/flake.nix @@ -30,8 +30,8 @@ }; }; - outputs = { self, nixpkgs, flake-utils, proj-repo, nixGL, ... }: flake-utils.lib.eachSystem [ "x86_64-linux" ] (system: - let + outputs = { self, nixpkgs, flake-utils, proj-repo, nixGL, ... }: flake-utils.lib.eachSystem [ "x86_64-linux" ] (system: + let pkgs = import nixpkgs { inherit system; config.allowUnfree = true; @@ -41,21 +41,21 @@ mkShell = attrs: pkgs.mkShell.override { stdenv = pkgs.cudaPackages.backendStdenv; } (attrs // { - hardeningDisable = ["all"]; # disable nixpkgs default compiler arguments, otherwise ubsan doesn't catch - # signed overflows due to the signedoverflow hardening setting. - # for more details, see the following (long-running) nixpkgs github issues: + hardeningDisable = ["all"]; # disable nixpkgs default compiler arguments, otherwise ubsan doesn't catch + # signed overflows due to the signedoverflow hardening setting. + # for more details, see the following (long-running) nixpkgs github issues: # - https://github.com/NixOS/nixpkgs/issues/18995 # - https://github.com/NixOS/nixpkgs/issues/60919 }); proj = proj-repo.packages.${system}.proj; - in + in { packages = rec { libdwarf-lite = pkgs.callPackage ./.flake/pkgs/libdwarf-lite.nix { }; cpptrace = pkgs.callPackage ./.flake/pkgs/cpptrace.nix { inherit libdwarf-lite; }; libassert = pkgs.callPackage ./.flake/pkgs/libassert.nix { inherit cpptrace; }; - legion = pkgs.callPackage ./.flake/pkgs/legion.nix { }; + realm = pkgs.callPackage ./.flake/pkgs/realm.nix { }; bencher-cli = pkgs.callPackage ./.flake/pkgs/bencher-cli.nix { }; ffdb = pkgs.callPackage ./.flake/pkgs/ffdb { inherit proj; }; hpp2plantuml = pkgs.python3Packages.callPackage ./.flake/pkgs/hpp2plantuml.nix { }; @@ -83,8 +83,7 @@ shellHook = '' export PATH="$HOME/ff/.scripts/:$PATH" export RC_PARAMS="max_discard_ratio=100" - export CMAKE_FLAGS="-DFF_USE_EXTERNAL_LEGION=ON \ - -DFF_USE_EXTERNAL_NCCL=ON \ + export CMAKE_FLAGS="-DFF_USE_EXTERNAL_NCCL=ON \ -DFF_USE_EXTERNAL_JSON=ON \ -DFF_USE_EXTERNAL_FMT=ON \ -DFF_USE_EXTERNAL_SPDLOG=ON \ @@ -94,7 +93,7 @@ -DFF_USE_EXTERNAL_GBENCHMARK=ON \ -DFF_USE_EXTERNAL_LIBASSERT=ON" ''; - + buildInputs = builtins.concatLists [ (with pkgs; [ zlib @@ -125,7 +124,7 @@ ]) (with self.packages.${system}; [ libassert - legion + realm rapidcheckFull doctest ]) diff --git a/lib/realm-execution/CMakeLists.txt b/lib/realm-execution/CMakeLists.txt index 0a1b681b8d..08676525e1 100644 --- a/lib/realm-execution/CMakeLists.txt +++ b/lib/realm-execution/CMakeLists.txt @@ -13,10 +13,10 @@ ff_add_library( local-execution op-attrs pcg - realm spdlog task-spec utils + Realm::Realm ) add_subdirectory(test) diff --git a/lib/realm-execution/src/realm-execution/realm_manager.cc b/lib/realm-execution/src/realm-execution/realm_manager.cc index acc11936c7..33e7ca252e 100644 --- a/lib/realm-execution/src/realm-execution/realm_manager.cc +++ b/lib/realm-execution/src/realm-execution/realm_manager.cc @@ -71,7 +71,7 @@ void RealmManager::controller_task_wrapper(void const *args, Realm::Processor proc) { assert(arglen == sizeof(void (*)(RealmManager &))); void (*thunk)(RealmManager &) = - *reinterpret_cast(const_cast(args)); + *reinterpret_cast(args); RealmManager manager(args, arglen, userdata, userlen, proc); thunk(manager); diff --git a/lib/realm-execution/test/src/realm-execution/realm_manager.cc b/lib/realm-execution/test/src/realm-execution/realm_manager.cc new file mode 100644 index 0000000000..880268c018 --- /dev/null +++ b/lib/realm-execution/test/src/realm-execution/realm_manager.cc @@ -0,0 +1,22 @@ +#include "realm-execution/realm_manager.h" +#include "realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("RealmManager") { + // Construct some fake command line for our test + char fake_executable_name[] = "fake_executable_name"; + std::vector fake_args{fake_executable_name}; + int fake_argc = fake_args.size(); + char **fake_argv = fake_args.data(); + + // Initialize Realm + RealmManager manager(&fake_argc, &fake_argv); + + // Launch a controller and wait on it + Realm::Event event = manager.start_controller([](RealmManager &manager) {}); + event.wait(); + } +} From 9fcc76e493952c88f5816e33623459ea7733a8ec Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Thu, 5 Feb 2026 12:33:22 -0800 Subject: [PATCH 11/88] Support passing closure arguments to controllers. --- .../include/realm-execution/realm_manager.h | 3 ++- lib/realm-execution/src/realm-execution/realm_manager.cc | 9 +++++---- .../test/src/realm-execution/realm_manager.cc | 7 +++++-- lib/realm-execution/test/src/realm-execution/test_e2e.cc | 3 ++- 4 files changed, 14 insertions(+), 8 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/realm_manager.h b/lib/realm-execution/include/realm-execution/realm_manager.h index 497a1f3958..88cc11f744 100644 --- a/lib/realm-execution/include/realm-execution/realm_manager.h +++ b/lib/realm-execution/include/realm-execution/realm_manager.h @@ -17,7 +17,8 @@ struct RealmManager { RealmManager(RealmManager const &) = delete; RealmManager(RealmManager &&) = delete; - Realm::Event start_controller(void (*thunk)(RealmManager &)); + [[nodiscard]] Realm::Event + start_controller(std::function); // Current device context Allocator &get_current_device_allocator() const; diff --git a/lib/realm-execution/src/realm-execution/realm_manager.cc b/lib/realm-execution/src/realm-execution/realm_manager.cc index 33e7ca252e..0ccf3f4116 100644 --- a/lib/realm-execution/src/realm-execution/realm_manager.cc +++ b/lib/realm-execution/src/realm-execution/realm_manager.cc @@ -25,7 +25,8 @@ RealmManager::~RealmManager() { } } -Realm::Event RealmManager::start_controller(void (*thunk)(RealmManager &)) { +Realm::Event + RealmManager::start_controller(std::function thunk) { constexpr int CONTROLLER_TASK_ID = Realm::Processor::TASK_ID_FIRST_AVAILABLE; Realm::Event task_ready = Realm::Processor::register_task_by_kind( Realm::Processor::LOC_PROC, @@ -69,9 +70,9 @@ void RealmManager::controller_task_wrapper(void const *args, void const *userdata, size_t userlen, Realm::Processor proc) { - assert(arglen == sizeof(void (*)(RealmManager &))); - void (*thunk)(RealmManager &) = - *reinterpret_cast(args); + ASSERT(arglen == sizeof(std::function)); + std::function thunk = + *reinterpret_cast const *>(args); RealmManager manager(args, arglen, userdata, userlen, proc); thunk(manager); diff --git a/lib/realm-execution/test/src/realm-execution/realm_manager.cc b/lib/realm-execution/test/src/realm-execution/realm_manager.cc index 880268c018..16b5338881 100644 --- a/lib/realm-execution/test/src/realm-execution/realm_manager.cc +++ b/lib/realm-execution/test/src/realm-execution/realm_manager.cc @@ -15,8 +15,11 @@ TEST_SUITE(FF_TEST_SUITE) { // Initialize Realm RealmManager manager(&fake_argc, &fake_argv); - // Launch a controller and wait on it - Realm::Event event = manager.start_controller([](RealmManager &manager) {}); + // Launch a controller + int some_data = 123; + Realm::Event event = manager.start_controller( + [&](RealmManager &manager) { ASSERT(some_data == 123); }); + // Need to block on the completion of the event to ensure we don't race event.wait(); } } diff --git a/lib/realm-execution/test/src/realm-execution/test_e2e.cc b/lib/realm-execution/test/src/realm-execution/test_e2e.cc index f09951e73c..623b8318e6 100644 --- a/lib/realm-execution/test/src/realm-execution/test_e2e.cc +++ b/lib/realm-execution/test/src/realm-execution/test_e2e.cc @@ -11,6 +11,7 @@ TEST_SUITE(FF_TEST_SUITE) { int fake_argc = fake_args.size(); char **fake_argv = fake_args.data(); RealmManager manager(&fake_argc, &fake_argv); - manager.start_controller([](RealmManager &manager) {}); + Realm::Event event = manager.start_controller([](RealmManager &manager) {}); + event.wait(); } } From 98c605388f931783ffabca898d810271daa9df39 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Thu, 5 Feb 2026 14:22:44 -0800 Subject: [PATCH 12/88] Move task IDs into Realm and assign IDs to remaining tasks. --- .../realm-execution}/task_id_t.dtg.toml | 5 +- .../include/realm-execution/task_id_t.h | 28 ++ .../src/realm-execution/task_id_t.cc | 192 ++++++++++++++ .../include/task-spec/ops/impl/dropout.h | 1 - .../task-spec/ops/op_task_id_t.dtg.toml | 18 -- .../task_id_with_noop_default_t.dtg.toml | 28 -- .../task-spec/task_id_with_noop_default_t.h | 28 -- .../task-spec/task_id_with_noop_default_t.cc | 243 ------------------ 8 files changed, 221 insertions(+), 322 deletions(-) rename lib/{task-spec/include/task-spec => realm-execution/include/realm-execution}/task_id_t.dtg.toml (98%) create mode 100644 lib/realm-execution/include/realm-execution/task_id_t.h create mode 100644 lib/realm-execution/src/realm-execution/task_id_t.cc delete mode 100644 lib/task-spec/include/task-spec/ops/op_task_id_t.dtg.toml delete mode 100644 lib/task-spec/include/task-spec/task_id_with_noop_default_t.dtg.toml delete mode 100644 lib/task-spec/include/task-spec/task_id_with_noop_default_t.h delete mode 100644 lib/task-spec/src/task-spec/task_id_with_noop_default_t.cc diff --git a/lib/task-spec/include/task-spec/task_id_t.dtg.toml b/lib/realm-execution/include/realm-execution/task_id_t.dtg.toml similarity index 98% rename from lib/task-spec/include/task-spec/task_id_t.dtg.toml rename to lib/realm-execution/include/realm-execution/task_id_t.dtg.toml index ce2de52d40..0336bc81a4 100644 --- a/lib/task-spec/include/task-spec/task_id_t.dtg.toml +++ b/lib/realm-execution/include/realm-execution/task_id_t.dtg.toml @@ -9,10 +9,7 @@ features = [ ] [[values]] -name = "TOP_LEVEL_TASK_ID" - -[[values]] -name = "FF_INIT_TASK_ID" +name = "CONTROLLER_TASK_ID" [[values]] name = "IMAGE_INIT_TASK_ID" diff --git a/lib/realm-execution/include/realm-execution/task_id_t.h b/lib/realm-execution/include/realm-execution/task_id_t.h new file mode 100644 index 0000000000..af20dc27f6 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/task_id_t.h @@ -0,0 +1,28 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASK_ID_T_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASK_ID_T_H + +#include "op-attrs/pcg_operator_attrs.dtg.h" +#include "pcg/optimizer_attrs.dtg.h" +#include "realm-execution/task_id_t.dtg.h" +#include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" +#include + +namespace FlexFlow { + +std::optional + get_task_id_for_op(DynamicNodeInvocation const &, + std::optional const &); + +std::optional + get_init_task_id_for_op_attrs(PCGOperatorAttrs const &); + +std::optional get_fwd_task_id_for_op_attrs(PCGOperatorAttrs const &); + +std::optional get_bwd_task_id_for_op_attrs(PCGOperatorAttrs const &); + +std::optional + get_update_task_id_for_optimizer_attrs(OptimizerAttrs const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/src/realm-execution/task_id_t.cc b/lib/realm-execution/src/realm-execution/task_id_t.cc new file mode 100644 index 0000000000..94b5fb5b24 --- /dev/null +++ b/lib/realm-execution/src/realm-execution/task_id_t.cc @@ -0,0 +1,192 @@ +#include "realm-execution/task_id_t.h" +#include "pcg/optimizer_attrs.dtg.h" +#include "pcg/optimizers/adam_optimizer_attrs.dtg.h" +#include "task-spec/dynamic_graph/dynamic_node_attrs.dtg.h" +#include "task-spec/dynamic_graph/dynamic_task_type.dtg.h" +#include "utils/overload.h" + +namespace FlexFlow { + +std::optional + get_task_id_for_op(DynamicNodeInvocation const &invocation, + std::optional const &optimizer_attrs) { + DynamicTaskType task_type = invocation.node_attrs.task_type.value(); + switch (task_type) { + case DynamicTaskType::FWD: + return get_fwd_task_id_for_op_attrs( + invocation.node_attrs.op_attrs.value()); + case DynamicTaskType::BWD: + return get_bwd_task_id_for_op_attrs( + invocation.node_attrs.op_attrs.value()); + case DynamicTaskType::UPD: + return get_update_task_id_for_optimizer_attrs(optimizer_attrs.value()); + case DynamicTaskType::LOSS: + return task_id_t::LOSS_BWD_TASK_ID; + default: + PANIC("Unhandled DynamicTaskType", task_type); + } +} + +std::optional + get_init_task_id_for_op_attrs(PCGOperatorAttrs const &op_attrs) { + + return op_attrs.visit>(overload{ + [](BatchMatmulAttrs const &) { return std::nullopt; }, + [](BatchNormAttrs const &) { return task_id_t::BATCHNORM_INIT_TASK_ID; }, + [](BroadcastAttrs const &) { return std::nullopt; }, + [](CastAttrs const &) { return std::nullopt; }, + [](CombineAttrs const &attrs) { return task_id_t::COMBINE_INIT_TASK_ID; }, + [](ConcatAttrs const &) { return std::nullopt; }, + [](Conv2DAttrs const &) { return task_id_t::CONV2D_INIT_TASK_ID; }, + [](DropoutAttrs const &) { return task_id_t::DROPOUT_INIT_TASK_ID; }, + [](ElementBinaryAttrs const &) { + return task_id_t::ELEMENTBINARY_INIT_TASK_ID; + }, + [](ElementUnaryAttrs const &) { + return task_id_t::ELEMENTUNARY_INIT_TASK_ID; + }, + [](EmbeddingAttrs const &) { return std::nullopt; }, + [](FlatAttrs const &) { return std::nullopt; }, + [](GatherAttrs const &) { return task_id_t::GATHER_INIT_TASK_ID; }, + [](InputAttrs const &) { return std::nullopt; }, + [](LayerNormAttrs const &) { return task_id_t::LAYERNORM_INIT_TASK_ID; }, + [](LinearAttrs const &) { return task_id_t::LINEAR_INIT_TASK_ID; }, + [](MultiHeadAttentionAttrs const &) { + return task_id_t::ATTENTION_INIT_TASK_ID; + }, + [](NoopAttrs const &) { return std::nullopt; }, + [](Pool2DAttrs const &) { return task_id_t::POOL2D_INIT_TASK_ID; }, + [](ReduceAttrs const &) { return task_id_t::REDUCE_INIT_TASK_ID; }, + [](ReductionAttrs const &attrs) { + return task_id_t::REDUCTION_INIT_TASK_ID; + }, + [](RepartitionAttrs const &attrs) { + return task_id_t::REPARTITION_INIT_TASK_ID; + }, + [](ReplicateAttrs const &attrs) { + return task_id_t::REPLICATE_INIT_TASK_ID; + }, + [](ReshapeAttrs const &) { return std::nullopt; }, + [](ReverseAttrs const &) { return std::nullopt; }, + [](SoftmaxAttrs const &) { return task_id_t::SOFTMAX_INIT_TASK_ID; }, + [](SplitAttrs const &) { return std::nullopt; }, + [](TopKAttrs const &) { return std::nullopt; }, + [](TransposeAttrs const &) { return std::nullopt; }, + [](WeightAttrs const &) { return std::nullopt; }, + }); +} + +std::optional + get_fwd_task_id_for_op_attrs(PCGOperatorAttrs const &op_attrs) { + + return op_attrs.visit>(overload{ + [](BatchMatmulAttrs const &) { + return task_id_t::BATCHMATMUL_FWD_TASK_ID; + }, + [](BatchNormAttrs const &) { return task_id_t::BATCHNORM_FWD_TASK_ID; }, + [](BroadcastAttrs const &) { return task_id_t::BROADCAST_FWD_TASK_ID; }, + [](CastAttrs const &) { return task_id_t::CAST_FWD_TASK_ID; }, + [](CombineAttrs const &attrs) { return task_id_t::COMBINE_FWD_TASK_ID; }, + [](ConcatAttrs const &) { return task_id_t::CONCAT_FWD_TASK_ID; }, + [](Conv2DAttrs const &) { return task_id_t::CONV2D_FWD_TASK_ID; }, + [](DropoutAttrs const &) { return task_id_t::DROPOUT_FWD_TASK_ID; }, + [](ElementBinaryAttrs const &) { + return task_id_t::ELEMENTBINARY_FWD_TASK_ID; + }, + [](ElementUnaryAttrs const &) { + return task_id_t::ELEMENTUNARY_FWD_TASK_ID; + }, + [](EmbeddingAttrs const &) { return task_id_t::EMBED_FWD_TASK_ID; }, + [](FlatAttrs const &) { return task_id_t::FLAT_FWD_TASK_ID; }, + [](GatherAttrs const &) { return task_id_t::GATHER_FWD_TASK_ID; }, + [](InputAttrs const &) { return std::nullopt; }, + [](LayerNormAttrs const &) { return task_id_t::LAYERNORM_FWD_TASK_ID; }, + [](LinearAttrs const &) { return task_id_t::LINEAR_FWD_TASK_ID; }, + [](MultiHeadAttentionAttrs const &) { + return task_id_t::ATTENTION_FWD_TASK_ID; + }, + [](NoopAttrs const &) { return std::nullopt; }, + [](Pool2DAttrs const &) { return task_id_t::POOL2D_FWD_TASK_ID; }, + [](ReduceAttrs const &) { return task_id_t::REDUCE_FWD_TASK_ID; }, + [](ReductionAttrs const &attrs) { + return task_id_t::REDUCTION_FWD_TASK_ID; + }, + [](RepartitionAttrs const &attrs) { + return task_id_t::REPARTITION_FWD_TASK_ID; + }, + [](ReplicateAttrs const &attrs) { + return task_id_t::REPLICATE_FWD_TASK_ID; + }, + [](ReshapeAttrs const &) { return task_id_t::RESHAPE_FWD_TASK_ID; }, + [](ReverseAttrs const &) { return task_id_t::REVERSE_FWD_TASK_ID; }, + [](SoftmaxAttrs const &) { return task_id_t::SOFTMAX_FWD_TASK_ID; }, + [](SplitAttrs const &) { return task_id_t::SPLIT_FWD_TASK_ID; }, + [](TopKAttrs const &) { return task_id_t::TOPK_FWD_TASK_ID; }, + [](TransposeAttrs const &) { return task_id_t::TRANSPOSE_FWD_TASK_ID; }, + [](WeightAttrs const &) { return std::nullopt; }, + }); +} + +std::optional + get_bwd_task_id_for_op_attrs(PCGOperatorAttrs const &op_attrs) { + + return op_attrs.visit>(overload{ + [](BatchMatmulAttrs const &) { + return task_id_t::BATCHMATMUL_BWD_TASK_ID; + }, + [](BatchNormAttrs const &) { return task_id_t::BATCHNORM_BWD_TASK_ID; }, + [](BroadcastAttrs const &) { return task_id_t::BROADCAST_BWD_TASK_ID; }, + [](CastAttrs const &) { return task_id_t::CAST_BWD_TASK_ID; }, + [](CombineAttrs const &attrs) { return task_id_t::COMBINE_BWD_TASK_ID; }, + [](ConcatAttrs const &) { return task_id_t::CONCAT_BWD_TASK_ID; }, + [](Conv2DAttrs const &) { return task_id_t::CONV2D_BWD_TASK_ID; }, + [](DropoutAttrs const &) { return task_id_t::DROPOUT_BWD_TASK_ID; }, + [](ElementBinaryAttrs const &) { + return task_id_t::ELEMENTBINARY_BWD_TASK_ID; + }, + [](ElementUnaryAttrs const &) { + return task_id_t::ELEMENTUNARY_BWD_TASK_ID; + }, + [](EmbeddingAttrs const &) { return task_id_t::EMBED_BWD_TASK_ID; }, + [](FlatAttrs const &) { return task_id_t::FLAT_BWD_TASK_ID; }, + [](GatherAttrs const &) { return task_id_t::GATHER_BWD_TASK_ID; }, + [](InputAttrs const &) { return std::nullopt; }, + [](LayerNormAttrs const &) { return task_id_t::LAYERNORM_BWD_TASK_ID; }, + [](LinearAttrs const &) { return task_id_t::LINEAR_BWD_TASK_ID; }, + [](MultiHeadAttentionAttrs const &) { + return task_id_t::ATTENTION_BWD_TASK_ID; + }, + [](NoopAttrs const &) { return std::nullopt; }, + [](Pool2DAttrs const &) { return task_id_t::POOL2D_BWD_TASK_ID; }, + [](ReduceAttrs const &) { return task_id_t::REDUCE_BWD_TASK_ID; }, + [](ReductionAttrs const &attrs) { + return task_id_t::REDUCTION_BWD_TASK_ID; + }, + [](RepartitionAttrs const &attrs) { + return task_id_t::REPARTITION_BWD_TASK_ID; + }, + [](ReplicateAttrs const &attrs) { + return task_id_t::REPLICATE_BWD_TASK_ID; + }, + [](ReshapeAttrs const &) { return task_id_t::RESHAPE_BWD_TASK_ID; }, + [](ReverseAttrs const &) { return task_id_t::REVERSE_BWD_TASK_ID; }, + [](SoftmaxAttrs const &) { return task_id_t::SOFTMAX_BWD_TASK_ID; }, + [](SplitAttrs const &) { return task_id_t::SPLIT_BWD_TASK_ID; }, + [](TopKAttrs const &) { return task_id_t::TOPK_BWD_TASK_ID; }, + [](TransposeAttrs const &) { return task_id_t::TRANSPOSE_BWD_TASK_ID; }, + [](WeightAttrs const &) { return std::nullopt; }, + }); +} + +std::optional get_update_task_id_for_optimizer_attrs( + OptimizerAttrs const &optimizer_attrs) { + + return optimizer_attrs.visit>(overload{ + [](SGDOptimizerAttrs const &) { return task_id_t::SGD_UPD_NCCL_TASK_ID; }, + [](AdamOptimizerAttrs const &) { + return task_id_t::ADAM_UPD_NCCL_TASK_ID; + }, + }); +} + +} // namespace FlexFlow diff --git a/lib/task-spec/include/task-spec/ops/impl/dropout.h b/lib/task-spec/include/task-spec/ops/impl/dropout.h index a7b382ce62..192f2f8244 100644 --- a/lib/task-spec/include/task-spec/ops/impl/dropout.h +++ b/lib/task-spec/include/task-spec/ops/impl/dropout.h @@ -2,7 +2,6 @@ #define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_DROPOUT_H #include "op-attrs/ops/dropout_attrs.dtg.h" -#include "task-spec/task_id_t.dtg.h" #include "task-spec/task_impl_function.dtg.h" namespace FlexFlow { diff --git a/lib/task-spec/include/task-spec/ops/op_task_id_t.dtg.toml b/lib/task-spec/include/task-spec/ops/op_task_id_t.dtg.toml deleted file mode 100644 index 557da6cf4c..0000000000 --- a/lib/task-spec/include/task-spec/ops/op_task_id_t.dtg.toml +++ /dev/null @@ -1,18 +0,0 @@ -namespace = "FlexFlow" -name = "op_task_id_t" -type = "enum" -features = [ - "hash", - "json", - "rapidcheck", - "fmt", -] - -[[values]] -name = "INIT" - -[[values]] -name = "FWD" - -[[values]] -name = "BWD" diff --git a/lib/task-spec/include/task-spec/task_id_with_noop_default_t.dtg.toml b/lib/task-spec/include/task-spec/task_id_with_noop_default_t.dtg.toml deleted file mode 100644 index 50349d5773..0000000000 --- a/lib/task-spec/include/task-spec/task_id_with_noop_default_t.dtg.toml +++ /dev/null @@ -1,28 +0,0 @@ -namespace = "FlexFlow" -name = "task_id_with_noop_default_t" -type = "variant" -features = [ - "eq", - "ord", - "hash", - "fmt", - "rapidcheck", -] - -includes = [ - "task-spec/task_id_t.dtg.h", - "", -] - -src_includes = [ - "utils/rapidcheck/monostate.h", - "utils/fmt/monostate.h", -] - -[[values]] -type = "::FlexFlow::task_id_t" -key = "real_task" - -[[values]] -type = "std::monostate" -key = "noop_task" diff --git a/lib/task-spec/include/task-spec/task_id_with_noop_default_t.h b/lib/task-spec/include/task-spec/task_id_with_noop_default_t.h deleted file mode 100644 index 054b73844e..0000000000 --- a/lib/task-spec/include/task-spec/task_id_with_noop_default_t.h +++ /dev/null @@ -1,28 +0,0 @@ -#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_TASK_ID_WITH_NOOP_DEFAULT_T_H -#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_TASK_ID_WITH_NOOP_DEFAULT_T_H - -#include "op-attrs/computation_graph_op_attrs.dtg.h" -#include "op-attrs/operator_type.dtg.h" -#include "task-spec/ops/op_task_id_t.dtg.h" -#include "task-spec/task_id_with_noop_default_t.dtg.h" - -namespace FlexFlow { - -task_id_with_noop_default_t lift_task_id_t(task_id_t); -task_id_with_noop_default_t default_noop_task(); - -task_id_with_noop_default_t lower_op_task_id_to_task_id_with_noop_default_t( - op_task_id_t, ComputationGraphOpAttrs const &); - -task_id_with_noop_default_t - get_init_task_id_for_op_attrs(ComputationGraphOpAttrs const &); - -task_id_with_noop_default_t - get_fwd_task_id_for_op_attrs(ComputationGraphOpAttrs const &); - -task_id_with_noop_default_t - get_bwd_task_id_for_op_attrs(ComputationGraphOpAttrs const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/task-spec/src/task-spec/task_id_with_noop_default_t.cc b/lib/task-spec/src/task-spec/task_id_with_noop_default_t.cc deleted file mode 100644 index 20e0d00c57..0000000000 --- a/lib/task-spec/src/task-spec/task_id_with_noop_default_t.cc +++ /dev/null @@ -1,243 +0,0 @@ -#include "task-spec/task_id_with_noop_default_t.h" -#include "utils/overload.h" - -namespace FlexFlow { - -task_id_with_noop_default_t lift_task_id_t(task_id_t task_id) { - return task_id_with_noop_default_t{task_id}; -} - -task_id_with_noop_default_t default_noop_task() { - return task_id_with_noop_default_t{std::monostate{}}; -} - -task_id_with_noop_default_t lower_op_task_id_to_task_id_with_noop_default_t( - op_task_id_t op_task_id, ComputationGraphOpAttrs const &op_attrs) { - switch (op_task_id) { - case op_task_id_t::INIT: - return get_init_task_id_for_op_attrs(op_attrs); - case op_task_id_t::FWD: - return get_fwd_task_id_for_op_attrs(op_attrs); - case op_task_id_t::BWD: - return get_bwd_task_id_for_op_attrs(op_attrs); - default: - PANIC("Unhandled op_task_id_t", op_task_id); - } -} - -task_id_with_noop_default_t - get_init_task_id_for_op_attrs(ComputationGraphOpAttrs const &op_attrs) { - - return op_attrs.visit(overload{ - [](BatchMatmulAttrs const &) { return default_noop_task(); }, - [](BatchNormAttrs const &) { - return lift_task_id_t(task_id_t::BATCHNORM_INIT_TASK_ID); - }, - [](BroadcastAttrs const &) { return default_noop_task(); }, - [](CastAttrs const &) { return default_noop_task(); }, - [](ConcatAttrs const &) { return default_noop_task(); }, - [](Conv2DAttrs const &) { - return lift_task_id_t(task_id_t::CONV2D_INIT_TASK_ID); - }, - [](DropoutAttrs const &) { - return lift_task_id_t(task_id_t::DROPOUT_INIT_TASK_ID); - }, - [](ElementBinaryAttrs const &) { - return lift_task_id_t(task_id_t::ELEMENTBINARY_INIT_TASK_ID); - }, - [](ElementUnaryAttrs const &) { - return lift_task_id_t(task_id_t::ELEMENTUNARY_INIT_TASK_ID); - }, - [](EmbeddingAttrs const &) { return default_noop_task(); }, - [](FlatAttrs const &) { return default_noop_task(); }, - [](GatherAttrs const &) { - return lift_task_id_t(task_id_t::GATHER_INIT_TASK_ID); - }, - [](InputAttrs const &) { return default_noop_task(); }, - [](LayerNormAttrs const &) { - return lift_task_id_t(task_id_t::LAYERNORM_INIT_TASK_ID); - }, - [](LinearAttrs const &) { - return lift_task_id_t(task_id_t::LINEAR_INIT_TASK_ID); - }, - [](MultiHeadAttentionAttrs const &) { - return lift_task_id_t(task_id_t::ATTENTION_INIT_TASK_ID); - }, - [](NoopAttrs const &) { return default_noop_task(); }, - [](Pool2DAttrs const &) { - return lift_task_id_t(task_id_t::POOL2D_INIT_TASK_ID); - }, - [](ReduceAttrs const &) { - return lift_task_id_t(task_id_t::REDUCE_INIT_TASK_ID); - }, - [](ReshapeAttrs const &) { return default_noop_task(); }, - [](ReverseAttrs const &) { return default_noop_task(); }, - [](SoftmaxAttrs const &) { - return lift_task_id_t(task_id_t::SOFTMAX_INIT_TASK_ID); - }, - [](SplitAttrs const &) { return default_noop_task(); }, - [](TopKAttrs const &) { return default_noop_task(); }, - [](TransposeAttrs const &) { return default_noop_task(); }, - [](WeightAttrs const &) { return default_noop_task(); }, - }); -} - -task_id_with_noop_default_t - get_fwd_task_id_for_op_attrs(ComputationGraphOpAttrs const &op_attrs) { - - return op_attrs.visit(overload{ - [](BatchMatmulAttrs const &) { - return lift_task_id_t(task_id_t::BATCHMATMUL_FWD_TASK_ID); - }, - [](BatchNormAttrs const &) { - return lift_task_id_t(task_id_t::BATCHNORM_FWD_TASK_ID); - }, - [](BroadcastAttrs const &) { - return lift_task_id_t(task_id_t::BROADCAST_FWD_TASK_ID); - }, - [](CastAttrs const &) { - return lift_task_id_t(task_id_t::CAST_FWD_TASK_ID); - }, - [](ConcatAttrs const &) { - return lift_task_id_t(task_id_t::CONCAT_FWD_TASK_ID); - }, - [](Conv2DAttrs const &) { - return lift_task_id_t(task_id_t::CONV2D_FWD_TASK_ID); - }, - [](DropoutAttrs const &) { - return lift_task_id_t(task_id_t::DROPOUT_FWD_TASK_ID); - }, - [](ElementBinaryAttrs const &) { - return lift_task_id_t(task_id_t::ELEMENTBINARY_FWD_TASK_ID); - }, - [](ElementUnaryAttrs const &) { - return lift_task_id_t(task_id_t::ELEMENTUNARY_FWD_TASK_ID); - }, - [](EmbeddingAttrs const &) { - return lift_task_id_t(task_id_t::EMBED_FWD_TASK_ID); - }, - [](FlatAttrs const &) { - return lift_task_id_t(task_id_t::FLAT_FWD_TASK_ID); - }, - [](GatherAttrs const &) { - return lift_task_id_t(task_id_t::GATHER_FWD_TASK_ID); - }, - [](InputAttrs const &) { return default_noop_task(); }, - [](LayerNormAttrs const &) { - return lift_task_id_t(task_id_t::LAYERNORM_FWD_TASK_ID); - }, - [](LinearAttrs const &) { - return lift_task_id_t(task_id_t::LINEAR_FWD_TASK_ID); - }, - [](MultiHeadAttentionAttrs const &) { - return lift_task_id_t(task_id_t::ATTENTION_FWD_TASK_ID); - }, - [](NoopAttrs const &) { return default_noop_task(); }, - [](Pool2DAttrs const &) { - return lift_task_id_t(task_id_t::POOL2D_FWD_TASK_ID); - }, - [](ReduceAttrs const &) { - return lift_task_id_t(task_id_t::REDUCE_FWD_TASK_ID); - }, - [](ReshapeAttrs const &) { - return lift_task_id_t(task_id_t::RESHAPE_FWD_TASK_ID); - }, - [](ReverseAttrs const &) { - return lift_task_id_t(task_id_t::REVERSE_FWD_TASK_ID); - }, - [](SoftmaxAttrs const &) { - return lift_task_id_t(task_id_t::SOFTMAX_FWD_TASK_ID); - }, - [](SplitAttrs const &) { - return lift_task_id_t(task_id_t::SPLIT_FWD_TASK_ID); - }, - [](TopKAttrs const &) { - return lift_task_id_t(task_id_t::TOPK_FWD_TASK_ID); - }, - [](TransposeAttrs const &) { - return lift_task_id_t(task_id_t::TRANSPOSE_FWD_TASK_ID); - }, - [](WeightAttrs const &) { return default_noop_task(); }, - }); -} - -task_id_with_noop_default_t - get_bwd_task_id_for_op_attrs(ComputationGraphOpAttrs const &op_attrs) { - - return op_attrs.visit(overload{ - [](BatchMatmulAttrs const &) { - return lift_task_id_t(task_id_t::BATCHMATMUL_BWD_TASK_ID); - }, - [](BatchNormAttrs const &) { - return lift_task_id_t(task_id_t::BATCHNORM_BWD_TASK_ID); - }, - [](BroadcastAttrs const &) { - return lift_task_id_t(task_id_t::BROADCAST_BWD_TASK_ID); - }, - [](CastAttrs const &) { - return lift_task_id_t(task_id_t::CAST_BWD_TASK_ID); - }, - [](ConcatAttrs const &) { - return lift_task_id_t(task_id_t::CONCAT_BWD_TASK_ID); - }, - [](Conv2DAttrs const &) { - return lift_task_id_t(task_id_t::CONV2D_BWD_TASK_ID); - }, - [](DropoutAttrs const &) { - return lift_task_id_t(task_id_t::DROPOUT_BWD_TASK_ID); - }, - [](ElementBinaryAttrs const &) { - return lift_task_id_t(task_id_t::ELEMENTBINARY_BWD_TASK_ID); - }, - [](ElementUnaryAttrs const &) { - return lift_task_id_t(task_id_t::ELEMENTUNARY_BWD_TASK_ID); - }, - [](EmbeddingAttrs const &) { - return lift_task_id_t(task_id_t::EMBED_BWD_TASK_ID); - }, - [](FlatAttrs const &) { - return lift_task_id_t(task_id_t::FLAT_BWD_TASK_ID); - }, - [](GatherAttrs const &) { - return lift_task_id_t(task_id_t::GATHER_BWD_TASK_ID); - }, - [](InputAttrs const &) { return default_noop_task(); }, - [](LayerNormAttrs const &) { - return lift_task_id_t(task_id_t::LAYERNORM_BWD_TASK_ID); - }, - [](LinearAttrs const &) { - return lift_task_id_t(task_id_t::LINEAR_BWD_TASK_ID); - }, - [](MultiHeadAttentionAttrs const &) { - return lift_task_id_t(task_id_t::ATTENTION_BWD_TASK_ID); - }, - [](NoopAttrs const &) { return default_noop_task(); }, - [](Pool2DAttrs const &) { - return lift_task_id_t(task_id_t::POOL2D_BWD_TASK_ID); - }, - [](ReduceAttrs const &) { - return lift_task_id_t(task_id_t::REDUCE_BWD_TASK_ID); - }, - [](ReshapeAttrs const &) { - return lift_task_id_t(task_id_t::RESHAPE_BWD_TASK_ID); - }, - [](ReverseAttrs const &) { - return lift_task_id_t(task_id_t::REVERSE_BWD_TASK_ID); - }, - [](SoftmaxAttrs const &) { - return lift_task_id_t(task_id_t::SOFTMAX_BWD_TASK_ID); - }, - [](SplitAttrs const &) { - return lift_task_id_t(task_id_t::SPLIT_BWD_TASK_ID); - }, - [](TopKAttrs const &) { - return lift_task_id_t(task_id_t::TOPK_BWD_TASK_ID); - }, - [](TransposeAttrs const &) { - return lift_task_id_t(task_id_t::TRANSPOSE_BWD_TASK_ID); - }, - [](WeightAttrs const &) { return default_noop_task(); }, - }); -} - -} // namespace FlexFlow From f8ab5752629eb777cdc455335374d5249a745065 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Thu, 5 Feb 2026 14:34:22 -0800 Subject: [PATCH 13/88] Avoid pulling in the entire invocation. --- .../include/realm-execution/task_id_t.h | 4 ++-- lib/realm-execution/src/realm-execution/task_id_t.cc | 10 ++++------ 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/task_id_t.h b/lib/realm-execution/include/realm-execution/task_id_t.h index af20dc27f6..38b82ad9e0 100644 --- a/lib/realm-execution/include/realm-execution/task_id_t.h +++ b/lib/realm-execution/include/realm-execution/task_id_t.h @@ -4,13 +4,13 @@ #include "op-attrs/pcg_operator_attrs.dtg.h" #include "pcg/optimizer_attrs.dtg.h" #include "realm-execution/task_id_t.dtg.h" -#include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" +#include "task-spec/dynamic_graph/dynamic_node_attrs.dtg.h" #include namespace FlexFlow { std::optional - get_task_id_for_op(DynamicNodeInvocation const &, + get_task_id_for_op(DynamicNodeAttrs const &, std::optional const &); std::optional diff --git a/lib/realm-execution/src/realm-execution/task_id_t.cc b/lib/realm-execution/src/realm-execution/task_id_t.cc index 94b5fb5b24..574dbb1e54 100644 --- a/lib/realm-execution/src/realm-execution/task_id_t.cc +++ b/lib/realm-execution/src/realm-execution/task_id_t.cc @@ -8,16 +8,14 @@ namespace FlexFlow { std::optional - get_task_id_for_op(DynamicNodeInvocation const &invocation, + get_task_id_for_op(DynamicNodeAttrs const &node_attrs, std::optional const &optimizer_attrs) { - DynamicTaskType task_type = invocation.node_attrs.task_type.value(); + DynamicTaskType task_type = node_attrs.task_type.value(); switch (task_type) { case DynamicTaskType::FWD: - return get_fwd_task_id_for_op_attrs( - invocation.node_attrs.op_attrs.value()); + return get_fwd_task_id_for_op_attrs(node_attrs.op_attrs.value()); case DynamicTaskType::BWD: - return get_bwd_task_id_for_op_attrs( - invocation.node_attrs.op_attrs.value()); + return get_bwd_task_id_for_op_attrs(node_attrs.op_attrs.value()); case DynamicTaskType::UPD: return get_update_task_id_for_optimizer_attrs(optimizer_attrs.value()); case DynamicTaskType::LOSS: From c9c2b183b01e9d05f287be3cc8bcdd9a1b23bb74 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Thu, 5 Feb 2026 14:45:46 -0800 Subject: [PATCH 14/88] Conversion into Realm task IDs. --- .../include/realm-execution/realm_task_id_t.h | 13 +++++++++++++ .../src/realm-execution/realm_manager.cc | 5 ++++- .../src/realm-execution/realm_task_id_t.cc | 10 ++++++++++ 3 files changed, 27 insertions(+), 1 deletion(-) create mode 100644 lib/realm-execution/include/realm-execution/realm_task_id_t.h create mode 100644 lib/realm-execution/src/realm-execution/realm_task_id_t.cc diff --git a/lib/realm-execution/include/realm-execution/realm_task_id_t.h b/lib/realm-execution/include/realm-execution/realm_task_id_t.h new file mode 100644 index 0000000000..6d2e316b14 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/realm_task_id_t.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_TASK_ID_T_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_TASK_ID_T_H + +#include "realm-execution/task_id_t.dtg.h" +#include "realm.h" + +namespace FlexFlow { + +Realm::Processor::TaskFuncID get_realm_task_id_for_task_id(task_id_t); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/src/realm-execution/realm_manager.cc b/lib/realm-execution/src/realm-execution/realm_manager.cc index 0ccf3f4116..747f603f5d 100644 --- a/lib/realm-execution/src/realm-execution/realm_manager.cc +++ b/lib/realm-execution/src/realm-execution/realm_manager.cc @@ -1,4 +1,6 @@ #include "realm-execution/realm_manager.h" +#include "realm-execution/realm_task_id_t.h" +#include "realm-execution/task_id_t.dtg.h" #include "utils/exception.h" namespace FlexFlow { @@ -27,7 +29,8 @@ RealmManager::~RealmManager() { Realm::Event RealmManager::start_controller(std::function thunk) { - constexpr int CONTROLLER_TASK_ID = Realm::Processor::TASK_ID_FIRST_AVAILABLE; + Realm::Processor::TaskFuncID CONTROLLER_TASK_ID = + get_realm_task_id_for_task_id(task_id_t::CONTROLLER_TASK_ID); Realm::Event task_ready = Realm::Processor::register_task_by_kind( Realm::Processor::LOC_PROC, /*global=*/false, diff --git a/lib/realm-execution/src/realm-execution/realm_task_id_t.cc b/lib/realm-execution/src/realm-execution/realm_task_id_t.cc new file mode 100644 index 0000000000..50b23dfe86 --- /dev/null +++ b/lib/realm-execution/src/realm-execution/realm_task_id_t.cc @@ -0,0 +1,10 @@ +#include "realm-execution/realm_task_id_t.h" + +namespace FlexFlow { + +Realm::Processor::TaskFuncID get_realm_task_id_for_task_id(task_id_t task_id) { + return Realm::Processor::TASK_ID_FIRST_AVAILABLE + + static_cast(task_id); +} + +} // namespace FlexFlow From 2a0bdd98ac6f6c2547f44e101f88d2ff4aed6ede Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Thu, 5 Feb 2026 15:49:20 -0800 Subject: [PATCH 15/88] Add a top-level PRealm switch. --- .../include/realm-execution/realm.h | 20 +++++++++++++++++++ .../include/realm-execution/realm_manager.h | 2 +- .../include/realm-execution/realm_task_id_t.h | 2 +- .../src/realm-execution/task_id_t.cc | 1 - .../test/src/realm-execution/realm_manager.cc | 2 +- .../test/src/realm-execution/test_e2e.cc | 3 ++- 6 files changed, 25 insertions(+), 5 deletions(-) create mode 100644 lib/realm-execution/include/realm-execution/realm.h diff --git a/lib/realm-execution/include/realm-execution/realm.h b/lib/realm-execution/include/realm-execution/realm.h new file mode 100644 index 0000000000..f15113ee92 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/realm.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_H + +#ifdef FLEXFLOW_USE_PREALM +#include +#else +#include +#endif + +namespace FlexFlow { + +#ifdef FLEXFLOW_USE_PREALM +namespace Realm = ::PRealm; +#else +namespace Realm = ::Realm; +#endif + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/realm_manager.h b/lib/realm-execution/include/realm-execution/realm_manager.h index 88cc11f744..b26adea548 100644 --- a/lib/realm-execution/include/realm-execution/realm_manager.h +++ b/lib/realm-execution/include/realm-execution/realm_manager.h @@ -4,7 +4,7 @@ #include "kernels/allocation.h" #include "kernels/device_handle_t.dtg.h" #include "pcg/device_id_t.dtg.h" -#include "realm.h" +#include "realm-execution/realm.h" namespace FlexFlow { diff --git a/lib/realm-execution/include/realm-execution/realm_task_id_t.h b/lib/realm-execution/include/realm-execution/realm_task_id_t.h index 6d2e316b14..8e6da1a2bd 100644 --- a/lib/realm-execution/include/realm-execution/realm_task_id_t.h +++ b/lib/realm-execution/include/realm-execution/realm_task_id_t.h @@ -1,8 +1,8 @@ #ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_TASK_ID_T_H #define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_TASK_ID_T_H +#include "realm-execution/realm.h" #include "realm-execution/task_id_t.dtg.h" -#include "realm.h" namespace FlexFlow { diff --git a/lib/realm-execution/src/realm-execution/task_id_t.cc b/lib/realm-execution/src/realm-execution/task_id_t.cc index 574dbb1e54..3521f50c02 100644 --- a/lib/realm-execution/src/realm-execution/task_id_t.cc +++ b/lib/realm-execution/src/realm-execution/task_id_t.cc @@ -2,7 +2,6 @@ #include "pcg/optimizer_attrs.dtg.h" #include "pcg/optimizers/adam_optimizer_attrs.dtg.h" #include "task-spec/dynamic_graph/dynamic_node_attrs.dtg.h" -#include "task-spec/dynamic_graph/dynamic_task_type.dtg.h" #include "utils/overload.h" namespace FlexFlow { diff --git a/lib/realm-execution/test/src/realm-execution/realm_manager.cc b/lib/realm-execution/test/src/realm-execution/realm_manager.cc index 16b5338881..f9fbd986c2 100644 --- a/lib/realm-execution/test/src/realm-execution/realm_manager.cc +++ b/lib/realm-execution/test/src/realm-execution/realm_manager.cc @@ -17,7 +17,7 @@ TEST_SUITE(FF_TEST_SUITE) { // Launch a controller int some_data = 123; - Realm::Event event = manager.start_controller( + FlexFlow::Realm::Event event = manager.start_controller( [&](RealmManager &manager) { ASSERT(some_data == 123); }); // Need to block on the completion of the event to ensure we don't race event.wait(); diff --git a/lib/realm-execution/test/src/realm-execution/test_e2e.cc b/lib/realm-execution/test/src/realm-execution/test_e2e.cc index 623b8318e6..fa9f798e4f 100644 --- a/lib/realm-execution/test/src/realm-execution/test_e2e.cc +++ b/lib/realm-execution/test/src/realm-execution/test_e2e.cc @@ -11,7 +11,8 @@ TEST_SUITE(FF_TEST_SUITE) { int fake_argc = fake_args.size(); char **fake_argv = fake_args.data(); RealmManager manager(&fake_argc, &fake_argv); - Realm::Event event = manager.start_controller([](RealmManager &manager) {}); + FlexFlow::Realm::Event event = + manager.start_controller([](RealmManager &manager) {}); event.wait(); } } From 05eeada94290694a10c7e7bab2190ccb663a0526 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Thu, 5 Feb 2026 17:08:38 -0800 Subject: [PATCH 16/88] Some work on Realm task registry. --- .../realm-execution/realm_task_registry.h | 13 +++++ .../realm-execution/realm_task_registry.cc | 55 +++++++++++++++++++ 2 files changed, 68 insertions(+) create mode 100644 lib/realm-execution/include/realm-execution/realm_task_registry.h create mode 100644 lib/realm-execution/src/realm-execution/realm_task_registry.cc diff --git a/lib/realm-execution/include/realm-execution/realm_task_registry.h b/lib/realm-execution/include/realm-execution/realm_task_registry.h new file mode 100644 index 0000000000..3a4cee106c --- /dev/null +++ b/lib/realm-execution/include/realm-execution/realm_task_registry.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_TASK_REGISTRY_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_TASK_REGISTRY_H + +#include "realm-execution/realm.h" +#include "realm-execution/task_id_t.dtg.h" + +namespace FlexFlow { + +Realm::Event register_all_tasks(); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/src/realm-execution/realm_task_registry.cc b/lib/realm-execution/src/realm-execution/realm_task_registry.cc new file mode 100644 index 0000000000..a5e52b7a7c --- /dev/null +++ b/lib/realm-execution/src/realm-execution/realm_task_registry.cc @@ -0,0 +1,55 @@ +#include "realm-execution/realm.h" +#include "realm-execution/realm_task_id_t.h" +#include "realm-execution/task_id_t.dtg.h" + +namespace FlexFlow { + +void op_task_wrapper( + void const *, size_t, void const *, size_t, Realm::Processor) {} + +static Realm::Event register_task(Realm::Processor::Kind target_kind, + task_id_t func_id, + void (*task_body)(void const *, + size_t, + void const *, + size_t, + Realm::Processor)) { + return Realm::Processor::register_task_by_kind( + target_kind, + /*global=*/false, + get_realm_task_id_for_task_id(func_id), + Realm::CodeDescriptor(task_body), + Realm::ProfilingRequestSet()); +} + +Realm::Event register_all_tasks() { + std::vector pending_registrations; + + std::vector init_task_ids = { + task_id_t::BATCHNORM_INIT_TASK_ID, + task_id_t::COMBINE_INIT_TASK_ID, + task_id_t::CONV2D_INIT_TASK_ID, + task_id_t::DROPOUT_INIT_TASK_ID, + task_id_t::ELEMENTBINARY_INIT_TASK_ID, + task_id_t::ELEMENTUNARY_INIT_TASK_ID, + task_id_t::GATHER_INIT_TASK_ID, + task_id_t::LAYERNORM_INIT_TASK_ID, + task_id_t::LINEAR_INIT_TASK_ID, + task_id_t::ATTENTION_INIT_TASK_ID, + task_id_t::POOL2D_INIT_TASK_ID, + task_id_t::REDUCE_INIT_TASK_ID, + task_id_t::REDUCTION_INIT_TASK_ID, + task_id_t::REPARTITION_INIT_TASK_ID, + task_id_t::REPLICATE_INIT_TASK_ID, + task_id_t::SOFTMAX_INIT_TASK_ID, + }; + + for (task_id_t init_task_id : init_task_ids) { + pending_registrations.push_back(register_task( + Realm::Processor::LOC_PROC, init_task_id, op_task_wrapper)); + } + + return Realm::Event::merge_events(pending_registrations); +} + +} // namespace FlexFlow From 621814b97023bf8538ac6d947af884ca38cd1e22 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Fri, 6 Feb 2026 09:20:53 -0800 Subject: [PATCH 17/88] Split out the Realm context. --- .../parallel_computation_graph_instance.h | 8 +-- .../include/realm-execution/realm_context.h | 34 +++++++++++ .../include/realm-execution/realm_manager.h | 25 ++------ .../parallel_computation_graph_instance.cc | 4 +- .../src/realm-execution/realm_context.cc | 34 +++++++++++ .../src/realm-execution/realm_manager.cc | 60 +++++-------------- .../test/src/realm-execution/realm_manager.cc | 2 +- .../test/src/realm-execution/test_e2e.cc | 4 +- 8 files changed, 96 insertions(+), 75 deletions(-) create mode 100644 lib/realm-execution/include/realm-execution/realm_context.h create mode 100644 lib/realm-execution/src/realm-execution/realm_context.cc diff --git a/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h b/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h index 4ba77a7925..0dd87d566f 100644 --- a/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h +++ b/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h @@ -9,7 +9,7 @@ #include "pcg/device_id_t.dtg.h" #include "pcg/optimizer_attrs.dtg.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" -#include "realm-execution/realm_manager.h" +#include "realm-execution/realm_context.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.dtg.h" #include "task-spec/dynamic_graph/dynamic_tensor_accessor.dtg.h" #include "task-spec/dynamic_graph/dynamic_value_attrs.dtg.h" @@ -21,7 +21,7 @@ namespace FlexFlow { struct ParallelComputationGraphInstance { public: - ParallelComputationGraphInstance(RealmManager &, + ParallelComputationGraphInstance(RealmContext &, DynamicOpenDataflowGraph, std::vector const &, OptimizerAttrs const &, @@ -36,7 +36,7 @@ struct ParallelComputationGraphInstance { std::optional get_loss_tensor_accessor() const; private: - RealmManager &realm; + RealmContext &realm; DynamicOpenDataflowGraph dataflow_graph; std::vector topological_ordering; OptimizerAttrs optimizer_attrs; @@ -45,7 +45,7 @@ struct ParallelComputationGraphInstance { }; ParallelComputationGraphInstance create_parallel_computation_graph_instance( - RealmManager &realm, + RealmContext &realm, ParallelComputationGraph const &pcg, OptimizerAttrs const &optimizer_attrs, std::optional const &loss_attrs, diff --git a/lib/realm-execution/include/realm-execution/realm_context.h b/lib/realm-execution/include/realm-execution/realm_context.h new file mode 100644 index 0000000000..5539fe693e --- /dev/null +++ b/lib/realm-execution/include/realm-execution/realm_context.h @@ -0,0 +1,34 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_CONTEXT_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_CONTEXT_H + +#include "kernels/allocation.h" +#include "kernels/device_handle_t.dtg.h" +#include "pcg/device_id_t.dtg.h" +#include "realm-execution/realm.h" + +namespace FlexFlow { + +struct RealmContext { +public: + RealmContext(); + virtual ~RealmContext(); + + RealmContext(RealmContext const &) = delete; + RealmContext(RealmContext &&) = delete; + + // Current device context + Allocator &get_current_device_allocator() const; + device_handle_t const &get_current_device_handle() const; + device_id_t const &get_current_device_idx() const; + +protected: + [[nodiscard]] Realm::Event merge_outstanding_events(); + +protected: + Realm::Runtime runtime; + std::vector outstanding_events; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/realm_manager.h b/lib/realm-execution/include/realm-execution/realm_manager.h index b26adea548..bf5e8f72f1 100644 --- a/lib/realm-execution/include/realm-execution/realm_manager.h +++ b/lib/realm-execution/include/realm-execution/realm_manager.h @@ -5,38 +5,21 @@ #include "kernels/device_handle_t.dtg.h" #include "pcg/device_id_t.dtg.h" #include "realm-execution/realm.h" +#include "realm-execution/realm_context.h" namespace FlexFlow { -struct RealmManager { +struct RealmManager : private RealmContext { public: RealmManager(int *argc, char ***argv); - ~RealmManager(); + virtual ~RealmManager(); RealmManager() = delete; RealmManager(RealmManager const &) = delete; RealmManager(RealmManager &&) = delete; [[nodiscard]] Realm::Event - start_controller(std::function); - - // Current device context - Allocator &get_current_device_allocator() const; - device_handle_t const &get_current_device_handle() const; - device_id_t const &get_current_device_idx() const; - -private: - RealmManager(void const *, size_t, void const *, size_t, Realm::Processor); - - [[nodiscard]] Realm::Event merge_outstanding_events(); - - static void controller_task_wrapper( - void const *, size_t, void const *, size_t, Realm::Processor); - -private: - Realm::Runtime runtime; - std::vector outstanding_events; - bool is_root_runtime; + start_controller(std::function); }; } // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc index 64c9da2f4c..c8100287f8 100644 --- a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc +++ b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc @@ -12,7 +12,7 @@ namespace FlexFlow { ParallelComputationGraphInstance::ParallelComputationGraphInstance( - RealmManager &realm, + RealmContext &realm, DynamicOpenDataflowGraph dataflow_graph, std::vector const &topological_ordering, OptimizerAttrs const &optimizer_attrs, @@ -61,7 +61,7 @@ static GenericTensorAccessorW } ParallelComputationGraphInstance create_parallel_computation_graph_instance( - RealmManager &realm, + RealmContext &realm, ParallelComputationGraph const &pcg, OptimizerAttrs const &optimizer_attrs, std::optional const &loss_attrs, diff --git a/lib/realm-execution/src/realm-execution/realm_context.cc b/lib/realm-execution/src/realm-execution/realm_context.cc new file mode 100644 index 0000000000..5068373ebe --- /dev/null +++ b/lib/realm-execution/src/realm-execution/realm_context.cc @@ -0,0 +1,34 @@ +#include "realm-execution/realm_context.h" +#include "realm-execution/realm_task_id_t.h" +#include "realm-execution/task_id_t.dtg.h" +#include "utils/exception.h" + +namespace FlexFlow { + +RealmContext::RealmContext() {} + +RealmContext::~RealmContext() { + if (!this->outstanding_events.empty()) { + Realm::Event outstanding = this->merge_outstanding_events(); + outstanding.wait(); + } +} + +Allocator &RealmContext::get_current_device_allocator() const { + NOT_IMPLEMENTED(); +} + +device_handle_t const &RealmContext::get_current_device_handle() const { + NOT_IMPLEMENTED(); +} +device_id_t const &RealmContext::get_current_device_idx() const { + NOT_IMPLEMENTED(); +} + +Realm::Event RealmContext::merge_outstanding_events() { + Realm::Event result = Realm::Event::merge_events(this->outstanding_events); + this->outstanding_events.clear(); + return result; +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/realm_manager.cc b/lib/realm-execution/src/realm-execution/realm_manager.cc index 747f603f5d..501ba7536a 100644 --- a/lib/realm-execution/src/realm-execution/realm_manager.cc +++ b/lib/realm-execution/src/realm-execution/realm_manager.cc @@ -5,37 +5,39 @@ namespace FlexFlow { -RealmManager::RealmManager(int *argc, char ***argv) : is_root_runtime(true) { +RealmManager::RealmManager(int *argc, char ***argv) { bool ok = this->runtime.init(argc, argv); ASSERT(ok); } -RealmManager::RealmManager(void const *args, - size_t arglen, - void const *userdata, - size_t userdatalen, - Realm::Processor proc) - : runtime(Realm::Runtime::get_runtime()), is_root_runtime(false) {} - RealmManager::~RealmManager() { Realm::Event outstanding = this->merge_outstanding_events(); - if (is_root_runtime) { this->runtime.shutdown(outstanding); this->runtime.wait_for_shutdown(); - } else { - outstanding.wait(); - } +} + +static void controller_task_wrapper(void const *args, + size_t arglen, + void const *userdata, + size_t userlen, + Realm::Processor proc) { + ASSERT(arglen == sizeof(std::function)); + std::function thunk = + *reinterpret_cast const *>(args); + + RealmContext ctx; + thunk(ctx); } Realm::Event - RealmManager::start_controller(std::function thunk) { + RealmManager::start_controller(std::function thunk) { Realm::Processor::TaskFuncID CONTROLLER_TASK_ID = get_realm_task_id_for_task_id(task_id_t::CONTROLLER_TASK_ID); Realm::Event task_ready = Realm::Processor::register_task_by_kind( Realm::Processor::LOC_PROC, /*global=*/false, CONTROLLER_TASK_ID, - Realm::CodeDescriptor(RealmManager::controller_task_wrapper), + Realm::CodeDescriptor(controller_task_wrapper), Realm::ProfilingRequestSet(), &thunk, sizeof(thunk)); @@ -51,34 +53,4 @@ Realm::Event return task_complete; } -Allocator &RealmManager::get_current_device_allocator() const { - NOT_IMPLEMENTED(); -} - -device_handle_t const &RealmManager::get_current_device_handle() const { - NOT_IMPLEMENTED(); -} -device_id_t const &RealmManager::get_current_device_idx() const { - NOT_IMPLEMENTED(); -} - -Realm::Event RealmManager::merge_outstanding_events() { - Realm::Event result = Realm::Event::merge_events(this->outstanding_events); - this->outstanding_events.clear(); - return result; -} - -void RealmManager::controller_task_wrapper(void const *args, - size_t arglen, - void const *userdata, - size_t userlen, - Realm::Processor proc) { - ASSERT(arglen == sizeof(std::function)); - std::function thunk = - *reinterpret_cast const *>(args); - - RealmManager manager(args, arglen, userdata, userlen, proc); - thunk(manager); -} - } // namespace FlexFlow diff --git a/lib/realm-execution/test/src/realm-execution/realm_manager.cc b/lib/realm-execution/test/src/realm-execution/realm_manager.cc index f9fbd986c2..6c28a001ad 100644 --- a/lib/realm-execution/test/src/realm-execution/realm_manager.cc +++ b/lib/realm-execution/test/src/realm-execution/realm_manager.cc @@ -18,7 +18,7 @@ TEST_SUITE(FF_TEST_SUITE) { // Launch a controller int some_data = 123; FlexFlow::Realm::Event event = manager.start_controller( - [&](RealmManager &manager) { ASSERT(some_data == 123); }); + [&](RealmContext &ctx) { ASSERT(some_data == 123); }); // Need to block on the completion of the event to ensure we don't race event.wait(); } diff --git a/lib/realm-execution/test/src/realm-execution/test_e2e.cc b/lib/realm-execution/test/src/realm-execution/test_e2e.cc index fa9f798e4f..a30d5c4d8e 100644 --- a/lib/realm-execution/test/src/realm-execution/test_e2e.cc +++ b/lib/realm-execution/test/src/realm-execution/test_e2e.cc @@ -11,8 +11,6 @@ TEST_SUITE(FF_TEST_SUITE) { int fake_argc = fake_args.size(); char **fake_argv = fake_args.data(); RealmManager manager(&fake_argc, &fake_argv); - FlexFlow::Realm::Event event = - manager.start_controller([](RealmManager &manager) {}); - event.wait(); + (void)manager.start_controller([](RealmContext &ctx) {}); } } From 362b6c03425a9f0e14df30e7a9c3ef237f6054ee Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Fri, 6 Feb 2026 09:32:02 -0800 Subject: [PATCH 18/88] Switch to mapped PCG. --- .../parallel_computation_graph_instance.h | 4 ++-- .../parallel_computation_graph_instance.cc | 7 ++++--- .../src/realm-execution/realm_manager.cc | 12 ++++++------ ...ke_dynamic_open_dataflow_graph_from_mpcg.h | 14 ++++++++++++++ ...ake_dynamic_open_dataflow_graph_from_pcg.h | 14 -------------- ..._dynamic_open_dataflow_graph_from_mpcg.cc} | 19 ++++++++++--------- 6 files changed, 36 insertions(+), 34 deletions(-) create mode 100644 lib/task-spec/include/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mpcg.h delete mode 100644 lib/task-spec/include/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_pcg.h rename lib/task-spec/src/task-spec/dynamic_graph/{make_dynamic_open_dataflow_graph_from_pcg.cc => make_dynamic_open_dataflow_graph_from_mpcg.cc} (84%) diff --git a/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h b/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h index 0dd87d566f..06c2d2d912 100644 --- a/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h +++ b/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h @@ -7,8 +7,8 @@ #include "kernels/profiling_settings.dtg.h" #include "op-attrs/ops/loss_functions/loss_attrs.dtg.h" #include "pcg/device_id_t.dtg.h" +#include "pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.dtg.h" #include "pcg/optimizer_attrs.dtg.h" -#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" #include "realm-execution/realm_context.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.dtg.h" #include "task-spec/dynamic_graph/dynamic_tensor_accessor.dtg.h" @@ -46,7 +46,7 @@ struct ParallelComputationGraphInstance { ParallelComputationGraphInstance create_parallel_computation_graph_instance( RealmContext &realm, - ParallelComputationGraph const &pcg, + MappedParallelComputationGraph const &mpcg, OptimizerAttrs const &optimizer_attrs, std::optional const &loss_attrs, std::optional label_tensor, diff --git a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc index c8100287f8..e7bf79f12d 100644 --- a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc +++ b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc @@ -4,7 +4,7 @@ #include "pcg/optimizer_attrs.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" #include "task-spec/dynamic_graph/loss_insertion.h" -#include "task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_pcg.h" +#include "task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mpcg.h" #include "task-spec/dynamic_graph/pass_expansion.h" #include "task-spec/dynamic_graph/update_insertion.h" #include "utils/exception.h" @@ -62,7 +62,7 @@ static GenericTensorAccessorW ParallelComputationGraphInstance create_parallel_computation_graph_instance( RealmContext &realm, - ParallelComputationGraph const &pcg, + MappedParallelComputationGraph const &mpcg, OptimizerAttrs const &optimizer_attrs, std::optional const &loss_attrs, std::optional label_tensor, @@ -72,7 +72,8 @@ ParallelComputationGraphInstance create_parallel_computation_graph_instance( ProfilingSettings const &profiling_settings, FFIterationConfig const &iteration_config) { - DynamicOpenDataflowGraph dg = make_dynamic_open_dataflow_graph_from_pcg(pcg); + DynamicOpenDataflowGraph dg = + make_dynamic_open_dataflow_graph_from_mpcg(mpcg); dg = perform_pass_expansion(dg); std::unordered_map inputs = diff --git a/lib/realm-execution/src/realm-execution/realm_manager.cc b/lib/realm-execution/src/realm-execution/realm_manager.cc index 501ba7536a..0c34d77204 100644 --- a/lib/realm-execution/src/realm-execution/realm_manager.cc +++ b/lib/realm-execution/src/realm-execution/realm_manager.cc @@ -12,15 +12,15 @@ RealmManager::RealmManager(int *argc, char ***argv) { RealmManager::~RealmManager() { Realm::Event outstanding = this->merge_outstanding_events(); - this->runtime.shutdown(outstanding); - this->runtime.wait_for_shutdown(); + this->runtime.shutdown(outstanding); + this->runtime.wait_for_shutdown(); } static void controller_task_wrapper(void const *args, - size_t arglen, - void const *userdata, - size_t userlen, - Realm::Processor proc) { + size_t arglen, + void const *userdata, + size_t userlen, + Realm::Processor proc) { ASSERT(arglen == sizeof(std::function)); std::function thunk = *reinterpret_cast const *>(args); diff --git a/lib/task-spec/include/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mpcg.h b/lib/task-spec/include/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mpcg.h new file mode 100644 index 0000000000..758a0c2813 --- /dev/null +++ b/lib/task-spec/include/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mpcg.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_DYNAMIC_OPEN_DATAFLOW_GRAPH_FROM_MPCG_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_DYNAMIC_OPEN_DATAFLOW_GRAPH_FROM_MPCG_H + +#include "pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.dtg.h" +#include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.dtg.h" + +namespace FlexFlow { + +DynamicOpenDataflowGraph make_dynamic_open_dataflow_graph_from_mpcg( + MappedParallelComputationGraph const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/task-spec/include/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_pcg.h b/lib/task-spec/include/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_pcg.h deleted file mode 100644 index a71eb558c1..0000000000 --- a/lib/task-spec/include/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_pcg.h +++ /dev/null @@ -1,14 +0,0 @@ -#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_DYNAMIC_OPEN_DATAFLOW_GRAPH_FROM_PCG_H -#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_DYNAMIC_OPEN_DATAFLOW_GRAPH_FROM_PCG_H - -#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" -#include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.dtg.h" - -namespace FlexFlow { - -DynamicOpenDataflowGraph - make_dynamic_open_dataflow_graph_from_pcg(ParallelComputationGraph const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_pcg.cc b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mpcg.cc similarity index 84% rename from lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_pcg.cc rename to lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mpcg.cc index 841be27dfd..e90ef10398 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_pcg.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mpcg.cc @@ -1,4 +1,4 @@ -#include "task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_pcg.h" +#include "task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mpcg.h" #include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/pcg_operator_attrs.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" @@ -13,26 +13,27 @@ namespace FlexFlow { -DynamicOpenDataflowGraph make_dynamic_open_dataflow_graph_from_pcg( - ParallelComputationGraph const &pcg) { +DynamicOpenDataflowGraph make_dynamic_open_dataflow_graph_from_mpcg( + MappedParallelComputationGraph const &mpcg) { DynamicOpenDataflowGraph result = make_empty_dynamic_open_dataflow_graph(); - for (auto const &[layer, attrs] : get_parallel_layer_attrs_mapping(pcg)) { + for (auto const &[layer, attrs] : + get_parallel_layer_attrs_mapping(mpcg.pcg)) { DynamicNodeAttrs result_attrs{ /*task_type=*/std::nullopt, /*device_coord=*/std::nullopt, - /*mapping=*/std::nullopt, + /*mapping=*/mpcg.mapped_tasks.at(layer), /*op_attrs=*/attrs.op_attrs, /*pcg_layer_guid=*/dynamic_layer_guid_t{layer}, /*per_device_op_state=*/std::nullopt, }; std::unordered_map result_inputs = - transform(get_incoming_tensors(pcg, layer), + transform(get_incoming_tensors(mpcg.pcg, layer), [&](TensorSlotName const &slot_name, parallel_tensor_guid_t const &tensor) { ParallelTensorAttrs attrs = - get_parallel_tensor_attrs(pcg, tensor); + get_parallel_tensor_attrs(mpcg.pcg, tensor); return std::pair{ DynamicTensorSlot{ /*slot_name=*/slot_name, @@ -48,11 +49,11 @@ DynamicOpenDataflowGraph make_dynamic_open_dataflow_graph_from_pcg( }; }); std::unordered_map result_outputs = - transform(get_outgoing_tensors(pcg, layer), + transform(get_outgoing_tensors(mpcg.pcg, layer), [&](TensorSlotName const &slot_name, parallel_tensor_guid_t const &tensor) { ParallelTensorAttrs attrs = - get_parallel_tensor_attrs(pcg, tensor); + get_parallel_tensor_attrs(mpcg.pcg, tensor); return std::pair{ DynamicTensorSlot{ /*slot_name=*/slot_name, From b39058cfb1a38e353e4158a3e99a8e075939bc27 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Fri, 6 Feb 2026 09:54:57 -0800 Subject: [PATCH 19/88] Add shard expansion pass (and implement shard expansion pass). --- .../parallel_computation_graph_instance.h | 3 ++- .../parallel_computation_graph_instance.cc | 9 ++++++--- .../task-spec/dynamic_graph/shard_expansion.cc | 15 +++++++++++++++ 3 files changed, 23 insertions(+), 4 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h b/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h index 06c2d2d912..f361cec3ca 100644 --- a/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h +++ b/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h @@ -9,6 +9,7 @@ #include "pcg/device_id_t.dtg.h" #include "pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.dtg.h" #include "pcg/optimizer_attrs.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" #include "realm-execution/realm_context.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.dtg.h" #include "task-spec/dynamic_graph/dynamic_tensor_accessor.dtg.h" @@ -50,7 +51,7 @@ ParallelComputationGraphInstance create_parallel_computation_graph_instance( OptimizerAttrs const &optimizer_attrs, std::optional const &loss_attrs, std::optional label_tensor, - std::optional logit_tensor, + std::optional logit_tensor, std::unordered_map const &input_tensors, ProfilingSettings const &profiling_settings, diff --git a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc index e7bf79f12d..80ed98f8c2 100644 --- a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc +++ b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc @@ -3,9 +3,11 @@ #include "local-execution/tensor_allocation.h" #include "pcg/optimizer_attrs.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" +#include "task-spec/dynamic_graph/dynamic_tensor_guid_t.dtg.h" #include "task-spec/dynamic_graph/loss_insertion.h" #include "task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mpcg.h" #include "task-spec/dynamic_graph/pass_expansion.h" +#include "task-spec/dynamic_graph/shard_expansion.h" #include "task-spec/dynamic_graph/update_insertion.h" #include "utils/exception.h" @@ -66,7 +68,7 @@ ParallelComputationGraphInstance create_parallel_computation_graph_instance( OptimizerAttrs const &optimizer_attrs, std::optional const &loss_attrs, std::optional label_tensor, - std::optional logit_tensor, + std::optional logit_tensor, std::unordered_map const &input_tensors, ProfilingSettings const &profiling_settings, @@ -81,13 +83,14 @@ ParallelComputationGraphInstance create_parallel_computation_graph_instance( std::optional logit_grad_value; if (loss_attrs) { auto [dg2, label_v, logit_grad_v] = perform_loss_insertion( - dg, assert_unwrap(loss_attrs), assert_unwrap(logit_tensor)); + dg, loss_attrs.value(), dynamic_tensor_guid_t{logit_tensor.value()}); dg = dg2; logit_grad_value = logit_grad_v; - inputs.insert(std::pair{label_v, assert_unwrap(label_tensor)}); + inputs.insert(std::pair{label_v, label_tensor.value()}); } dg = perform_update_insertion(dg, optimizer_attrs); + dg = perform_shard_expansion(dg); dg = perform_tensor_allocation( dg, inputs, realm.get_current_device_allocator()); diff --git a/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc b/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc index ea253b63f8..33b7fb8591 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc @@ -81,4 +81,19 @@ std::unordered_set }); } +DynamicOpenDataflowGraph + perform_shard_expansion(DynamicOpenDataflowGraph const &g) { + + ASSERT(no_part_of_graph_is_shard_expanded(g)); + + DynamicOpenDataflowGraph result = + flatmap_dynamic_invocation_set(g, [&](DynamicNodeInvocation const &i) { + return perform_shard_expansion_for_invocation(i); + }); + + ASSERT(graph_is_fully_shard_expanded(result)); + + return result; +} + } // namespace FlexFlow From 32c4f61e8a43f746deacaf12c521d0fed89915ac Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Fri, 6 Feb 2026 10:47:43 -0800 Subject: [PATCH 20/88] Add instance field to dynamic graph, more task IDs. --- .../include/realm-execution/realm_context.h | 2 +- .../include/realm-execution/realm_manager.h | 2 +- .../include/realm-execution/realm_task_id_t.h | 2 +- .../realm-execution/realm_task_registry.h | 4 +- .../realm-execution/realm_task_registry.cc | 81 +++++++++++++++++-- lib/task-spec/CMakeLists.txt | 1 + .../dynamic_value_attrs.dtg.toml | 6 ++ .../include/task-spec/realm/fmt/instance.h | 35 ++++++++ .../include/task-spec/realm}/realm.h | 4 +- .../task-spec/dynamic_graph/loss_insertion.cc | 2 + ...ake_dynamic_open_dataflow_graph_from_cg.cc | 2 + ...e_dynamic_open_dataflow_graph_from_mpcg.cc | 2 + .../dynamic_graph/update_insertion.cc | 1 + .../src/task-spec/realm/fmt/instance.h | 10 +++ .../dynamic_open_dataflow_graph.cc | 3 + .../dynamic_graph/machine_slicing.cc | 1 + .../task-spec/dynamic_graph/pass_expansion.cc | 3 + .../dynamic_graph/shard_expansion.cc | 1 + 18 files changed, 148 insertions(+), 14 deletions(-) create mode 100644 lib/task-spec/include/task-spec/realm/fmt/instance.h rename lib/{realm-execution/include/realm-execution => task-spec/include/task-spec/realm}/realm.h (63%) create mode 100644 lib/task-spec/src/task-spec/realm/fmt/instance.h diff --git a/lib/realm-execution/include/realm-execution/realm_context.h b/lib/realm-execution/include/realm-execution/realm_context.h index 5539fe693e..357b05b699 100644 --- a/lib/realm-execution/include/realm-execution/realm_context.h +++ b/lib/realm-execution/include/realm-execution/realm_context.h @@ -4,7 +4,7 @@ #include "kernels/allocation.h" #include "kernels/device_handle_t.dtg.h" #include "pcg/device_id_t.dtg.h" -#include "realm-execution/realm.h" +#include "task-spec/realm/realm.h" namespace FlexFlow { diff --git a/lib/realm-execution/include/realm-execution/realm_manager.h b/lib/realm-execution/include/realm-execution/realm_manager.h index bf5e8f72f1..ebf3bb401e 100644 --- a/lib/realm-execution/include/realm-execution/realm_manager.h +++ b/lib/realm-execution/include/realm-execution/realm_manager.h @@ -4,8 +4,8 @@ #include "kernels/allocation.h" #include "kernels/device_handle_t.dtg.h" #include "pcg/device_id_t.dtg.h" -#include "realm-execution/realm.h" #include "realm-execution/realm_context.h" +#include "task-spec/realm/realm.h" namespace FlexFlow { diff --git a/lib/realm-execution/include/realm-execution/realm_task_id_t.h b/lib/realm-execution/include/realm-execution/realm_task_id_t.h index 8e6da1a2bd..327cf9ffd0 100644 --- a/lib/realm-execution/include/realm-execution/realm_task_id_t.h +++ b/lib/realm-execution/include/realm-execution/realm_task_id_t.h @@ -1,8 +1,8 @@ #ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_TASK_ID_T_H #define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_TASK_ID_T_H -#include "realm-execution/realm.h" #include "realm-execution/task_id_t.dtg.h" +#include "task-spec/realm/realm.h" namespace FlexFlow { diff --git a/lib/realm-execution/include/realm-execution/realm_task_registry.h b/lib/realm-execution/include/realm-execution/realm_task_registry.h index 3a4cee106c..d9d993795b 100644 --- a/lib/realm-execution/include/realm-execution/realm_task_registry.h +++ b/lib/realm-execution/include/realm-execution/realm_task_registry.h @@ -1,12 +1,12 @@ #ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_TASK_REGISTRY_H #define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_TASK_REGISTRY_H -#include "realm-execution/realm.h" #include "realm-execution/task_id_t.dtg.h" +#include "task-spec/realm/realm.h" namespace FlexFlow { -Realm::Event register_all_tasks(); +[[nodiscard]] Realm::Event register_all_tasks(); } // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/realm_task_registry.cc b/lib/realm-execution/src/realm-execution/realm_task_registry.cc index a5e52b7a7c..5c61c208fb 100644 --- a/lib/realm-execution/src/realm-execution/realm_task_registry.cc +++ b/lib/realm-execution/src/realm-execution/realm_task_registry.cc @@ -1,11 +1,13 @@ -#include "realm-execution/realm.h" +#include "realm-execution/realm_task_registry.h" #include "realm-execution/realm_task_id_t.h" -#include "realm-execution/task_id_t.dtg.h" +#include "utils/exception.h" namespace FlexFlow { -void op_task_wrapper( - void const *, size_t, void const *, size_t, Realm::Processor) {} +static void operation_task_wrapper( + void const *, size_t, void const *, size_t, Realm::Processor) { + NOT_IMPLEMENTED(); +} static Realm::Event register_task(Realm::Processor::Kind target_kind, task_id_t func_id, @@ -25,7 +27,8 @@ static Realm::Event register_task(Realm::Processor::Kind target_kind, Realm::Event register_all_tasks() { std::vector pending_registrations; - std::vector init_task_ids = { + std::vector task_ids = { + // Init tasks task_id_t::BATCHNORM_INIT_TASK_ID, task_id_t::COMBINE_INIT_TASK_ID, task_id_t::CONV2D_INIT_TASK_ID, @@ -42,11 +45,75 @@ Realm::Event register_all_tasks() { task_id_t::REPARTITION_INIT_TASK_ID, task_id_t::REPLICATE_INIT_TASK_ID, task_id_t::SOFTMAX_INIT_TASK_ID, + + // Forward tasks + task_id_t::BATCHMATMUL_FWD_TASK_ID, + task_id_t::BATCHNORM_FWD_TASK_ID, + task_id_t::BROADCAST_FWD_TASK_ID, + task_id_t::CAST_FWD_TASK_ID, + task_id_t::COMBINE_FWD_TASK_ID, + task_id_t::CONCAT_FWD_TASK_ID, + task_id_t::CONV2D_FWD_TASK_ID, + task_id_t::DROPOUT_FWD_TASK_ID, + task_id_t::ELEMENTBINARY_FWD_TASK_ID, + task_id_t::ELEMENTUNARY_FWD_TASK_ID, + task_id_t::EMBED_FWD_TASK_ID, + task_id_t::FLAT_FWD_TASK_ID, + task_id_t::GATHER_FWD_TASK_ID, + task_id_t::LAYERNORM_FWD_TASK_ID, + task_id_t::LINEAR_FWD_TASK_ID, + task_id_t::ATTENTION_FWD_TASK_ID, + task_id_t::POOL2D_FWD_TASK_ID, + task_id_t::REDUCE_FWD_TASK_ID, + task_id_t::REDUCTION_FWD_TASK_ID, + task_id_t::REPARTITION_FWD_TASK_ID, + task_id_t::REPLICATE_FWD_TASK_ID, + task_id_t::RESHAPE_FWD_TASK_ID, + task_id_t::REVERSE_FWD_TASK_ID, + task_id_t::SOFTMAX_FWD_TASK_ID, + task_id_t::SPLIT_FWD_TASK_ID, + task_id_t::TOPK_FWD_TASK_ID, + task_id_t::TRANSPOSE_FWD_TASK_ID, + + // Backward tasks + task_id_t::BATCHMATMUL_BWD_TASK_ID, + task_id_t::BATCHNORM_BWD_TASK_ID, + task_id_t::BROADCAST_BWD_TASK_ID, + task_id_t::CAST_BWD_TASK_ID, + task_id_t::COMBINE_BWD_TASK_ID, + task_id_t::CONCAT_BWD_TASK_ID, + task_id_t::CONV2D_BWD_TASK_ID, + task_id_t::DROPOUT_BWD_TASK_ID, + task_id_t::ELEMENTBINARY_BWD_TASK_ID, + task_id_t::ELEMENTUNARY_BWD_TASK_ID, + task_id_t::EMBED_BWD_TASK_ID, + task_id_t::FLAT_BWD_TASK_ID, + task_id_t::GATHER_BWD_TASK_ID, + task_id_t::LAYERNORM_BWD_TASK_ID, + task_id_t::LINEAR_BWD_TASK_ID, + task_id_t::ATTENTION_BWD_TASK_ID, + task_id_t::POOL2D_BWD_TASK_ID, + task_id_t::REDUCE_BWD_TASK_ID, + task_id_t::REDUCTION_BWD_TASK_ID, + task_id_t::REPARTITION_BWD_TASK_ID, + task_id_t::REPLICATE_BWD_TASK_ID, + task_id_t::RESHAPE_BWD_TASK_ID, + task_id_t::REVERSE_BWD_TASK_ID, + task_id_t::SOFTMAX_BWD_TASK_ID, + task_id_t::SPLIT_BWD_TASK_ID, + task_id_t::TOPK_BWD_TASK_ID, + task_id_t::TRANSPOSE_BWD_TASK_ID, + + // Update tasks + task_id_t::SGD_UPD_NCCL_TASK_ID, + task_id_t::ADAM_UPD_NCCL_TASK_ID, }; - for (task_id_t init_task_id : init_task_ids) { + for (task_id_t task_id : task_ids) { + pending_registrations.push_back(register_task( + Realm::Processor::LOC_PROC, task_id, operation_task_wrapper)); pending_registrations.push_back(register_task( - Realm::Processor::LOC_PROC, init_task_id, op_task_wrapper)); + Realm::Processor::TOC_PROC, task_id, operation_task_wrapper)); } return Realm::Event::merge_events(pending_registrations); diff --git a/lib/task-spec/CMakeLists.txt b/lib/task-spec/CMakeLists.txt index 3c7c91af67..f4f5353f70 100644 --- a/lib/task-spec/CMakeLists.txt +++ b/lib/task-spec/CMakeLists.txt @@ -14,6 +14,7 @@ ff_add_library( pcg spdlog compiler + Realm::Realm ) add_subdirectory(test) diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_attrs.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_attrs.dtg.toml index 89b94b1017..763ebf180f 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_attrs.dtg.toml +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_attrs.dtg.toml @@ -14,6 +14,8 @@ includes = [ "op-attrs/parallel_tensor_space_coordinate.dtg.h", "task-spec/dynamic_graph/dynamic_tensor_accessor.dtg.h", "task-spec/dynamic_graph/dynamic_tensor_role.dtg.h", + "task-spec/realm/fmt/instance.h", + "task-spec/realm/realm.h", ] src_includes = [ @@ -36,6 +38,10 @@ type = "std::optional<::FlexFlow::ParallelTensorSpaceCoordinate>" name = "accessor" type = "std::optional<::FlexFlow::DynamicTensorAccessor>" +[[fields]] +name = "instance" +type = "std::optional<::FlexFlow::Realm::RegionInstance>" + [[fields]] name = "role" type = "std::optional<::FlexFlow::DynamicTensorRole>" diff --git a/lib/task-spec/include/task-spec/realm/fmt/instance.h b/lib/task-spec/include/task-spec/realm/fmt/instance.h new file mode 100644 index 0000000000..23979c7efc --- /dev/null +++ b/lib/task-spec/include/task-spec/realm/fmt/instance.h @@ -0,0 +1,35 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_PAIR_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_PAIR_H + +#include "task-spec/realm/realm.h" +#include "utils/check_fmtable.h" +#include +#include + +namespace fmt { + +template +struct formatter<::FlexFlow::Realm::RegionInstance, + Char, + std::enable_if_t::value>> + : formatter<::std::string> { + template + auto format(::FlexFlow::Realm::RegionInstance const &m, + FormatContext &ctx) const -> decltype(ctx.out()) { + std::string result = fmt::format("", m.id); + + return formatter::format(result, ctx); + } +}; + +} // namespace fmt + +namespace FlexFlow { + +std::ostream &operator<<(std::ostream &s, + ::FlexFlow::Realm::RegionInstance const &m); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/realm.h b/lib/task-spec/include/task-spec/realm/realm.h similarity index 63% rename from lib/realm-execution/include/realm-execution/realm.h rename to lib/task-spec/include/task-spec/realm/realm.h index f15113ee92..8123c9e9fa 100644 --- a/lib/realm-execution/include/realm-execution/realm.h +++ b/lib/task-spec/include/task-spec/realm/realm.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_H -#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_H +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_REALM_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_REALM_H #ifdef FLEXFLOW_USE_PREALM #include diff --git a/lib/task-spec/src/task-spec/dynamic_graph/loss_insertion.cc b/lib/task-spec/src/task-spec/dynamic_graph/loss_insertion.cc index 4270119612..837ade2aad 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/loss_insertion.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/loss_insertion.cc @@ -23,6 +23,7 @@ LossInsertionResult perform_loss_insertion(DynamicOpenDataflowGraph const &dg, /*parallel_tensor_shape=*/logit_value.parallel_tensor_shape, /*shard_coord=*/logit_value.shard_coord, /*accessor=*/std::nullopt, + /*instance=*/std::nullopt, /*role=*/mk_dynamic_tensor_role_loss(), }; DynamicValueAttrs logit_grad_value{ @@ -30,6 +31,7 @@ LossInsertionResult perform_loss_insertion(DynamicOpenDataflowGraph const &dg, /*parallel_tensor_shape=*/logit_value.parallel_tensor_shape, /*shard_coord=*/logit_value.shard_coord, /*accessor=*/std::nullopt, + /*instance=*/std::nullopt, /*role=*/mk_dynamic_tensor_role_bwd(), }; DynamicNodeInvocation loss_invocation{ diff --git a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_cg.cc b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_cg.cc index 204597386e..294241b732 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_cg.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_cg.cc @@ -45,6 +45,7 @@ DynamicOpenDataflowGraph /*parallel_tensor_shape=*/lift_to_parallel(attrs.shape), /*shard_coord=*/std::nullopt, /*accessor=*/std::nullopt, + /*instance=*/std::nullopt, /*role=*/std::nullopt, }, }; @@ -64,6 +65,7 @@ DynamicOpenDataflowGraph /*parallel_tensor_shape=*/lift_to_parallel(attrs.shape), /*shard_coord=*/std::nullopt, /*accessor=*/std::nullopt, + /*instance=*/std::nullopt, /*role=*/std::nullopt, }, }; diff --git a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mpcg.cc b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mpcg.cc index e90ef10398..eceb580a20 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mpcg.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mpcg.cc @@ -44,6 +44,7 @@ DynamicOpenDataflowGraph make_dynamic_open_dataflow_graph_from_mpcg( /*parallel_tensor_shape=*/attrs.shape, /*shard_coord=*/std::nullopt, /*accessor=*/std::nullopt, + /*instance=*/std::nullopt, /*role=*/std::nullopt, }, }; @@ -64,6 +65,7 @@ DynamicOpenDataflowGraph make_dynamic_open_dataflow_graph_from_mpcg( /*parallel_tensor_shape=*/attrs.shape, /*shard_coord=*/std::nullopt, /*accessor=*/std::nullopt, + /*instance=*/std::nullopt, /*role=*/std::nullopt, }, }; diff --git a/lib/task-spec/src/task-spec/dynamic_graph/update_insertion.cc b/lib/task-spec/src/task-spec/dynamic_graph/update_insertion.cc index 58a32db6c1..23708f3779 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/update_insertion.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/update_insertion.cc @@ -51,6 +51,7 @@ static DynamicNodeInvocation get_update_invocation_for_invocation( DynamicValueAttrs value_attrs = output.second; ASSERT(value_attrs.accessor == std::nullopt); + ASSERT(value_attrs.instance == std::nullopt); DynamicNodeAttrs update_node_attrs = i.node_attrs; update_node_attrs.task_type = DynamicTaskType::UPD; diff --git a/lib/task-spec/src/task-spec/realm/fmt/instance.h b/lib/task-spec/src/task-spec/realm/fmt/instance.h new file mode 100644 index 0000000000..fa15e1c16f --- /dev/null +++ b/lib/task-spec/src/task-spec/realm/fmt/instance.h @@ -0,0 +1,10 @@ +#include "task-spec/realm/fmt/instance.h" + +namespace FlexFlow { + +std::ostream &operator<<(std::ostream &s, + ::FlexFlow::Realm::RegionInstance const &m) { + return s << fmt::to_string(m); +} + +} // namespace FlexFlow diff --git a/lib/task-spec/test/src/task-spec/dynamic_graph/dynamic_open_dataflow_graph.cc b/lib/task-spec/test/src/task-spec/dynamic_graph/dynamic_open_dataflow_graph.cc index fc9110b6e4..bb9a45e59a 100644 --- a/lib/task-spec/test/src/task-spec/dynamic_graph/dynamic_open_dataflow_graph.cc +++ b/lib/task-spec/test/src/task-spec/dynamic_graph/dynamic_open_dataflow_graph.cc @@ -16,6 +16,7 @@ TEST_SUITE(FF_TEST_SUITE) { /*parallel_tensor_shape=*/std::nullopt, /*shard_coord=*/std::nullopt, /*accessor=*/std::nullopt, + /*instance=*/std::nullopt, /*tensor_type=*/std::nullopt, }; @@ -29,6 +30,7 @@ TEST_SUITE(FF_TEST_SUITE) { /*parallel_tensor_shape=*/std::nullopt, /*shard_coord=*/std::nullopt, /*accessor=*/std::nullopt, + /*instance=*/std::nullopt, /*tensor_type=*/std::nullopt, }; @@ -42,6 +44,7 @@ TEST_SUITE(FF_TEST_SUITE) { /*parallel_tensor_shape=*/std::nullopt, /*shard_coord=*/std::nullopt, /*accessor=*/std::nullopt, + /*instance=*/std::nullopt, /*tensor_type=*/std::nullopt, }; diff --git a/lib/task-spec/test/src/task-spec/dynamic_graph/machine_slicing.cc b/lib/task-spec/test/src/task-spec/dynamic_graph/machine_slicing.cc index 40d37f50df..c28e12e0af 100644 --- a/lib/task-spec/test/src/task-spec/dynamic_graph/machine_slicing.cc +++ b/lib/task-spec/test/src/task-spec/dynamic_graph/machine_slicing.cc @@ -76,6 +76,7 @@ TEST_SUITE(FF_TEST_SUITE) { /*parallel_tensor_shape=*/std::nullopt, /*shard_coord=*/shard_coord, /*accessor=*/std::nullopt, + /*instance=*/std::nullopt, /*role=*/std::nullopt, }; }; diff --git a/lib/task-spec/test/src/task-spec/dynamic_graph/pass_expansion.cc b/lib/task-spec/test/src/task-spec/dynamic_graph/pass_expansion.cc index e8fcf2e40b..e57691b475 100644 --- a/lib/task-spec/test/src/task-spec/dynamic_graph/pass_expansion.cc +++ b/lib/task-spec/test/src/task-spec/dynamic_graph/pass_expansion.cc @@ -20,6 +20,7 @@ TEST_SUITE(FF_TEST_SUITE) { /*parallel_tensor_shape=*/std::nullopt, /*shard_coord=*/std::nullopt, /*accessor=*/std::nullopt, + /*instance=*/std::nullopt, /*role=*/tensor_role, }; }; @@ -113,6 +114,7 @@ TEST_SUITE(FF_TEST_SUITE) { /*parallel_tensor_shape=*/std::nullopt, /*shard_coord=*/std::nullopt, /*accessor=*/std::nullopt, + /*instance=*/std::nullopt, /*role=*/tensor_role, }; }; @@ -229,6 +231,7 @@ TEST_SUITE(FF_TEST_SUITE) { /*parallel_tensor_shape=*/std::nullopt, /*shard_coord=*/std::nullopt, /*accessor=*/std::nullopt, + /*instance=*/std::nullopt, /*role=*/tensor_type, }; }; diff --git a/lib/task-spec/test/src/task-spec/dynamic_graph/shard_expansion.cc b/lib/task-spec/test/src/task-spec/dynamic_graph/shard_expansion.cc index 23fbb6e514..4d88dde805 100644 --- a/lib/task-spec/test/src/task-spec/dynamic_graph/shard_expansion.cc +++ b/lib/task-spec/test/src/task-spec/dynamic_graph/shard_expansion.cc @@ -121,6 +121,7 @@ TEST_SUITE(FF_TEST_SUITE) { /*parallel_tensor_shape=*/std::nullopt, /*shard_coord=*/shard_coord, /*accessor=*/std::nullopt, + /*instance=*/std::nullopt, /*role=*/std::nullopt, }; }; From 5bd089e3c3ffcee7c1645392f08e79fc5d83400e Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Fri, 6 Feb 2026 10:48:32 -0800 Subject: [PATCH 21/88] Fix filename. --- lib/task-spec/src/task-spec/realm/fmt/{instance.h => instance.cc} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename lib/task-spec/src/task-spec/realm/fmt/{instance.h => instance.cc} (100%) diff --git a/lib/task-spec/src/task-spec/realm/fmt/instance.h b/lib/task-spec/src/task-spec/realm/fmt/instance.cc similarity index 100% rename from lib/task-spec/src/task-spec/realm/fmt/instance.h rename to lib/task-spec/src/task-spec/realm/fmt/instance.cc From 6bd47e042d8b865513e045d211cfbb7c9529e16a Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Fri, 6 Feb 2026 11:17:12 -0800 Subject: [PATCH 22/88] Some work in instance allocation and registry/manager. --- .../realm-execution/instance_allocation.h | 26 +++++ .../include/realm-execution/realm_context.h | 2 + .../realm-execution/realm_task_registry.h | 8 ++ .../realm-execution/instance_allocation.cc | 104 ++++++++++++++++++ .../parallel_computation_graph_instance.cc | 13 +-- .../src/realm-execution/realm_context.cc | 6 + .../src/realm-execution/realm_manager.cc | 46 ++++---- .../realm-execution/realm_task_registry.cc | 14 +-- 8 files changed, 182 insertions(+), 37 deletions(-) create mode 100644 lib/realm-execution/include/realm-execution/instance_allocation.h create mode 100644 lib/realm-execution/src/realm-execution/instance_allocation.cc diff --git a/lib/realm-execution/include/realm-execution/instance_allocation.h b/lib/realm-execution/include/realm-execution/instance_allocation.h new file mode 100644 index 0000000000..ea07cf0601 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/instance_allocation.h @@ -0,0 +1,26 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_INSTANCE_ALLOCATION_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_INSTANCE_ALLOCATION_H + +#include "realm-execution/realm_context.h" +#include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.dtg.h" + +namespace FlexFlow { + +bool no_instances_are_allocated(DynamicOpenDataflowGraph const &); +bool all_instances_are_allocated(DynamicOpenDataflowGraph const &); + +bool instances_are_ready_for_allocation(DynamicOpenDataflowGraph const &g); + +DynamicValueAttrs + perform_instance_allocation_for_value(DynamicValueAttrs const &, + Allocator &); + +DynamicOpenDataflowGraph perform_instance_allocation( + DynamicOpenDataflowGraph const &, + std::unordered_map const + &preallocated, + RealmContext &); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/realm_context.h b/lib/realm-execution/include/realm-execution/realm_context.h index 357b05b699..c72fe30b72 100644 --- a/lib/realm-execution/include/realm-execution/realm_context.h +++ b/lib/realm-execution/include/realm-execution/realm_context.h @@ -21,6 +21,8 @@ struct RealmContext { device_handle_t const &get_current_device_handle() const; device_id_t const &get_current_device_idx() const; + Realm::Event get_outstanding_events(); + protected: [[nodiscard]] Realm::Event merge_outstanding_events(); diff --git a/lib/realm-execution/include/realm-execution/realm_task_registry.h b/lib/realm-execution/include/realm-execution/realm_task_registry.h index d9d993795b..d6bf5b927f 100644 --- a/lib/realm-execution/include/realm-execution/realm_task_registry.h +++ b/lib/realm-execution/include/realm-execution/realm_task_registry.h @@ -6,6 +6,14 @@ namespace FlexFlow { +[[nodiscard]] Realm::Event register_task(Realm::Processor::Kind target_kind, + task_id_t func_id, + void (*task_body)(void const *, + size_t, + void const *, + size_t, + Realm::Processor)); + [[nodiscard]] Realm::Event register_all_tasks(); } // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/instance_allocation.cc b/lib/realm-execution/src/realm-execution/instance_allocation.cc new file mode 100644 index 0000000000..76d89313a6 --- /dev/null +++ b/lib/realm-execution/src/realm-execution/instance_allocation.cc @@ -0,0 +1,104 @@ +#include "realm-execution/instance_allocation.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "realm-execution/realm_context.h" +#include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" +#include "task-spec/dynamic_graph/dynamic_tensor_accessor.dtg.h" +#include "utils/bidict/generate_bidict.h" +#include "utils/containers/all_are_true.h" +#include "utils/containers/contains_key.h" +#include "utils/containers/map_values.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/exception.h" +#include "utils/optional.h" + +namespace FlexFlow { + +bool no_instances_are_allocated(DynamicOpenDataflowGraph const &g) { + return all_are_true( + transform(get_dynamic_values(g), [](DynamicValueAttrs const &v) -> bool { + return !v.accessor.has_value() && !v.instance.has_value(); + })); +} + +bool all_instances_are_allocated(DynamicOpenDataflowGraph const &g) { + return all_are_true( + transform(get_dynamic_values(g), [](DynamicValueAttrs const &v) -> bool { + return v.instance.has_value(); + })); +} + +bool instances_are_ready_for_allocation(DynamicOpenDataflowGraph const &g) { + return all_are_true( + transform(get_dynamic_values(g), [](DynamicValueAttrs const &v) -> bool { + return v.parallel_tensor_shape.has_value(); + })); +} + +DynamicValueAttrs + perform_instance_allocation_for_value(DynamicValueAttrs const &value, + RealmContext &ctx) { + ASSERT(value.accessor == std::nullopt); + ASSERT(value.instance == std::nullopt); + + TensorShape shape = + get_piece_shape(assert_unwrap(value.parallel_tensor_shape)); + + NOT_IMPLEMENTED(); + // GenericTensorAccessorW accessor = allocator.allocate_tensor(shape); + + DynamicValueAttrs result = value; + // result.accessor = DynamicTensorAccessor{accessor}; + + return result; +} + +DynamicOpenDataflowGraph perform_instance_allocation( + DynamicOpenDataflowGraph const &g, + std::unordered_map const + &preallocated, + RealmContext &ctx) { + ASSERT(no_instances_are_allocated(g)); + ASSERT(instances_are_ready_for_allocation(g)); + for (DynamicValueAttrs const &v : keys(preallocated)) { + ASSERT(v.accessor == std::nullopt); + ASSERT(v.instance == std::nullopt); + } + + std::unordered_set all_values = + unordered_set_of(get_dynamic_values(g)); + + bidict unallocated_to_allocated = + generate_bidict(all_values, + [&](DynamicValueAttrs const &v) -> DynamicValueAttrs { + if (contains_key(preallocated, v)) { + // FIXME: Attach external instance to existing + // allocation and use that + NOT_IMPLEMENTED(); + } else { + return perform_instance_allocation_for_value(v, ctx); + } + }); + + DynamicOpenDataflowGraph result = transform_dynamic_invocation_set( + g, [&](DynamicNodeInvocation const &i) -> DynamicNodeInvocation { + return DynamicNodeInvocation{ + /*inputs=*/map_values( + i.inputs, + [&](DynamicValueAttrs const &v) -> DynamicValueAttrs { + return unallocated_to_allocated.at_l(v); + }), + /*node_attrs=*/i.node_attrs, + /*outputs=*/ + map_values(i.outputs, + [&](DynamicValueAttrs const &v) -> DynamicValueAttrs { + return unallocated_to_allocated.at_l(v); + }), + }; + }); + + ASSERT(all_instances_are_allocated(result)); + + return result; +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc index 80ed98f8c2..ec80519cf3 100644 --- a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc +++ b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc @@ -1,7 +1,7 @@ #include "realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h" #include "local-execution/device_state_initialization.h" -#include "local-execution/tensor_allocation.h" #include "pcg/optimizer_attrs.h" +#include "realm-execution/instance_allocation.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" #include "task-spec/dynamic_graph/dynamic_tensor_guid_t.dtg.h" #include "task-spec/dynamic_graph/loss_insertion.h" @@ -63,7 +63,7 @@ static GenericTensorAccessorW } ParallelComputationGraphInstance create_parallel_computation_graph_instance( - RealmContext &realm, + RealmContext &ctx, MappedParallelComputationGraph const &mpcg, OptimizerAttrs const &optimizer_attrs, std::optional const &loss_attrs, @@ -91,8 +91,7 @@ ParallelComputationGraphInstance create_parallel_computation_graph_instance( dg = perform_update_insertion(dg, optimizer_attrs); dg = perform_shard_expansion(dg); - dg = perform_tensor_allocation( - dg, inputs, realm.get_current_device_allocator()); + dg = perform_instance_allocation(dg, inputs, ctx); std::optional logit_grad_tensor = transform(logit_grad_value, [&](DynamicValueAttrs const &lgv) { @@ -100,12 +99,12 @@ ParallelComputationGraphInstance create_parallel_computation_graph_instance( }); dg = perform_device_state_initialization(dg, - realm.get_current_device_allocator(), + ctx.get_current_device_allocator(), profiling_settings, - realm.get_current_device_handle(), + ctx.get_current_device_handle(), iteration_config, optimizer_attrs, - realm.get_current_device_idx()); + ctx.get_current_device_idx()); NOT_IMPLEMENTED(); } diff --git a/lib/realm-execution/src/realm-execution/realm_context.cc b/lib/realm-execution/src/realm-execution/realm_context.cc index 5068373ebe..ede6ae6d8d 100644 --- a/lib/realm-execution/src/realm-execution/realm_context.cc +++ b/lib/realm-execution/src/realm-execution/realm_context.cc @@ -25,6 +25,12 @@ device_id_t const &RealmContext::get_current_device_idx() const { NOT_IMPLEMENTED(); } +Realm::Event RealmContext::get_outstanding_events() { + Realm::Event result = this->merge_outstanding_events(); + this->outstanding_events.push_back(result); + return result; +} + Realm::Event RealmContext::merge_outstanding_events() { Realm::Event result = Realm::Event::merge_events(this->outstanding_events); this->outstanding_events.clear(); diff --git a/lib/realm-execution/src/realm-execution/realm_manager.cc b/lib/realm-execution/src/realm-execution/realm_manager.cc index 0c34d77204..63c6266948 100644 --- a/lib/realm-execution/src/realm-execution/realm_manager.cc +++ b/lib/realm-execution/src/realm-execution/realm_manager.cc @@ -1,21 +1,11 @@ #include "realm-execution/realm_manager.h" #include "realm-execution/realm_task_id_t.h" +#include "realm-execution/realm_task_registry.h" #include "realm-execution/task_id_t.dtg.h" #include "utils/exception.h" namespace FlexFlow { -RealmManager::RealmManager(int *argc, char ***argv) { - bool ok = this->runtime.init(argc, argv); - ASSERT(ok); -} - -RealmManager::~RealmManager() { - Realm::Event outstanding = this->merge_outstanding_events(); - this->runtime.shutdown(outstanding); - this->runtime.wait_for_shutdown(); -} - static void controller_task_wrapper(void const *args, size_t arglen, void const *userdata, @@ -29,26 +19,36 @@ static void controller_task_wrapper(void const *args, thunk(ctx); } +RealmManager::RealmManager(int *argc, char ***argv) { + bool ok = this->runtime.init(argc, argv); + ASSERT(ok); + + // Register all tasks at initialization time so we don't need to later + register_all_tasks().wait(); + register_task(Realm::Processor::LOC_PROC, + task_id_t::CONTROLLER_TASK_ID, + controller_task_wrapper) + .wait(); +} + +RealmManager::~RealmManager() { + Realm::Event outstanding = this->merge_outstanding_events(); + this->runtime.shutdown(outstanding); + this->runtime.wait_for_shutdown(); +} + Realm::Event RealmManager::start_controller(std::function thunk) { - Realm::Processor::TaskFuncID CONTROLLER_TASK_ID = - get_realm_task_id_for_task_id(task_id_t::CONTROLLER_TASK_ID); - Realm::Event task_ready = Realm::Processor::register_task_by_kind( - Realm::Processor::LOC_PROC, - /*global=*/false, - CONTROLLER_TASK_ID, - Realm::CodeDescriptor(controller_task_wrapper), - Realm::ProfilingRequestSet(), - &thunk, - sizeof(thunk)); - Realm::Processor target_proc = Realm::Machine::ProcessorQuery(Realm::Machine::get_machine()) .only_kind(Realm::Processor::LOC_PROC) .first(); Realm::Event task_complete = this->runtime.collective_spawn( - target_proc, CONTROLLER_TASK_ID, &thunk, sizeof(thunk), task_ready); + target_proc, + get_realm_task_id_for_task_id(task_id_t::CONTROLLER_TASK_ID), + &thunk, + sizeof(thunk)); this->outstanding_events.push_back(task_complete); return task_complete; } diff --git a/lib/realm-execution/src/realm-execution/realm_task_registry.cc b/lib/realm-execution/src/realm-execution/realm_task_registry.cc index 5c61c208fb..436a6af3f3 100644 --- a/lib/realm-execution/src/realm-execution/realm_task_registry.cc +++ b/lib/realm-execution/src/realm-execution/realm_task_registry.cc @@ -9,13 +9,13 @@ static void operation_task_wrapper( NOT_IMPLEMENTED(); } -static Realm::Event register_task(Realm::Processor::Kind target_kind, - task_id_t func_id, - void (*task_body)(void const *, - size_t, - void const *, - size_t, - Realm::Processor)) { +Realm::Event register_task(Realm::Processor::Kind target_kind, + task_id_t func_id, + void (*task_body)(void const *, + size_t, + void const *, + size_t, + Realm::Processor)) { return Realm::Processor::register_task_by_kind( target_kind, /*global=*/false, From 8e4cd0957e8ed7e7767c407c7e526e7474f667a4 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Fri, 6 Feb 2026 12:21:57 -0800 Subject: [PATCH 23/88] Instance allocation. --- .../realm-execution/instance_allocation.h | 2 +- .../include/realm-execution/realm_context.h | 10 ++ .../realm-execution/instance_allocation.cc | 16 ++-- .../parallel_computation_graph_instance.cc | 18 ++-- .../src/realm-execution/realm_context.cc | 93 +++++++++++++++++++ 5 files changed, 124 insertions(+), 15 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/instance_allocation.h b/lib/realm-execution/include/realm-execution/instance_allocation.h index ea07cf0601..d1dfa3fda0 100644 --- a/lib/realm-execution/include/realm-execution/instance_allocation.h +++ b/lib/realm-execution/include/realm-execution/instance_allocation.h @@ -15,7 +15,7 @@ DynamicValueAttrs perform_instance_allocation_for_value(DynamicValueAttrs const &, Allocator &); -DynamicOpenDataflowGraph perform_instance_allocation( +std::pair perform_instance_allocation( DynamicOpenDataflowGraph const &, std::unordered_map const &preallocated, diff --git a/lib/realm-execution/include/realm-execution/realm_context.h b/lib/realm-execution/include/realm-execution/realm_context.h index c72fe30b72..90ef402fb6 100644 --- a/lib/realm-execution/include/realm-execution/realm_context.h +++ b/lib/realm-execution/include/realm-execution/realm_context.h @@ -21,9 +21,19 @@ struct RealmContext { device_handle_t const &get_current_device_handle() const; device_id_t const &get_current_device_idx() const; + // Instance management + std::pair + create_instance(Realm::Memory memory, + TensorShape const &shape, + Realm::ProfilingRequestSet const &prs, + Realm::Event wait_on = Realm::Event::NO_EVENT); + + // Get the current set of outstanding events Realm::Event get_outstanding_events(); protected: + // Compact AND CLEAR the outstanding event queue + // Important: USER MUST BLOCK on event or else use it, or it WILL BE LOST [[nodiscard]] Realm::Event merge_outstanding_events(); protected: diff --git a/lib/realm-execution/src/realm-execution/instance_allocation.cc b/lib/realm-execution/src/realm-execution/instance_allocation.cc index 76d89313a6..0870117bfe 100644 --- a/lib/realm-execution/src/realm-execution/instance_allocation.cc +++ b/lib/realm-execution/src/realm-execution/instance_allocation.cc @@ -1,11 +1,13 @@ #include "realm-execution/instance_allocation.h" #include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/tensor_shape.dtg.h" #include "realm-execution/realm_context.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" #include "task-spec/dynamic_graph/dynamic_tensor_accessor.dtg.h" #include "utils/bidict/generate_bidict.h" #include "utils/containers/all_are_true.h" #include "utils/containers/contains_key.h" +#include "utils/containers/make.h" #include "utils/containers/map_values.h" #include "utils/containers/unordered_set_of.h" #include "utils/exception.h" @@ -40,19 +42,19 @@ DynamicValueAttrs ASSERT(value.accessor == std::nullopt); ASSERT(value.instance == std::nullopt); - TensorShape shape = - get_piece_shape(assert_unwrap(value.parallel_tensor_shape)); + TensorShape shape = get_piece_shape(value.parallel_tensor_shape.value()); - NOT_IMPLEMENTED(); - // GenericTensorAccessorW accessor = allocator.allocate_tensor(shape); + Realm::Memory memory = Realm::Memory::NO_MEMORY; // FIXME + auto [instance, ready] = + ctx.create_instance(memory, shape, Realm::ProfilingRequestSet()); DynamicValueAttrs result = value; - // result.accessor = DynamicTensorAccessor{accessor}; + result.instance = instance; return result; } -DynamicOpenDataflowGraph perform_instance_allocation( +std::pair perform_instance_allocation( DynamicOpenDataflowGraph const &g, std::unordered_map const &preallocated, @@ -98,7 +100,7 @@ DynamicOpenDataflowGraph perform_instance_allocation( ASSERT(all_instances_are_allocated(result)); - return result; + return std::pair{result, ctx.get_outstanding_events()}; } } // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc index ec80519cf3..dddb624df3 100644 --- a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc +++ b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc @@ -53,13 +53,12 @@ std::optional return this->logit_grad_tensor; } -static GenericTensorAccessorW - get_loss_tensor_accessor(DynamicOpenDataflowGraph const &dg, +static Realm::RegionInstance + get_loss_tensor_instance(DynamicOpenDataflowGraph const &dg, DynamicValueAttrs const &value) { return find_output_tensor(dg, value.tensor_guid, value.role) .value() - .second.accessor.value() - .get(); + .second.instance.value(); } ParallelComputationGraphInstance create_parallel_computation_graph_instance( @@ -91,11 +90,16 @@ ParallelComputationGraphInstance create_parallel_computation_graph_instance( dg = perform_update_insertion(dg, optimizer_attrs); dg = perform_shard_expansion(dg); - dg = perform_instance_allocation(dg, inputs, ctx); + Realm::Event instances_ready; + { + auto [dg2, ready] = perform_instance_allocation(dg, inputs, ctx); + dg = dg2; + instances_ready = ready; + } - std::optional logit_grad_tensor = + std::optional logit_grad_tensor = transform(logit_grad_value, [&](DynamicValueAttrs const &lgv) { - return get_loss_tensor_accessor(dg, lgv); + return get_loss_tensor_instance(dg, lgv); }); dg = perform_device_state_initialization(dg, diff --git a/lib/realm-execution/src/realm-execution/realm_context.cc b/lib/realm-execution/src/realm-execution/realm_context.cc index ede6ae6d8d..6ab7f992fa 100644 --- a/lib/realm-execution/src/realm-execution/realm_context.cc +++ b/lib/realm-execution/src/realm-execution/realm_context.cc @@ -1,7 +1,9 @@ #include "realm-execution/realm_context.h" +#include "op-attrs/datatype.h" #include "realm-execution/realm_task_id_t.h" #include "realm-execution/task_id_t.dtg.h" #include "utils/exception.h" +#include "utils/positive_int/positive_int.h" namespace FlexFlow { @@ -25,6 +27,97 @@ device_id_t const &RealmContext::get_current_device_idx() const { NOT_IMPLEMENTED(); } +std::pair + RealmContext::create_instance(Realm::Memory memory, + TensorShape const &shape, + Realm::ProfilingRequestSet const &prs, + Realm::Event wait_on) { + std::vector dims{shape.dims.ff_ordered.begin(), + shape.dims.ff_ordered.end()}; + std::vector field_sizes{ + static_cast(int{size_of_datatype(shape.data_type)})}; + Realm::RegionInstance inst; + Realm::Event ready; + switch (shape.dims.ff_ordered.num_dims()) { +#if REALM_MAX_DIM >= 1 + case 1: + ready = Realm::RegionInstance::create_instance( + inst, + memory, + Realm::Rect<1>(Realm::Point<1>::ZEROES(), + Realm::Point<1>(dims.data()) - + Realm::Point<1>::ONES()), + field_sizes, + /*block_size=*/0 /*SOA*/, + prs, + wait_on); + break; +#endif +#if REALM_MAX_DIM >= 2 + case 2: + ready = Realm::RegionInstance::create_instance( + inst, + memory, + Realm::Rect<2>(Realm::Point<2>::ZEROES(), + Realm::Point<2>(dims.data()) - + Realm::Point<2>::ONES()), + field_sizes, + /*block_size=*/0 /*SOA*/, + prs, + wait_on); + break; +#endif +#if REALM_MAX_DIM >= 3 + case 3: + ready = Realm::RegionInstance::create_instance( + inst, + memory, + Realm::Rect<3>(Realm::Point<3>::ZEROES(), + Realm::Point<3>(dims.data()) - + Realm::Point<3>::ONES()), + field_sizes, + /*block_size=*/0 /*SOA*/, + prs, + wait_on); + break; +#endif +#if REALM_MAX_DIM >= 4 + case 4: + ready = Realm::RegionInstance::create_instance( + inst, + memory, + Realm::Rect<4>(Realm::Point<4>::ZEROES(), + Realm::Point<4>(dims.data()) - + Realm::Point<4>::ONES()), + field_sizes, + /*block_size=*/0 /*SOA*/, + prs, + wait_on); + break; +#endif +#if REALM_MAX_DIM >= 5 + case 5: + ready = Realm::RegionInstance::create_instance( + inst, + memory, + Realm::Rect<5>(Realm::Point<5>::ZEROES(), + Realm::Point<5>(dims.data()) - + Realm::Point<5>::ONES()), + field_sizes, + /*block_size=*/0 /*SOA*/, + prs, + wait_on); + break; +#endif + default: + PANIC("TensorShape dims greater than REALM_MAX_DIM", + fmt::to_string(shape.dims.ff_ordered.num_dims())); + break; + } + this->outstanding_events.push_back(ready); + return std::pair{inst, ready}; +} + Realm::Event RealmContext::get_outstanding_events() { Realm::Event result = this->merge_outstanding_events(); this->outstanding_events.push_back(result); From ad536719e1e8f7d05f204b56dd4a9becfd192178 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Fri, 6 Feb 2026 12:26:48 -0800 Subject: [PATCH 24/88] Simplify dims and use constructors. --- .../src/realm-execution/realm_context.cc | 34 +++++++++---------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/lib/realm-execution/src/realm-execution/realm_context.cc b/lib/realm-execution/src/realm-execution/realm_context.cc index 6ab7f992fa..4890eb4a5d 100644 --- a/lib/realm-execution/src/realm-execution/realm_context.cc +++ b/lib/realm-execution/src/realm-execution/realm_context.cc @@ -38,15 +38,15 @@ std::pair static_cast(int{size_of_datatype(shape.data_type)})}; Realm::RegionInstance inst; Realm::Event ready; - switch (shape.dims.ff_ordered.num_dims()) { + switch (dims.size()) { #if REALM_MAX_DIM >= 1 case 1: ready = Realm::RegionInstance::create_instance( inst, memory, - Realm::Rect<1>(Realm::Point<1>::ZEROES(), - Realm::Point<1>(dims.data()) - - Realm::Point<1>::ONES()), + Realm::Rect<1>{Realm::Point<1>::ZEROES(), + Realm::Point<1>{dims.data()} - + Realm::Point<1>::ONES()}, field_sizes, /*block_size=*/0 /*SOA*/, prs, @@ -58,9 +58,9 @@ std::pair ready = Realm::RegionInstance::create_instance( inst, memory, - Realm::Rect<2>(Realm::Point<2>::ZEROES(), - Realm::Point<2>(dims.data()) - - Realm::Point<2>::ONES()), + Realm::Rect<2>{Realm::Point<2>::ZEROES(), + Realm::Point<2>{dims.data()} - + Realm::Point<2>::ONES()}, field_sizes, /*block_size=*/0 /*SOA*/, prs, @@ -72,9 +72,9 @@ std::pair ready = Realm::RegionInstance::create_instance( inst, memory, - Realm::Rect<3>(Realm::Point<3>::ZEROES(), - Realm::Point<3>(dims.data()) - - Realm::Point<3>::ONES()), + Realm::Rect<3>{Realm::Point<3>::ZEROES(), + Realm::Point<3>{dims.data()} - + Realm::Point<3>::ONES()}, field_sizes, /*block_size=*/0 /*SOA*/, prs, @@ -86,9 +86,9 @@ std::pair ready = Realm::RegionInstance::create_instance( inst, memory, - Realm::Rect<4>(Realm::Point<4>::ZEROES(), - Realm::Point<4>(dims.data()) - - Realm::Point<4>::ONES()), + Realm::Rect<4>{Realm::Point<4>::ZEROES(), + Realm::Point<4>{dims.data()} - + Realm::Point<4>::ONES()}, field_sizes, /*block_size=*/0 /*SOA*/, prs, @@ -100,9 +100,9 @@ std::pair ready = Realm::RegionInstance::create_instance( inst, memory, - Realm::Rect<5>(Realm::Point<5>::ZEROES(), - Realm::Point<5>(dims.data()) - - Realm::Point<5>::ONES()), + Realm::Rect<5>{Realm::Point<5>::ZEROES(), + Realm::Point<5>{dims.data()} - + Realm::Point<5>::ONES()}, field_sizes, /*block_size=*/0 /*SOA*/, prs, @@ -111,7 +111,7 @@ std::pair #endif default: PANIC("TensorShape dims greater than REALM_MAX_DIM", - fmt::to_string(shape.dims.ff_ordered.num_dims())); + fmt::to_string(dims.size())); break; } this->outstanding_events.push_back(ready); From 876ccc0b6ed1ff5067158e13cdec1d34a18d8bce Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Fri, 6 Feb 2026 13:24:41 -0800 Subject: [PATCH 25/88] Refactor. --- .../src/realm-execution/realm_context.cc | 105 +++++++++--------- 1 file changed, 51 insertions(+), 54 deletions(-) diff --git a/lib/realm-execution/src/realm-execution/realm_context.cc b/lib/realm-execution/src/realm-execution/realm_context.cc index 4890eb4a5d..b2671f709e 100644 --- a/lib/realm-execution/src/realm-execution/realm_context.cc +++ b/lib/realm-execution/src/realm-execution/realm_context.cc @@ -1,5 +1,6 @@ #include "realm-execution/realm_context.h" #include "op-attrs/datatype.h" +#include "op-attrs/tensor_dims.dtg.h" #include "realm-execution/realm_task_id_t.h" #include "realm-execution/task_id_t.dtg.h" #include "utils/exception.h" @@ -27,91 +28,87 @@ device_id_t const &RealmContext::get_current_device_idx() const { NOT_IMPLEMENTED(); } +template +static Realm::Rect rect_from_dims(TensorDims const &dims) { + std::vector values{dims.ff_ordered.begin(), dims.ff_ordered.end()}; + return Realm::Rect{Realm::Point::ZEROES(), + Realm::Point{values.data()} - + Realm::Point::ONES()}; +} + std::pair RealmContext::create_instance(Realm::Memory memory, TensorShape const &shape, Realm::ProfilingRequestSet const &prs, Realm::Event wait_on) { - std::vector dims{shape.dims.ff_ordered.begin(), - shape.dims.ff_ordered.end()}; std::vector field_sizes{ static_cast(int{size_of_datatype(shape.data_type)})}; Realm::RegionInstance inst; Realm::Event ready; - switch (dims.size()) { + switch (shape.dims.ff_ordered.num_dims()) { #if REALM_MAX_DIM >= 1 case 1: - ready = Realm::RegionInstance::create_instance( - inst, - memory, - Realm::Rect<1>{Realm::Point<1>::ZEROES(), - Realm::Point<1>{dims.data()} - - Realm::Point<1>::ONES()}, - field_sizes, - /*block_size=*/0 /*SOA*/, - prs, - wait_on); + ready = + Realm::RegionInstance::create_instance(inst, + memory, + rect_from_dims<1>(shape.dims), + field_sizes, + 0 /*SOA*/, + prs, + wait_on); break; #endif #if REALM_MAX_DIM >= 2 case 2: - ready = Realm::RegionInstance::create_instance( - inst, - memory, - Realm::Rect<2>{Realm::Point<2>::ZEROES(), - Realm::Point<2>{dims.data()} - - Realm::Point<2>::ONES()}, - field_sizes, - /*block_size=*/0 /*SOA*/, - prs, - wait_on); + ready = + Realm::RegionInstance::create_instance(inst, + memory, + rect_from_dims<2>(shape.dims), + field_sizes, + 0 /*SOA*/, + prs, + wait_on); break; #endif #if REALM_MAX_DIM >= 3 case 3: - ready = Realm::RegionInstance::create_instance( - inst, - memory, - Realm::Rect<3>{Realm::Point<3>::ZEROES(), - Realm::Point<3>{dims.data()} - - Realm::Point<3>::ONES()}, - field_sizes, - /*block_size=*/0 /*SOA*/, - prs, - wait_on); + ready = + Realm::RegionInstance::create_instance(inst, + memory, + rect_from_dims<3>(shape.dims), + field_sizes, + 0 /*SOA*/, + prs, + wait_on); break; #endif #if REALM_MAX_DIM >= 4 case 4: - ready = Realm::RegionInstance::create_instance( - inst, - memory, - Realm::Rect<4>{Realm::Point<4>::ZEROES(), - Realm::Point<4>{dims.data()} - - Realm::Point<4>::ONES()}, - field_sizes, - /*block_size=*/0 /*SOA*/, - prs, - wait_on); + ready = + Realm::RegionInstance::create_instance(inst, + memory, + rect_from_dims<4>(shape.dims), + field_sizes, + 0 /*SOA*/, + prs, + wait_on); break; #endif #if REALM_MAX_DIM >= 5 case 5: - ready = Realm::RegionInstance::create_instance( - inst, - memory, - Realm::Rect<5>{Realm::Point<5>::ZEROES(), - Realm::Point<5>{dims.data()} - - Realm::Point<5>::ONES()}, - field_sizes, - /*block_size=*/0 /*SOA*/, - prs, - wait_on); + ready = + Realm::RegionInstance::create_instance(inst, + memory, + rect_from_dims<5>(shape.dims), + field_sizes, + 0 /*SOA*/, + prs, + wait_on); break; #endif default: PANIC("TensorShape dims greater than REALM_MAX_DIM", - fmt::to_string(dims.size())); + fmt::to_string(shape.dims.ff_ordered.num_dims())); break; } this->outstanding_events.push_back(ready); From 24012347cb981b48900971583b0405325d14d1fd Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Fri, 6 Feb 2026 14:39:40 -0800 Subject: [PATCH 26/88] Sketch out device mapping. --- .../include/realm-execution/realm_context.h | 5 +++ .../realm-execution/instance_allocation.cc | 41 +++++++++++-------- .../src/realm-execution/realm_context.cc | 9 ++++ 3 files changed, 37 insertions(+), 18 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/realm_context.h b/lib/realm-execution/include/realm-execution/realm_context.h index 90ef402fb6..6ba64338c9 100644 --- a/lib/realm-execution/include/realm-execution/realm_context.h +++ b/lib/realm-execution/include/realm-execution/realm_context.h @@ -16,6 +16,11 @@ struct RealmContext { RealmContext(RealmContext const &) = delete; RealmContext(RealmContext &&) = delete; + // Device mapping + Realm::Processor + map_device_coord_to_processor(MachineSpaceCoordinate const &); + Realm::Memory get_nearest_memory(Realm::Processor) const; + // Current device context Allocator &get_current_device_allocator() const; device_handle_t const &get_current_device_handle() const; diff --git a/lib/realm-execution/src/realm-execution/instance_allocation.cc b/lib/realm-execution/src/realm-execution/instance_allocation.cc index 0870117bfe..33b7b54937 100644 --- a/lib/realm-execution/src/realm-execution/instance_allocation.cc +++ b/lib/realm-execution/src/realm-execution/instance_allocation.cc @@ -2,8 +2,10 @@ #include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/tensor_shape.dtg.h" #include "realm-execution/realm_context.h" +#include "task-spec/dynamic_graph/dynamic_node_attrs.dtg.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" #include "task-spec/dynamic_graph/dynamic_tensor_accessor.dtg.h" +#include "task-spec/dynamic_graph/dynamic_value_attrs.dtg.h" #include "utils/bidict/generate_bidict.h" #include "utils/containers/all_are_true.h" #include "utils/containers/contains_key.h" @@ -37,14 +39,17 @@ bool instances_are_ready_for_allocation(DynamicOpenDataflowGraph const &g) { } DynamicValueAttrs - perform_instance_allocation_for_value(DynamicValueAttrs const &value, + perform_instance_allocation_for_value(DynamicNodeAttrs const &node, + DynamicValueAttrs const &value, RealmContext &ctx) { ASSERT(value.accessor == std::nullopt); ASSERT(value.instance == std::nullopt); TensorShape shape = get_piece_shape(value.parallel_tensor_shape.value()); - Realm::Memory memory = Realm::Memory::NO_MEMORY; // FIXME + MachineSpaceCoordinate device_coord = assert_unwrap(node.device_coord); + Realm::Processor proc = ctx.map_device_coord_to_processor(device_coord); + Realm::Memory memory = ctx.get_nearest_memory(proc); auto [instance, ready] = ctx.create_instance(memory, shape, Realm::ProfilingRequestSet()); @@ -66,20 +71,20 @@ std::pair perform_instance_allocation( ASSERT(v.instance == std::nullopt); } - std::unordered_set all_values = - unordered_set_of(get_dynamic_values(g)); - - bidict unallocated_to_allocated = - generate_bidict(all_values, - [&](DynamicValueAttrs const &v) -> DynamicValueAttrs { - if (contains_key(preallocated, v)) { - // FIXME: Attach external instance to existing - // allocation and use that - NOT_IMPLEMENTED(); - } else { - return perform_instance_allocation_for_value(v, ctx); - } - }); + bidict unallocated_to_allocated; + auto allocate = [&](DynamicNodeAttrs const &n, DynamicValueAttrs const &v) { + if (contains_key(preallocated, v)) { + // FIXME: Attach external instance to existing allocation and use that + NOT_IMPLEMENTED(); + } else { + if (contains_key(unallocated_to_allocated, v)) { + return unallocated_to_allocated.at_l(v); + } else { + DynamicValueAttrs v2 = perform_instance_allocation_for_value(n, v, ctx); + uallocated_to_allocated.equate(v, v2); + } + } + }; DynamicOpenDataflowGraph result = transform_dynamic_invocation_set( g, [&](DynamicNodeInvocation const &i) -> DynamicNodeInvocation { @@ -87,13 +92,13 @@ std::pair perform_instance_allocation( /*inputs=*/map_values( i.inputs, [&](DynamicValueAttrs const &v) -> DynamicValueAttrs { - return unallocated_to_allocated.at_l(v); + return allocate(i.node_attrs, v); }), /*node_attrs=*/i.node_attrs, /*outputs=*/ map_values(i.outputs, [&](DynamicValueAttrs const &v) -> DynamicValueAttrs { - return unallocated_to_allocated.at_l(v); + return allocate(i.node_attrs, v); }), }; }); diff --git a/lib/realm-execution/src/realm-execution/realm_context.cc b/lib/realm-execution/src/realm-execution/realm_context.cc index b2671f709e..30343652d7 100644 --- a/lib/realm-execution/src/realm-execution/realm_context.cc +++ b/lib/realm-execution/src/realm-execution/realm_context.cc @@ -17,6 +17,15 @@ RealmContext::~RealmContext() { } } +Realm::Processor RealmContext::map_device_coord_to_processor( + MachineSpaceCoordinate const &device_coord) { + NOT_IMPLEMENTED(); +} + +Realm::Memory get_nearest_memory(Realm::Processor proc) const { + NOT_IMPLEMENTED(); +} + Allocator &RealmContext::get_current_device_allocator() const { NOT_IMPLEMENTED(); } From d71801214ec349136f70fdd9804d6d34209daeae Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Fri, 6 Feb 2026 15:53:27 -0800 Subject: [PATCH 27/88] Move instance backing to a separate map, remove realm from task-spec. --- .../include/realm-execution}/fmt/instance.h | 6 +- .../realm-execution/instance_allocation.h | 8 +-- .../include/realm-execution}/realm.h | 0 .../include/realm-execution/realm_context.h | 3 +- .../include/realm-execution/realm_manager.h | 2 +- .../include/realm-execution/realm_task_id_t.h | 2 +- .../realm-execution/realm_task_registry.h | 2 +- .../tensor_instance_backing.dtg.toml | 24 +++++++ .../realm-execution/tensor_instance_backing.h | 12 ++++ .../src/realm-execution}/fmt/instance.cc | 2 +- .../realm-execution/instance_allocation.cc | 72 ++++--------------- .../parallel_computation_graph_instance.cc | 17 +---- .../src/realm-execution/realm_context.cc | 2 +- .../tensor_instance_backing.cc | 11 +++ lib/task-spec/CMakeLists.txt | 1 - .../dynamic_value_attrs.dtg.toml | 6 -- .../task-spec/dynamic_graph/loss_insertion.cc | 2 - ...ake_dynamic_open_dataflow_graph_from_cg.cc | 2 - ...e_dynamic_open_dataflow_graph_from_mpcg.cc | 2 - .../dynamic_graph/update_insertion.cc | 1 - .../dynamic_open_dataflow_graph.cc | 3 - .../dynamic_graph/machine_slicing.cc | 1 - .../task-spec/dynamic_graph/pass_expansion.cc | 3 - .../dynamic_graph/shard_expansion.cc | 1 - 24 files changed, 74 insertions(+), 111 deletions(-) rename lib/{task-spec/include/task-spec/realm => realm-execution/include/realm-execution}/fmt/instance.h (83%) rename lib/{task-spec/include/task-spec/realm => realm-execution/include/realm-execution}/realm.h (100%) create mode 100644 lib/realm-execution/include/realm-execution/tensor_instance_backing.dtg.toml create mode 100644 lib/realm-execution/include/realm-execution/tensor_instance_backing.h rename lib/{task-spec/src/task-spec/realm => realm-execution/src/realm-execution}/fmt/instance.cc (82%) create mode 100644 lib/realm-execution/src/realm-execution/tensor_instance_backing.cc diff --git a/lib/task-spec/include/task-spec/realm/fmt/instance.h b/lib/realm-execution/include/realm-execution/fmt/instance.h similarity index 83% rename from lib/task-spec/include/task-spec/realm/fmt/instance.h rename to lib/realm-execution/include/realm-execution/fmt/instance.h index 23979c7efc..b2efc59b7d 100644 --- a/lib/task-spec/include/task-spec/realm/fmt/instance.h +++ b/lib/realm-execution/include/realm-execution/fmt/instance.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_PAIR_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_PAIR_H -#include "task-spec/realm/realm.h" +#include "realm-execution/realm.h" #include "utils/check_fmtable.h" #include #include @@ -15,8 +15,8 @@ struct formatter<::FlexFlow::Realm::RegionInstance, ::FlexFlow::Realm::RegionInstance>::value>> : formatter<::std::string> { template - auto format(::FlexFlow::Realm::RegionInstance const &m, - FormatContext &ctx) const -> decltype(ctx.out()) { + auto format(::FlexFlow::Realm::RegionInstance const &m, FormatContext &ctx) + -> decltype(ctx.out()) { std::string result = fmt::format("", m.id); return formatter::format(result, ctx); diff --git a/lib/realm-execution/include/realm-execution/instance_allocation.h b/lib/realm-execution/include/realm-execution/instance_allocation.h index d1dfa3fda0..59065694e9 100644 --- a/lib/realm-execution/include/realm-execution/instance_allocation.h +++ b/lib/realm-execution/include/realm-execution/instance_allocation.h @@ -2,20 +2,16 @@ #define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_INSTANCE_ALLOCATION_H #include "realm-execution/realm_context.h" +#include "realm-execution/tensor_instance_backing.dtg.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.dtg.h" namespace FlexFlow { -bool no_instances_are_allocated(DynamicOpenDataflowGraph const &); -bool all_instances_are_allocated(DynamicOpenDataflowGraph const &); - -bool instances_are_ready_for_allocation(DynamicOpenDataflowGraph const &g); - DynamicValueAttrs perform_instance_allocation_for_value(DynamicValueAttrs const &, Allocator &); -std::pair perform_instance_allocation( +TensorInstanceBacking perform_instance_allocation( DynamicOpenDataflowGraph const &, std::unordered_map const &preallocated, diff --git a/lib/task-spec/include/task-spec/realm/realm.h b/lib/realm-execution/include/realm-execution/realm.h similarity index 100% rename from lib/task-spec/include/task-spec/realm/realm.h rename to lib/realm-execution/include/realm-execution/realm.h diff --git a/lib/realm-execution/include/realm-execution/realm_context.h b/lib/realm-execution/include/realm-execution/realm_context.h index 6ba64338c9..bfc1a53cd3 100644 --- a/lib/realm-execution/include/realm-execution/realm_context.h +++ b/lib/realm-execution/include/realm-execution/realm_context.h @@ -4,7 +4,8 @@ #include "kernels/allocation.h" #include "kernels/device_handle_t.dtg.h" #include "pcg/device_id_t.dtg.h" -#include "task-spec/realm/realm.h" +#include "pcg/machine_space_coordinate.dtg.h" +#include "realm-execution/realm.h" namespace FlexFlow { diff --git a/lib/realm-execution/include/realm-execution/realm_manager.h b/lib/realm-execution/include/realm-execution/realm_manager.h index ebf3bb401e..bf5e8f72f1 100644 --- a/lib/realm-execution/include/realm-execution/realm_manager.h +++ b/lib/realm-execution/include/realm-execution/realm_manager.h @@ -4,8 +4,8 @@ #include "kernels/allocation.h" #include "kernels/device_handle_t.dtg.h" #include "pcg/device_id_t.dtg.h" +#include "realm-execution/realm.h" #include "realm-execution/realm_context.h" -#include "task-spec/realm/realm.h" namespace FlexFlow { diff --git a/lib/realm-execution/include/realm-execution/realm_task_id_t.h b/lib/realm-execution/include/realm-execution/realm_task_id_t.h index 327cf9ffd0..8e6da1a2bd 100644 --- a/lib/realm-execution/include/realm-execution/realm_task_id_t.h +++ b/lib/realm-execution/include/realm-execution/realm_task_id_t.h @@ -1,8 +1,8 @@ #ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_TASK_ID_T_H #define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_TASK_ID_T_H +#include "realm-execution/realm.h" #include "realm-execution/task_id_t.dtg.h" -#include "task-spec/realm/realm.h" namespace FlexFlow { diff --git a/lib/realm-execution/include/realm-execution/realm_task_registry.h b/lib/realm-execution/include/realm-execution/realm_task_registry.h index d6bf5b927f..f800b1d8c4 100644 --- a/lib/realm-execution/include/realm-execution/realm_task_registry.h +++ b/lib/realm-execution/include/realm-execution/realm_task_registry.h @@ -1,8 +1,8 @@ #ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_TASK_REGISTRY_H #define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_TASK_REGISTRY_H +#include "realm-execution/realm.h" #include "realm-execution/task_id_t.dtg.h" -#include "task-spec/realm/realm.h" namespace FlexFlow { diff --git a/lib/realm-execution/include/realm-execution/tensor_instance_backing.dtg.toml b/lib/realm-execution/include/realm-execution/tensor_instance_backing.dtg.toml new file mode 100644 index 0000000000..bdf08df59c --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tensor_instance_backing.dtg.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "TensorInstanceBacking" +type = "struct" +features = [ + "eq", + #"fmt", + "hash", +] + +includes = [ + "", + "realm-execution/realm.h", + "task-spec/dynamic_graph/dynamic_value_attrs.dtg.h", +] + +src_includes = [ + "realm-execution/fmt/instance.h", + "utils/hash/unordered_map.h", + "utils/fmt/unordered_map.h", +] + +[[fields]] +name = "backing" +type = "std::unordered_map<::FlexFlow::DynamicValueAttrs, std::pair<::FlexFlow::Realm::RegionInstance, ::FlexFlow::Realm::Event>>" diff --git a/lib/realm-execution/include/realm-execution/tensor_instance_backing.h b/lib/realm-execution/include/realm-execution/tensor_instance_backing.h new file mode 100644 index 0000000000..1d143b7409 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tensor_instance_backing.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TENSOR_INSTANCE_BACKING_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TENSOR_INSTANCE_BACKING_H + +#include "realm-execution/tensor_instance_backing.dtg.h" + +namespace FlexFlow { + +TensorInstanceBacking make_empty_tensor_instance_backing(); + +} // namespace FlexFlow + +#endif diff --git a/lib/task-spec/src/task-spec/realm/fmt/instance.cc b/lib/realm-execution/src/realm-execution/fmt/instance.cc similarity index 82% rename from lib/task-spec/src/task-spec/realm/fmt/instance.cc rename to lib/realm-execution/src/realm-execution/fmt/instance.cc index fa15e1c16f..f8eabe9bb0 100644 --- a/lib/task-spec/src/task-spec/realm/fmt/instance.cc +++ b/lib/realm-execution/src/realm-execution/fmt/instance.cc @@ -1,4 +1,4 @@ -#include "task-spec/realm/fmt/instance.h" +#include "realm-execution/fmt/instance.h" namespace FlexFlow { diff --git a/lib/realm-execution/src/realm-execution/instance_allocation.cc b/lib/realm-execution/src/realm-execution/instance_allocation.cc index 33b7b54937..c033f0bac1 100644 --- a/lib/realm-execution/src/realm-execution/instance_allocation.cc +++ b/lib/realm-execution/src/realm-execution/instance_allocation.cc @@ -1,7 +1,9 @@ #include "realm-execution/instance_allocation.h" +#include "local-execution/tensor_allocation.h" #include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/tensor_shape.dtg.h" #include "realm-execution/realm_context.h" +#include "realm-execution/tensor_instance_backing.h" #include "task-spec/dynamic_graph/dynamic_node_attrs.dtg.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" #include "task-spec/dynamic_graph/dynamic_tensor_accessor.dtg.h" @@ -17,95 +19,47 @@ namespace FlexFlow { -bool no_instances_are_allocated(DynamicOpenDataflowGraph const &g) { - return all_are_true( - transform(get_dynamic_values(g), [](DynamicValueAttrs const &v) -> bool { - return !v.accessor.has_value() && !v.instance.has_value(); - })); -} - -bool all_instances_are_allocated(DynamicOpenDataflowGraph const &g) { - return all_are_true( - transform(get_dynamic_values(g), [](DynamicValueAttrs const &v) -> bool { - return v.instance.has_value(); - })); -} - -bool instances_are_ready_for_allocation(DynamicOpenDataflowGraph const &g) { - return all_are_true( - transform(get_dynamic_values(g), [](DynamicValueAttrs const &v) -> bool { - return v.parallel_tensor_shape.has_value(); - })); -} - -DynamicValueAttrs +std::pair perform_instance_allocation_for_value(DynamicNodeAttrs const &node, DynamicValueAttrs const &value, RealmContext &ctx) { ASSERT(value.accessor == std::nullopt); - ASSERT(value.instance == std::nullopt); TensorShape shape = get_piece_shape(value.parallel_tensor_shape.value()); MachineSpaceCoordinate device_coord = assert_unwrap(node.device_coord); Realm::Processor proc = ctx.map_device_coord_to_processor(device_coord); Realm::Memory memory = ctx.get_nearest_memory(proc); - auto [instance, ready] = - ctx.create_instance(memory, shape, Realm::ProfilingRequestSet()); - - DynamicValueAttrs result = value; - result.instance = instance; - - return result; + return ctx.create_instance(memory, shape, Realm::ProfilingRequestSet()); } -std::pair perform_instance_allocation( +TensorInstanceBacking perform_instance_allocation( DynamicOpenDataflowGraph const &g, std::unordered_map const &preallocated, RealmContext &ctx) { - ASSERT(no_instances_are_allocated(g)); - ASSERT(instances_are_ready_for_allocation(g)); + ASSERT(no_tensors_are_allocated(g)); + ASSERT(tensors_are_ready_for_allocation(g)); for (DynamicValueAttrs const &v : keys(preallocated)) { ASSERT(v.accessor == std::nullopt); - ASSERT(v.instance == std::nullopt); } - bidict unallocated_to_allocated; + TensorInstanceBacking result = make_empty_tensor_instance_backing(); auto allocate = [&](DynamicNodeAttrs const &n, DynamicValueAttrs const &v) { if (contains_key(preallocated, v)) { // FIXME: Attach external instance to existing allocation and use that NOT_IMPLEMENTED(); } else { - if (contains_key(unallocated_to_allocated, v)) { - return unallocated_to_allocated.at_l(v); + if (contains_key(result.backing, v)) { + return result.backing.at(v); } else { - DynamicValueAttrs v2 = perform_instance_allocation_for_value(n, v, ctx); - uallocated_to_allocated.equate(v, v2); + result.backing.insert( + std::pair{v, perform_instance_allocation_for_value(n, v, ctx)}); } } }; - DynamicOpenDataflowGraph result = transform_dynamic_invocation_set( - g, [&](DynamicNodeInvocation const &i) -> DynamicNodeInvocation { - return DynamicNodeInvocation{ - /*inputs=*/map_values( - i.inputs, - [&](DynamicValueAttrs const &v) -> DynamicValueAttrs { - return allocate(i.node_attrs, v); - }), - /*node_attrs=*/i.node_attrs, - /*outputs=*/ - map_values(i.outputs, - [&](DynamicValueAttrs const &v) -> DynamicValueAttrs { - return allocate(i.node_attrs, v); - }), - }; - }); - - ASSERT(all_instances_are_allocated(result)); - - return std::pair{result, ctx.get_outstanding_events()}; + return result; } } // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc index dddb624df3..e0e4f769d3 100644 --- a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc +++ b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc @@ -53,14 +53,6 @@ std::optional return this->logit_grad_tensor; } -static Realm::RegionInstance - get_loss_tensor_instance(DynamicOpenDataflowGraph const &dg, - DynamicValueAttrs const &value) { - return find_output_tensor(dg, value.tensor_guid, value.role) - .value() - .second.instance.value(); -} - ParallelComputationGraphInstance create_parallel_computation_graph_instance( RealmContext &ctx, MappedParallelComputationGraph const &mpcg, @@ -90,16 +82,11 @@ ParallelComputationGraphInstance create_parallel_computation_graph_instance( dg = perform_update_insertion(dg, optimizer_attrs); dg = perform_shard_expansion(dg); - Realm::Event instances_ready; - { - auto [dg2, ready] = perform_instance_allocation(dg, inputs, ctx); - dg = dg2; - instances_ready = ready; - } + TensorInstanceBacking backing = perform_instance_allocation(dg, inputs, ctx); std::optional logit_grad_tensor = transform(logit_grad_value, [&](DynamicValueAttrs const &lgv) { - return get_loss_tensor_instance(dg, lgv); + return backing.backing.at(lgv).first; }); dg = perform_device_state_initialization(dg, diff --git a/lib/realm-execution/src/realm-execution/realm_context.cc b/lib/realm-execution/src/realm-execution/realm_context.cc index 30343652d7..4c02c13aa0 100644 --- a/lib/realm-execution/src/realm-execution/realm_context.cc +++ b/lib/realm-execution/src/realm-execution/realm_context.cc @@ -22,7 +22,7 @@ Realm::Processor RealmContext::map_device_coord_to_processor( NOT_IMPLEMENTED(); } -Realm::Memory get_nearest_memory(Realm::Processor proc) const { +Realm::Memory RealmContext::get_nearest_memory(Realm::Processor proc) const { NOT_IMPLEMENTED(); } diff --git a/lib/realm-execution/src/realm-execution/tensor_instance_backing.cc b/lib/realm-execution/src/realm-execution/tensor_instance_backing.cc new file mode 100644 index 0000000000..53c2a2b271 --- /dev/null +++ b/lib/realm-execution/src/realm-execution/tensor_instance_backing.cc @@ -0,0 +1,11 @@ +#include "realm-execution/tensor_instance_backing.h" + +namespace FlexFlow { + +TensorInstanceBacking make_empty_tensor_instance_backing() { + return TensorInstanceBacking{ + /*backing=*/{}, + }; +} + +} // namespace FlexFlow diff --git a/lib/task-spec/CMakeLists.txt b/lib/task-spec/CMakeLists.txt index f4f5353f70..3c7c91af67 100644 --- a/lib/task-spec/CMakeLists.txt +++ b/lib/task-spec/CMakeLists.txt @@ -14,7 +14,6 @@ ff_add_library( pcg spdlog compiler - Realm::Realm ) add_subdirectory(test) diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_attrs.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_attrs.dtg.toml index 763ebf180f..89b94b1017 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_attrs.dtg.toml +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_attrs.dtg.toml @@ -14,8 +14,6 @@ includes = [ "op-attrs/parallel_tensor_space_coordinate.dtg.h", "task-spec/dynamic_graph/dynamic_tensor_accessor.dtg.h", "task-spec/dynamic_graph/dynamic_tensor_role.dtg.h", - "task-spec/realm/fmt/instance.h", - "task-spec/realm/realm.h", ] src_includes = [ @@ -38,10 +36,6 @@ type = "std::optional<::FlexFlow::ParallelTensorSpaceCoordinate>" name = "accessor" type = "std::optional<::FlexFlow::DynamicTensorAccessor>" -[[fields]] -name = "instance" -type = "std::optional<::FlexFlow::Realm::RegionInstance>" - [[fields]] name = "role" type = "std::optional<::FlexFlow::DynamicTensorRole>" diff --git a/lib/task-spec/src/task-spec/dynamic_graph/loss_insertion.cc b/lib/task-spec/src/task-spec/dynamic_graph/loss_insertion.cc index 837ade2aad..4270119612 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/loss_insertion.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/loss_insertion.cc @@ -23,7 +23,6 @@ LossInsertionResult perform_loss_insertion(DynamicOpenDataflowGraph const &dg, /*parallel_tensor_shape=*/logit_value.parallel_tensor_shape, /*shard_coord=*/logit_value.shard_coord, /*accessor=*/std::nullopt, - /*instance=*/std::nullopt, /*role=*/mk_dynamic_tensor_role_loss(), }; DynamicValueAttrs logit_grad_value{ @@ -31,7 +30,6 @@ LossInsertionResult perform_loss_insertion(DynamicOpenDataflowGraph const &dg, /*parallel_tensor_shape=*/logit_value.parallel_tensor_shape, /*shard_coord=*/logit_value.shard_coord, /*accessor=*/std::nullopt, - /*instance=*/std::nullopt, /*role=*/mk_dynamic_tensor_role_bwd(), }; DynamicNodeInvocation loss_invocation{ diff --git a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_cg.cc b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_cg.cc index 294241b732..204597386e 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_cg.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_cg.cc @@ -45,7 +45,6 @@ DynamicOpenDataflowGraph /*parallel_tensor_shape=*/lift_to_parallel(attrs.shape), /*shard_coord=*/std::nullopt, /*accessor=*/std::nullopt, - /*instance=*/std::nullopt, /*role=*/std::nullopt, }, }; @@ -65,7 +64,6 @@ DynamicOpenDataflowGraph /*parallel_tensor_shape=*/lift_to_parallel(attrs.shape), /*shard_coord=*/std::nullopt, /*accessor=*/std::nullopt, - /*instance=*/std::nullopt, /*role=*/std::nullopt, }, }; diff --git a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mpcg.cc b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mpcg.cc index eceb580a20..e90ef10398 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mpcg.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mpcg.cc @@ -44,7 +44,6 @@ DynamicOpenDataflowGraph make_dynamic_open_dataflow_graph_from_mpcg( /*parallel_tensor_shape=*/attrs.shape, /*shard_coord=*/std::nullopt, /*accessor=*/std::nullopt, - /*instance=*/std::nullopt, /*role=*/std::nullopt, }, }; @@ -65,7 +64,6 @@ DynamicOpenDataflowGraph make_dynamic_open_dataflow_graph_from_mpcg( /*parallel_tensor_shape=*/attrs.shape, /*shard_coord=*/std::nullopt, /*accessor=*/std::nullopt, - /*instance=*/std::nullopt, /*role=*/std::nullopt, }, }; diff --git a/lib/task-spec/src/task-spec/dynamic_graph/update_insertion.cc b/lib/task-spec/src/task-spec/dynamic_graph/update_insertion.cc index 23708f3779..58a32db6c1 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/update_insertion.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/update_insertion.cc @@ -51,7 +51,6 @@ static DynamicNodeInvocation get_update_invocation_for_invocation( DynamicValueAttrs value_attrs = output.second; ASSERT(value_attrs.accessor == std::nullopt); - ASSERT(value_attrs.instance == std::nullopt); DynamicNodeAttrs update_node_attrs = i.node_attrs; update_node_attrs.task_type = DynamicTaskType::UPD; diff --git a/lib/task-spec/test/src/task-spec/dynamic_graph/dynamic_open_dataflow_graph.cc b/lib/task-spec/test/src/task-spec/dynamic_graph/dynamic_open_dataflow_graph.cc index bb9a45e59a..fc9110b6e4 100644 --- a/lib/task-spec/test/src/task-spec/dynamic_graph/dynamic_open_dataflow_graph.cc +++ b/lib/task-spec/test/src/task-spec/dynamic_graph/dynamic_open_dataflow_graph.cc @@ -16,7 +16,6 @@ TEST_SUITE(FF_TEST_SUITE) { /*parallel_tensor_shape=*/std::nullopt, /*shard_coord=*/std::nullopt, /*accessor=*/std::nullopt, - /*instance=*/std::nullopt, /*tensor_type=*/std::nullopt, }; @@ -30,7 +29,6 @@ TEST_SUITE(FF_TEST_SUITE) { /*parallel_tensor_shape=*/std::nullopt, /*shard_coord=*/std::nullopt, /*accessor=*/std::nullopt, - /*instance=*/std::nullopt, /*tensor_type=*/std::nullopt, }; @@ -44,7 +42,6 @@ TEST_SUITE(FF_TEST_SUITE) { /*parallel_tensor_shape=*/std::nullopt, /*shard_coord=*/std::nullopt, /*accessor=*/std::nullopt, - /*instance=*/std::nullopt, /*tensor_type=*/std::nullopt, }; diff --git a/lib/task-spec/test/src/task-spec/dynamic_graph/machine_slicing.cc b/lib/task-spec/test/src/task-spec/dynamic_graph/machine_slicing.cc index c28e12e0af..40d37f50df 100644 --- a/lib/task-spec/test/src/task-spec/dynamic_graph/machine_slicing.cc +++ b/lib/task-spec/test/src/task-spec/dynamic_graph/machine_slicing.cc @@ -76,7 +76,6 @@ TEST_SUITE(FF_TEST_SUITE) { /*parallel_tensor_shape=*/std::nullopt, /*shard_coord=*/shard_coord, /*accessor=*/std::nullopt, - /*instance=*/std::nullopt, /*role=*/std::nullopt, }; }; diff --git a/lib/task-spec/test/src/task-spec/dynamic_graph/pass_expansion.cc b/lib/task-spec/test/src/task-spec/dynamic_graph/pass_expansion.cc index e57691b475..e8fcf2e40b 100644 --- a/lib/task-spec/test/src/task-spec/dynamic_graph/pass_expansion.cc +++ b/lib/task-spec/test/src/task-spec/dynamic_graph/pass_expansion.cc @@ -20,7 +20,6 @@ TEST_SUITE(FF_TEST_SUITE) { /*parallel_tensor_shape=*/std::nullopt, /*shard_coord=*/std::nullopt, /*accessor=*/std::nullopt, - /*instance=*/std::nullopt, /*role=*/tensor_role, }; }; @@ -114,7 +113,6 @@ TEST_SUITE(FF_TEST_SUITE) { /*parallel_tensor_shape=*/std::nullopt, /*shard_coord=*/std::nullopt, /*accessor=*/std::nullopt, - /*instance=*/std::nullopt, /*role=*/tensor_role, }; }; @@ -231,7 +229,6 @@ TEST_SUITE(FF_TEST_SUITE) { /*parallel_tensor_shape=*/std::nullopt, /*shard_coord=*/std::nullopt, /*accessor=*/std::nullopt, - /*instance=*/std::nullopt, /*role=*/tensor_type, }; }; diff --git a/lib/task-spec/test/src/task-spec/dynamic_graph/shard_expansion.cc b/lib/task-spec/test/src/task-spec/dynamic_graph/shard_expansion.cc index 4d88dde805..23fbb6e514 100644 --- a/lib/task-spec/test/src/task-spec/dynamic_graph/shard_expansion.cc +++ b/lib/task-spec/test/src/task-spec/dynamic_graph/shard_expansion.cc @@ -121,7 +121,6 @@ TEST_SUITE(FF_TEST_SUITE) { /*parallel_tensor_shape=*/std::nullopt, /*shard_coord=*/shard_coord, /*accessor=*/std::nullopt, - /*instance=*/std::nullopt, /*role=*/std::nullopt, }; }; From 1da64500be213566d245cbb881643e9a60a89b59 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Fri, 6 Feb 2026 16:51:32 -0800 Subject: [PATCH 28/88] Implement processor queries. --- .../include/realm-execution/realm_context.h | 11 +++- .../parallel_computation_graph_instance.cc | 7 ++- .../src/realm-execution/realm_context.cc | 56 ++++++++++++++++++- .../src/realm-execution/realm_manager.cc | 6 +- 4 files changed, 72 insertions(+), 8 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/realm_context.h b/lib/realm-execution/include/realm-execution/realm_context.h index bfc1a53cd3..73d60e9f50 100644 --- a/lib/realm-execution/include/realm-execution/realm_context.h +++ b/lib/realm-execution/include/realm-execution/realm_context.h @@ -6,14 +6,16 @@ #include "pcg/device_id_t.dtg.h" #include "pcg/machine_space_coordinate.dtg.h" #include "realm-execution/realm.h" +#include namespace FlexFlow { struct RealmContext { public: - RealmContext(); + RealmContext(Realm::Processor); virtual ~RealmContext(); + RealmContext() = delete; RealmContext(RealmContext const &) = delete; RealmContext(RealmContext &&) = delete; @@ -23,6 +25,7 @@ struct RealmContext { Realm::Memory get_nearest_memory(Realm::Processor) const; // Current device context + Realm::Processor get_current_processor() const; Allocator &get_current_device_allocator() const; device_handle_t const &get_current_device_handle() const; device_id_t const &get_current_device_idx() const; @@ -42,9 +45,15 @@ struct RealmContext { // Important: USER MUST BLOCK on event or else use it, or it WILL BE LOST [[nodiscard]] Realm::Event merge_outstanding_events(); + void discover_machine_topology(); + protected: Realm::Runtime runtime; + Realm::Processor processor; std::vector outstanding_events; + std::unordered_map, + std::vector> + processors; }; } // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc index e0e4f769d3..5d6aeddf83 100644 --- a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc +++ b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc @@ -10,6 +10,7 @@ #include "task-spec/dynamic_graph/shard_expansion.h" #include "task-spec/dynamic_graph/update_insertion.h" #include "utils/exception.h" +#include "utils/optional.h" namespace FlexFlow { @@ -74,10 +75,12 @@ ParallelComputationGraphInstance create_parallel_computation_graph_instance( std::optional logit_grad_value; if (loss_attrs) { auto [dg2, label_v, logit_grad_v] = perform_loss_insertion( - dg, loss_attrs.value(), dynamic_tensor_guid_t{logit_tensor.value()}); + dg, + assert_unwrap(loss_attrs), + dynamic_tensor_guid_t{assert_unwrap(logit_tensor)}); dg = dg2; logit_grad_value = logit_grad_v; - inputs.insert(std::pair{label_v, label_tensor.value()}); + inputs.insert(std::pair{label_v, assert_unwrap(label_tensor)}); } dg = perform_update_insertion(dg, optimizer_attrs); diff --git a/lib/realm-execution/src/realm-execution/realm_context.cc b/lib/realm-execution/src/realm-execution/realm_context.cc index 4c02c13aa0..bf5f337796 100644 --- a/lib/realm-execution/src/realm-execution/realm_context.cc +++ b/lib/realm-execution/src/realm-execution/realm_context.cc @@ -1,14 +1,19 @@ #include "realm-execution/realm_context.h" #include "op-attrs/datatype.h" #include "op-attrs/tensor_dims.dtg.h" +#include "pcg/device_type.dtg.h" #include "realm-execution/realm_task_id_t.h" #include "realm-execution/task_id_t.dtg.h" +#include "utils/containers/contains_key.h" +#include "utils/containers/transform.h" #include "utils/exception.h" +#include "utils/nonnegative_int/nonnegative_int.h" +#include "utils/one_to_many/one_to_many.h" #include "utils/positive_int/positive_int.h" namespace FlexFlow { -RealmContext::RealmContext() {} +RealmContext::RealmContext(Realm::Processor proc) : processor(proc) {} RealmContext::~RealmContext() { if (!this->outstanding_events.empty()) { @@ -17,13 +22,45 @@ RealmContext::~RealmContext() { } } +static std::tuple + convert_machine_space_coordinate( + MachineSpaceCoordinate const &device_coord) { + Realm::AddressSpace as = int{device_coord.node_idx}; + Realm::Processor::Kind kind; + switch (device_coord.device_type) { + case DeviceType::CPU: + kind = Realm::Processor::Kind::LOC_PROC; + break; + case DeviceType::GPU: + kind = Realm::Processor::Kind::TOC_PROC; + break; + default: + PANIC("Unhandled DeviceType", fmt::to_string(device_coord.device_type)); + break; + } + nonnegative_int proc_in_node = device_coord.device_idx; + return std::tuple{as, kind, proc_in_node}; +} + Realm::Processor RealmContext::map_device_coord_to_processor( MachineSpaceCoordinate const &device_coord) { - NOT_IMPLEMENTED(); + this->discover_machine_topology(); + auto [as, kind, proc_in_node] = + convert_machine_space_coordinate(device_coord); + return this->processors.at(std::pair{as, kind}).at(int{proc_in_node}); } Realm::Memory RealmContext::get_nearest_memory(Realm::Processor proc) const { - NOT_IMPLEMENTED(); + // FIMXE: this isn't going to do what you expect until + // https://github.com/StanfordLegion/realm/pull/392 merges + Realm::Machine::MemoryQuery mq(Realm::Machine::get_machine()); + mq.best_affinity_to(proc); + ASSERT(mq.count() > 0); + return mq.first(); +} + +Realm::Processor RealmContext::get_current_processor() const { + return this->processor; } Allocator &RealmContext::get_current_device_allocator() const { @@ -136,4 +173,17 @@ Realm::Event RealmContext::merge_outstanding_events() { return result; } +void RealmContext::discover_machine_topology() { + if (!this->processors.empty()) { + return; + } + + Realm::Machine::ProcessorQuery pq(Realm::Machine::get_machine()); + for (Realm::Processor proc : pq) { + Realm::AddressSpace as = proc.address_space(); + Realm::Processor::Kind kind = proc.kind(); + this->processors[std::pair{as, kind}].push_back(proc); + } +} + } // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/realm_manager.cc b/lib/realm-execution/src/realm-execution/realm_manager.cc index 63c6266948..f8a3e4014b 100644 --- a/lib/realm-execution/src/realm-execution/realm_manager.cc +++ b/lib/realm-execution/src/realm-execution/realm_manager.cc @@ -1,4 +1,5 @@ #include "realm-execution/realm_manager.h" +#include "realm-execution/realm_context.h" #include "realm-execution/realm_task_id_t.h" #include "realm-execution/realm_task_registry.h" #include "realm-execution/task_id_t.dtg.h" @@ -15,11 +16,12 @@ static void controller_task_wrapper(void const *args, std::function thunk = *reinterpret_cast const *>(args); - RealmContext ctx; + RealmContext ctx{proc}; thunk(ctx); } -RealmManager::RealmManager(int *argc, char ***argv) { +RealmManager::RealmManager(int *argc, char ***argv) + : RealmContext(Realm::Processor::NO_PROC) { bool ok = this->runtime.init(argc, argv); ASSERT(ok); From a507ce10ef1269547713687e5d59249598123dc2 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Sat, 7 Feb 2026 11:41:18 -0800 Subject: [PATCH 29/88] Enable PRealm. --- .flake/pkgs/realm.nix | 10 ++++++---- lib/realm-execution/include/realm-execution/realm.h | 2 ++ .../realm-execution/tensor_instance_backing.dtg.toml | 2 +- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/.flake/pkgs/realm.nix b/.flake/pkgs/realm.nix index 1249c0ae28..b809573690 100644 --- a/.flake/pkgs/realm.nix +++ b/.flake/pkgs/realm.nix @@ -3,6 +3,7 @@ , fetchFromGitHub , cmake , cudaPackages ? { } +, zlib , maxDim ? 5 }: @@ -12,14 +13,13 @@ in stdenv.mkDerivation rec { pname = "realm"; - version = "2025-01-06"; + version = "2026-02-06"; - # This version is compatible with Legion 7be1abd0207eb1126c7629b16d1123fa6f58ce9d src = fetchFromGitHub { owner = "StanfordLegion"; repo = "realm"; - rev = "0ef7edc8c012d4ab6a50805c044cec8a8edeae33"; - sha256 = "sha256-57/a1lAgs+ajpRn0y0Lk1gP5nKt+N08WW0DIJP4vdho="; + rev = "0405b67ca14b586f7dec0dcddee194cecee7efa6"; + sha256 = "sha256-iUPVV1rh3QuyDKgXuu8aDlaZGlNwcpPvPsSVLWp8tr4="; }; nativeBuildInputs = [ @@ -29,11 +29,13 @@ stdenv.mkDerivation rec { cmakeFlags = [ "-DBUILD_SHARED_LIBS=ON" "-DREALM_ENABLE_CUDA=ON" + "-DREALM_ENABLE_PREALM=ON" "-DREALM_MAX_DIM=${toString maxDim}" ]; buildInputs = [ cudatoolkit + zlib ]; meta = with lib; { diff --git a/lib/realm-execution/include/realm-execution/realm.h b/lib/realm-execution/include/realm-execution/realm.h index 8123c9e9fa..b6913e66f5 100644 --- a/lib/realm-execution/include/realm-execution/realm.h +++ b/lib/realm-execution/include/realm-execution/realm.h @@ -1,6 +1,8 @@ #ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_REALM_H #define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_REALM_H +#define FLEXFLOW_USE_PREALM + #ifdef FLEXFLOW_USE_PREALM #include #else diff --git a/lib/realm-execution/include/realm-execution/tensor_instance_backing.dtg.toml b/lib/realm-execution/include/realm-execution/tensor_instance_backing.dtg.toml index bdf08df59c..e6a8bd58d9 100644 --- a/lib/realm-execution/include/realm-execution/tensor_instance_backing.dtg.toml +++ b/lib/realm-execution/include/realm-execution/tensor_instance_backing.dtg.toml @@ -4,7 +4,7 @@ type = "struct" features = [ "eq", #"fmt", - "hash", + #"hash", ] includes = [ From 7b60556a0a9330e972eb30e343a4770b3ec3d0c5 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Tue, 10 Feb 2026 10:18:32 -0800 Subject: [PATCH 30/88] Move tasks to dedicated file, stub out device state init, shuffle directories. --- .../distributed_device_state_initialization.h | 21 ++++++++++++++++ .../{ => tasks}/realm_task_id_t.h | 2 +- .../{ => tasks}/realm_task_registry.h | 2 +- .../realm-execution/tasks/realm_tasks.h | 15 ++++++++++++ .../{ => tasks}/task_id_t.dtg.toml | 0 .../realm-execution/{ => tasks}/task_id_t.h | 2 +- ...distributed_device_state_initialization.cc | 15 ++++++++++++ .../parallel_computation_graph_instance.cc | 17 ++++++------- .../src/realm-execution/realm_context.cc | 4 ++-- .../src/realm-execution/realm_manager.cc | 23 +++--------------- .../{ => tasks}/realm_task_id_t.cc | 2 +- .../{ => tasks}/realm_task_registry.cc | 21 ++++++++-------- .../src/realm-execution/tasks/realm_tasks.cc | 24 +++++++++++++++++++ .../realm-execution/{ => tasks}/task_id_t.cc | 2 +- 14 files changed, 104 insertions(+), 46 deletions(-) create mode 100644 lib/realm-execution/include/realm-execution/distributed_device_state_initialization.h rename lib/realm-execution/include/realm-execution/{ => tasks}/realm_task_id_t.h (86%) rename lib/realm-execution/include/realm-execution/{ => tasks}/realm_task_registry.h (94%) create mode 100644 lib/realm-execution/include/realm-execution/tasks/realm_tasks.h rename lib/realm-execution/include/realm-execution/{ => tasks}/task_id_t.dtg.toml (100%) rename lib/realm-execution/include/realm-execution/{ => tasks}/task_id_t.h (94%) create mode 100644 lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc rename lib/realm-execution/src/realm-execution/{ => tasks}/realm_task_id_t.cc (82%) rename lib/realm-execution/src/realm-execution/{ => tasks}/realm_task_registry.cc (86%) create mode 100644 lib/realm-execution/src/realm-execution/tasks/realm_tasks.cc rename lib/realm-execution/src/realm-execution/{ => tasks}/task_id_t.cc (99%) diff --git a/lib/realm-execution/include/realm-execution/distributed_device_state_initialization.h b/lib/realm-execution/include/realm-execution/distributed_device_state_initialization.h new file mode 100644 index 0000000000..4121f10341 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/distributed_device_state_initialization.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_DISTRIBUTED_DEVICE_STATE_INITIALIZATION_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_DISTRIBUTED_DEVICE_STATE_INITIALIZATION_H + +#include "kernels/profiling_settings.dtg.h" +#include "pcg/optimizer_attrs.dtg.h" +#include "realm-execution/realm_context.h" +#include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.dtg.h" +#include "task-spec/ff_iteration_config.dtg.h" + +namespace FlexFlow { + +DynamicOpenDataflowGraph perform_distributed_device_state_initialization( + DynamicOpenDataflowGraph const &, + RealmContext &ctx, + ProfilingSettings const &profiling_settings, + FFIterationConfig const &iteration_config, + OptimizerAttrs const &optimizer_attrs); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/realm_task_id_t.h b/lib/realm-execution/include/realm-execution/tasks/realm_task_id_t.h similarity index 86% rename from lib/realm-execution/include/realm-execution/realm_task_id_t.h rename to lib/realm-execution/include/realm-execution/tasks/realm_task_id_t.h index 8e6da1a2bd..cd5eba2f34 100644 --- a/lib/realm-execution/include/realm-execution/realm_task_id_t.h +++ b/lib/realm-execution/include/realm-execution/tasks/realm_task_id_t.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_TASK_ID_T_H #include "realm-execution/realm.h" -#include "realm-execution/task_id_t.dtg.h" +#include "realm-execution/tasks/task_id_t.dtg.h" namespace FlexFlow { diff --git a/lib/realm-execution/include/realm-execution/realm_task_registry.h b/lib/realm-execution/include/realm-execution/tasks/realm_task_registry.h similarity index 94% rename from lib/realm-execution/include/realm-execution/realm_task_registry.h rename to lib/realm-execution/include/realm-execution/tasks/realm_task_registry.h index f800b1d8c4..a0277382bf 100644 --- a/lib/realm-execution/include/realm-execution/realm_task_registry.h +++ b/lib/realm-execution/include/realm-execution/tasks/realm_task_registry.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_TASK_REGISTRY_H #include "realm-execution/realm.h" -#include "realm-execution/task_id_t.dtg.h" +#include "realm-execution/tasks/task_id_t.dtg.h" namespace FlexFlow { diff --git a/lib/realm-execution/include/realm-execution/tasks/realm_tasks.h b/lib/realm-execution/include/realm-execution/tasks/realm_tasks.h new file mode 100644 index 0000000000..d2b104faa8 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tasks/realm_tasks.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_TASKS_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_TASKS_H + +#include "realm-execution/realm.h" + +namespace FlexFlow { + +void op_task_body(void const *, size_t, void const *, size_t, Realm::Processor); + +void controller_task_body( + void const *, size_t, void const *, size_t, Realm::Processor); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/task_id_t.dtg.toml b/lib/realm-execution/include/realm-execution/tasks/task_id_t.dtg.toml similarity index 100% rename from lib/realm-execution/include/realm-execution/task_id_t.dtg.toml rename to lib/realm-execution/include/realm-execution/tasks/task_id_t.dtg.toml diff --git a/lib/realm-execution/include/realm-execution/task_id_t.h b/lib/realm-execution/include/realm-execution/tasks/task_id_t.h similarity index 94% rename from lib/realm-execution/include/realm-execution/task_id_t.h rename to lib/realm-execution/include/realm-execution/tasks/task_id_t.h index 38b82ad9e0..4a5d9299ae 100644 --- a/lib/realm-execution/include/realm-execution/task_id_t.h +++ b/lib/realm-execution/include/realm-execution/tasks/task_id_t.h @@ -3,7 +3,7 @@ #include "op-attrs/pcg_operator_attrs.dtg.h" #include "pcg/optimizer_attrs.dtg.h" -#include "realm-execution/task_id_t.dtg.h" +#include "realm-execution/tasks/task_id_t.dtg.h" #include "task-spec/dynamic_graph/dynamic_node_attrs.dtg.h" #include diff --git a/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc b/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc new file mode 100644 index 0000000000..c6d0621f3d --- /dev/null +++ b/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc @@ -0,0 +1,15 @@ +#include "realm-execution/distributed_device_state_initialization.h" +#include "utils/exception.h" + +namespace FlexFlow { + +DynamicOpenDataflowGraph perform_distributed_device_state_initialization( + DynamicOpenDataflowGraph const &dg, + RealmContext &ctx, + ProfilingSettings const &profiling_settings, + FFIterationConfig const &iteration_config, + OptimizerAttrs const &optimizer_attrs) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc index 5d6aeddf83..bb763334d5 100644 --- a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc +++ b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc @@ -1,6 +1,6 @@ #include "realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h" -#include "local-execution/device_state_initialization.h" #include "pcg/optimizer_attrs.h" +#include "realm-execution/distributed_device_state_initialization.h" #include "realm-execution/instance_allocation.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" #include "task-spec/dynamic_graph/dynamic_tensor_guid_t.dtg.h" @@ -92,14 +92,15 @@ ParallelComputationGraphInstance create_parallel_computation_graph_instance( return backing.backing.at(lgv).first; }); - dg = perform_device_state_initialization(dg, - ctx.get_current_device_allocator(), - profiling_settings, - ctx.get_current_device_handle(), - iteration_config, - optimizer_attrs, - ctx.get_current_device_idx()); + dg = perform_distributed_device_state_initialization( + dg, ctx, profiling_settings, iteration_config, optimizer_attrs); NOT_IMPLEMENTED(); + + // TODO list: + // * per-device state initialization (RPC mechanism?) + // * Realm allocator + // * task body + // * external instances } } // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/realm_context.cc b/lib/realm-execution/src/realm-execution/realm_context.cc index bf5f337796..37f72ba86d 100644 --- a/lib/realm-execution/src/realm-execution/realm_context.cc +++ b/lib/realm-execution/src/realm-execution/realm_context.cc @@ -2,8 +2,8 @@ #include "op-attrs/datatype.h" #include "op-attrs/tensor_dims.dtg.h" #include "pcg/device_type.dtg.h" -#include "realm-execution/realm_task_id_t.h" -#include "realm-execution/task_id_t.dtg.h" +#include "realm-execution/tasks/realm_task_id_t.h" +#include "realm-execution/tasks/task_id_t.dtg.h" #include "utils/containers/contains_key.h" #include "utils/containers/transform.h" #include "utils/exception.h" diff --git a/lib/realm-execution/src/realm-execution/realm_manager.cc b/lib/realm-execution/src/realm-execution/realm_manager.cc index f8a3e4014b..9d8b9f0b7f 100644 --- a/lib/realm-execution/src/realm-execution/realm_manager.cc +++ b/lib/realm-execution/src/realm-execution/realm_manager.cc @@ -1,25 +1,12 @@ #include "realm-execution/realm_manager.h" #include "realm-execution/realm_context.h" -#include "realm-execution/realm_task_id_t.h" -#include "realm-execution/realm_task_registry.h" -#include "realm-execution/task_id_t.dtg.h" +#include "realm-execution/tasks/realm_task_id_t.h" +#include "realm-execution/tasks/realm_task_registry.h" +#include "realm-execution/tasks/task_id_t.dtg.h" #include "utils/exception.h" namespace FlexFlow { -static void controller_task_wrapper(void const *args, - size_t arglen, - void const *userdata, - size_t userlen, - Realm::Processor proc) { - ASSERT(arglen == sizeof(std::function)); - std::function thunk = - *reinterpret_cast const *>(args); - - RealmContext ctx{proc}; - thunk(ctx); -} - RealmManager::RealmManager(int *argc, char ***argv) : RealmContext(Realm::Processor::NO_PROC) { bool ok = this->runtime.init(argc, argv); @@ -27,10 +14,6 @@ RealmManager::RealmManager(int *argc, char ***argv) // Register all tasks at initialization time so we don't need to later register_all_tasks().wait(); - register_task(Realm::Processor::LOC_PROC, - task_id_t::CONTROLLER_TASK_ID, - controller_task_wrapper) - .wait(); } RealmManager::~RealmManager() { diff --git a/lib/realm-execution/src/realm-execution/realm_task_id_t.cc b/lib/realm-execution/src/realm-execution/tasks/realm_task_id_t.cc similarity index 82% rename from lib/realm-execution/src/realm-execution/realm_task_id_t.cc rename to lib/realm-execution/src/realm-execution/tasks/realm_task_id_t.cc index 50b23dfe86..ec1aa143a6 100644 --- a/lib/realm-execution/src/realm-execution/realm_task_id_t.cc +++ b/lib/realm-execution/src/realm-execution/tasks/realm_task_id_t.cc @@ -1,4 +1,4 @@ -#include "realm-execution/realm_task_id_t.h" +#include "realm-execution/tasks/realm_task_id_t.h" namespace FlexFlow { diff --git a/lib/realm-execution/src/realm-execution/realm_task_registry.cc b/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc similarity index 86% rename from lib/realm-execution/src/realm-execution/realm_task_registry.cc rename to lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc index 436a6af3f3..7e30edbc9f 100644 --- a/lib/realm-execution/src/realm-execution/realm_task_registry.cc +++ b/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc @@ -1,14 +1,10 @@ -#include "realm-execution/realm_task_registry.h" -#include "realm-execution/realm_task_id_t.h" +#include "realm-execution/tasks/realm_task_registry.h" +#include "realm-execution/tasks/realm_task_id_t.h" +#include "realm-execution/tasks/realm_tasks.h" #include "utils/exception.h" namespace FlexFlow { -static void operation_task_wrapper( - void const *, size_t, void const *, size_t, Realm::Processor) { - NOT_IMPLEMENTED(); -} - Realm::Event register_task(Realm::Processor::Kind target_kind, task_id_t func_id, void (*task_body)(void const *, @@ -110,12 +106,15 @@ Realm::Event register_all_tasks() { }; for (task_id_t task_id : task_ids) { - pending_registrations.push_back(register_task( - Realm::Processor::LOC_PROC, task_id, operation_task_wrapper)); - pending_registrations.push_back(register_task( - Realm::Processor::TOC_PROC, task_id, operation_task_wrapper)); + pending_registrations.push_back( + register_task(Realm::Processor::LOC_PROC, task_id, op_task_body)); + pending_registrations.push_back( + register_task(Realm::Processor::TOC_PROC, task_id, op_task_body)); } + pending_registrations.push_back(register_task(Realm::Processor::LOC_PROC, + task_id_t::CONTROLLER_TASK_ID, + controller_task_body)); return Realm::Event::merge_events(pending_registrations); } diff --git a/lib/realm-execution/src/realm-execution/tasks/realm_tasks.cc b/lib/realm-execution/src/realm-execution/tasks/realm_tasks.cc new file mode 100644 index 0000000000..a50f7f3e47 --- /dev/null +++ b/lib/realm-execution/src/realm-execution/tasks/realm_tasks.cc @@ -0,0 +1,24 @@ +#include "realm-execution/tasks/realm_tasks.h" +#include "realm-execution/realm_context.h" + +namespace FlexFlow { + +void op_task_body( + void const *, size_t, void const *, size_t, Realm::Processor) { + NOT_IMPLEMENTED(); +} + +void controller_task_body(void const *args, + size_t arglen, + void const *userdata, + size_t userlen, + Realm::Processor proc) { + ASSERT(arglen == sizeof(std::function)); + std::function thunk = + *reinterpret_cast const *>(args); + + RealmContext ctx{proc}; + thunk(ctx); +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/task_id_t.cc b/lib/realm-execution/src/realm-execution/tasks/task_id_t.cc similarity index 99% rename from lib/realm-execution/src/realm-execution/task_id_t.cc rename to lib/realm-execution/src/realm-execution/tasks/task_id_t.cc index 3521f50c02..5a99f2bea8 100644 --- a/lib/realm-execution/src/realm-execution/task_id_t.cc +++ b/lib/realm-execution/src/realm-execution/tasks/task_id_t.cc @@ -1,4 +1,4 @@ -#include "realm-execution/task_id_t.h" +#include "realm-execution/tasks/task_id_t.h" #include "pcg/optimizer_attrs.dtg.h" #include "pcg/optimizers/adam_optimizer_attrs.dtg.h" #include "task-spec/dynamic_graph/dynamic_node_attrs.dtg.h" From 950e6e8053441096098d68ae26992260303e4629 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Tue, 10 Feb 2026 10:47:41 -0800 Subject: [PATCH 31/88] Make use of task args struct. --- .../realm-execution/tasks/realm_tasks.h | 20 +++++++++++++++++++ .../src/realm-execution/tasks/realm_tasks.cc | 14 +++++++++---- 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/tasks/realm_tasks.h b/lib/realm-execution/include/realm-execution/tasks/realm_tasks.h index d2b104faa8..ceda961914 100644 --- a/lib/realm-execution/include/realm-execution/tasks/realm_tasks.h +++ b/lib/realm-execution/include/realm-execution/tasks/realm_tasks.h @@ -2,11 +2,31 @@ #define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_TASKS_H #include "realm-execution/realm.h" +#include "realm-execution/realm_context.h" +#include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" +#include namespace FlexFlow { void op_task_body(void const *, size_t, void const *, size_t, Realm::Processor); +// TODO: at some point we're going to have to actually serialize these, but for +// now just pass the pointer and assume we're running inside a single address +// space +struct DeviceInitTaskArgs { +public: + DynamicNodeInvocation *invocation; +}; +static_assert(std::has_unique_object_representations_v); + +void device_init_task_body( + void const *, size_t, void const *, size_t, Realm::Processor); + +struct ControllerTaskArgs { +public: + std::function thunk; +}; + void controller_task_body( void const *, size_t, void const *, size_t, Realm::Processor); diff --git a/lib/realm-execution/src/realm-execution/tasks/realm_tasks.cc b/lib/realm-execution/src/realm-execution/tasks/realm_tasks.cc index a50f7f3e47..b1da1f0694 100644 --- a/lib/realm-execution/src/realm-execution/tasks/realm_tasks.cc +++ b/lib/realm-execution/src/realm-execution/tasks/realm_tasks.cc @@ -1,5 +1,6 @@ #include "realm-execution/tasks/realm_tasks.h" #include "realm-execution/realm_context.h" +#include "utils/exception.h" namespace FlexFlow { @@ -8,17 +9,22 @@ void op_task_body( NOT_IMPLEMENTED(); } +void device_init_task_body( + void const *, size_t, void const *, size_t, Realm::Processor) { + NOT_IMPLEMENTED(); +} + void controller_task_body(void const *args, size_t arglen, void const *userdata, size_t userlen, Realm::Processor proc) { - ASSERT(arglen == sizeof(std::function)); - std::function thunk = - *reinterpret_cast const *>(args); + ASSERT(arglen == sizeof(ControllerTaskArgs)); + ControllerTaskArgs task_args = + *reinterpret_cast(args); RealmContext ctx{proc}; - thunk(ctx); + task_args.thunk(ctx); } } // namespace FlexFlow From 901f0cb998b1f38812a6f58609fc203201e1f4be Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Tue, 10 Feb 2026 10:54:01 -0800 Subject: [PATCH 32/88] Use task args struct. --- lib/realm-execution/src/realm-execution/realm_manager.cc | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/lib/realm-execution/src/realm-execution/realm_manager.cc b/lib/realm-execution/src/realm-execution/realm_manager.cc index 9d8b9f0b7f..dec2ed7847 100644 --- a/lib/realm-execution/src/realm-execution/realm_manager.cc +++ b/lib/realm-execution/src/realm-execution/realm_manager.cc @@ -2,6 +2,7 @@ #include "realm-execution/realm_context.h" #include "realm-execution/tasks/realm_task_id_t.h" #include "realm-execution/tasks/realm_task_registry.h" +#include "realm-execution/tasks/realm_tasks.h" #include "realm-execution/tasks/task_id_t.dtg.h" #include "utils/exception.h" @@ -29,11 +30,13 @@ Realm::Event .only_kind(Realm::Processor::LOC_PROC) .first(); + ControllerTaskArgs task_args; + task_args.thunk = thunk; Realm::Event task_complete = this->runtime.collective_spawn( target_proc, get_realm_task_id_for_task_id(task_id_t::CONTROLLER_TASK_ID), - &thunk, - sizeof(thunk)); + &task_args, + sizeof(task_args)); this->outstanding_events.push_back(task_complete); return task_complete; } From 0535c34253805f0f3e940c0543e9c8781ccee975 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Tue, 10 Feb 2026 14:26:45 -0800 Subject: [PATCH 33/88] Refactor task APIs. --- .../include/realm-execution/realm_context.h | 18 +++++++ .../tasks/impl/controller_task.h | 19 +++++++ .../tasks/impl/device_init_return_task.h | 21 ++++++++ .../tasks/impl/device_init_task.h | 24 +++++++++ .../realm-execution/tasks/impl/op_task.h | 21 ++++++++ .../tasks/realm_task_registry.h | 4 +- .../realm-execution/tasks/realm_tasks.h | 35 ------------ .../realm-execution/tasks/task_id_t.dtg.toml | 3 ++ .../include/realm-execution/tasks/task_id_t.h | 4 +- .../src/realm-execution/realm_context.cc | 35 ++++++++++++ .../src/realm-execution/realm_manager.cc | 15 +----- .../tasks/impl/controller_task.cc | 37 +++++++++++++ .../tasks/impl/device_init_return_task.cc | 49 +++++++++++++++++ .../tasks/impl/device_init_task.cc | 54 +++++++++++++++++++ .../src/realm-execution/tasks/impl/op_task.cc | 48 +++++++++++++++++ .../tasks/realm_task_registry.cc | 5 +- .../src/realm-execution/tasks/realm_tasks.cc | 30 ----------- 17 files changed, 339 insertions(+), 83 deletions(-) create mode 100644 lib/realm-execution/include/realm-execution/tasks/impl/controller_task.h create mode 100644 lib/realm-execution/include/realm-execution/tasks/impl/device_init_return_task.h create mode 100644 lib/realm-execution/include/realm-execution/tasks/impl/device_init_task.h create mode 100644 lib/realm-execution/include/realm-execution/tasks/impl/op_task.h delete mode 100644 lib/realm-execution/include/realm-execution/tasks/realm_tasks.h create mode 100644 lib/realm-execution/src/realm-execution/tasks/impl/controller_task.cc create mode 100644 lib/realm-execution/src/realm-execution/tasks/impl/device_init_return_task.cc create mode 100644 lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc create mode 100644 lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc delete mode 100644 lib/realm-execution/src/realm-execution/tasks/realm_tasks.cc diff --git a/lib/realm-execution/include/realm-execution/realm_context.h b/lib/realm-execution/include/realm-execution/realm_context.h index 73d60e9f50..422c4f4027 100644 --- a/lib/realm-execution/include/realm-execution/realm_context.h +++ b/lib/realm-execution/include/realm-execution/realm_context.h @@ -6,6 +6,7 @@ #include "pcg/device_id_t.dtg.h" #include "pcg/machine_space_coordinate.dtg.h" #include "realm-execution/realm.h" +#include "realm-execution/tasks/task_id_t.dtg.h" #include namespace FlexFlow { @@ -30,6 +31,23 @@ struct RealmContext { device_handle_t const &get_current_device_handle() const; device_id_t const &get_current_device_idx() const; + // Task creation + Realm::Event spawn_task(Realm::Processor proc, + task_id_t task_id, + void const *args, + size_t arglen, + Realm::ProfilingRequestSet const &requests, + Realm::Event wait_on = Realm::Event::NO_EVENT, + int priority = 0); + + Realm::Event + collective_spawn_task(Realm::Processor target_proc, + task_id_t task_id, + void const *args, + size_t arglen, + Realm::Event wait_on = Realm::Event::NO_EVENT, + int priority = 0); + // Instance management std::pair create_instance(Realm::Memory memory, diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/controller_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/controller_task.h new file mode 100644 index 0000000000..d4c397bb37 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tasks/impl/controller_task.h @@ -0,0 +1,19 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_CONTROLLER_TASK_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_CONTROLLER_TASK_H + +#include "realm-execution/realm.h" +#include "realm-execution/realm_context.h" + +namespace FlexFlow { + +void controller_task_body( + void const *, size_t, void const *, size_t, Realm::Processor); + +Realm::Event + collective_spawn_controller_task(RealmContext &ctx, + Realm::Processor &target_proc, + std::function thunk); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/device_init_return_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/device_init_return_task.h new file mode 100644 index 0000000000..fc6c8bdb9f --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tasks/impl/device_init_return_task.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_INIT_RETURN_TASK_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_INIT_RETURN_TASK_H + +#include "realm-execution/realm.h" +#include "realm-execution/realm_context.h" +#include "task-spec/device_specific_per_device_op_state.dtg.h" + +namespace FlexFlow { + +void device_init_return_task_body( + void const *, size_t, void const *, size_t, Realm::Processor); + +Realm::Event spawn_device_init_return_task( + RealmContext &ctx, + Realm::Processor origin_proc, + DeviceSpecificPerDeviceOpState const &result, + DeviceSpecificPerDeviceOpState *origin_result_ptr); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/device_init_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/device_init_task.h new file mode 100644 index 0000000000..bd4ca269df --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tasks/impl/device_init_task.h @@ -0,0 +1,24 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_INIT_TASK_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_INIT_TASK_H + +#include "pcg/optimizer_attrs.dtg.h" +#include "realm-execution/realm.h" +#include "realm-execution/realm_context.h" +#include "task-spec/device_specific_per_device_op_state.dtg.h" +#include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" + +namespace FlexFlow { + +void device_init_task_body( + void const *, size_t, void const *, size_t, Realm::Processor); + +Realm::Event + spawn_device_init_task(RealmContext &ctx, + Realm::Processor &target_proc, + DynamicNodeInvocation const &invocation, + std::optional const &optimizer_attrs, + DeviceSpecificPerDeviceOpState *result_ptr); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h new file mode 100644 index 0000000000..4c3e6d38d1 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_OP_TASK_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_OP_TASK_H + +#include "pcg/optimizer_attrs.dtg.h" +#include "realm-execution/realm.h" +#include "realm-execution/realm_context.h" +#include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" + +namespace FlexFlow { + +void op_task_body(void const *, size_t, void const *, size_t, Realm::Processor); + +Realm::Event + spawn_op_task(RealmContext &ctx, + Realm::Processor &target_proc, + DynamicNodeInvocation const &invocation, + std::optional const &optimizer_attrs); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/tasks/realm_task_registry.h b/lib/realm-execution/include/realm-execution/tasks/realm_task_registry.h index a0277382bf..8114f1a82c 100644 --- a/lib/realm-execution/include/realm-execution/tasks/realm_task_registry.h +++ b/lib/realm-execution/include/realm-execution/tasks/realm_task_registry.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_TASK_REGISTRY_H -#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_TASK_REGISTRY_H +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_REALM_TASK_REGISTRY_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_REALM_TASK_REGISTRY_H #include "realm-execution/realm.h" #include "realm-execution/tasks/task_id_t.dtg.h" diff --git a/lib/realm-execution/include/realm-execution/tasks/realm_tasks.h b/lib/realm-execution/include/realm-execution/tasks/realm_tasks.h deleted file mode 100644 index ceda961914..0000000000 --- a/lib/realm-execution/include/realm-execution/tasks/realm_tasks.h +++ /dev/null @@ -1,35 +0,0 @@ -#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_TASKS_H -#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_TASKS_H - -#include "realm-execution/realm.h" -#include "realm-execution/realm_context.h" -#include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" -#include - -namespace FlexFlow { - -void op_task_body(void const *, size_t, void const *, size_t, Realm::Processor); - -// TODO: at some point we're going to have to actually serialize these, but for -// now just pass the pointer and assume we're running inside a single address -// space -struct DeviceInitTaskArgs { -public: - DynamicNodeInvocation *invocation; -}; -static_assert(std::has_unique_object_representations_v); - -void device_init_task_body( - void const *, size_t, void const *, size_t, Realm::Processor); - -struct ControllerTaskArgs { -public: - std::function thunk; -}; - -void controller_task_body( - void const *, size_t, void const *, size_t, Realm::Processor); - -} // namespace FlexFlow - -#endif diff --git a/lib/realm-execution/include/realm-execution/tasks/task_id_t.dtg.toml b/lib/realm-execution/include/realm-execution/tasks/task_id_t.dtg.toml index 0336bc81a4..34e5183488 100644 --- a/lib/realm-execution/include/realm-execution/tasks/task_id_t.dtg.toml +++ b/lib/realm-execution/include/realm-execution/tasks/task_id_t.dtg.toml @@ -11,6 +11,9 @@ features = [ [[values]] name = "CONTROLLER_TASK_ID" +[[values]] +name = "DEVICE_INIT_RETURN_TASK_ID" + [[values]] name = "IMAGE_INIT_TASK_ID" diff --git a/lib/realm-execution/include/realm-execution/tasks/task_id_t.h b/lib/realm-execution/include/realm-execution/tasks/task_id_t.h index 4a5d9299ae..53945d2e5b 100644 --- a/lib/realm-execution/include/realm-execution/tasks/task_id_t.h +++ b/lib/realm-execution/include/realm-execution/tasks/task_id_t.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASK_ID_T_H -#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASK_ID_T_H +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_TASK_ID_T_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_TASK_ID_T_H #include "op-attrs/pcg_operator_attrs.dtg.h" #include "pcg/optimizer_attrs.dtg.h" diff --git a/lib/realm-execution/src/realm-execution/realm_context.cc b/lib/realm-execution/src/realm-execution/realm_context.cc index 37f72ba86d..7e6c73c9e7 100644 --- a/lib/realm-execution/src/realm-execution/realm_context.cc +++ b/lib/realm-execution/src/realm-execution/realm_context.cc @@ -74,6 +74,41 @@ device_id_t const &RealmContext::get_current_device_idx() const { NOT_IMPLEMENTED(); } +Realm::Event + RealmContext::spawn_task(Realm::Processor proc, + task_id_t task_id, + void const *args, + size_t arglen, + Realm::ProfilingRequestSet const &requests, + Realm::Event wait_on, + int priority) { + Realm::Event result = proc.spawn(get_realm_task_id_for_task_id(task_id), + args, + arglen, + requests, + wait_on, + priority); + this->outstanding_events.push_back(result); + return result; +} + +Realm::Event RealmContext::collective_spawn_task(Realm::Processor target_proc, + task_id_t task_id, + void const *args, + size_t arglen, + Realm::Event wait_on, + int priority) { + Realm::Event result = + this->runtime.collective_spawn(target_proc, + get_realm_task_id_for_task_id(task_id), + args, + arglen, + wait_on, + priority); + this->outstanding_events.push_back(result); + return result; +} + template static Realm::Rect rect_from_dims(TensorDims const &dims) { std::vector values{dims.ff_ordered.begin(), dims.ff_ordered.end()}; diff --git a/lib/realm-execution/src/realm-execution/realm_manager.cc b/lib/realm-execution/src/realm-execution/realm_manager.cc index dec2ed7847..7233103cc3 100644 --- a/lib/realm-execution/src/realm-execution/realm_manager.cc +++ b/lib/realm-execution/src/realm-execution/realm_manager.cc @@ -1,10 +1,7 @@ #include "realm-execution/realm_manager.h" #include "realm-execution/realm_context.h" -#include "realm-execution/tasks/realm_task_id_t.h" +#include "realm-execution/tasks/impl/controller_task.h" #include "realm-execution/tasks/realm_task_registry.h" -#include "realm-execution/tasks/realm_tasks.h" -#include "realm-execution/tasks/task_id_t.dtg.h" -#include "utils/exception.h" namespace FlexFlow { @@ -30,15 +27,7 @@ Realm::Event .only_kind(Realm::Processor::LOC_PROC) .first(); - ControllerTaskArgs task_args; - task_args.thunk = thunk; - Realm::Event task_complete = this->runtime.collective_spawn( - target_proc, - get_realm_task_id_for_task_id(task_id_t::CONTROLLER_TASK_ID), - &task_args, - sizeof(task_args)); - this->outstanding_events.push_back(task_complete); - return task_complete; + return collective_spawn_controller_task(*this, target_proc, thunk); } } // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/controller_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/controller_task.cc new file mode 100644 index 0000000000..2fd5cee52d --- /dev/null +++ b/lib/realm-execution/src/realm-execution/tasks/impl/controller_task.cc @@ -0,0 +1,37 @@ +#include "realm-execution/tasks/impl/op_task.h" +#include "realm-execution/tasks/task_id_t.h" + +namespace FlexFlow { + +struct ControllerTaskArgs { +public: + std::function thunk; +}; + +void controller_task_body(void const *args, + size_t arglen, + void const *userdata, + size_t userlen, + Realm::Processor proc) { + ASSERT(arglen == sizeof(ControllerTaskArgs)); + ControllerTaskArgs task_args = + *reinterpret_cast(args); + + RealmContext ctx{proc}; + task_args.thunk(ctx); +} + +Realm::Event collective_spawn_controller_task( + RealmContext &ctx, + Realm::Processor &target_proc, + std::function thunk) { + ControllerTaskArgs task_args; + task_args.thunk = thunk; + + return ctx.collective_spawn_task(target_proc, + task_id_t::CONTROLLER_TASK_ID, + &task_args, + sizeof(task_args)); +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/device_init_return_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/device_init_return_task.cc new file mode 100644 index 0000000000..fa421cda30 --- /dev/null +++ b/lib/realm-execution/src/realm-execution/tasks/impl/device_init_return_task.cc @@ -0,0 +1,49 @@ +#include "realm-execution/tasks/impl/device_init_task.h" +#include "realm-execution/tasks/task_id_t.dtg.h" + +namespace FlexFlow { + +// FIXME: Can't make this trivially copyable? +struct DeviceInitReturnTaskArgs { +public: + DeviceInitReturnTaskArgs() = delete; + DeviceInitReturnTaskArgs(DeviceSpecificPerDeviceOpState result, + Realm::Processor origin_proc, + DeviceSpecificPerDeviceOpState *origin_result_ptr) + : result(result), origin_proc(origin_proc), + origin_result_ptr(origin_result_ptr) {} + +public: + DeviceSpecificPerDeviceOpState result; + Realm::Processor origin_proc; + DeviceSpecificPerDeviceOpState *origin_result_ptr; +}; + +void device_init_return_task_body(void const *args, + size_t arglen, + void const *userdata, + size_t userlen, + Realm::Processor proc) { + ASSERT(arglen == sizeof(DeviceInitReturnTaskArgs)); + DeviceInitReturnTaskArgs task_args = + *reinterpret_cast(args); + + ASSERT(task_args.origin_proc.address_space() == proc.address_space()); + *task_args.origin_result_ptr = task_args.result; +} + +Realm::Event spawn_device_init_return_task( + RealmContext &ctx, + Realm::Processor origin_proc, + DeviceSpecificPerDeviceOpState const &result, + DeviceSpecificPerDeviceOpState *origin_result_ptr) { + DeviceInitReturnTaskArgs task_args{result, origin_proc, origin_result_ptr}; + + return ctx.spawn_task(origin_proc, + task_id_t::DEVICE_INIT_RETURN_TASK_ID, + &task_args, + sizeof(task_args), + Realm::ProfilingRequestSet{}); +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc new file mode 100644 index 0000000000..0deb8407c4 --- /dev/null +++ b/lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc @@ -0,0 +1,54 @@ +#include "realm-execution/tasks/impl/device_init_task.h" +#include "realm-execution/tasks/task_id_t.h" +#include "utils/optional.h" +#include + +namespace FlexFlow { + +// TODO: at some point we're going to have to actually serialize these, but for +// now just pass the pointer and assume we're running inside a single address +// space +struct DeviceInitTaskArgs { +public: + DynamicNodeInvocation const *invocation; + Realm::Processor origin_proc; + DeviceSpecificPerDeviceOpState *origin_result_ptr; +}; +static_assert(std::has_unique_object_representations_v); + +void device_init_task_body(void const *args, + size_t arglen, + void const *userdata, + size_t userlen, + Realm::Processor proc) { + ASSERT(arglen == sizeof(DeviceInitTaskArgs)); + DeviceInitTaskArgs task_args = + *reinterpret_cast(args); + + // FIXME: not safe to dereference unless we're on the same address space + ASSERT(task_args.origin_proc.address_space() == proc.address_space()); + + RealmContext ctx{proc}; + NOT_IMPLEMENTED(); +} + +Realm::Event + spawn_device_init_task(RealmContext &ctx, + Realm::Processor &target_proc, + DynamicNodeInvocation const &invocation, + std::optional const &optimizer_attrs, + DeviceSpecificPerDeviceOpState *result_ptr) { + DeviceInitTaskArgs task_args; + task_args.invocation = &invocation; + task_args.origin_proc = ctx.get_current_processor(); + task_args.origin_result_ptr = result_ptr; + + return ctx.spawn_task(target_proc, + assert_unwrap(get_init_task_id_for_op_attrs( + assert_unwrap(invocation.node_attrs.op_attrs))), + &task_args, + sizeof(task_args), + Realm::ProfilingRequestSet{}); +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc new file mode 100644 index 0000000000..9d9a36e2d5 --- /dev/null +++ b/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc @@ -0,0 +1,48 @@ +#include "realm-execution/tasks/impl/op_task.h" +#include "realm-execution/tasks/task_id_t.h" +#include "utils/optional.h" +#include + +namespace FlexFlow { + +// TODO: at some point we're going to have to actually serialize these, but for +// now just pass the pointer and assume we're running inside a single address +// space +struct OpTaskArgs { +public: + DynamicNodeInvocation const *invocation; + Realm::Processor origin_proc; +}; +static_assert(std::has_unique_object_representations_v); + +void op_task_body(void const *args, + size_t arglen, + void const *userdata, + size_t userlen, + Realm::Processor proc) { + ASSERT(arglen == sizeof(OpTaskArgs)); + OpTaskArgs task_args = *reinterpret_cast(args); + + // FIXME: not safe to dereference unless we're on the same address space + ASSERT(task_args.origin_proc.address_space() == proc.address_space()); + + RealmContext ctx{proc}; + NOT_IMPLEMENTED(); +} + +Realm::Event + spawn_op_task(RealmContext &ctx, + Realm::Processor &target_proc, + DynamicNodeInvocation const &invocation, + std::optional const &optimizer_attrs) { + OpTaskArgs task_args; + task_args.invocation = &invocation; + return ctx.spawn_task( + target_proc, + assert_unwrap(get_task_id_for_op(invocation.node_attrs, optimizer_attrs)), + &task_args, + sizeof(task_args), + Realm::ProfilingRequestSet{}); +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc b/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc index 7e30edbc9f..c604d1b06a 100644 --- a/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc +++ b/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc @@ -1,6 +1,9 @@ #include "realm-execution/tasks/realm_task_registry.h" +#include "realm-execution/tasks/impl/controller_task.h" +#include "realm-execution/tasks/impl/device_init_return_task.h" +#include "realm-execution/tasks/impl/device_init_task.h" +#include "realm-execution/tasks/impl/op_task.h" #include "realm-execution/tasks/realm_task_id_t.h" -#include "realm-execution/tasks/realm_tasks.h" #include "utils/exception.h" namespace FlexFlow { diff --git a/lib/realm-execution/src/realm-execution/tasks/realm_tasks.cc b/lib/realm-execution/src/realm-execution/tasks/realm_tasks.cc deleted file mode 100644 index b1da1f0694..0000000000 --- a/lib/realm-execution/src/realm-execution/tasks/realm_tasks.cc +++ /dev/null @@ -1,30 +0,0 @@ -#include "realm-execution/tasks/realm_tasks.h" -#include "realm-execution/realm_context.h" -#include "utils/exception.h" - -namespace FlexFlow { - -void op_task_body( - void const *, size_t, void const *, size_t, Realm::Processor) { - NOT_IMPLEMENTED(); -} - -void device_init_task_body( - void const *, size_t, void const *, size_t, Realm::Processor) { - NOT_IMPLEMENTED(); -} - -void controller_task_body(void const *args, - size_t arglen, - void const *userdata, - size_t userlen, - Realm::Processor proc) { - ASSERT(arglen == sizeof(ControllerTaskArgs)); - ControllerTaskArgs task_args = - *reinterpret_cast(args); - - RealmContext ctx{proc}; - task_args.thunk(ctx); -} - -} // namespace FlexFlow From 1d656480c15263a5aa368ad9f3949fe8bc5fe1a9 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Tue, 10 Feb 2026 14:45:44 -0800 Subject: [PATCH 34/88] Finish implementation of device init task. --- .../tasks/impl/device_init_task.h | 15 +++--- .../realm-execution/tasks/realm_task_id_t.h | 4 +- .../tasks/impl/device_init_task.cc | 50 ++++++++++++++++--- .../tasks/realm_task_registry.cc | 13 ++++- 4 files changed, 67 insertions(+), 15 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/device_init_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/device_init_task.h index bd4ca269df..ebce5fed4c 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/device_init_task.h +++ b/lib/realm-execution/include/realm-execution/tasks/impl/device_init_task.h @@ -1,23 +1,26 @@ #ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_INIT_TASK_H #define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_INIT_TASK_H +#include "kernels/profiling_settings.dtg.h" #include "pcg/optimizer_attrs.dtg.h" #include "realm-execution/realm.h" #include "realm-execution/realm_context.h" #include "task-spec/device_specific_per_device_op_state.dtg.h" #include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" +#include "task-spec/ff_iteration_config.dtg.h" namespace FlexFlow { void device_init_task_body( void const *, size_t, void const *, size_t, Realm::Processor); -Realm::Event - spawn_device_init_task(RealmContext &ctx, - Realm::Processor &target_proc, - DynamicNodeInvocation const &invocation, - std::optional const &optimizer_attrs, - DeviceSpecificPerDeviceOpState *result_ptr); +Realm::Event spawn_device_init_task(RealmContext &ctx, + Realm::Processor &target_proc, + DynamicNodeInvocation const &invocation, + ProfilingSettings const &profiling_settings, + FFIterationConfig const &iteration_config, + OptimizerAttrs const &optimizer_attrs, + DeviceSpecificPerDeviceOpState *result_ptr); } // namespace FlexFlow diff --git a/lib/realm-execution/include/realm-execution/tasks/realm_task_id_t.h b/lib/realm-execution/include/realm-execution/tasks/realm_task_id_t.h index cd5eba2f34..a3c6891fb0 100644 --- a/lib/realm-execution/include/realm-execution/tasks/realm_task_id_t.h +++ b/lib/realm-execution/include/realm-execution/tasks/realm_task_id_t.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_TASK_ID_T_H -#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_TASK_ID_T_H +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_REALM_TASK_ID_T_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_REALM_TASK_ID_T_H #include "realm-execution/realm.h" #include "realm-execution/tasks/task_id_t.dtg.h" diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc index 0deb8407c4..c27fc5802b 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc @@ -1,6 +1,9 @@ #include "realm-execution/tasks/impl/device_init_task.h" +#include "local-execution/device_state_initialization.h" +#include "realm-execution/tasks/impl/device_init_return_task.h" #include "realm-execution/tasks/task_id_t.h" #include "utils/optional.h" +#include #include namespace FlexFlow { @@ -9,8 +12,22 @@ namespace FlexFlow { // now just pass the pointer and assume we're running inside a single address // space struct DeviceInitTaskArgs { + DeviceInitTaskArgs() = delete; + DeviceInitTaskArgs(DynamicNodeInvocation const *invocation, + ProfilingSettings const *profiling_settings, + FFIterationConfig const *iteration_config, + OptimizerAttrs const *optimizer_attrs, + Realm::Processor origin_proc, + DeviceSpecificPerDeviceOpState *origin_result_ptr) + : invocation(invocation), profiling_settings(profiling_settings), + iteration_config(iteration_config), optimizer_attrs(optimizer_attrs), + origin_proc(origin_proc), origin_result_ptr(origin_result_ptr) {} + public: DynamicNodeInvocation const *invocation; + ProfilingSettings const *profiling_settings; + FFIterationConfig const *iteration_config; + OptimizerAttrs const *optimizer_attrs; Realm::Processor origin_proc; DeviceSpecificPerDeviceOpState *origin_result_ptr; }; @@ -29,19 +46,40 @@ void device_init_task_body(void const *args, ASSERT(task_args.origin_proc.address_space() == proc.address_space()); RealmContext ctx{proc}; - NOT_IMPLEMENTED(); + DynamicNodeInvocation result_invocation = + initialize_node(*task_args.invocation, + ctx.get_current_device_allocator(), + *task_args.profiling_settings, + ctx.get_current_device_handle(), + *task_args.iteration_config, + *task_args.optimizer_attrs, + ctx.get_current_device_idx()); + std::optional result_state = + result_invocation.node_attrs.per_device_op_state; + if (result_state) { + spawn_device_init_return_task(ctx, + task_args.origin_proc, + assert_unwrap(result_state), + task_args.origin_result_ptr); + } } Realm::Event spawn_device_init_task(RealmContext &ctx, Realm::Processor &target_proc, DynamicNodeInvocation const &invocation, - std::optional const &optimizer_attrs, + ProfilingSettings const &profiling_settings, + FFIterationConfig const &iteration_config, + OptimizerAttrs const &optimizer_attrs, DeviceSpecificPerDeviceOpState *result_ptr) { - DeviceInitTaskArgs task_args; - task_args.invocation = &invocation; - task_args.origin_proc = ctx.get_current_processor(); - task_args.origin_result_ptr = result_ptr; + DeviceInitTaskArgs task_args{ + &invocation, + &profiling_settings, + &iteration_config, + &optimizer_attrs, + ctx.get_current_processor(), + result_ptr, + }; return ctx.spawn_task(target_proc, assert_unwrap(get_init_task_id_for_op_attrs( diff --git a/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc b/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc index c604d1b06a..c63d4727a9 100644 --- a/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc +++ b/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc @@ -26,7 +26,7 @@ Realm::Event register_task(Realm::Processor::Kind target_kind, Realm::Event register_all_tasks() { std::vector pending_registrations; - std::vector task_ids = { + std::vector init_task_ids = { // Init tasks task_id_t::BATCHNORM_INIT_TASK_ID, task_id_t::COMBINE_INIT_TASK_ID, @@ -44,7 +44,14 @@ Realm::Event register_all_tasks() { task_id_t::REPARTITION_INIT_TASK_ID, task_id_t::REPLICATE_INIT_TASK_ID, task_id_t::SOFTMAX_INIT_TASK_ID, + }; + for (task_id_t task_id : init_task_ids) { + pending_registrations.push_back(register_task( + Realm::Processor::TOC_PROC, task_id, device_init_task_body)); + } + + std::vector task_ids = { // Forward tasks task_id_t::BATCHMATMUL_FWD_TASK_ID, task_id_t::BATCHNORM_FWD_TASK_ID, @@ -118,6 +125,10 @@ Realm::Event register_all_tasks() { pending_registrations.push_back(register_task(Realm::Processor::LOC_PROC, task_id_t::CONTROLLER_TASK_ID, controller_task_body)); + pending_registrations.push_back( + register_task(Realm::Processor::LOC_PROC, + task_id_t::DEVICE_INIT_RETURN_TASK_ID, + device_init_return_task_body)); return Realm::Event::merge_events(pending_registrations); } From 95df07366b40761444e61780e4456bc8a049b880 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Tue, 10 Feb 2026 15:14:24 -0800 Subject: [PATCH 35/88] Finish implementation of device state initialization. --- .../tasks/impl/device_init_task.h | 15 ++--- ...distributed_device_state_initialization.cc | 57 ++++++++++++++++++- .../tasks/impl/device_init_task.cc | 29 +++++----- 3 files changed, 79 insertions(+), 22 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/device_init_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/device_init_task.h index ebce5fed4c..af07139483 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/device_init_task.h +++ b/lib/realm-execution/include/realm-execution/tasks/impl/device_init_task.h @@ -14,13 +14,14 @@ namespace FlexFlow { void device_init_task_body( void const *, size_t, void const *, size_t, Realm::Processor); -Realm::Event spawn_device_init_task(RealmContext &ctx, - Realm::Processor &target_proc, - DynamicNodeInvocation const &invocation, - ProfilingSettings const &profiling_settings, - FFIterationConfig const &iteration_config, - OptimizerAttrs const &optimizer_attrs, - DeviceSpecificPerDeviceOpState *result_ptr); +std::optional + spawn_device_init_task(RealmContext &ctx, + Realm::Processor &target_proc, + DynamicNodeInvocation const &invocation, + ProfilingSettings const &profiling_settings, + FFIterationConfig const &iteration_config, + OptimizerAttrs const &optimizer_attrs, + DeviceSpecificPerDeviceOpState *result_ptr); } // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc b/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc index c6d0621f3d..f7fcea87e7 100644 --- a/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc +++ b/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc @@ -1,5 +1,11 @@ #include "realm-execution/distributed_device_state_initialization.h" -#include "utils/exception.h" +#include "local-execution/device_state_initialization.h" +#include "realm-execution/tasks/impl/device_init_task.h" +#include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" +#include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" +#include "utils/optional.h" +#include +#include namespace FlexFlow { @@ -9,7 +15,54 @@ DynamicOpenDataflowGraph perform_distributed_device_state_initialization( ProfilingSettings const &profiling_settings, FFIterationConfig const &iteration_config, OptimizerAttrs const &optimizer_attrs) { - NOT_IMPLEMENTED(); + + // Initialize all operators and save the per-device op state + ASSERT(no_nodes_are_initialized(dg)); + + std::unordered_map + result_map; + for (DynamicNodeInvocation const &invocation : dg.invocations) { + Realm::Processor target_proc = ctx.map_device_coord_to_processor( + assert_unwrap(invocation.node_attrs.device_coord)); + + // FIXME: in the absense of a real serializer we're just tossing around raw + // bytes, which means we need to bypass the constructor for this type (yes, + // ugh) + DeviceSpecificPerDeviceOpState *output = + static_cast( + malloc(sizeof(DeviceSpecificPerDeviceOpState))); + std::optional result = + spawn_device_init_task(ctx, + target_proc, + invocation, + profiling_settings, + iteration_config, + optimizer_attrs, + output); + if (result) { + result_map[invocation] = output; + } else { + free(output); + } + } + + ctx.get_outstanding_events().wait(); + + DynamicOpenDataflowGraph result = transform_dynamic_invocation_set( + dg, [&](DynamicNodeInvocation const &invocation) { + DynamicNodeInvocation result = invocation; + auto device_state = result_map.find(invocation); + if (device_state != result_map.end()) { + result.node_attrs.per_device_op_state = *device_state->second; + } + return result; + }); + + for (auto &[invocation, output] : result_map) { + free(output); + } + + return result; } } // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc index c27fc5802b..91b753d639 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc @@ -1,6 +1,7 @@ #include "realm-execution/tasks/impl/device_init_task.h" #include "local-execution/device_state_initialization.h" #include "realm-execution/tasks/impl/device_init_return_task.h" +#include "realm-execution/tasks/task_id_t.dtg.h" #include "realm-execution/tasks/task_id_t.h" #include "utils/optional.h" #include @@ -56,15 +57,13 @@ void device_init_task_body(void const *args, ctx.get_current_device_idx()); std::optional result_state = result_invocation.node_attrs.per_device_op_state; - if (result_state) { - spawn_device_init_return_task(ctx, - task_args.origin_proc, - assert_unwrap(result_state), - task_args.origin_result_ptr); - } + spawn_device_init_return_task(ctx, + task_args.origin_proc, + assert_unwrap(result_state), + task_args.origin_result_ptr); } -Realm::Event +std::optional spawn_device_init_task(RealmContext &ctx, Realm::Processor &target_proc, DynamicNodeInvocation const &invocation, @@ -81,12 +80,16 @@ Realm::Event result_ptr, }; - return ctx.spawn_task(target_proc, - assert_unwrap(get_init_task_id_for_op_attrs( - assert_unwrap(invocation.node_attrs.op_attrs))), - &task_args, - sizeof(task_args), - Realm::ProfilingRequestSet{}); + std::optional task_id = get_init_task_id_for_op_attrs( + assert_unwrap(invocation.node_attrs.op_attrs)); + if (task_id) { + return ctx.spawn_task(target_proc, + assert_unwrap(task_id), + &task_args, + sizeof(task_args), + Realm::ProfilingRequestSet{}); + } + return std::nullopt; } } // namespace FlexFlow From 9a41fb42be943f55c9e326e15f8fa18905b4e9bc Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Tue, 10 Feb 2026 15:15:51 -0800 Subject: [PATCH 36/88] Block on initialization. --- .../parallel_computation_graph_instance.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc index bb763334d5..cdb3e5fe46 100644 --- a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc +++ b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc @@ -87,6 +87,10 @@ ParallelComputationGraphInstance create_parallel_computation_graph_instance( dg = perform_shard_expansion(dg); TensorInstanceBacking backing = perform_instance_allocation(dg, inputs, ctx); + // FIXME: for now we're going to be lazy and block on everything rather than + // do fine-grained dependencies + ctx.get_outstanding_events().wait(); + std::optional logit_grad_tensor = transform(logit_grad_value, [&](DynamicValueAttrs const &lgv) { return backing.backing.at(lgv).first; From de338aea87e29f8552d8f306286f2dd4e6dabe0e Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Tue, 10 Feb 2026 17:04:10 -0800 Subject: [PATCH 37/88] Wire up rest of Realm implementation. --- .../parallel_computation_graph_instance.h | 19 +-- .../realm-execution/tasks/impl/op_task.h | 8 +- .../parallel_computation_graph_instance.cc | 159 +++++++++++++++--- .../tasks/impl/device_init_task.cc | 13 +- .../src/realm-execution/tasks/impl/op_task.cc | 49 +++++- 5 files changed, 206 insertions(+), 42 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h b/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h index f361cec3ca..0886dcf4c0 100644 --- a/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h +++ b/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h @@ -23,30 +23,27 @@ namespace FlexFlow { struct ParallelComputationGraphInstance { public: ParallelComputationGraphInstance(RealmContext &, - DynamicOpenDataflowGraph, std::vector const &, OptimizerAttrs const &, std::optional const &, - std::optional); - DynamicOpenDataflowGraph const &get_dynamic_dataflow_graph() const; - Allocator &get_allocator() const; - std::vector const &get_topological_ordering() const; + std::optional); + RealmContext &get_realm_context(); + std::vector const &get_execution_order() const; OptimizerAttrs const &get_optimizer_attrs() const; void update_optimizer_attrs_for_next_iter(); std::optional const &get_loss_attrs() const; - std::optional get_loss_tensor_accessor() const; + std::optional get_loss_tensor_instance() const; private: - RealmContext &realm; - DynamicOpenDataflowGraph dataflow_graph; - std::vector topological_ordering; + RealmContext &ctx; + std::vector execution_order; OptimizerAttrs optimizer_attrs; std::optional loss_attrs; - std::optional logit_grad_tensor; + std::optional logit_grad_tensor; }; ParallelComputationGraphInstance create_parallel_computation_graph_instance( - RealmContext &realm, + RealmContext &ctx, MappedParallelComputationGraph const &mpcg, OptimizerAttrs const &optimizer_attrs, std::optional const &loss_attrs, diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h index 4c3e6d38d1..dd75ed66ea 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h +++ b/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h @@ -1,10 +1,13 @@ #ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_OP_TASK_H #define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_OP_TASK_H +#include "kernels/profiling_settings.dtg.h" +#include "op-attrs/ops/loss_functions/loss_attrs.dtg.h" #include "pcg/optimizer_attrs.dtg.h" #include "realm-execution/realm.h" #include "realm-execution/realm_context.h" #include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" +#include "task-spec/ff_iteration_config.dtg.h" namespace FlexFlow { @@ -12,8 +15,11 @@ void op_task_body(void const *, size_t, void const *, size_t, Realm::Processor); Realm::Event spawn_op_task(RealmContext &ctx, - Realm::Processor &target_proc, + Realm::Processor target_proc, DynamicNodeInvocation const &invocation, + ProfilingSettings const &profiling_settings, + std::optional const &loss_attrs, + FFIterationConfig const &iteration_config, std::optional const &optimizer_attrs); } // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc index cdb3e5fe46..2683d019c3 100644 --- a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc +++ b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc @@ -2,6 +2,8 @@ #include "pcg/optimizer_attrs.h" #include "realm-execution/distributed_device_state_initialization.h" #include "realm-execution/instance_allocation.h" +#include "realm-execution/realm_context.h" +#include "realm-execution/tasks/impl/op_task.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" #include "task-spec/dynamic_graph/dynamic_tensor_guid_t.dtg.h" #include "task-spec/dynamic_graph/loss_insertion.h" @@ -9,33 +11,27 @@ #include "task-spec/dynamic_graph/pass_expansion.h" #include "task-spec/dynamic_graph/shard_expansion.h" #include "task-spec/dynamic_graph/update_insertion.h" -#include "utils/exception.h" +#include "utils/graph/digraph/algorithms/get_topological_ordering.h" #include "utils/optional.h" namespace FlexFlow { ParallelComputationGraphInstance::ParallelComputationGraphInstance( - RealmContext &realm, - DynamicOpenDataflowGraph dataflow_graph, - std::vector const &topological_ordering, + RealmContext &ctx, + std::vector const &execution_order, OptimizerAttrs const &optimizer_attrs, std::optional const &loss_attrs, - std::optional logit_grad_tensor) - : realm(realm), dataflow_graph(dataflow_graph), - topological_ordering(topological_ordering), + std::optional logit_grad_tensor) + : ctx(ctx), execution_order(execution_order), optimizer_attrs(optimizer_attrs), loss_attrs(loss_attrs), logit_grad_tensor(logit_grad_tensor) {} -DynamicOpenDataflowGraph const & - ParallelComputationGraphInstance::get_dynamic_dataflow_graph() const { - return this->dataflow_graph; -} -Allocator &ParallelComputationGraphInstance::get_allocator() const { - return this->realm.get_current_device_allocator(); +RealmContext &ParallelComputationGraphInstance::get_realm_context() { + return this->ctx; } std::vector const & - ParallelComputationGraphInstance::get_topological_ordering() const { - return this->topological_ordering; + ParallelComputationGraphInstance::get_execution_order() const { + return this->execution_order; } OptimizerAttrs const & ParallelComputationGraphInstance::get_optimizer_attrs() const { @@ -49,8 +45,8 @@ std::optional const & ParallelComputationGraphInstance::get_loss_attrs() const { return this->loss_attrs; } -std::optional - ParallelComputationGraphInstance::get_loss_tensor_accessor() const { +std::optional + ParallelComputationGraphInstance::get_loss_tensor_instance() const { return this->logit_grad_tensor; } @@ -88,7 +84,7 @@ ParallelComputationGraphInstance create_parallel_computation_graph_instance( TensorInstanceBacking backing = perform_instance_allocation(dg, inputs, ctx); // FIXME: for now we're going to be lazy and block on everything rather than - // do fine-grained dependencies + // do fine-grained dependencies on instances ctx.get_outstanding_events().wait(); std::optional logit_grad_tensor = @@ -98,13 +94,134 @@ ParallelComputationGraphInstance create_parallel_computation_graph_instance( dg = perform_distributed_device_state_initialization( dg, ctx, profiling_settings, iteration_config, optimizer_attrs); - NOT_IMPLEMENTED(); + + // Compute the topological ordering of the graph + auto [kwarg_graph, node_map] = + labelled_open_kwarg_dataflow_graph_from_dynamic_open_dataflow_graph(dg); + std::vector node_topo_order = get_topological_ordering(kwarg_graph); + std::vector invocation_topo_order = transform( + node_topo_order, [&](Node node) { return node_map.at_l(node); }); + + return ParallelComputationGraphInstance{ctx, + invocation_topo_order, + optimizer_attrs, + loss_attrs, + logit_grad_tensor}; // TODO list: - // * per-device state initialization (RPC mechanism?) // * Realm allocator - // * task body // * external instances } +static std::unordered_map + execute_distributed_dynamic_node_invocation_set( + RealmContext &ctx, + std::vector const &invocations, + OptimizerAttrs const &optimizer_attrs, + ProfilingSettings const &profiling_settings, + std::optional const &loss_attrs, + FFIterationConfig iteration_config) { + return unordered_map_from_pairs( + transform(invocations, [&](DynamicNodeInvocation const &invocation) { + Realm::Event result = + spawn_op_task(ctx, + ctx.map_device_coord_to_processor(assert_unwrap( + invocation.node_attrs.device_coord)), + invocation, + profiling_settings, + loss_attrs, + iteration_config, + optimizer_attrs); + return std::pair{invocation.node_attrs.layer_guid, result}; + })); +} + +std::unordered_map + perform_all_passes_for_parallel_computation_graph_instance( + ParallelComputationGraphInstance &instance, + ProfilingSettings const &profiling_settings, + FFIterationConfig iteration_config) { + std::vector const &execution_order = + instance.get_execution_order(); + std::unordered_map result = + execute_distributed_dynamic_node_invocation_set( + /*ctx=*/instance.get_realm_context(), + /*invocations=*/execution_order, + /*optimizer_attrs=*/instance.get_optimizer_attrs(), + /*profiling_settings=*/profiling_settings, + /*loss_attrs=*/instance.get_loss_attrs(), + /*iteration_config=*/iteration_config); + instance.update_optimizer_attrs_for_next_iter(); + return result; +} + +std::unordered_map + perform_forward_pass_for_parallel_computation_graph_instance( + ParallelComputationGraphInstance &instance, + ProfilingSettings const &profiling_settings, + FFIterationConfig iteration_config) { + std::vector const &execution_order = + filter(instance.get_execution_order(), + [](DynamicNodeInvocation const &invocation) { + DynamicTaskType task_type = + assert_unwrap(invocation.node_attrs.task_type); + return task_type == DynamicTaskType::FWD; + }); + + return execute_distributed_dynamic_node_invocation_set( + /*ctx=*/instance.get_realm_context(), + /*invocations=*/execution_order, + /*optimizer_attrs=*/instance.get_optimizer_attrs(), + /*profiling_settings=*/profiling_settings, + /*loss_attrs=*/instance.get_loss_attrs(), + /*iteration_config=*/iteration_config); +} + +std::unordered_map + perform_backward_pass_for_parallel_computation_graph_instance( + ParallelComputationGraphInstance &instance, + ProfilingSettings const &profiling_settings, + FFIterationConfig iteration_config) { + std::vector const &execution_order = + filter(instance.get_execution_order(), + [](DynamicNodeInvocation const &invocation) { + DynamicTaskType task_type = + assert_unwrap(invocation.node_attrs.task_type); + return task_type == DynamicTaskType::BWD; + }); + + return execute_distributed_dynamic_node_invocation_set( + /*ctx=*/instance.get_realm_context(), + /*invocations=*/execution_order, + /*optimizer_attrs=*/instance.get_optimizer_attrs(), + /*profiling_settings=*/profiling_settings, + /*loss_attrs=*/instance.get_loss_attrs(), + /*iteration_config=*/iteration_config); +} + +std::unordered_map + perform_update_pass_for_parallel_computation_graph_instance( + ParallelComputationGraphInstance &instance, + ProfilingSettings const &profiling_settings, + FFIterationConfig iteration_config) { + std::vector const &execution_order = + filter(instance.get_execution_order(), + [](DynamicNodeInvocation const &invocation) { + DynamicTaskType task_type = + assert_unwrap(invocation.node_attrs.task_type); + return task_type == DynamicTaskType::UPD; + }); + + std::unordered_map result = + execute_distributed_dynamic_node_invocation_set( + /*ctx=*/instance.get_realm_context(), + /*invocations=*/execution_order, + /*optimizer_attrs=*/instance.get_optimizer_attrs(), + /*profiling_settings=*/profiling_settings, + /*loss_attrs=*/instance.get_loss_attrs(), + /*iteration_config=*/iteration_config); + instance.update_optimizer_attrs_for_next_iter(); + return result; +} + } // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc index 91b753d639..49b5568d26 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc @@ -3,6 +3,7 @@ #include "realm-execution/tasks/impl/device_init_return_task.h" #include "realm-execution/tasks/task_id_t.dtg.h" #include "realm-execution/tasks/task_id_t.h" +#include "task-spec/device_specific_per_device_op_state.dtg.h" #include "utils/optional.h" #include #include @@ -43,7 +44,7 @@ void device_init_task_body(void const *args, DeviceInitTaskArgs task_args = *reinterpret_cast(args); - // FIXME: not safe to dereference unless we're on the same address space + // FIXME: serialize instead of passing pointers around ASSERT(task_args.origin_proc.address_space() == proc.address_space()); RealmContext ctx{proc}; @@ -55,11 +56,15 @@ void device_init_task_body(void const *args, *task_args.iteration_config, *task_args.optimizer_attrs, ctx.get_current_device_idx()); - std::optional result_state = - result_invocation.node_attrs.per_device_op_state; + DeviceSpecificPerDeviceOpState result_state = + assert_unwrap(result_invocation.node_attrs.per_device_op_state); + // Important: to make sure this doesn't get deallocated, we intentionally leak + // the allocation here + DeviceSpecificPerDeviceOpState *result_state_ptr = + new DeviceSpecificPerDeviceOpState{result_state}; spawn_device_init_return_task(ctx, task_args.origin_proc, - assert_unwrap(result_state), + *result_state_ptr, task_args.origin_result_ptr); } diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc index 9d9a36e2d5..79c152844b 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc @@ -1,5 +1,7 @@ #include "realm-execution/tasks/impl/op_task.h" +#include "local-execution/task_execution.h" #include "realm-execution/tasks/task_id_t.h" +#include "task-spec/per_device_op_state.h" #include "utils/optional.h" #include @@ -9,8 +11,24 @@ namespace FlexFlow { // now just pass the pointer and assume we're running inside a single address // space struct OpTaskArgs { +public: + OpTaskArgs() = delete; + OpTaskArgs(DynamicNodeInvocation const *invocation, + ProfilingSettings const *profiling_settings, + std::optional const *loss_attrs, + FFIterationConfig const *iteration_config, + std::optional const *optimizer_attrs, + Realm::Processor origin_proc) + : invocation(invocation), profiling_settings(profiling_settings), + loss_attrs(loss_attrs), iteration_config(iteration_config), + optimizer_attrs(optimizer_attrs) {} + public: DynamicNodeInvocation const *invocation; + ProfilingSettings const *profiling_settings; + std::optional const *loss_attrs; + FFIterationConfig const *iteration_config; + std::optional const *optimizer_attrs; Realm::Processor origin_proc; }; static_assert(std::has_unique_object_representations_v); @@ -23,20 +41,41 @@ void op_task_body(void const *args, ASSERT(arglen == sizeof(OpTaskArgs)); OpTaskArgs task_args = *reinterpret_cast(args); - // FIXME: not safe to dereference unless we're on the same address space + // FIXME: serialize instead of passing pointers around ASSERT(task_args.origin_proc.address_space() == proc.address_space()); RealmContext ctx{proc}; - NOT_IMPLEMENTED(); + execute_dynamic_node_invocation( + /*invocation=*/*task_args.invocation, + /*allocator=*/ctx.get_current_device_allocator(), + /*profiling_settings=*/*task_args.profiling_settings, + /*ff_handle=*/ctx.get_current_device_handle(), + /*loss_attrs=*/*task_args.loss_attrs, + /*per_device_op_state=*/ + transform(task_args.invocation->node_attrs.per_device_op_state, + [&](DeviceSpecificPerDeviceOpState const &op_state) { + return get_device_state_from_device_specific( + op_state, ctx.get_current_device_idx()); + }), + /*iteration_config=*/*task_args.iteration_config, + /*optimizer_attrs=*/*task_args.optimizer_attrs, + /*device_idx=*/ctx.get_current_device_idx()); } Realm::Event spawn_op_task(RealmContext &ctx, - Realm::Processor &target_proc, + Realm::Processor target_proc, DynamicNodeInvocation const &invocation, + ProfilingSettings const &profiling_settings, + std::optional const &loss_attrs, + FFIterationConfig const &iteration_config, std::optional const &optimizer_attrs) { - OpTaskArgs task_args; - task_args.invocation = &invocation; + OpTaskArgs task_args{&invocation, + &profiling_settings, + &loss_attrs, + &iteration_config, + &optimizer_attrs, + ctx.get_current_processor()}; return ctx.spawn_task( target_proc, assert_unwrap(get_task_id_for_op(invocation.node_attrs, optimizer_attrs)), From 46b7053f52db4aa3291b0247a21ffc4fd29d0cb4 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Tue, 10 Feb 2026 17:16:31 -0800 Subject: [PATCH 38/88] Implement Realm device idx. --- .../include/realm-execution/realm_context.h | 2 +- .../src/realm-execution/realm_context.cc | 26 +++++++++++++++++-- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/realm_context.h b/lib/realm-execution/include/realm-execution/realm_context.h index 422c4f4027..e28e91234e 100644 --- a/lib/realm-execution/include/realm-execution/realm_context.h +++ b/lib/realm-execution/include/realm-execution/realm_context.h @@ -29,7 +29,7 @@ struct RealmContext { Realm::Processor get_current_processor() const; Allocator &get_current_device_allocator() const; device_handle_t const &get_current_device_handle() const; - device_id_t const &get_current_device_idx() const; + device_id_t get_current_device_idx() const; // Task creation Realm::Event spawn_task(Realm::Processor proc, diff --git a/lib/realm-execution/src/realm-execution/realm_context.cc b/lib/realm-execution/src/realm-execution/realm_context.cc index 7e6c73c9e7..781561c95a 100644 --- a/lib/realm-execution/src/realm-execution/realm_context.cc +++ b/lib/realm-execution/src/realm-execution/realm_context.cc @@ -1,6 +1,7 @@ #include "realm-execution/realm_context.h" #include "op-attrs/datatype.h" #include "op-attrs/tensor_dims.dtg.h" +#include "pcg/device_id_t.h" #include "pcg/device_type.dtg.h" #include "realm-execution/tasks/realm_task_id_t.h" #include "realm-execution/tasks/task_id_t.dtg.h" @@ -70,8 +71,29 @@ Allocator &RealmContext::get_current_device_allocator() const { device_handle_t const &RealmContext::get_current_device_handle() const { NOT_IMPLEMENTED(); } -device_id_t const &RealmContext::get_current_device_idx() const { - NOT_IMPLEMENTED(); +device_id_t RealmContext::get_current_device_idx() const { + Realm::Processor proc = this->get_current_processor(); + + // FIXME: find a more efficient way to implement this than scanning the + // machine every time + Realm::Machine::ProcessorQuery pq(Realm::Machine::get_machine()); + pq.same_address_space_as(proc); + nonnegative_int idx{0}; + for (Realm::Processor p : pq) { + if (p == proc) { + break; + } + idx++; + } + + switch (proc.kind()) { + case Realm::Processor::LOC_PROC: + return make_device_id_t_from_idx(idx, DeviceType::CPU); + case Realm::Processor::TOC_PROC: + return make_device_id_t_from_idx(idx, DeviceType::GPU); + default: + PANIC("Unhandled Realm::ProcessorKind", fmt::to_string(int{proc.kind()})); + } } Realm::Event From 45634545373edea2b3c870809c26bc70758d75d2 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Thu, 12 Feb 2026 09:47:00 -0800 Subject: [PATCH 39/88] Updates to compile against latest local-execution. --- .../parallel_computation_graph_instance.h | 3 -- .../realm-execution/tasks/impl/op_task.h | 1 - .../parallel_computation_graph_instance.cc | 32 ++++++------------- .../tasks/impl/device_init_task.cc | 11 +++++-- .../src/realm-execution/tasks/impl/op_task.cc | 8 +---- .../src/realm-execution/tasks/task_id_t.cc | 12 ++++--- ...e_dynamic_open_dataflow_graph_from_mpcg.cc | 2 +- 7 files changed, 28 insertions(+), 41 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h b/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h index 0886dcf4c0..de06f457e2 100644 --- a/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h +++ b/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h @@ -25,20 +25,17 @@ struct ParallelComputationGraphInstance { ParallelComputationGraphInstance(RealmContext &, std::vector const &, OptimizerAttrs const &, - std::optional const &, std::optional); RealmContext &get_realm_context(); std::vector const &get_execution_order() const; OptimizerAttrs const &get_optimizer_attrs() const; void update_optimizer_attrs_for_next_iter(); - std::optional const &get_loss_attrs() const; std::optional get_loss_tensor_instance() const; private: RealmContext &ctx; std::vector execution_order; OptimizerAttrs optimizer_attrs; - std::optional loss_attrs; std::optional logit_grad_tensor; }; diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h index dd75ed66ea..3fcffc30fa 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h +++ b/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h @@ -18,7 +18,6 @@ Realm::Event Realm::Processor target_proc, DynamicNodeInvocation const &invocation, ProfilingSettings const &profiling_settings, - std::optional const &loss_attrs, FFIterationConfig const &iteration_config, std::optional const &optimizer_attrs); diff --git a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc index 2683d019c3..05dfec74c3 100644 --- a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc +++ b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc @@ -20,11 +20,9 @@ ParallelComputationGraphInstance::ParallelComputationGraphInstance( RealmContext &ctx, std::vector const &execution_order, OptimizerAttrs const &optimizer_attrs, - std::optional const &loss_attrs, std::optional logit_grad_tensor) : ctx(ctx), execution_order(execution_order), - optimizer_attrs(optimizer_attrs), loss_attrs(loss_attrs), - logit_grad_tensor(logit_grad_tensor) {} + optimizer_attrs(optimizer_attrs), logit_grad_tensor(logit_grad_tensor) {} RealmContext &ParallelComputationGraphInstance::get_realm_context() { return this->ctx; @@ -41,10 +39,6 @@ void ParallelComputationGraphInstance::update_optimizer_attrs_for_next_iter() { this->optimizer_attrs = get_optimizer_attrs_for_next_iter(this->optimizer_attrs); } -std::optional const & - ParallelComputationGraphInstance::get_loss_attrs() const { - return this->loss_attrs; -} std::optional ParallelComputationGraphInstance::get_loss_tensor_instance() const { return this->logit_grad_tensor; @@ -102,15 +96,15 @@ ParallelComputationGraphInstance create_parallel_computation_graph_instance( std::vector invocation_topo_order = transform( node_topo_order, [&](Node node) { return node_map.at_l(node); }); - return ParallelComputationGraphInstance{ctx, - invocation_topo_order, - optimizer_attrs, - loss_attrs, - logit_grad_tensor}; + return ParallelComputationGraphInstance{ + ctx, invocation_topo_order, optimizer_attrs, logit_grad_tensor}; // TODO list: // * Realm allocator // * external instances + // * dependencies + // * task argument serializer + // * copies } static std::unordered_map @@ -119,7 +113,6 @@ static std::unordered_map std::vector const &invocations, OptimizerAttrs const &optimizer_attrs, ProfilingSettings const &profiling_settings, - std::optional const &loss_attrs, FFIterationConfig iteration_config) { return unordered_map_from_pairs( transform(invocations, [&](DynamicNodeInvocation const &invocation) { @@ -129,7 +122,6 @@ static std::unordered_map invocation.node_attrs.device_coord)), invocation, profiling_settings, - loss_attrs, iteration_config, optimizer_attrs); return std::pair{invocation.node_attrs.layer_guid, result}; @@ -141,7 +133,7 @@ std::unordered_map ParallelComputationGraphInstance &instance, ProfilingSettings const &profiling_settings, FFIterationConfig iteration_config) { - std::vector const &execution_order = + std::vector execution_order = instance.get_execution_order(); std::unordered_map result = execute_distributed_dynamic_node_invocation_set( @@ -149,7 +141,6 @@ std::unordered_map /*invocations=*/execution_order, /*optimizer_attrs=*/instance.get_optimizer_attrs(), /*profiling_settings=*/profiling_settings, - /*loss_attrs=*/instance.get_loss_attrs(), /*iteration_config=*/iteration_config); instance.update_optimizer_attrs_for_next_iter(); return result; @@ -160,7 +151,7 @@ std::unordered_map ParallelComputationGraphInstance &instance, ProfilingSettings const &profiling_settings, FFIterationConfig iteration_config) { - std::vector const &execution_order = + std::vector execution_order = filter(instance.get_execution_order(), [](DynamicNodeInvocation const &invocation) { DynamicTaskType task_type = @@ -173,7 +164,6 @@ std::unordered_map /*invocations=*/execution_order, /*optimizer_attrs=*/instance.get_optimizer_attrs(), /*profiling_settings=*/profiling_settings, - /*loss_attrs=*/instance.get_loss_attrs(), /*iteration_config=*/iteration_config); } @@ -182,7 +172,7 @@ std::unordered_map ParallelComputationGraphInstance &instance, ProfilingSettings const &profiling_settings, FFIterationConfig iteration_config) { - std::vector const &execution_order = + std::vector execution_order = filter(instance.get_execution_order(), [](DynamicNodeInvocation const &invocation) { DynamicTaskType task_type = @@ -195,7 +185,6 @@ std::unordered_map /*invocations=*/execution_order, /*optimizer_attrs=*/instance.get_optimizer_attrs(), /*profiling_settings=*/profiling_settings, - /*loss_attrs=*/instance.get_loss_attrs(), /*iteration_config=*/iteration_config); } @@ -204,7 +193,7 @@ std::unordered_map ParallelComputationGraphInstance &instance, ProfilingSettings const &profiling_settings, FFIterationConfig iteration_config) { - std::vector const &execution_order = + std::vector execution_order = filter(instance.get_execution_order(), [](DynamicNodeInvocation const &invocation) { DynamicTaskType task_type = @@ -218,7 +207,6 @@ std::unordered_map /*invocations=*/execution_order, /*optimizer_attrs=*/instance.get_optimizer_attrs(), /*profiling_settings=*/profiling_settings, - /*loss_attrs=*/instance.get_loss_attrs(), /*iteration_config=*/iteration_config); instance.update_optimizer_attrs_for_next_iter(); return result; diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc index 49b5568d26..cc080255e2 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc @@ -4,6 +4,7 @@ #include "realm-execution/tasks/task_id_t.dtg.h" #include "realm-execution/tasks/task_id_t.h" #include "task-spec/device_specific_per_device_op_state.dtg.h" +#include "task-spec/dynamic_graph/training_operation_attrs.dtg.h" #include "utils/optional.h" #include #include @@ -85,9 +86,13 @@ std::optional result_ptr, }; - std::optional task_id = get_init_task_id_for_op_attrs( - assert_unwrap(invocation.node_attrs.op_attrs)); - if (task_id) { + std::optional task_id = + and_then(and_then(invocation.node_attrs.op_attrs, + [](TrainingOperationAttrs const &op_attrs) { + return op_attrs.try_require_pcg_op(); + }), + get_init_task_id_for_op_attrs); + if (task_id.has_value()) { return ctx.spawn_task(target_proc, assert_unwrap(task_id), &task_args, diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc index 79c152844b..5f6ab40607 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc @@ -15,18 +15,15 @@ struct OpTaskArgs { OpTaskArgs() = delete; OpTaskArgs(DynamicNodeInvocation const *invocation, ProfilingSettings const *profiling_settings, - std::optional const *loss_attrs, FFIterationConfig const *iteration_config, std::optional const *optimizer_attrs, Realm::Processor origin_proc) : invocation(invocation), profiling_settings(profiling_settings), - loss_attrs(loss_attrs), iteration_config(iteration_config), - optimizer_attrs(optimizer_attrs) {} + iteration_config(iteration_config), optimizer_attrs(optimizer_attrs) {} public: DynamicNodeInvocation const *invocation; ProfilingSettings const *profiling_settings; - std::optional const *loss_attrs; FFIterationConfig const *iteration_config; std::optional const *optimizer_attrs; Realm::Processor origin_proc; @@ -50,7 +47,6 @@ void op_task_body(void const *args, /*allocator=*/ctx.get_current_device_allocator(), /*profiling_settings=*/*task_args.profiling_settings, /*ff_handle=*/ctx.get_current_device_handle(), - /*loss_attrs=*/*task_args.loss_attrs, /*per_device_op_state=*/ transform(task_args.invocation->node_attrs.per_device_op_state, [&](DeviceSpecificPerDeviceOpState const &op_state) { @@ -67,12 +63,10 @@ Realm::Event Realm::Processor target_proc, DynamicNodeInvocation const &invocation, ProfilingSettings const &profiling_settings, - std::optional const &loss_attrs, FFIterationConfig const &iteration_config, std::optional const &optimizer_attrs) { OpTaskArgs task_args{&invocation, &profiling_settings, - &loss_attrs, &iteration_config, &optimizer_attrs, ctx.get_current_processor()}; diff --git a/lib/realm-execution/src/realm-execution/tasks/task_id_t.cc b/lib/realm-execution/src/realm-execution/tasks/task_id_t.cc index 5a99f2bea8..94e1b887e7 100644 --- a/lib/realm-execution/src/realm-execution/tasks/task_id_t.cc +++ b/lib/realm-execution/src/realm-execution/tasks/task_id_t.cc @@ -2,6 +2,7 @@ #include "pcg/optimizer_attrs.dtg.h" #include "pcg/optimizers/adam_optimizer_attrs.dtg.h" #include "task-spec/dynamic_graph/dynamic_node_attrs.dtg.h" +#include "utils/optional.h" #include "utils/overload.h" namespace FlexFlow { @@ -9,14 +10,17 @@ namespace FlexFlow { std::optional get_task_id_for_op(DynamicNodeAttrs const &node_attrs, std::optional const &optimizer_attrs) { - DynamicTaskType task_type = node_attrs.task_type.value(); + DynamicTaskType task_type = assert_unwrap(node_attrs.task_type); switch (task_type) { case DynamicTaskType::FWD: - return get_fwd_task_id_for_op_attrs(node_attrs.op_attrs.value()); + return get_fwd_task_id_for_op_attrs( + assert_unwrap(node_attrs.op_attrs).require_pcg_op()); case DynamicTaskType::BWD: - return get_bwd_task_id_for_op_attrs(node_attrs.op_attrs.value()); + return get_bwd_task_id_for_op_attrs( + assert_unwrap(node_attrs.op_attrs).require_pcg_op()); case DynamicTaskType::UPD: - return get_update_task_id_for_optimizer_attrs(optimizer_attrs.value()); + return get_update_task_id_for_optimizer_attrs( + assert_unwrap(optimizer_attrs)); case DynamicTaskType::LOSS: return task_id_t::LOSS_BWD_TASK_ID; default: diff --git a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mpcg.cc b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mpcg.cc index e90ef10398..ced98dfd44 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mpcg.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mpcg.cc @@ -23,7 +23,7 @@ DynamicOpenDataflowGraph make_dynamic_open_dataflow_graph_from_mpcg( /*task_type=*/std::nullopt, /*device_coord=*/std::nullopt, /*mapping=*/mpcg.mapped_tasks.at(layer), - /*op_attrs=*/attrs.op_attrs, + /*op_attrs=*/TrainingOperationAttrs{attrs.op_attrs}, /*pcg_layer_guid=*/dynamic_layer_guid_t{layer}, /*per_device_op_state=*/std::nullopt, }; From 6daf3709946d80895f020c167c8ed914462abd9e Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Thu, 12 Feb 2026 10:01:14 -0800 Subject: [PATCH 40/88] Fix up function arguments. --- .../distributed_device_state_initialization.h | 2 +- .../include/realm-execution/instance_allocation.h | 11 ++++++----- .../parallel_computation_graph_instance.h | 9 +++++---- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/distributed_device_state_initialization.h b/lib/realm-execution/include/realm-execution/distributed_device_state_initialization.h index 4121f10341..d2ed093c0b 100644 --- a/lib/realm-execution/include/realm-execution/distributed_device_state_initialization.h +++ b/lib/realm-execution/include/realm-execution/distributed_device_state_initialization.h @@ -10,7 +10,7 @@ namespace FlexFlow { DynamicOpenDataflowGraph perform_distributed_device_state_initialization( - DynamicOpenDataflowGraph const &, + DynamicOpenDataflowGraph const &dg, RealmContext &ctx, ProfilingSettings const &profiling_settings, FFIterationConfig const &iteration_config, diff --git a/lib/realm-execution/include/realm-execution/instance_allocation.h b/lib/realm-execution/include/realm-execution/instance_allocation.h index 59065694e9..09709201ce 100644 --- a/lib/realm-execution/include/realm-execution/instance_allocation.h +++ b/lib/realm-execution/include/realm-execution/instance_allocation.h @@ -7,15 +7,16 @@ namespace FlexFlow { -DynamicValueAttrs - perform_instance_allocation_for_value(DynamicValueAttrs const &, - Allocator &); +std::pair + perform_instance_allocation_for_value(DynamicNodeAttrs const &node, + DynamicValueAttrs const &value, + RealmContext &ctx); TensorInstanceBacking perform_instance_allocation( - DynamicOpenDataflowGraph const &, + DynamicOpenDataflowGraph const &g, std::unordered_map const &preallocated, - RealmContext &); + RealmContext &ctx); } // namespace FlexFlow diff --git a/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h b/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h index de06f457e2..f48879a2bb 100644 --- a/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h +++ b/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h @@ -22,10 +22,11 @@ namespace FlexFlow { struct ParallelComputationGraphInstance { public: - ParallelComputationGraphInstance(RealmContext &, - std::vector const &, - OptimizerAttrs const &, - std::optional); + ParallelComputationGraphInstance( + RealmContext &ctx, + std::vector const &execution_order, + OptimizerAttrs const &optimizer_attrs, + std::optional logit_grad_tensor); RealmContext &get_realm_context(); std::vector const &get_execution_order() const; OptimizerAttrs const &get_optimizer_attrs() const; From f7e58bd805dfd3955f032d330324264b8c591028 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Thu, 12 Feb 2026 10:50:16 -0800 Subject: [PATCH 41/88] Rename PCGInstance and add dependency set. --- .../realm-execution/atomic_dependency_set.h | 26 ++++++++++++ .../include/realm-execution/dependency_set.h | 34 +++++++++++++++ .../pcg_instance.h} | 13 +++--- .../realm-execution/atomic_dependency_set.cc | 23 +++++++++++ .../src/realm-execution/dependency_set.cc | 41 +++++++++++++++++++ .../pcg_instance.cc} | 31 +++++++------- .../test/src/realm-execution/realm_manager.cc | 1 - .../test/src/realm-execution/test_e2e.cc | 2 +- 8 files changed, 150 insertions(+), 21 deletions(-) create mode 100644 lib/realm-execution/include/realm-execution/atomic_dependency_set.h create mode 100644 lib/realm-execution/include/realm-execution/dependency_set.h rename lib/realm-execution/include/realm-execution/{parallel_computation_graph_instance/parallel_computation_graph_instance.h => pcg_instance/pcg_instance.h} (84%) create mode 100644 lib/realm-execution/src/realm-execution/atomic_dependency_set.cc create mode 100644 lib/realm-execution/src/realm-execution/dependency_set.cc rename lib/realm-execution/src/realm-execution/{parallel_computation_graph_instance/parallel_computation_graph_instance.cc => pcg_instance/pcg_instance.cc} (90%) diff --git a/lib/realm-execution/include/realm-execution/atomic_dependency_set.h b/lib/realm-execution/include/realm-execution/atomic_dependency_set.h new file mode 100644 index 0000000000..8a1ae96b3e --- /dev/null +++ b/lib/realm-execution/include/realm-execution/atomic_dependency_set.h @@ -0,0 +1,26 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_ATOMIC_DEPENDENCY_SET_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_ATOMIC_DEPENDENCY_SET_H + +#include "realm-execution/realm.h" +#include + +namespace FlexFlow { + +struct AtomicDependencySet { +public: + AtomicDependencySet() = delete; + explicit AtomicDependencySet(Realm::Event precondition); + + void add_writer(Realm::Event writer); + void add_reader(Realm::Event reader); + + Realm::Event get_current_outstanding_events() const; + +private: + Realm::Event writer; + std::vector readers; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/dependency_set.h b/lib/realm-execution/include/realm-execution/dependency_set.h new file mode 100644 index 0000000000..a7100076b2 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/dependency_set.h @@ -0,0 +1,34 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_DEPENDENCY_SET_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_DEPENDENCY_SET_H + +#include "realm-execution/atomic_dependency_set.h" +#include "realm-execution/realm.h" +#include "task-spec/dynamic_graph/dynamic_value_attrs.dtg.h" +#include + +namespace FlexFlow { + +struct DependencySet { +public: + DependencySet() = delete; + explicit DependencySet(Realm::Event precondition); + + void add_writer(DynamicValueAttrs const &value, Realm::Event writer); + void add_reader(DynamicValueAttrs const &value, Realm::Event reader); + + Realm::Event + get_current_outstanding_events(DynamicValueAttrs const &value) const; + +private: + AtomicDependencySet & + get_atomic_dependency_set(DynamicValueAttrs const &value); + +private: + Realm::Event precondition; + std::unordered_map + atomic_dependencies; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h b/lib/realm-execution/include/realm-execution/pcg_instance/pcg_instance.h similarity index 84% rename from lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h rename to lib/realm-execution/include/realm-execution/pcg_instance/pcg_instance.h index f48879a2bb..3c5b4189ea 100644 --- a/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h +++ b/lib/realm-execution/include/realm-execution/pcg_instance/pcg_instance.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_PARALLEL_COMPUTATION_GRAPH_INSTANCE_H -#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_PARALLEL_COMPUTATION_GRAPH_INSTANCE_H +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_PCG_INSTANCE_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_PCG_INSTANCE_H #include "kernels/accessor.h" #include "kernels/allocation.h" @@ -20,9 +20,12 @@ namespace FlexFlow { -struct ParallelComputationGraphInstance { +struct PCGInstance { public: - ParallelComputationGraphInstance( + PCGInstance() = delete; + PCGInstance(PCGInstance const &) = delete; + PCGInstance(PCGInstance &&) = delete; + explicit PCGInstance( RealmContext &ctx, std::vector const &execution_order, OptimizerAttrs const &optimizer_attrs, @@ -40,7 +43,7 @@ struct ParallelComputationGraphInstance { std::optional logit_grad_tensor; }; -ParallelComputationGraphInstance create_parallel_computation_graph_instance( +PCGInstance create_pcg_instance( RealmContext &ctx, MappedParallelComputationGraph const &mpcg, OptimizerAttrs const &optimizer_attrs, diff --git a/lib/realm-execution/src/realm-execution/atomic_dependency_set.cc b/lib/realm-execution/src/realm-execution/atomic_dependency_set.cc new file mode 100644 index 0000000000..bdc05b7c46 --- /dev/null +++ b/lib/realm-execution/src/realm-execution/atomic_dependency_set.cc @@ -0,0 +1,23 @@ +#include "realm-execution/atomic_dependency_set.h" + +namespace FlexFlow { + +AtomicDependencySet::AtomicDependencySet(Realm::Event precondition) + : writer(precondition) {} + +void AtomicDependencySet::add_writer(Realm::Event writer) { + this->writer = Realm::Event::merge_events( + writer, this->get_current_outstanding_events()); + this->readers.clear(); +} + +void AtomicDependencySet::add_reader(Realm::Event reader) { + this->readers.push_back(reader); +} + +Realm::Event AtomicDependencySet::get_current_outstanding_events() const { + Realm::Event readers = Realm::Event::merge_events(this->readers); + return Realm::Event::merge_events(writer, readers); +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/dependency_set.cc b/lib/realm-execution/src/realm-execution/dependency_set.cc new file mode 100644 index 0000000000..3af03ffcef --- /dev/null +++ b/lib/realm-execution/src/realm-execution/dependency_set.cc @@ -0,0 +1,41 @@ +#include "realm-execution/dependency_set.h" +#include "realm-execution/atomic_dependency_set.h" +#include "utils/containers/contains_key.h" + +namespace FlexFlow { + +DependencySet::DependencySet(Realm::Event precondition) + : precondition(precondition) {} + +void DependencySet::add_writer(DynamicValueAttrs const &value, + Realm::Event writer) { + AtomicDependencySet &atomic_dependence_set = + this->get_atomic_dependency_set(value); + atomic_dependence_set.add_writer(writer); +} + +void DependencySet::add_reader(DynamicValueAttrs const &value, + Realm::Event reader) { + AtomicDependencySet &atomic_dependence_set = + this->get_atomic_dependency_set(value); + atomic_dependence_set.add_reader(reader); +} + +Realm::Event DependencySet::get_current_outstanding_events( + DynamicValueAttrs const &value) const { + if (contains_key(this->atomic_dependencies, value)) { + return this->atomic_dependencies.at(value).get_current_outstanding_events(); + } + return this->precondition; +} + +AtomicDependencySet & + DependencySet::get_atomic_dependency_set(DynamicValueAttrs const &value) { + if (!contains_key(this->atomic_dependencies, value)) { + this->atomic_dependencies.insert( + {value, AtomicDependencySet{this->precondition}}); + } + return this->atomic_dependencies.at(value); +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc b/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc similarity index 90% rename from lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc rename to lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc index 05dfec74c3..c1654397ec 100644 --- a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc +++ b/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc @@ -1,5 +1,6 @@ -#include "realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h" +#include "realm-execution/pcg_instance/pcg_instance.h" #include "pcg/optimizer_attrs.h" +#include "realm-execution/dependency_set.h" #include "realm-execution/distributed_device_state_initialization.h" #include "realm-execution/instance_allocation.h" #include "realm-execution/realm_context.h" @@ -16,7 +17,7 @@ namespace FlexFlow { -ParallelComputationGraphInstance::ParallelComputationGraphInstance( +PCGInstance::PCGInstance( RealmContext &ctx, std::vector const &execution_order, OptimizerAttrs const &optimizer_attrs, @@ -24,27 +25,26 @@ ParallelComputationGraphInstance::ParallelComputationGraphInstance( : ctx(ctx), execution_order(execution_order), optimizer_attrs(optimizer_attrs), logit_grad_tensor(logit_grad_tensor) {} -RealmContext &ParallelComputationGraphInstance::get_realm_context() { +RealmContext &PCGInstance::get_realm_context() { return this->ctx; } std::vector const & - ParallelComputationGraphInstance::get_execution_order() const { + PCGInstance::get_execution_order() const { return this->execution_order; } -OptimizerAttrs const & - ParallelComputationGraphInstance::get_optimizer_attrs() const { +OptimizerAttrs const &PCGInstance::get_optimizer_attrs() const { return this->optimizer_attrs; } -void ParallelComputationGraphInstance::update_optimizer_attrs_for_next_iter() { +void PCGInstance::update_optimizer_attrs_for_next_iter() { this->optimizer_attrs = get_optimizer_attrs_for_next_iter(this->optimizer_attrs); } std::optional - ParallelComputationGraphInstance::get_loss_tensor_instance() const { + PCGInstance::get_loss_tensor_instance() const { return this->logit_grad_tensor; } -ParallelComputationGraphInstance create_parallel_computation_graph_instance( +PCGInstance create_parallel_computation_graph_instance( RealmContext &ctx, MappedParallelComputationGraph const &mpcg, OptimizerAttrs const &optimizer_attrs, @@ -96,7 +96,7 @@ ParallelComputationGraphInstance create_parallel_computation_graph_instance( std::vector invocation_topo_order = transform( node_topo_order, [&](Node node) { return node_map.at_l(node); }); - return ParallelComputationGraphInstance{ + return PCGInstance{ ctx, invocation_topo_order, optimizer_attrs, logit_grad_tensor}; // TODO list: @@ -114,6 +114,9 @@ static std::unordered_map OptimizerAttrs const &optimizer_attrs, ProfilingSettings const &profiling_settings, FFIterationConfig iteration_config) { + // For simplicity we'll track a dependency on all outstanding operations up to + // this point. This will create an effective barrier between phases. + DependencySet dependency_set{ctx.get_outstanding_events()}; return unordered_map_from_pairs( transform(invocations, [&](DynamicNodeInvocation const &invocation) { Realm::Event result = @@ -130,7 +133,7 @@ static std::unordered_map std::unordered_map perform_all_passes_for_parallel_computation_graph_instance( - ParallelComputationGraphInstance &instance, + PCGInstance &instance, ProfilingSettings const &profiling_settings, FFIterationConfig iteration_config) { std::vector execution_order = @@ -148,7 +151,7 @@ std::unordered_map std::unordered_map perform_forward_pass_for_parallel_computation_graph_instance( - ParallelComputationGraphInstance &instance, + PCGInstance &instance, ProfilingSettings const &profiling_settings, FFIterationConfig iteration_config) { std::vector execution_order = @@ -169,7 +172,7 @@ std::unordered_map std::unordered_map perform_backward_pass_for_parallel_computation_graph_instance( - ParallelComputationGraphInstance &instance, + PCGInstance &instance, ProfilingSettings const &profiling_settings, FFIterationConfig iteration_config) { std::vector execution_order = @@ -190,7 +193,7 @@ std::unordered_map std::unordered_map perform_update_pass_for_parallel_computation_graph_instance( - ParallelComputationGraphInstance &instance, + PCGInstance &instance, ProfilingSettings const &profiling_settings, FFIterationConfig iteration_config) { std::vector execution_order = diff --git a/lib/realm-execution/test/src/realm-execution/realm_manager.cc b/lib/realm-execution/test/src/realm-execution/realm_manager.cc index 6c28a001ad..94e0d7d0f4 100644 --- a/lib/realm-execution/test/src/realm-execution/realm_manager.cc +++ b/lib/realm-execution/test/src/realm-execution/realm_manager.cc @@ -1,5 +1,4 @@ #include "realm-execution/realm_manager.h" -#include "realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h" #include using namespace ::FlexFlow; diff --git a/lib/realm-execution/test/src/realm-execution/test_e2e.cc b/lib/realm-execution/test/src/realm-execution/test_e2e.cc index a30d5c4d8e..37f1a9b42c 100644 --- a/lib/realm-execution/test/src/realm-execution/test_e2e.cc +++ b/lib/realm-execution/test/src/realm-execution/test_e2e.cc @@ -1,4 +1,4 @@ -#include "realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h" +#include "realm-execution/pcg_instance/pcg_instance.h" #include "realm-execution/realm_manager.h" #include From bb5a54aab3e3b5aab4a428e523e1726d5c71fcfb Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Thu, 12 Feb 2026 11:17:21 -0800 Subject: [PATCH 42/88] Dependency tracking. --- .../realm-execution/atomic_dependency_set.h | 3 +- .../include/realm-execution/dependency_set.h | 4 +- .../distributed_device_state_initialization.h | 3 +- .../tasks/impl/controller_task.h | 3 +- .../tasks/impl/device_init_return_task.h | 3 +- .../tasks/impl/device_init_task.h | 3 +- .../realm-execution/tasks/impl/op_task.h | 14 +++---- .../realm-execution/atomic_dependency_set.cc | 12 ++++-- .../src/realm-execution/dependency_set.cc | 12 +++++- ...distributed_device_state_initialization.cc | 6 ++- .../pcg_instance/pcg_instance.cc | 39 +++++++++++++++---- .../src/realm-execution/realm_manager.cc | 3 +- .../tasks/impl/controller_task.cc | 12 +++--- .../tasks/impl/device_init_return_task.cc | 6 ++- .../tasks/impl/device_init_task.cc | 9 +++-- .../src/realm-execution/tasks/impl/op_task.cc | 17 ++++---- 16 files changed, 101 insertions(+), 48 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/atomic_dependency_set.h b/lib/realm-execution/include/realm-execution/atomic_dependency_set.h index 8a1ae96b3e..da6ba86638 100644 --- a/lib/realm-execution/include/realm-execution/atomic_dependency_set.h +++ b/lib/realm-execution/include/realm-execution/atomic_dependency_set.h @@ -14,7 +14,8 @@ struct AtomicDependencySet { void add_writer(Realm::Event writer); void add_reader(Realm::Event reader); - Realm::Event get_current_outstanding_events() const; + Realm::Event get_dependency_for_writer() const; + Realm::Event get_dependency_for_reader() const; private: Realm::Event writer; diff --git a/lib/realm-execution/include/realm-execution/dependency_set.h b/lib/realm-execution/include/realm-execution/dependency_set.h index a7100076b2..629a40e2e7 100644 --- a/lib/realm-execution/include/realm-execution/dependency_set.h +++ b/lib/realm-execution/include/realm-execution/dependency_set.h @@ -16,8 +16,8 @@ struct DependencySet { void add_writer(DynamicValueAttrs const &value, Realm::Event writer); void add_reader(DynamicValueAttrs const &value, Realm::Event reader); - Realm::Event - get_current_outstanding_events(DynamicValueAttrs const &value) const; + Realm::Event get_dependency_for_writer(DynamicValueAttrs const &value) const; + Realm::Event get_dependency_for_reader(DynamicValueAttrs const &value) const; private: AtomicDependencySet & diff --git a/lib/realm-execution/include/realm-execution/distributed_device_state_initialization.h b/lib/realm-execution/include/realm-execution/distributed_device_state_initialization.h index d2ed093c0b..5530f473d8 100644 --- a/lib/realm-execution/include/realm-execution/distributed_device_state_initialization.h +++ b/lib/realm-execution/include/realm-execution/distributed_device_state_initialization.h @@ -14,7 +14,8 @@ DynamicOpenDataflowGraph perform_distributed_device_state_initialization( RealmContext &ctx, ProfilingSettings const &profiling_settings, FFIterationConfig const &iteration_config, - OptimizerAttrs const &optimizer_attrs); + OptimizerAttrs const &optimizer_attrs, + Realm::Event precondition); } // namespace FlexFlow diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/controller_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/controller_task.h index d4c397bb37..7134973ead 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/controller_task.h +++ b/lib/realm-execution/include/realm-execution/tasks/impl/controller_task.h @@ -12,7 +12,8 @@ void controller_task_body( Realm::Event collective_spawn_controller_task(RealmContext &ctx, Realm::Processor &target_proc, - std::function thunk); + std::function thunk, + Realm::Event precondition); } // namespace FlexFlow diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/device_init_return_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/device_init_return_task.h index fc6c8bdb9f..0f92b35c24 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/device_init_return_task.h +++ b/lib/realm-execution/include/realm-execution/tasks/impl/device_init_return_task.h @@ -14,7 +14,8 @@ Realm::Event spawn_device_init_return_task( RealmContext &ctx, Realm::Processor origin_proc, DeviceSpecificPerDeviceOpState const &result, - DeviceSpecificPerDeviceOpState *origin_result_ptr); + DeviceSpecificPerDeviceOpState *origin_result_ptr, + Realm::Event precondition); } // namespace FlexFlow diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/device_init_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/device_init_task.h index af07139483..7842963c7b 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/device_init_task.h +++ b/lib/realm-execution/include/realm-execution/tasks/impl/device_init_task.h @@ -21,7 +21,8 @@ std::optional ProfilingSettings const &profiling_settings, FFIterationConfig const &iteration_config, OptimizerAttrs const &optimizer_attrs, - DeviceSpecificPerDeviceOpState *result_ptr); + DeviceSpecificPerDeviceOpState *result_ptr, + Realm::Event precondition); } // namespace FlexFlow diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h index 3fcffc30fa..21d8795339 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h +++ b/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h @@ -13,13 +13,13 @@ namespace FlexFlow { void op_task_body(void const *, size_t, void const *, size_t, Realm::Processor); -Realm::Event - spawn_op_task(RealmContext &ctx, - Realm::Processor target_proc, - DynamicNodeInvocation const &invocation, - ProfilingSettings const &profiling_settings, - FFIterationConfig const &iteration_config, - std::optional const &optimizer_attrs); +Realm::Event spawn_op_task(RealmContext &ctx, + Realm::Processor target_proc, + DynamicNodeInvocation const &invocation, + ProfilingSettings const &profiling_settings, + FFIterationConfig const &iteration_config, + std::optional const &optimizer_attrs, + Realm::Event precondition); } // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/atomic_dependency_set.cc b/lib/realm-execution/src/realm-execution/atomic_dependency_set.cc index bdc05b7c46..ba4fcc5a9f 100644 --- a/lib/realm-execution/src/realm-execution/atomic_dependency_set.cc +++ b/lib/realm-execution/src/realm-execution/atomic_dependency_set.cc @@ -6,8 +6,8 @@ AtomicDependencySet::AtomicDependencySet(Realm::Event precondition) : writer(precondition) {} void AtomicDependencySet::add_writer(Realm::Event writer) { - this->writer = Realm::Event::merge_events( - writer, this->get_current_outstanding_events()); + this->writer = + Realm::Event::merge_events(writer, this->get_dependency_for_writer()); this->readers.clear(); } @@ -15,9 +15,13 @@ void AtomicDependencySet::add_reader(Realm::Event reader) { this->readers.push_back(reader); } -Realm::Event AtomicDependencySet::get_current_outstanding_events() const { +Realm::Event AtomicDependencySet::get_dependency_for_writer() const { Realm::Event readers = Realm::Event::merge_events(this->readers); - return Realm::Event::merge_events(writer, readers); + return Realm::Event::merge_events(this->writer, readers); +} + +Realm::Event AtomicDependencySet::get_dependency_for_reader() const { + return this->writer; } } // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/dependency_set.cc b/lib/realm-execution/src/realm-execution/dependency_set.cc index 3af03ffcef..84412a125d 100644 --- a/lib/realm-execution/src/realm-execution/dependency_set.cc +++ b/lib/realm-execution/src/realm-execution/dependency_set.cc @@ -21,10 +21,18 @@ void DependencySet::add_reader(DynamicValueAttrs const &value, atomic_dependence_set.add_reader(reader); } -Realm::Event DependencySet::get_current_outstanding_events( +Realm::Event DependencySet::get_dependency_for_writer( DynamicValueAttrs const &value) const { if (contains_key(this->atomic_dependencies, value)) { - return this->atomic_dependencies.at(value).get_current_outstanding_events(); + return this->atomic_dependencies.at(value).get_dependency_for_writer(); + } + return this->precondition; +} + +Realm::Event DependencySet::get_dependency_for_reader( + DynamicValueAttrs const &value) const { + if (contains_key(this->atomic_dependencies, value)) { + return this->atomic_dependencies.at(value).get_dependency_for_reader(); } return this->precondition; } diff --git a/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc b/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc index f7fcea87e7..4ea8d0bbd1 100644 --- a/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc +++ b/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc @@ -14,7 +14,8 @@ DynamicOpenDataflowGraph perform_distributed_device_state_initialization( RealmContext &ctx, ProfilingSettings const &profiling_settings, FFIterationConfig const &iteration_config, - OptimizerAttrs const &optimizer_attrs) { + OptimizerAttrs const &optimizer_attrs, + Realm::Event precondition) { // Initialize all operators and save the per-device op state ASSERT(no_nodes_are_initialized(dg)); @@ -38,7 +39,8 @@ DynamicOpenDataflowGraph perform_distributed_device_state_initialization( profiling_settings, iteration_config, optimizer_attrs, - output); + output, + precondition); if (result) { result_map[invocation] = output; } else { diff --git a/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc b/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc index c1654397ec..e636cbf259 100644 --- a/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc +++ b/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc @@ -7,11 +7,14 @@ #include "realm-execution/tasks/impl/op_task.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" #include "task-spec/dynamic_graph/dynamic_tensor_guid_t.dtg.h" +#include "task-spec/dynamic_graph/dynamic_value_attrs.dtg.h" #include "task-spec/dynamic_graph/loss_insertion.h" #include "task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mpcg.h" #include "task-spec/dynamic_graph/pass_expansion.h" #include "task-spec/dynamic_graph/shard_expansion.h" #include "task-spec/dynamic_graph/update_insertion.h" +#include "utils/containers/transform.h" +#include "utils/containers/values.h" #include "utils/graph/digraph/algorithms/get_topological_ordering.h" #include "utils/optional.h" @@ -77,17 +80,20 @@ PCGInstance create_parallel_computation_graph_instance( dg = perform_shard_expansion(dg); TensorInstanceBacking backing = perform_instance_allocation(dg, inputs, ctx); - // FIXME: for now we're going to be lazy and block on everything rather than - // do fine-grained dependencies on instances - ctx.get_outstanding_events().wait(); - std::optional logit_grad_tensor = transform(logit_grad_value, [&](DynamicValueAttrs const &lgv) { return backing.backing.at(lgv).first; }); + // FIXME: for now we're going to be lazy and block on everything rather than + // do fine-grained dependencies on instances dg = perform_distributed_device_state_initialization( - dg, ctx, profiling_settings, iteration_config, optimizer_attrs); + dg, + ctx, + profiling_settings, + iteration_config, + optimizer_attrs, + ctx.get_outstanding_events()); // Compute the topological ordering of the graph auto [kwarg_graph, node_map] = @@ -102,7 +108,6 @@ PCGInstance create_parallel_computation_graph_instance( // TODO list: // * Realm allocator // * external instances - // * dependencies // * task argument serializer // * copies } @@ -119,6 +124,19 @@ static std::unordered_map DependencySet dependency_set{ctx.get_outstanding_events()}; return unordered_map_from_pairs( transform(invocations, [&](DynamicNodeInvocation const &invocation) { + std::vector input_dependencies = + transform(vector_of(values(invocation.inputs)), + [&](DynamicValueAttrs const &value) { + return dependency_set.get_dependency_for_reader(value); + }); + std::vector output_dependencies = + transform(vector_of(values(invocation.outputs)), + [&](DynamicValueAttrs const &value) { + return dependency_set.get_dependency_for_writer(value); + }); + Realm::Event dependencies = Realm::Event::merge_events( + Realm::Event::merge_events(input_dependencies), + Realm::Event::merge_events(output_dependencies)); Realm::Event result = spawn_op_task(ctx, ctx.map_device_coord_to_processor(assert_unwrap( @@ -126,7 +144,14 @@ static std::unordered_map invocation, profiling_settings, iteration_config, - optimizer_attrs); + optimizer_attrs, + dependencies); + for (DynamicValueAttrs const &value : values(invocation.inputs)) { + dependency_set.add_reader(value, result); + } + for (DynamicValueAttrs const &value : values(invocation.outputs)) { + dependency_set.add_writer(value, result); + } return std::pair{invocation.node_attrs.layer_guid, result}; })); } diff --git a/lib/realm-execution/src/realm-execution/realm_manager.cc b/lib/realm-execution/src/realm-execution/realm_manager.cc index 7233103cc3..adafea47e6 100644 --- a/lib/realm-execution/src/realm-execution/realm_manager.cc +++ b/lib/realm-execution/src/realm-execution/realm_manager.cc @@ -27,7 +27,8 @@ Realm::Event .only_kind(Realm::Processor::LOC_PROC) .first(); - return collective_spawn_controller_task(*this, target_proc, thunk); + return collective_spawn_controller_task( + *this, target_proc, thunk, Realm::Event::NO_EVENT); } } // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/controller_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/controller_task.cc index 2fd5cee52d..285e8acaa7 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/controller_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/controller_task.cc @@ -21,17 +21,19 @@ void controller_task_body(void const *args, task_args.thunk(ctx); } -Realm::Event collective_spawn_controller_task( - RealmContext &ctx, - Realm::Processor &target_proc, - std::function thunk) { +Realm::Event + collective_spawn_controller_task(RealmContext &ctx, + Realm::Processor &target_proc, + std::function thunk, + Realm::Event precondition) { ControllerTaskArgs task_args; task_args.thunk = thunk; return ctx.collective_spawn_task(target_proc, task_id_t::CONTROLLER_TASK_ID, &task_args, - sizeof(task_args)); + sizeof(task_args), + precondition); } } // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/device_init_return_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/device_init_return_task.cc index fa421cda30..610500a94b 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/device_init_return_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/device_init_return_task.cc @@ -36,14 +36,16 @@ Realm::Event spawn_device_init_return_task( RealmContext &ctx, Realm::Processor origin_proc, DeviceSpecificPerDeviceOpState const &result, - DeviceSpecificPerDeviceOpState *origin_result_ptr) { + DeviceSpecificPerDeviceOpState *origin_result_ptr, + Realm::Event precondition) { DeviceInitReturnTaskArgs task_args{result, origin_proc, origin_result_ptr}; return ctx.spawn_task(origin_proc, task_id_t::DEVICE_INIT_RETURN_TASK_ID, &task_args, sizeof(task_args), - Realm::ProfilingRequestSet{}); + Realm::ProfilingRequestSet{}, + precondition); } } // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc index cc080255e2..7f36f48921 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc @@ -66,7 +66,8 @@ void device_init_task_body(void const *args, spawn_device_init_return_task(ctx, task_args.origin_proc, *result_state_ptr, - task_args.origin_result_ptr); + task_args.origin_result_ptr, + Realm::Event::NO_EVENT); } std::optional @@ -76,7 +77,8 @@ std::optional ProfilingSettings const &profiling_settings, FFIterationConfig const &iteration_config, OptimizerAttrs const &optimizer_attrs, - DeviceSpecificPerDeviceOpState *result_ptr) { + DeviceSpecificPerDeviceOpState *result_ptr, + Realm::Event precondition) { DeviceInitTaskArgs task_args{ &invocation, &profiling_settings, @@ -97,7 +99,8 @@ std::optional assert_unwrap(task_id), &task_args, sizeof(task_args), - Realm::ProfilingRequestSet{}); + Realm::ProfilingRequestSet{}, + precondition); } return std::nullopt; } diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc index 5f6ab40607..216f0badde 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc @@ -58,13 +58,13 @@ void op_task_body(void const *args, /*device_idx=*/ctx.get_current_device_idx()); } -Realm::Event - spawn_op_task(RealmContext &ctx, - Realm::Processor target_proc, - DynamicNodeInvocation const &invocation, - ProfilingSettings const &profiling_settings, - FFIterationConfig const &iteration_config, - std::optional const &optimizer_attrs) { +Realm::Event spawn_op_task(RealmContext &ctx, + Realm::Processor target_proc, + DynamicNodeInvocation const &invocation, + ProfilingSettings const &profiling_settings, + FFIterationConfig const &iteration_config, + std::optional const &optimizer_attrs, + Realm::Event precondition) { OpTaskArgs task_args{&invocation, &profiling_settings, &iteration_config, @@ -75,7 +75,8 @@ Realm::Event assert_unwrap(get_task_id_for_op(invocation.node_attrs, optimizer_attrs)), &task_args, sizeof(task_args), - Realm::ProfilingRequestSet{}); + Realm::ProfilingRequestSet{}, + precondition); } } // namespace FlexFlow From 8588e3677072a2a97286da89ae083c5de8d2c540 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Thu, 12 Feb 2026 11:18:53 -0800 Subject: [PATCH 43/88] Add event argument to controller. --- lib/realm-execution/include/realm-execution/realm_manager.h | 3 ++- lib/realm-execution/src/realm-execution/realm_manager.cc | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/realm_manager.h b/lib/realm-execution/include/realm-execution/realm_manager.h index bf5e8f72f1..8a79476bcf 100644 --- a/lib/realm-execution/include/realm-execution/realm_manager.h +++ b/lib/realm-execution/include/realm-execution/realm_manager.h @@ -19,7 +19,8 @@ struct RealmManager : private RealmContext { RealmManager(RealmManager &&) = delete; [[nodiscard]] Realm::Event - start_controller(std::function); + start_controller(std::function, + Realm::Event wait_on = Realm::Event::NO_EVENT); }; } // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/realm_manager.cc b/lib/realm-execution/src/realm-execution/realm_manager.cc index adafea47e6..fc74fffe5d 100644 --- a/lib/realm-execution/src/realm-execution/realm_manager.cc +++ b/lib/realm-execution/src/realm-execution/realm_manager.cc @@ -21,14 +21,14 @@ RealmManager::~RealmManager() { } Realm::Event - RealmManager::start_controller(std::function thunk) { + RealmManager::start_controller(std::function thunk, + Realm::Event wait_on) { Realm::Processor target_proc = Realm::Machine::ProcessorQuery(Realm::Machine::get_machine()) .only_kind(Realm::Processor::LOC_PROC) .first(); - return collective_spawn_controller_task( - *this, target_proc, thunk, Realm::Event::NO_EVENT); + return collective_spawn_controller_task(*this, target_proc, thunk, wait_on); } } // namespace FlexFlow From eacdc8ccd25184ef41fe4022ec0cc01b1eda5d8c Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Thu, 12 Feb 2026 11:53:33 -0800 Subject: [PATCH 44/88] Implement the allocator. --- .../include/realm-execution/realm_allocator.h | 31 +++++++++++ .../include/realm-execution/realm_context.h | 6 ++- .../pcg_instance/pcg_instance.cc | 2 +- .../src/realm-execution/realm_allocator.cc | 53 +++++++++++++++++++ .../src/realm-execution/realm_context.cc | 10 ++-- 5 files changed, 95 insertions(+), 7 deletions(-) create mode 100644 lib/realm-execution/include/realm-execution/realm_allocator.h create mode 100644 lib/realm-execution/src/realm-execution/realm_allocator.cc diff --git a/lib/realm-execution/include/realm-execution/realm_allocator.h b/lib/realm-execution/include/realm-execution/realm_allocator.h new file mode 100644 index 0000000000..dab6f3ea63 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/realm_allocator.h @@ -0,0 +1,31 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_ALLOCATOR_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_ALLOCATOR_H + +#include "kernels/allocation.h" +#include "realm-execution/realm.h" + +namespace FlexFlow { + +struct RealmAllocator : public IAllocator { + RealmAllocator(Realm::Processor processor, Realm::Memory memory); + RealmAllocator(RealmAllocator const &) = delete; + RealmAllocator(RealmAllocator &&) = delete; + ~RealmAllocator() = default; + + void *allocate(size_t) override; + void deallocate(void *) override; + + DeviceType get_allocation_device_type() const override; + +private: + Realm::Processor processor; + Realm::Memory memory; + std::unordered_map ptr_instances; +}; +CHECK_RC_COPY_VIRTUAL_COMPLIANT(RealmAllocator); + +Allocator get_realm_allocator(Realm::Processor processor, Realm::Memory memory); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/realm_context.h b/lib/realm-execution/include/realm-execution/realm_context.h index e28e91234e..755bf595d6 100644 --- a/lib/realm-execution/include/realm-execution/realm_context.h +++ b/lib/realm-execution/include/realm-execution/realm_context.h @@ -6,6 +6,7 @@ #include "pcg/device_id_t.dtg.h" #include "pcg/machine_space_coordinate.dtg.h" #include "realm-execution/realm.h" +#include "realm-execution/realm_allocator.h" #include "realm-execution/tasks/task_id_t.dtg.h" #include @@ -23,11 +24,11 @@ struct RealmContext { // Device mapping Realm::Processor map_device_coord_to_processor(MachineSpaceCoordinate const &); - Realm::Memory get_nearest_memory(Realm::Processor) const; + static Realm::Memory get_nearest_memory(Realm::Processor); // Current device context Realm::Processor get_current_processor() const; - Allocator &get_current_device_allocator() const; + Allocator &get_current_device_allocator(); device_handle_t const &get_current_device_handle() const; device_id_t get_current_device_idx() const; @@ -68,6 +69,7 @@ struct RealmContext { protected: Realm::Runtime runtime; Realm::Processor processor; + Allocator allocator; std::vector outstanding_events; std::unordered_map, std::vector> diff --git a/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc b/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc index e636cbf259..93b42743a0 100644 --- a/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc +++ b/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc @@ -106,7 +106,7 @@ PCGInstance create_parallel_computation_graph_instance( ctx, invocation_topo_order, optimizer_attrs, logit_grad_tensor}; // TODO list: - // * Realm allocator + // * current device handle // * external instances // * task argument serializer // * copies diff --git a/lib/realm-execution/src/realm-execution/realm_allocator.cc b/lib/realm-execution/src/realm-execution/realm_allocator.cc new file mode 100644 index 0000000000..f24106b0bc --- /dev/null +++ b/lib/realm-execution/src/realm-execution/realm_allocator.cc @@ -0,0 +1,53 @@ +#include "realm-execution/realm_allocator.h" +#include "kernels/device.h" +#include "pcg/device_type.dtg.h" + +namespace FlexFlow { + +RealmAllocator::RealmAllocator(Realm::Processor processor, Realm::Memory memory) + : processor(processor), memory(memory) {} + +void *RealmAllocator::allocate(size_t requested_memory_size) { + Realm::Rect<1> bounds{Realm::Point<1>::ZEROES(), + Realm::Point<1>{requested_memory_size} - + Realm::Point<1>::ONES()}; + std::vector field_sizes{1}; + Realm::RegionInstance inst; + Realm::Event ready = + Realm::RegionInstance::create_instance(inst, + this->memory, + bounds, + field_sizes, + 0 /*SOA*/, + Realm::ProfilingRequestSet{}); + ready.wait(); + void *ptr = + inst.pointer_untyped(/*offset=*/0, /*datalen=*/requested_memory_size); + ASSERT(ptr); + this->ptr_instances.insert({ptr, inst}); + return ptr; +} + +void RealmAllocator::deallocate(void *ptr) { + this->ptr_instances.at(ptr).destroy(Realm::Event::NO_EVENT); + this->ptr_instances.erase(ptr); +} + +DeviceType RealmAllocator::get_allocation_device_type() const { + switch (this->processor.kind()) { + case Realm::Processor::Kind::LOC_PROC: + return DeviceType::CPU; + case Realm::Processor::Kind::TOC_PROC: + return DeviceType::GPU; + default: + PANIC("Unhandled FwbTensorType", this->processor.kind()); + } +} + +Allocator get_realm_allocator(Realm::Processor processor, + Realm::Memory memory) { + Allocator allocator = Allocator::create(processor, memory); + return allocator; +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/realm_context.cc b/lib/realm-execution/src/realm-execution/realm_context.cc index 781561c95a..a77383779f 100644 --- a/lib/realm-execution/src/realm-execution/realm_context.cc +++ b/lib/realm-execution/src/realm-execution/realm_context.cc @@ -14,7 +14,9 @@ namespace FlexFlow { -RealmContext::RealmContext(Realm::Processor proc) : processor(proc) {} +RealmContext::RealmContext(Realm::Processor proc) + : processor(proc), allocator(get_realm_allocator( + proc, RealmContext::get_nearest_memory(proc))) {} RealmContext::~RealmContext() { if (!this->outstanding_events.empty()) { @@ -51,7 +53,7 @@ Realm::Processor RealmContext::map_device_coord_to_processor( return this->processors.at(std::pair{as, kind}).at(int{proc_in_node}); } -Realm::Memory RealmContext::get_nearest_memory(Realm::Processor proc) const { +Realm::Memory RealmContext::get_nearest_memory(Realm::Processor proc) { // FIMXE: this isn't going to do what you expect until // https://github.com/StanfordLegion/realm/pull/392 merges Realm::Machine::MemoryQuery mq(Realm::Machine::get_machine()); @@ -64,8 +66,8 @@ Realm::Processor RealmContext::get_current_processor() const { return this->processor; } -Allocator &RealmContext::get_current_device_allocator() const { - NOT_IMPLEMENTED(); +Allocator &RealmContext::get_current_device_allocator() { + return this->allocator; } device_handle_t const &RealmContext::get_current_device_handle() const { From 6828cfa4aa112892a67d4e6952d4a7aa488df69a Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Thu, 12 Feb 2026 12:17:50 -0800 Subject: [PATCH 45/88] Implement device handle. --- .../include/realm-execution/realm_allocator.h | 2 + .../include/realm-execution/realm_context.h | 10 ++++- .../pcg_instance/pcg_instance.cc | 1 - .../src/realm-execution/realm_context.cc | 42 ++++++++++++++++--- 4 files changed, 47 insertions(+), 8 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/realm_allocator.h b/lib/realm-execution/include/realm-execution/realm_allocator.h index dab6f3ea63..d72f2d7f91 100644 --- a/lib/realm-execution/include/realm-execution/realm_allocator.h +++ b/lib/realm-execution/include/realm-execution/realm_allocator.h @@ -8,6 +8,8 @@ namespace FlexFlow { struct RealmAllocator : public IAllocator { RealmAllocator(Realm::Processor processor, Realm::Memory memory); + + RealmAllocator() = delete; RealmAllocator(RealmAllocator const &) = delete; RealmAllocator(RealmAllocator &&) = delete; ~RealmAllocator() = default; diff --git a/lib/realm-execution/include/realm-execution/realm_context.h b/lib/realm-execution/include/realm-execution/realm_context.h index 755bf595d6..eb4d6d0935 100644 --- a/lib/realm-execution/include/realm-execution/realm_context.h +++ b/lib/realm-execution/include/realm-execution/realm_context.h @@ -3,18 +3,19 @@ #include "kernels/allocation.h" #include "kernels/device_handle_t.dtg.h" +#include "kernels/managed_per_device_ff_handle.h" #include "pcg/device_id_t.dtg.h" #include "pcg/machine_space_coordinate.dtg.h" #include "realm-execution/realm.h" -#include "realm-execution/realm_allocator.h" #include "realm-execution/tasks/task_id_t.dtg.h" +#include #include namespace FlexFlow { struct RealmContext { public: - RealmContext(Realm::Processor); + RealmContext(Realm::Processor processor); virtual ~RealmContext(); RealmContext() = delete; @@ -66,10 +67,15 @@ struct RealmContext { void discover_machine_topology(); + static std::optional + make_device_handle_for_processor(Realm::Processor processor); + protected: Realm::Runtime runtime; Realm::Processor processor; Allocator allocator; + std::optional managed_handle; + device_handle_t device_handle; std::vector outstanding_events; std::unordered_map, std::vector> diff --git a/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc b/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc index 93b42743a0..d56dbb9ca9 100644 --- a/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc +++ b/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc @@ -106,7 +106,6 @@ PCGInstance create_parallel_computation_graph_instance( ctx, invocation_topo_order, optimizer_attrs, logit_grad_tensor}; // TODO list: - // * current device handle // * external instances // * task argument serializer // * copies diff --git a/lib/realm-execution/src/realm-execution/realm_context.cc b/lib/realm-execution/src/realm-execution/realm_context.cc index a77383779f..38ce052da9 100644 --- a/lib/realm-execution/src/realm-execution/realm_context.cc +++ b/lib/realm-execution/src/realm-execution/realm_context.cc @@ -1,22 +1,27 @@ #include "realm-execution/realm_context.h" +#include "kernels/device_handle_t.dtg.h" +#include "kernels/device_handle_t.h" #include "op-attrs/datatype.h" #include "op-attrs/tensor_dims.dtg.h" #include "pcg/device_id_t.h" #include "pcg/device_type.dtg.h" +#include "realm-execution/realm_allocator.h" #include "realm-execution/tasks/realm_task_id_t.h" #include "realm-execution/tasks/task_id_t.dtg.h" #include "utils/containers/contains_key.h" #include "utils/containers/transform.h" -#include "utils/exception.h" #include "utils/nonnegative_int/nonnegative_int.h" #include "utils/one_to_many/one_to_many.h" #include "utils/positive_int/positive_int.h" namespace FlexFlow { -RealmContext::RealmContext(Realm::Processor proc) - : processor(proc), allocator(get_realm_allocator( - proc, RealmContext::get_nearest_memory(proc))) {} +RealmContext::RealmContext(Realm::Processor processor) + : processor(processor), + allocator(get_realm_allocator( + processor, RealmContext::get_nearest_memory(processor))), + managed_handle(RealmContext::make_device_handle_for_processor(processor)), + device_handle(device_handle_t_from_managed_handle(managed_handle)) {} RealmContext::~RealmContext() { if (!this->outstanding_events.empty()) { @@ -54,6 +59,10 @@ Realm::Processor RealmContext::map_device_coord_to_processor( } Realm::Memory RealmContext::get_nearest_memory(Realm::Processor proc) { + if (!proc.exists()) { + return Realm::Memory::NO_MEMORY; + } + // FIMXE: this isn't going to do what you expect until // https://github.com/StanfordLegion/realm/pull/392 merges Realm::Machine::MemoryQuery mq(Realm::Machine::get_machine()); @@ -71,8 +80,9 @@ Allocator &RealmContext::get_current_device_allocator() { } device_handle_t const &RealmContext::get_current_device_handle() const { - NOT_IMPLEMENTED(); + return this->device_handle; } + device_id_t RealmContext::get_current_device_idx() const { Realm::Processor proc = this->get_current_processor(); @@ -245,4 +255,26 @@ void RealmContext::discover_machine_topology() { } } +std::optional + RealmContext::make_device_handle_for_processor(Realm::Processor processor) { + if (!processor.exists()) { + return std::nullopt; + } + + switch (processor.kind()) { + case Realm::Processor::LOC_PROC: + return std::nullopt; + case Realm::Processor::TOC_PROC: + // FIXME: not sure what workSpaceSize to choose here + return initialize_multi_gpu_handle( + /*num_ranks=*/Realm::Machine::get_machine().get_address_space_count(), + /*my_rank=*/processor.address_space(), + /*workSpaceSize=*/1024 * 1024, + /*allowTensorOpMathConversion=*/true); + default: + PANIC("Unhandled Realm::ProcessorKind", + fmt::to_string(int{processor.kind()})); + } +} + } // namespace FlexFlow From 03cda523ef47e6aa6a4d1330ca69b1f7cf2a36e2 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Thu, 12 Feb 2026 14:47:23 -0800 Subject: [PATCH 46/88] Distributed device handle initialization. --- .../distributed_device_handle.h | 38 +++++++ .../impl/device_handle_init_return_task.h | 24 +++++ .../tasks/impl/device_handle_init_task.h | 24 +++++ .../tasks/impl/device_init_task.h | 29 ----- ...task.h => device_state_init_return_task.h} | 8 +- .../tasks/impl/device_state_init_task.h | 29 +++++ .../realm-execution/tasks/task_id_t.dtg.toml | 8 +- .../distributed_device_handle.cc | 50 +++++++++ ...distributed_device_state_initialization.cc | 18 ++-- .../impl/device_handle_init_return_task.cc | 55 ++++++++++ .../tasks/impl/device_handle_init_task.cc | 100 ++++++++++++++++++ .../tasks/impl/device_init_return_task.cc | 51 --------- .../impl/device_state_init_return_task.cc | 53 ++++++++++ ...init_task.cc => device_state_init_task.cc} | 67 ++++++------ .../tasks/realm_task_registry.cc | 10 +- 15 files changed, 432 insertions(+), 132 deletions(-) create mode 100644 lib/realm-execution/include/realm-execution/distributed_device_handle.h create mode 100644 lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_return_task.h create mode 100644 lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_task.h delete mode 100644 lib/realm-execution/include/realm-execution/tasks/impl/device_init_task.h rename lib/realm-execution/include/realm-execution/tasks/impl/{device_init_return_task.h => device_state_init_return_task.h} (77%) create mode 100644 lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task.h create mode 100644 lib/realm-execution/src/realm-execution/distributed_device_handle.cc create mode 100644 lib/realm-execution/src/realm-execution/tasks/impl/device_handle_init_return_task.cc create mode 100644 lib/realm-execution/src/realm-execution/tasks/impl/device_handle_init_task.cc delete mode 100644 lib/realm-execution/src/realm-execution/tasks/impl/device_init_return_task.cc create mode 100644 lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_return_task.cc rename lib/realm-execution/src/realm-execution/tasks/impl/{device_init_task.cc => device_state_init_task.cc} (58%) diff --git a/lib/realm-execution/include/realm-execution/distributed_device_handle.h b/lib/realm-execution/include/realm-execution/distributed_device_handle.h new file mode 100644 index 0000000000..ca3f08fc41 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/distributed_device_handle.h @@ -0,0 +1,38 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_DISTRIBUTED_DEVICE_HANDLE_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_DISTRIBUTED_DEVICE_HANDLE_H + +#include "kernels/managed_per_device_ff_handle.h" +#include "realm-execution/realm.h" +#include "realm-execution/realm_context.h" +#include "task-spec/device_specific.h" +#include +#include + +namespace FlexFlow { + +struct DistributedDeviceHandle { +public: + DistributedDeviceHandle() = delete; + explicit DistributedDeviceHandle( + std::map>> const + &handles); + + DeviceSpecific> const & + at(Realm::Processor processor) const; + +private: + std::map>> + handles; +}; + +DistributedDeviceHandle create_distributed_device_handle( + RealmContext &ctx, + size_t workSpaceSize, + bool allowTensorOpMathConversion, + Realm::Event precondition = Realm::Event::NO_EVENT); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_return_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_return_task.h new file mode 100644 index 0000000000..8b358ee4ce --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_return_task.h @@ -0,0 +1,24 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_HANDLE_INIT_RETURN_TASK_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_HANDLE_INIT_RETURN_TASK_H + +#include "kernels/managed_per_device_ff_handle.h" +#include "realm-execution/realm.h" +#include "realm-execution/realm_context.h" +#include "task-spec/device_specific_per_device_op_state.dtg.h" + +namespace FlexFlow { + +void device_handle_init_return_task_body( + void const *, size_t, void const *, size_t, Realm::Processor); + +Realm::Event spawn_device_handle_init_return_task( + RealmContext &ctx, + Realm::Processor origin_proc, + DeviceSpecific> const &result, + DeviceSpecific> + *origin_result_ptr, + Realm::Event precondition); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_task.h new file mode 100644 index 0000000000..c26633bd9a --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_task.h @@ -0,0 +1,24 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_HANDLE_INIT_TASK_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_HANDLE_INIT_TASK_H + +#include "kernels/managed_per_device_ff_handle.h" +#include "realm-execution/realm.h" +#include "realm-execution/realm_context.h" +#include "task-spec/device_specific_per_device_op_state.dtg.h" + +namespace FlexFlow { + +void device_handle_init_task_body( + void const *, size_t, void const *, size_t, Realm::Processor); + +Realm::Event spawn_device_handle_init_task( + RealmContext &ctx, + Realm::Processor target_proc, + size_t workSpaceSize, + bool allowTensorOpMathConversion, + DeviceSpecific> *result_ptr, + Realm::Event precondition); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/device_init_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/device_init_task.h deleted file mode 100644 index 7842963c7b..0000000000 --- a/lib/realm-execution/include/realm-execution/tasks/impl/device_init_task.h +++ /dev/null @@ -1,29 +0,0 @@ -#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_INIT_TASK_H -#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_INIT_TASK_H - -#include "kernels/profiling_settings.dtg.h" -#include "pcg/optimizer_attrs.dtg.h" -#include "realm-execution/realm.h" -#include "realm-execution/realm_context.h" -#include "task-spec/device_specific_per_device_op_state.dtg.h" -#include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" -#include "task-spec/ff_iteration_config.dtg.h" - -namespace FlexFlow { - -void device_init_task_body( - void const *, size_t, void const *, size_t, Realm::Processor); - -std::optional - spawn_device_init_task(RealmContext &ctx, - Realm::Processor &target_proc, - DynamicNodeInvocation const &invocation, - ProfilingSettings const &profiling_settings, - FFIterationConfig const &iteration_config, - OptimizerAttrs const &optimizer_attrs, - DeviceSpecificPerDeviceOpState *result_ptr, - Realm::Event precondition); - -} // namespace FlexFlow - -#endif diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/device_init_return_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_return_task.h similarity index 77% rename from lib/realm-execution/include/realm-execution/tasks/impl/device_init_return_task.h rename to lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_return_task.h index 0f92b35c24..8f44680815 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/device_init_return_task.h +++ b/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_return_task.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_INIT_RETURN_TASK_H -#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_INIT_RETURN_TASK_H +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_STATE_INIT_RETURN_TASK_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_STATE_INIT_RETURN_TASK_H #include "realm-execution/realm.h" #include "realm-execution/realm_context.h" @@ -7,10 +7,10 @@ namespace FlexFlow { -void device_init_return_task_body( +void device_state_init_return_task_body( void const *, size_t, void const *, size_t, Realm::Processor); -Realm::Event spawn_device_init_return_task( +Realm::Event spawn_device_state_init_return_task( RealmContext &ctx, Realm::Processor origin_proc, DeviceSpecificPerDeviceOpState const &result, diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task.h new file mode 100644 index 0000000000..4cd65a0a2a --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task.h @@ -0,0 +1,29 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_STATE_INIT_TASK_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_STATE_INIT_TASK_H + +#include "kernels/profiling_settings.dtg.h" +#include "pcg/optimizer_attrs.dtg.h" +#include "realm-execution/realm.h" +#include "realm-execution/realm_context.h" +#include "task-spec/device_specific_per_device_op_state.dtg.h" +#include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" +#include "task-spec/ff_iteration_config.dtg.h" + +namespace FlexFlow { + +void device_state_init_task_body( + void const *, size_t, void const *, size_t, Realm::Processor); + +std::optional + spawn_device_state_init_task(RealmContext &ctx, + Realm::Processor target_proc, + DynamicNodeInvocation const &invocation, + ProfilingSettings const &profiling_settings, + FFIterationConfig const &iteration_config, + OptimizerAttrs const &optimizer_attrs, + DeviceSpecificPerDeviceOpState *result_ptr, + Realm::Event precondition); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/tasks/task_id_t.dtg.toml b/lib/realm-execution/include/realm-execution/tasks/task_id_t.dtg.toml index 34e5183488..97b19b5f51 100644 --- a/lib/realm-execution/include/realm-execution/tasks/task_id_t.dtg.toml +++ b/lib/realm-execution/include/realm-execution/tasks/task_id_t.dtg.toml @@ -12,7 +12,13 @@ features = [ name = "CONTROLLER_TASK_ID" [[values]] -name = "DEVICE_INIT_RETURN_TASK_ID" +name = "DEVICE_HANDLE_INIT_TASK_ID" + +[[values]] +name = "DEVICE_HANDLE_INIT_RETURN_TASK_ID" + +[[values]] +name = "DEVICE_STATE_INIT_RETURN_TASK_ID" [[values]] name = "IMAGE_INIT_TASK_ID" diff --git a/lib/realm-execution/src/realm-execution/distributed_device_handle.cc b/lib/realm-execution/src/realm-execution/distributed_device_handle.cc new file mode 100644 index 0000000000..00c2e76360 --- /dev/null +++ b/lib/realm-execution/src/realm-execution/distributed_device_handle.cc @@ -0,0 +1,50 @@ +#include "realm-execution/distributed_device_handle.h" +#include "realm-execution/tasks/impl/device_handle_init_task.h" +#include "task-spec/device_specific.h" + +namespace FlexFlow { + +DistributedDeviceHandle::DistributedDeviceHandle( + std::map>> const + &handles) + : handles(handles) {} + +DeviceSpecific> const & + DistributedDeviceHandle::at(Realm::Processor processor) const { + return this->handles.at(processor); +} + +DistributedDeviceHandle + create_distributed_device_handle(RealmContext &ctx, + size_t workSpaceSize, + bool allowTensorOpMathConversion, + Realm::Event precondition) { + std::map>> + handles; + + // Allocate space for the result before launching any tasks + Realm::Machine::ProcessorQuery pq(Realm::Machine::get_machine()); + for (Realm::Processor proc : pq) { + handles.insert( + {proc, + DeviceSpecific>::create( + ctx.get_current_device_idx(), std::nullopt)}); + } + + for (auto &[proc, handle] : handles) { + spawn_device_handle_init_task(ctx, + proc, + workSpaceSize, + allowTensorOpMathConversion, + &handle, + precondition); + } + + ctx.get_outstanding_events().wait(); + + return DistributedDeviceHandle{handles}; +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc b/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc index 4ea8d0bbd1..9627a71e87 100644 --- a/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc +++ b/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc @@ -1,6 +1,6 @@ #include "realm-execution/distributed_device_state_initialization.h" #include "local-execution/device_state_initialization.h" -#include "realm-execution/tasks/impl/device_init_task.h" +#include "realm-execution/tasks/impl/device_state_init_task.h" #include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" #include "utils/optional.h" @@ -33,14 +33,14 @@ DynamicOpenDataflowGraph perform_distributed_device_state_initialization( static_cast( malloc(sizeof(DeviceSpecificPerDeviceOpState))); std::optional result = - spawn_device_init_task(ctx, - target_proc, - invocation, - profiling_settings, - iteration_config, - optimizer_attrs, - output, - precondition); + spawn_device_state_init_task(ctx, + target_proc, + invocation, + profiling_settings, + iteration_config, + optimizer_attrs, + output, + precondition); if (result) { result_map[invocation] = output; } else { diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/device_handle_init_return_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/device_handle_init_return_task.cc new file mode 100644 index 0000000000..2839beef0c --- /dev/null +++ b/lib/realm-execution/src/realm-execution/tasks/impl/device_handle_init_return_task.cc @@ -0,0 +1,55 @@ +#include "realm-execution/tasks/impl/device_handle_init_task.h" +#include "realm-execution/tasks/task_id_t.dtg.h" + +namespace FlexFlow { + +// FIXME: Can't make this trivially copyable? +struct DeviceHandleInitReturnTaskArgs { +public: + DeviceHandleInitReturnTaskArgs() = delete; + DeviceHandleInitReturnTaskArgs( + DeviceSpecific> result, + Realm::Processor origin_proc, + DeviceSpecific> + *origin_result_ptr) + : result(result), origin_proc(origin_proc), + origin_result_ptr(origin_result_ptr) {} + +public: + DeviceSpecific> result; + Realm::Processor origin_proc; + DeviceSpecific> *origin_result_ptr; +}; + +void device_handle_init_return_task_body(void const *args, + size_t arglen, + void const *userdata, + size_t userlen, + Realm::Processor proc) { + ASSERT(arglen == sizeof(DeviceHandleInitReturnTaskArgs)); + DeviceHandleInitReturnTaskArgs task_args = + *reinterpret_cast(args); + + ASSERT(task_args.origin_proc.address_space() == proc.address_space()); + *task_args.origin_result_ptr = task_args.result; +} + +Realm::Event spawn_device_handle_init_return_task( + RealmContext &ctx, + Realm::Processor origin_proc, + DeviceSpecific> const &result, + DeviceSpecific> + *origin_result_ptr, + Realm::Event precondition) { + DeviceHandleInitReturnTaskArgs task_args{ + result, origin_proc, origin_result_ptr}; + + return ctx.spawn_task(origin_proc, + task_id_t::DEVICE_HANDLE_INIT_RETURN_TASK_ID, + &task_args, + sizeof(task_args), + Realm::ProfilingRequestSet{}, + precondition); +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/device_handle_init_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/device_handle_init_task.cc new file mode 100644 index 0000000000..86a576d26b --- /dev/null +++ b/lib/realm-execution/src/realm-execution/tasks/impl/device_handle_init_task.cc @@ -0,0 +1,100 @@ +#include "realm-execution/tasks/impl/device_handle_init_task.h" +#include "realm-execution/tasks/impl/device_handle_init_return_task.h" +#include "realm-execution/tasks/task_id_t.dtg.h" +#include + +namespace FlexFlow { + +// TODO: at some point we're going to have to actually serialize these, but for +// now just pass the pointer and assume we're running inside a single address +// space +struct DeviceHandleInitTaskArgs { + DeviceHandleInitTaskArgs() = delete; + DeviceHandleInitTaskArgs( + size_t workSpaceSize, + bool allowTensorOpMathConversion, + Realm::Processor origin_proc, + DeviceSpecific> + *origin_result_ptr) + : workSpaceSize(workSpaceSize), + allowTensorOpMathConversion(allowTensorOpMathConversion), + origin_proc(origin_proc), origin_result_ptr(origin_result_ptr) {} + +public: + size_t workSpaceSize; + bool allowTensorOpMathConversion; + Realm::Processor origin_proc; + DeviceSpecific> *origin_result_ptr; +}; +static_assert(std::is_trivially_copy_constructible_v); + +static std::optional + make_device_handle_for_processor(Realm::Processor processor, + size_t workSpaceSize, + bool allowTensorOpMathConversion) { + switch (processor.kind()) { + case Realm::Processor::LOC_PROC: + return std::nullopt; + case Realm::Processor::TOC_PROC: + return new ManagedPerDeviceFFHandle{initialize_multi_gpu_handle( + /*num_ranks=*/Realm::Machine::get_machine().get_address_space_count(), + /*my_rank=*/processor.address_space(), + /*workSpaceSize=*/workSpaceSize, + /*allowTensorOpMathConversion=*/allowTensorOpMathConversion)}; + default: + PANIC("Unhandled Realm::ProcessorKind", + fmt::to_string(int{processor.kind()})); + } +} + +void device_handle_init_task_body(void const *args, + size_t arglen, + void const *userdata, + size_t userlen, + Realm::Processor proc) { + ASSERT(arglen == sizeof(DeviceHandleInitTaskArgs)); + DeviceHandleInitTaskArgs task_args = + *reinterpret_cast(args); + + // FIXME: serialize instead of passing pointers around + ASSERT(task_args.origin_proc.address_space() == proc.address_space()); + + RealmContext ctx{proc}; + DeviceSpecific> managed_handle = + DeviceSpecific>::create( + ctx.get_current_device_idx(), + make_device_handle_for_processor( + proc, + task_args.workSpaceSize, + task_args.allowTensorOpMathConversion)); + + spawn_device_handle_init_return_task(ctx, + task_args.origin_proc, + managed_handle, + task_args.origin_result_ptr, + Realm::Event::NO_EVENT); +} + +Realm::Event spawn_device_handle_init_task( + RealmContext &ctx, + Realm::Processor target_proc, + size_t workSpaceSize, + bool allowTensorOpMathConversion, + DeviceSpecific> *result_ptr, + Realm::Event precondition) { + DeviceHandleInitTaskArgs task_args{ + workSpaceSize, + allowTensorOpMathConversion, + ctx.get_current_processor(), + result_ptr, + }; + + return ctx.spawn_task(target_proc, + task_id_t::DEVICE_HANDLE_INIT_TASK_ID, + &task_args, + sizeof(task_args), + Realm::ProfilingRequestSet{}, + precondition); +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/device_init_return_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/device_init_return_task.cc deleted file mode 100644 index 610500a94b..0000000000 --- a/lib/realm-execution/src/realm-execution/tasks/impl/device_init_return_task.cc +++ /dev/null @@ -1,51 +0,0 @@ -#include "realm-execution/tasks/impl/device_init_task.h" -#include "realm-execution/tasks/task_id_t.dtg.h" - -namespace FlexFlow { - -// FIXME: Can't make this trivially copyable? -struct DeviceInitReturnTaskArgs { -public: - DeviceInitReturnTaskArgs() = delete; - DeviceInitReturnTaskArgs(DeviceSpecificPerDeviceOpState result, - Realm::Processor origin_proc, - DeviceSpecificPerDeviceOpState *origin_result_ptr) - : result(result), origin_proc(origin_proc), - origin_result_ptr(origin_result_ptr) {} - -public: - DeviceSpecificPerDeviceOpState result; - Realm::Processor origin_proc; - DeviceSpecificPerDeviceOpState *origin_result_ptr; -}; - -void device_init_return_task_body(void const *args, - size_t arglen, - void const *userdata, - size_t userlen, - Realm::Processor proc) { - ASSERT(arglen == sizeof(DeviceInitReturnTaskArgs)); - DeviceInitReturnTaskArgs task_args = - *reinterpret_cast(args); - - ASSERT(task_args.origin_proc.address_space() == proc.address_space()); - *task_args.origin_result_ptr = task_args.result; -} - -Realm::Event spawn_device_init_return_task( - RealmContext &ctx, - Realm::Processor origin_proc, - DeviceSpecificPerDeviceOpState const &result, - DeviceSpecificPerDeviceOpState *origin_result_ptr, - Realm::Event precondition) { - DeviceInitReturnTaskArgs task_args{result, origin_proc, origin_result_ptr}; - - return ctx.spawn_task(origin_proc, - task_id_t::DEVICE_INIT_RETURN_TASK_ID, - &task_args, - sizeof(task_args), - Realm::ProfilingRequestSet{}, - precondition); -} - -} // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_return_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_return_task.cc new file mode 100644 index 0000000000..c1bd7c1081 --- /dev/null +++ b/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_return_task.cc @@ -0,0 +1,53 @@ +#include "realm-execution/tasks/impl/device_state_init_return_task.h" +#include "realm-execution/tasks/task_id_t.dtg.h" + +namespace FlexFlow { + +// FIXME: Can't make this trivially copyable? +struct DeviceStateInitReturnTaskArgs { +public: + DeviceStateInitReturnTaskArgs() = delete; + DeviceStateInitReturnTaskArgs( + DeviceSpecificPerDeviceOpState result, + Realm::Processor origin_proc, + DeviceSpecificPerDeviceOpState *origin_result_ptr) + : result(result), origin_proc(origin_proc), + origin_result_ptr(origin_result_ptr) {} + +public: + DeviceSpecificPerDeviceOpState result; + Realm::Processor origin_proc; + DeviceSpecificPerDeviceOpState *origin_result_ptr; +}; + +void device_state_init_return_task_body(void const *args, + size_t arglen, + void const *userdata, + size_t userlen, + Realm::Processor proc) { + ASSERT(arglen == sizeof(DeviceStateInitReturnTaskArgs)); + DeviceStateInitReturnTaskArgs task_args = + *reinterpret_cast(args); + + ASSERT(task_args.origin_proc.address_space() == proc.address_space()); + *task_args.origin_result_ptr = task_args.result; +} + +Realm::Event spawn_device_state_init_return_task( + RealmContext &ctx, + Realm::Processor origin_proc, + DeviceSpecificPerDeviceOpState const &result, + DeviceSpecificPerDeviceOpState *origin_result_ptr, + Realm::Event precondition) { + DeviceStateInitReturnTaskArgs task_args{ + result, origin_proc, origin_result_ptr}; + + return ctx.spawn_task(origin_proc, + task_id_t::DEVICE_STATE_INIT_RETURN_TASK_ID, + &task_args, + sizeof(task_args), + Realm::ProfilingRequestSet{}, + precondition); +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_task.cc similarity index 58% rename from lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc rename to lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_task.cc index 7f36f48921..f63efba14b 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_task.cc @@ -1,6 +1,6 @@ -#include "realm-execution/tasks/impl/device_init_task.h" +#include "realm-execution/tasks/impl/device_state_init_task.h" #include "local-execution/device_state_initialization.h" -#include "realm-execution/tasks/impl/device_init_return_task.h" +#include "realm-execution/tasks/impl/device_state_init_return_task.h" #include "realm-execution/tasks/task_id_t.dtg.h" #include "realm-execution/tasks/task_id_t.h" #include "task-spec/device_specific_per_device_op_state.dtg.h" @@ -14,14 +14,14 @@ namespace FlexFlow { // TODO: at some point we're going to have to actually serialize these, but for // now just pass the pointer and assume we're running inside a single address // space -struct DeviceInitTaskArgs { - DeviceInitTaskArgs() = delete; - DeviceInitTaskArgs(DynamicNodeInvocation const *invocation, - ProfilingSettings const *profiling_settings, - FFIterationConfig const *iteration_config, - OptimizerAttrs const *optimizer_attrs, - Realm::Processor origin_proc, - DeviceSpecificPerDeviceOpState *origin_result_ptr) +struct DeviceStateInitTaskArgs { + DeviceStateInitTaskArgs() = delete; + DeviceStateInitTaskArgs(DynamicNodeInvocation const *invocation, + ProfilingSettings const *profiling_settings, + FFIterationConfig const *iteration_config, + OptimizerAttrs const *optimizer_attrs, + Realm::Processor origin_proc, + DeviceSpecificPerDeviceOpState *origin_result_ptr) : invocation(invocation), profiling_settings(profiling_settings), iteration_config(iteration_config), optimizer_attrs(optimizer_attrs), origin_proc(origin_proc), origin_result_ptr(origin_result_ptr) {} @@ -34,16 +34,17 @@ struct DeviceInitTaskArgs { Realm::Processor origin_proc; DeviceSpecificPerDeviceOpState *origin_result_ptr; }; -static_assert(std::has_unique_object_representations_v); +static_assert( + std::has_unique_object_representations_v); -void device_init_task_body(void const *args, - size_t arglen, - void const *userdata, - size_t userlen, - Realm::Processor proc) { - ASSERT(arglen == sizeof(DeviceInitTaskArgs)); - DeviceInitTaskArgs task_args = - *reinterpret_cast(args); +void device_state_init_task_body(void const *args, + size_t arglen, + void const *userdata, + size_t userlen, + Realm::Processor proc) { + ASSERT(arglen == sizeof(DeviceStateInitTaskArgs)); + DeviceStateInitTaskArgs task_args = + *reinterpret_cast(args); // FIXME: serialize instead of passing pointers around ASSERT(task_args.origin_proc.address_space() == proc.address_space()); @@ -63,23 +64,23 @@ void device_init_task_body(void const *args, // the allocation here DeviceSpecificPerDeviceOpState *result_state_ptr = new DeviceSpecificPerDeviceOpState{result_state}; - spawn_device_init_return_task(ctx, - task_args.origin_proc, - *result_state_ptr, - task_args.origin_result_ptr, - Realm::Event::NO_EVENT); + spawn_device_state_init_return_task(ctx, + task_args.origin_proc, + *result_state_ptr, + task_args.origin_result_ptr, + Realm::Event::NO_EVENT); } std::optional - spawn_device_init_task(RealmContext &ctx, - Realm::Processor &target_proc, - DynamicNodeInvocation const &invocation, - ProfilingSettings const &profiling_settings, - FFIterationConfig const &iteration_config, - OptimizerAttrs const &optimizer_attrs, - DeviceSpecificPerDeviceOpState *result_ptr, - Realm::Event precondition) { - DeviceInitTaskArgs task_args{ + spawn_device_state_init_task(RealmContext &ctx, + Realm::Processor target_proc, + DynamicNodeInvocation const &invocation, + ProfilingSettings const &profiling_settings, + FFIterationConfig const &iteration_config, + OptimizerAttrs const &optimizer_attrs, + DeviceSpecificPerDeviceOpState *result_ptr, + Realm::Event precondition) { + DeviceStateInitTaskArgs task_args{ &invocation, &profiling_settings, &iteration_config, diff --git a/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc b/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc index c63d4727a9..9150ce6892 100644 --- a/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc +++ b/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc @@ -1,7 +1,7 @@ #include "realm-execution/tasks/realm_task_registry.h" #include "realm-execution/tasks/impl/controller_task.h" -#include "realm-execution/tasks/impl/device_init_return_task.h" -#include "realm-execution/tasks/impl/device_init_task.h" +#include "realm-execution/tasks/impl/device_state_init_return_task.h" +#include "realm-execution/tasks/impl/device_state_init_task.h" #include "realm-execution/tasks/impl/op_task.h" #include "realm-execution/tasks/realm_task_id_t.h" #include "utils/exception.h" @@ -48,7 +48,7 @@ Realm::Event register_all_tasks() { for (task_id_t task_id : init_task_ids) { pending_registrations.push_back(register_task( - Realm::Processor::TOC_PROC, task_id, device_init_task_body)); + Realm::Processor::TOC_PROC, task_id, device_state_init_task_body)); } std::vector task_ids = { @@ -127,8 +127,8 @@ Realm::Event register_all_tasks() { controller_task_body)); pending_registrations.push_back( register_task(Realm::Processor::LOC_PROC, - task_id_t::DEVICE_INIT_RETURN_TASK_ID, - device_init_return_task_body)); + task_id_t::DEVICE_STATE_INIT_RETURN_TASK_ID, + device_state_init_return_task_body)); return Realm::Event::merge_events(pending_registrations); } From a10b35a7dbbc43a7d0272bf1a027b2a033e7e2a8 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Thu, 12 Feb 2026 16:45:58 -0800 Subject: [PATCH 47/88] Distributed device handle initialization. --- lib/kernels/include/kernels/device_handle_t.h | 3 ++ lib/kernels/src/kernels/device_handle_t.cc | 9 ++++ ...ific_managed_per_device_ff_handle.dtg.toml | 16 ++++++ ...ce_specific_managed_per_device_ff_handle.h | 19 +++++++ .../distributed_device_handle.h | 16 +++--- .../distributed_device_state_initialization.h | 2 + .../include/realm-execution/fmt/instance.h | 4 +- .../include/realm-execution/hash/processor.h | 16 ++++++ .../pcg_instance/pcg_instance.h | 2 + .../include/realm-execution/realm_context.h | 3 -- .../impl/device_handle_init_return_task.h | 8 ++- .../tasks/impl/device_handle_init_task.h | 5 +- .../tasks/impl/device_state_init_task.h | 20 ++++---- .../realm-execution/tasks/impl/op_task.h | 17 ++++--- ...e_specific_managed_per_device_ff_handle.cc | 21 ++++++++ .../distributed_device_handle.cc | 17 +++---- ...distributed_device_state_initialization.cc | 2 + .../src/realm-execution/hash/processor.cc | 11 +++++ .../pcg_instance/pcg_instance.cc | 30 ++++++++---- .../src/realm-execution/realm_context.cc | 30 +----------- .../impl/device_handle_init_return_task.cc | 15 +++--- .../tasks/impl/device_handle_init_task.cc | 12 ++--- .../impl/device_state_init_return_task.cc | 1 - .../tasks/impl/device_state_init_task.cc | 49 +++++++++++-------- .../src/realm-execution/tasks/impl/op_task.cc | 29 +++++++---- 25 files changed, 224 insertions(+), 133 deletions(-) create mode 100644 lib/realm-execution/include/realm-execution/device_specific_managed_per_device_ff_handle.dtg.toml create mode 100644 lib/realm-execution/include/realm-execution/device_specific_managed_per_device_ff_handle.h create mode 100644 lib/realm-execution/include/realm-execution/hash/processor.h create mode 100644 lib/realm-execution/src/realm-execution/device_specific_managed_per_device_ff_handle.cc create mode 100644 lib/realm-execution/src/realm-execution/hash/processor.cc diff --git a/lib/kernels/include/kernels/device_handle_t.h b/lib/kernels/include/kernels/device_handle_t.h index 9b7769355e..0836503717 100644 --- a/lib/kernels/include/kernels/device_handle_t.h +++ b/lib/kernels/include/kernels/device_handle_t.h @@ -9,6 +9,9 @@ namespace FlexFlow { device_handle_t device_handle_t_from_managed_handle( std::optional const &managed_handle); +device_handle_t device_handle_t_from_managed_handle_ptr( + std::optional const &managed_handle); + device_handle_t gpu_make_device_handle_t(PerDeviceFFHandle const &ff_handle); device_handle_t cpu_make_device_handle_t(); diff --git a/lib/kernels/src/kernels/device_handle_t.cc b/lib/kernels/src/kernels/device_handle_t.cc index 85f9e2a388..0225ee8e94 100644 --- a/lib/kernels/src/kernels/device_handle_t.cc +++ b/lib/kernels/src/kernels/device_handle_t.cc @@ -11,6 +11,15 @@ device_handle_t device_handle_t_from_managed_handle( } } +device_handle_t device_handle_t_from_managed_handle_ptr( + std::optional const &managed_handle) { + if (managed_handle.has_value()) { + return gpu_make_device_handle_t(managed_handle.value()->raw_handle()); + } else { + return cpu_make_device_handle_t(); + } +} + device_handle_t gpu_make_device_handle_t(PerDeviceFFHandle const &ff_handle) { return device_handle_t{ ff_handle, diff --git a/lib/realm-execution/include/realm-execution/device_specific_managed_per_device_ff_handle.dtg.toml b/lib/realm-execution/include/realm-execution/device_specific_managed_per_device_ff_handle.dtg.toml new file mode 100644 index 0000000000..1458adcba3 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/device_specific_managed_per_device_ff_handle.dtg.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "DeviceSpecificManagedPerDeviceFFHandle" +type = "struct" +features = [ + "eq", +] + +includes = [ + "", + "kernels/managed_per_device_ff_handle.h", + "task-spec/device_specific.h", +] + +[[fields]] +name = "handle" +type = "::FlexFlow::DeviceSpecific>" diff --git a/lib/realm-execution/include/realm-execution/device_specific_managed_per_device_ff_handle.h b/lib/realm-execution/include/realm-execution/device_specific_managed_per_device_ff_handle.h new file mode 100644 index 0000000000..eefa6c86ac --- /dev/null +++ b/lib/realm-execution/include/realm-execution/device_specific_managed_per_device_ff_handle.h @@ -0,0 +1,19 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_DEVICE_SPECIFIC_MANAGED_PER_DEVICE_FF_HANDLE_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_DEVICE_SPECIFIC_MANAGED_PER_DEVICE_FF_HANDLE_H + +#include "kernels/device_handle_t.dtg.h" +#include "kernels/managed_per_device_ff_handle.h" +#include "pcg/device_id_t.dtg.h" +#include "realm-execution/device_specific_managed_per_device_ff_handle.dtg.h" + +namespace FlexFlow { + +DeviceSpecificManagedPerDeviceFFHandle make_device_specific_managed_handle( + device_id_t const &, std::optional const &); + +device_handle_t device_handle_t_from_device_specific_managed_handle( + DeviceSpecificManagedPerDeviceFFHandle const &, device_id_t); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/distributed_device_handle.h b/lib/realm-execution/include/realm-execution/distributed_device_handle.h index ca3f08fc41..3f55c47192 100644 --- a/lib/realm-execution/include/realm-execution/distributed_device_handle.h +++ b/lib/realm-execution/include/realm-execution/distributed_device_handle.h @@ -1,12 +1,11 @@ #ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_DISTRIBUTED_DEVICE_HANDLE_H #define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_DISTRIBUTED_DEVICE_HANDLE_H -#include "kernels/managed_per_device_ff_handle.h" +#include "realm-execution/device_specific_managed_per_device_ff_handle.dtg.h" +#include "realm-execution/hash/processor.h" #include "realm-execution/realm.h" #include "realm-execution/realm_context.h" -#include "task-spec/device_specific.h" -#include -#include +#include namespace FlexFlow { @@ -14,17 +13,14 @@ struct DistributedDeviceHandle { public: DistributedDeviceHandle() = delete; explicit DistributedDeviceHandle( - std::map>> const + std::unordered_map const &handles); - DeviceSpecific> const & + DeviceSpecificManagedPerDeviceFFHandle const & at(Realm::Processor processor) const; private: - std::map>> - handles; + std::unordered_map handles; }; DistributedDeviceHandle create_distributed_device_handle( diff --git a/lib/realm-execution/include/realm-execution/distributed_device_state_initialization.h b/lib/realm-execution/include/realm-execution/distributed_device_state_initialization.h index 5530f473d8..ca24ecdd4c 100644 --- a/lib/realm-execution/include/realm-execution/distributed_device_state_initialization.h +++ b/lib/realm-execution/include/realm-execution/distributed_device_state_initialization.h @@ -3,6 +3,7 @@ #include "kernels/profiling_settings.dtg.h" #include "pcg/optimizer_attrs.dtg.h" +#include "realm-execution/distributed_device_handle.h" #include "realm-execution/realm_context.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.dtg.h" #include "task-spec/ff_iteration_config.dtg.h" @@ -13,6 +14,7 @@ DynamicOpenDataflowGraph perform_distributed_device_state_initialization( DynamicOpenDataflowGraph const &dg, RealmContext &ctx, ProfilingSettings const &profiling_settings, + DistributedDeviceHandle const &device_handle, FFIterationConfig const &iteration_config, OptimizerAttrs const &optimizer_attrs, Realm::Event precondition); diff --git a/lib/realm-execution/include/realm-execution/fmt/instance.h b/lib/realm-execution/include/realm-execution/fmt/instance.h index b2efc59b7d..c7c2df6735 100644 --- a/lib/realm-execution/include/realm-execution/fmt/instance.h +++ b/lib/realm-execution/include/realm-execution/fmt/instance.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_PAIR_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_PAIR_H +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_FMT_INSTANCE_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_FMT_INSTANCE_H #include "realm-execution/realm.h" #include "utils/check_fmtable.h" diff --git a/lib/realm-execution/include/realm-execution/hash/processor.h b/lib/realm-execution/include/realm-execution/hash/processor.h new file mode 100644 index 0000000000..e5eb8eb503 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/hash/processor.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_HASH_PROCESSOR_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_HASH_PROCESSOR_H + +#include "realm-execution/realm.h" +#include + +namespace std { + +template <> +struct hash<::FlexFlow::Realm::Processor> { + size_t operator()(::FlexFlow::Realm::Processor const &p) const; +}; + +} // namespace std + +#endif diff --git a/lib/realm-execution/include/realm-execution/pcg_instance/pcg_instance.h b/lib/realm-execution/include/realm-execution/pcg_instance/pcg_instance.h index 3c5b4189ea..b917477df4 100644 --- a/lib/realm-execution/include/realm-execution/pcg_instance/pcg_instance.h +++ b/lib/realm-execution/include/realm-execution/pcg_instance/pcg_instance.h @@ -10,6 +10,7 @@ #include "pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.dtg.h" #include "pcg/optimizer_attrs.dtg.h" #include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" +#include "realm-execution/distributed_device_handle.h" #include "realm-execution/realm_context.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.dtg.h" #include "task-spec/dynamic_graph/dynamic_tensor_accessor.dtg.h" @@ -53,6 +54,7 @@ PCGInstance create_pcg_instance( std::unordered_map const &input_tensors, ProfilingSettings const &profiling_settings, + DistributedDeviceHandle const &device_handle, FFIterationConfig const &iteration_config); } // namespace FlexFlow diff --git a/lib/realm-execution/include/realm-execution/realm_context.h b/lib/realm-execution/include/realm-execution/realm_context.h index eb4d6d0935..b8baad41b9 100644 --- a/lib/realm-execution/include/realm-execution/realm_context.h +++ b/lib/realm-execution/include/realm-execution/realm_context.h @@ -30,7 +30,6 @@ struct RealmContext { // Current device context Realm::Processor get_current_processor() const; Allocator &get_current_device_allocator(); - device_handle_t const &get_current_device_handle() const; device_id_t get_current_device_idx() const; // Task creation @@ -74,8 +73,6 @@ struct RealmContext { Realm::Runtime runtime; Realm::Processor processor; Allocator allocator; - std::optional managed_handle; - device_handle_t device_handle; std::vector outstanding_events; std::unordered_map, std::vector> diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_return_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_return_task.h index 8b358ee4ce..9bae546403 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_return_task.h +++ b/lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_return_task.h @@ -1,10 +1,9 @@ #ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_HANDLE_INIT_RETURN_TASK_H #define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_HANDLE_INIT_RETURN_TASK_H -#include "kernels/managed_per_device_ff_handle.h" +#include "realm-execution/device_specific_managed_per_device_ff_handle.dtg.h" #include "realm-execution/realm.h" #include "realm-execution/realm_context.h" -#include "task-spec/device_specific_per_device_op_state.dtg.h" namespace FlexFlow { @@ -14,9 +13,8 @@ void device_handle_init_return_task_body( Realm::Event spawn_device_handle_init_return_task( RealmContext &ctx, Realm::Processor origin_proc, - DeviceSpecific> const &result, - DeviceSpecific> - *origin_result_ptr, + DeviceSpecificManagedPerDeviceFFHandle const &result, + DeviceSpecificManagedPerDeviceFFHandle *origin_result_ptr, Realm::Event precondition); } // namespace FlexFlow diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_task.h index c26633bd9a..624eb6e682 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_task.h +++ b/lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_task.h @@ -1,10 +1,9 @@ #ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_HANDLE_INIT_TASK_H #define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_HANDLE_INIT_TASK_H -#include "kernels/managed_per_device_ff_handle.h" +#include "realm-execution/device_specific_managed_per_device_ff_handle.dtg.h" #include "realm-execution/realm.h" #include "realm-execution/realm_context.h" -#include "task-spec/device_specific_per_device_op_state.dtg.h" namespace FlexFlow { @@ -16,7 +15,7 @@ Realm::Event spawn_device_handle_init_task( Realm::Processor target_proc, size_t workSpaceSize, bool allowTensorOpMathConversion, - DeviceSpecific> *result_ptr, + DeviceSpecificManagedPerDeviceFFHandle *result_ptr, Realm::Event precondition); } // namespace FlexFlow diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task.h index 4cd65a0a2a..933d4f9283 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task.h +++ b/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task.h @@ -3,6 +3,7 @@ #include "kernels/profiling_settings.dtg.h" #include "pcg/optimizer_attrs.dtg.h" +#include "realm-execution/device_specific_managed_per_device_ff_handle.dtg.h" #include "realm-execution/realm.h" #include "realm-execution/realm_context.h" #include "task-spec/device_specific_per_device_op_state.dtg.h" @@ -14,15 +15,16 @@ namespace FlexFlow { void device_state_init_task_body( void const *, size_t, void const *, size_t, Realm::Processor); -std::optional - spawn_device_state_init_task(RealmContext &ctx, - Realm::Processor target_proc, - DynamicNodeInvocation const &invocation, - ProfilingSettings const &profiling_settings, - FFIterationConfig const &iteration_config, - OptimizerAttrs const &optimizer_attrs, - DeviceSpecificPerDeviceOpState *result_ptr, - Realm::Event precondition); +std::optional spawn_device_state_init_task( + RealmContext &ctx, + Realm::Processor target_proc, + DynamicNodeInvocation const &invocation, + ProfilingSettings const &profiling_settings, + DeviceSpecificManagedPerDeviceFFHandle const &device_handle, + FFIterationConfig const &iteration_config, + OptimizerAttrs const &optimizer_attrs, + DeviceSpecificPerDeviceOpState *result_ptr, + Realm::Event precondition); } // namespace FlexFlow diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h index 21d8795339..847154192a 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h +++ b/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h @@ -4,6 +4,7 @@ #include "kernels/profiling_settings.dtg.h" #include "op-attrs/ops/loss_functions/loss_attrs.dtg.h" #include "pcg/optimizer_attrs.dtg.h" +#include "realm-execution/device_specific_managed_per_device_ff_handle.dtg.h" #include "realm-execution/realm.h" #include "realm-execution/realm_context.h" #include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" @@ -13,13 +14,15 @@ namespace FlexFlow { void op_task_body(void const *, size_t, void const *, size_t, Realm::Processor); -Realm::Event spawn_op_task(RealmContext &ctx, - Realm::Processor target_proc, - DynamicNodeInvocation const &invocation, - ProfilingSettings const &profiling_settings, - FFIterationConfig const &iteration_config, - std::optional const &optimizer_attrs, - Realm::Event precondition); +Realm::Event + spawn_op_task(RealmContext &ctx, + Realm::Processor target_proc, + DynamicNodeInvocation const &invocation, + ProfilingSettings const &profiling_settings, + DeviceSpecificManagedPerDeviceFFHandle const &device_handle, + FFIterationConfig const &iteration_config, + std::optional const &optimizer_attrs, + Realm::Event precondition); } // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/device_specific_managed_per_device_ff_handle.cc b/lib/realm-execution/src/realm-execution/device_specific_managed_per_device_ff_handle.cc new file mode 100644 index 0000000000..440b9d18f7 --- /dev/null +++ b/lib/realm-execution/src/realm-execution/device_specific_managed_per_device_ff_handle.cc @@ -0,0 +1,21 @@ +#include "realm-execution/device_specific_managed_per_device_ff_handle.h" +#include "kernels/device_handle_t.h" + +namespace FlexFlow { + +DeviceSpecificManagedPerDeviceFFHandle make_device_specific_managed_handle( + device_id_t const &device_id, + std::optional const &managed_handle) { + return DeviceSpecificManagedPerDeviceFFHandle{ + DeviceSpecific>::create( + device_id, managed_handle)}; +} + +device_handle_t device_handle_t_from_device_specific_managed_handle( + DeviceSpecificManagedPerDeviceFFHandle const &device_specific, + device_id_t device_idx) { + return device_handle_t_from_managed_handle_ptr( + *device_specific.handle.get(device_idx)); +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/distributed_device_handle.cc b/lib/realm-execution/src/realm-execution/distributed_device_handle.cc index 00c2e76360..404feb014c 100644 --- a/lib/realm-execution/src/realm-execution/distributed_device_handle.cc +++ b/lib/realm-execution/src/realm-execution/distributed_device_handle.cc @@ -1,16 +1,16 @@ #include "realm-execution/distributed_device_handle.h" +#include "realm-execution/device_specific_managed_per_device_ff_handle.h" #include "realm-execution/tasks/impl/device_handle_init_task.h" #include "task-spec/device_specific.h" namespace FlexFlow { DistributedDeviceHandle::DistributedDeviceHandle( - std::map>> const + std::unordered_map const &handles) : handles(handles) {} -DeviceSpecific> const & +DeviceSpecificManagedPerDeviceFFHandle const & DistributedDeviceHandle::at(Realm::Processor processor) const { return this->handles.at(processor); } @@ -20,17 +20,14 @@ DistributedDeviceHandle size_t workSpaceSize, bool allowTensorOpMathConversion, Realm::Event precondition) { - std::map>> - handles; + std::unordered_map handles; // Allocate space for the result before launching any tasks Realm::Machine::ProcessorQuery pq(Realm::Machine::get_machine()); for (Realm::Processor proc : pq) { - handles.insert( - {proc, - DeviceSpecific>::create( - ctx.get_current_device_idx(), std::nullopt)}); + handles.insert({proc, + make_device_specific_managed_handle( + ctx.get_current_device_idx(), std::nullopt)}); } for (auto &[proc, handle] : handles) { diff --git a/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc b/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc index 9627a71e87..cab2b49e15 100644 --- a/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc +++ b/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc @@ -13,6 +13,7 @@ DynamicOpenDataflowGraph perform_distributed_device_state_initialization( DynamicOpenDataflowGraph const &dg, RealmContext &ctx, ProfilingSettings const &profiling_settings, + DistributedDeviceHandle const &device_handle, FFIterationConfig const &iteration_config, OptimizerAttrs const &optimizer_attrs, Realm::Event precondition) { @@ -37,6 +38,7 @@ DynamicOpenDataflowGraph perform_distributed_device_state_initialization( target_proc, invocation, profiling_settings, + device_handle.at(target_proc), iteration_config, optimizer_attrs, output, diff --git a/lib/realm-execution/src/realm-execution/hash/processor.cc b/lib/realm-execution/src/realm-execution/hash/processor.cc new file mode 100644 index 0000000000..dcc1bc5d06 --- /dev/null +++ b/lib/realm-execution/src/realm-execution/hash/processor.cc @@ -0,0 +1,11 @@ +#include "realm-execution/hash/processor.h" +#include + +namespace std { + +size_t hash<::FlexFlow::Realm::Processor>::operator()( + ::FlexFlow::Realm::Processor const &p) const { + return hash<::FlexFlow::Realm::Processor::id_t>{}(p.id); +} + +} // namespace std diff --git a/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc b/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc index d56dbb9ca9..c79d8e8abd 100644 --- a/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc +++ b/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc @@ -57,6 +57,7 @@ PCGInstance create_parallel_computation_graph_instance( std::unordered_map const &input_tensors, ProfilingSettings const &profiling_settings, + DistributedDeviceHandle const &device_handle, FFIterationConfig const &iteration_config) { DynamicOpenDataflowGraph dg = @@ -91,6 +92,7 @@ PCGInstance create_parallel_computation_graph_instance( dg, ctx, profiling_settings, + device_handle, iteration_config, optimizer_attrs, ctx.get_outstanding_events()); @@ -117,6 +119,7 @@ static std::unordered_map std::vector const &invocations, OptimizerAttrs const &optimizer_attrs, ProfilingSettings const &profiling_settings, + DistributedDeviceHandle const &device_handle, FFIterationConfig iteration_config) { // For simplicity we'll track a dependency on all outstanding operations up to // this point. This will create an effective barrier between phases. @@ -136,15 +139,16 @@ static std::unordered_map Realm::Event dependencies = Realm::Event::merge_events( Realm::Event::merge_events(input_dependencies), Realm::Event::merge_events(output_dependencies)); - Realm::Event result = - spawn_op_task(ctx, - ctx.map_device_coord_to_processor(assert_unwrap( - invocation.node_attrs.device_coord)), - invocation, - profiling_settings, - iteration_config, - optimizer_attrs, - dependencies); + Realm::Processor target_proc = ctx.map_device_coord_to_processor( + assert_unwrap(invocation.node_attrs.device_coord)); + Realm::Event result = spawn_op_task(ctx, + target_proc, + invocation, + profiling_settings, + device_handle.at(target_proc), + iteration_config, + optimizer_attrs, + dependencies); for (DynamicValueAttrs const &value : values(invocation.inputs)) { dependency_set.add_reader(value, result); } @@ -159,6 +163,7 @@ std::unordered_map perform_all_passes_for_parallel_computation_graph_instance( PCGInstance &instance, ProfilingSettings const &profiling_settings, + DistributedDeviceHandle const &device_handle, FFIterationConfig iteration_config) { std::vector execution_order = instance.get_execution_order(); @@ -168,6 +173,7 @@ std::unordered_map /*invocations=*/execution_order, /*optimizer_attrs=*/instance.get_optimizer_attrs(), /*profiling_settings=*/profiling_settings, + /*device_handle=*/device_handle, /*iteration_config=*/iteration_config); instance.update_optimizer_attrs_for_next_iter(); return result; @@ -177,6 +183,7 @@ std::unordered_map perform_forward_pass_for_parallel_computation_graph_instance( PCGInstance &instance, ProfilingSettings const &profiling_settings, + DistributedDeviceHandle const &device_handle, FFIterationConfig iteration_config) { std::vector execution_order = filter(instance.get_execution_order(), @@ -191,6 +198,7 @@ std::unordered_map /*invocations=*/execution_order, /*optimizer_attrs=*/instance.get_optimizer_attrs(), /*profiling_settings=*/profiling_settings, + /*device_handle=*/device_handle, /*iteration_config=*/iteration_config); } @@ -198,6 +206,7 @@ std::unordered_map perform_backward_pass_for_parallel_computation_graph_instance( PCGInstance &instance, ProfilingSettings const &profiling_settings, + DistributedDeviceHandle const &device_handle, FFIterationConfig iteration_config) { std::vector execution_order = filter(instance.get_execution_order(), @@ -212,6 +221,7 @@ std::unordered_map /*invocations=*/execution_order, /*optimizer_attrs=*/instance.get_optimizer_attrs(), /*profiling_settings=*/profiling_settings, + /*device_handle=*/device_handle, /*iteration_config=*/iteration_config); } @@ -219,6 +229,7 @@ std::unordered_map perform_update_pass_for_parallel_computation_graph_instance( PCGInstance &instance, ProfilingSettings const &profiling_settings, + DistributedDeviceHandle const &device_handle, FFIterationConfig iteration_config) { std::vector execution_order = filter(instance.get_execution_order(), @@ -234,6 +245,7 @@ std::unordered_map /*invocations=*/execution_order, /*optimizer_attrs=*/instance.get_optimizer_attrs(), /*profiling_settings=*/profiling_settings, + /*device_handle=*/device_handle, /*iteration_config=*/iteration_config); instance.update_optimizer_attrs_for_next_iter(); return result; diff --git a/lib/realm-execution/src/realm-execution/realm_context.cc b/lib/realm-execution/src/realm-execution/realm_context.cc index 38ce052da9..3427e8cbee 100644 --- a/lib/realm-execution/src/realm-execution/realm_context.cc +++ b/lib/realm-execution/src/realm-execution/realm_context.cc @@ -19,9 +19,7 @@ namespace FlexFlow { RealmContext::RealmContext(Realm::Processor processor) : processor(processor), allocator(get_realm_allocator( - processor, RealmContext::get_nearest_memory(processor))), - managed_handle(RealmContext::make_device_handle_for_processor(processor)), - device_handle(device_handle_t_from_managed_handle(managed_handle)) {} + processor, RealmContext::get_nearest_memory(processor))) {} RealmContext::~RealmContext() { if (!this->outstanding_events.empty()) { @@ -79,10 +77,6 @@ Allocator &RealmContext::get_current_device_allocator() { return this->allocator; } -device_handle_t const &RealmContext::get_current_device_handle() const { - return this->device_handle; -} - device_id_t RealmContext::get_current_device_idx() const { Realm::Processor proc = this->get_current_processor(); @@ -255,26 +249,4 @@ void RealmContext::discover_machine_topology() { } } -std::optional - RealmContext::make_device_handle_for_processor(Realm::Processor processor) { - if (!processor.exists()) { - return std::nullopt; - } - - switch (processor.kind()) { - case Realm::Processor::LOC_PROC: - return std::nullopt; - case Realm::Processor::TOC_PROC: - // FIXME: not sure what workSpaceSize to choose here - return initialize_multi_gpu_handle( - /*num_ranks=*/Realm::Machine::get_machine().get_address_space_count(), - /*my_rank=*/processor.address_space(), - /*workSpaceSize=*/1024 * 1024, - /*allowTensorOpMathConversion=*/true); - default: - PANIC("Unhandled Realm::ProcessorKind", - fmt::to_string(int{processor.kind()})); - } -} - } // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/device_handle_init_return_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/device_handle_init_return_task.cc index 2839beef0c..bda6f7781c 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/device_handle_init_return_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/device_handle_init_return_task.cc @@ -3,22 +3,20 @@ namespace FlexFlow { -// FIXME: Can't make this trivially copyable? struct DeviceHandleInitReturnTaskArgs { public: DeviceHandleInitReturnTaskArgs() = delete; DeviceHandleInitReturnTaskArgs( - DeviceSpecific> result, + DeviceSpecificManagedPerDeviceFFHandle result, Realm::Processor origin_proc, - DeviceSpecific> - *origin_result_ptr) + DeviceSpecificManagedPerDeviceFFHandle *origin_result_ptr) : result(result), origin_proc(origin_proc), origin_result_ptr(origin_result_ptr) {} public: - DeviceSpecific> result; + DeviceSpecificManagedPerDeviceFFHandle result; Realm::Processor origin_proc; - DeviceSpecific> *origin_result_ptr; + DeviceSpecificManagedPerDeviceFFHandle *origin_result_ptr; }; void device_handle_init_return_task_body(void const *args, @@ -37,9 +35,8 @@ void device_handle_init_return_task_body(void const *args, Realm::Event spawn_device_handle_init_return_task( RealmContext &ctx, Realm::Processor origin_proc, - DeviceSpecific> const &result, - DeviceSpecific> - *origin_result_ptr, + DeviceSpecificManagedPerDeviceFFHandle const &result, + DeviceSpecificManagedPerDeviceFFHandle *origin_result_ptr, Realm::Event precondition) { DeviceHandleInitReturnTaskArgs task_args{ result, origin_proc, origin_result_ptr}; diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/device_handle_init_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/device_handle_init_task.cc index 86a576d26b..cd5608ca7e 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/device_handle_init_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/device_handle_init_task.cc @@ -1,4 +1,5 @@ #include "realm-execution/tasks/impl/device_handle_init_task.h" +#include "realm-execution/device_specific_managed_per_device_ff_handle.h" #include "realm-execution/tasks/impl/device_handle_init_return_task.h" #include "realm-execution/tasks/task_id_t.dtg.h" #include @@ -14,8 +15,7 @@ struct DeviceHandleInitTaskArgs { size_t workSpaceSize, bool allowTensorOpMathConversion, Realm::Processor origin_proc, - DeviceSpecific> - *origin_result_ptr) + DeviceSpecificManagedPerDeviceFFHandle *origin_result_ptr) : workSpaceSize(workSpaceSize), allowTensorOpMathConversion(allowTensorOpMathConversion), origin_proc(origin_proc), origin_result_ptr(origin_result_ptr) {} @@ -24,7 +24,7 @@ struct DeviceHandleInitTaskArgs { size_t workSpaceSize; bool allowTensorOpMathConversion; Realm::Processor origin_proc; - DeviceSpecific> *origin_result_ptr; + DeviceSpecificManagedPerDeviceFFHandle *origin_result_ptr; }; static_assert(std::is_trivially_copy_constructible_v); @@ -60,8 +60,8 @@ void device_handle_init_task_body(void const *args, ASSERT(task_args.origin_proc.address_space() == proc.address_space()); RealmContext ctx{proc}; - DeviceSpecific> managed_handle = - DeviceSpecific>::create( + DeviceSpecificManagedPerDeviceFFHandle managed_handle = + make_device_specific_managed_handle( ctx.get_current_device_idx(), make_device_handle_for_processor( proc, @@ -80,7 +80,7 @@ Realm::Event spawn_device_handle_init_task( Realm::Processor target_proc, size_t workSpaceSize, bool allowTensorOpMathConversion, - DeviceSpecific> *result_ptr, + DeviceSpecificManagedPerDeviceFFHandle *result_ptr, Realm::Event precondition) { DeviceHandleInitTaskArgs task_args{ workSpaceSize, diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_return_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_return_task.cc index c1bd7c1081..306697e950 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_return_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_return_task.cc @@ -3,7 +3,6 @@ namespace FlexFlow { -// FIXME: Can't make this trivially copyable? struct DeviceStateInitReturnTaskArgs { public: DeviceStateInitReturnTaskArgs() = delete; diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_task.cc index f63efba14b..5a51b1c803 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_task.cc @@ -1,5 +1,7 @@ #include "realm-execution/tasks/impl/device_state_init_task.h" +#include "kernels/device_handle_t.dtg.h" #include "local-execution/device_state_initialization.h" +#include "realm-execution/device_specific_managed_per_device_ff_handle.h" #include "realm-execution/tasks/impl/device_state_init_return_task.h" #include "realm-execution/tasks/task_id_t.dtg.h" #include "realm-execution/tasks/task_id_t.h" @@ -16,26 +18,28 @@ namespace FlexFlow { // space struct DeviceStateInitTaskArgs { DeviceStateInitTaskArgs() = delete; - DeviceStateInitTaskArgs(DynamicNodeInvocation const *invocation, - ProfilingSettings const *profiling_settings, - FFIterationConfig const *iteration_config, - OptimizerAttrs const *optimizer_attrs, - Realm::Processor origin_proc, - DeviceSpecificPerDeviceOpState *origin_result_ptr) + DeviceStateInitTaskArgs( + DynamicNodeInvocation const *invocation, + ProfilingSettings const *profiling_settings, + DeviceSpecificManagedPerDeviceFFHandle const &device_handle, + FFIterationConfig const *iteration_config, + OptimizerAttrs const *optimizer_attrs, + Realm::Processor origin_proc, + DeviceSpecificPerDeviceOpState *origin_result_ptr) : invocation(invocation), profiling_settings(profiling_settings), - iteration_config(iteration_config), optimizer_attrs(optimizer_attrs), - origin_proc(origin_proc), origin_result_ptr(origin_result_ptr) {} + device_handle(device_handle), iteration_config(iteration_config), + optimizer_attrs(optimizer_attrs), origin_proc(origin_proc), + origin_result_ptr(origin_result_ptr) {} public: DynamicNodeInvocation const *invocation; ProfilingSettings const *profiling_settings; + DeviceSpecificManagedPerDeviceFFHandle device_handle; FFIterationConfig const *iteration_config; OptimizerAttrs const *optimizer_attrs; Realm::Processor origin_proc; DeviceSpecificPerDeviceOpState *origin_result_ptr; }; -static_assert( - std::has_unique_object_representations_v); void device_state_init_task_body(void const *args, size_t arglen, @@ -50,11 +54,14 @@ void device_state_init_task_body(void const *args, ASSERT(task_args.origin_proc.address_space() == proc.address_space()); RealmContext ctx{proc}; + device_handle_t device_handle = + device_handle_t_from_device_specific_managed_handle( + task_args.device_handle, ctx.get_current_device_idx()); DynamicNodeInvocation result_invocation = initialize_node(*task_args.invocation, ctx.get_current_device_allocator(), *task_args.profiling_settings, - ctx.get_current_device_handle(), + device_handle, *task_args.iteration_config, *task_args.optimizer_attrs, ctx.get_current_device_idx()); @@ -71,18 +78,20 @@ void device_state_init_task_body(void const *args, Realm::Event::NO_EVENT); } -std::optional - spawn_device_state_init_task(RealmContext &ctx, - Realm::Processor target_proc, - DynamicNodeInvocation const &invocation, - ProfilingSettings const &profiling_settings, - FFIterationConfig const &iteration_config, - OptimizerAttrs const &optimizer_attrs, - DeviceSpecificPerDeviceOpState *result_ptr, - Realm::Event precondition) { +std::optional spawn_device_state_init_task( + RealmContext &ctx, + Realm::Processor target_proc, + DynamicNodeInvocation const &invocation, + ProfilingSettings const &profiling_settings, + DeviceSpecificManagedPerDeviceFFHandle const &device_handle, + FFIterationConfig const &iteration_config, + OptimizerAttrs const &optimizer_attrs, + DeviceSpecificPerDeviceOpState *result_ptr, + Realm::Event precondition) { DeviceStateInitTaskArgs task_args{ &invocation, &profiling_settings, + device_handle, &iteration_config, &optimizer_attrs, ctx.get_current_processor(), diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc index 216f0badde..e17973febb 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc @@ -1,5 +1,6 @@ #include "realm-execution/tasks/impl/op_task.h" #include "local-execution/task_execution.h" +#include "realm-execution/device_specific_managed_per_device_ff_handle.h" #include "realm-execution/tasks/task_id_t.h" #include "task-spec/per_device_op_state.h" #include "utils/optional.h" @@ -15,20 +16,22 @@ struct OpTaskArgs { OpTaskArgs() = delete; OpTaskArgs(DynamicNodeInvocation const *invocation, ProfilingSettings const *profiling_settings, + DeviceSpecificManagedPerDeviceFFHandle const &device_handle, FFIterationConfig const *iteration_config, std::optional const *optimizer_attrs, Realm::Processor origin_proc) : invocation(invocation), profiling_settings(profiling_settings), - iteration_config(iteration_config), optimizer_attrs(optimizer_attrs) {} + device_handle(device_handle), iteration_config(iteration_config), + optimizer_attrs(optimizer_attrs) {} public: DynamicNodeInvocation const *invocation; ProfilingSettings const *profiling_settings; + DeviceSpecificManagedPerDeviceFFHandle device_handle; FFIterationConfig const *iteration_config; std::optional const *optimizer_attrs; Realm::Processor origin_proc; }; -static_assert(std::has_unique_object_representations_v); void op_task_body(void const *args, size_t arglen, @@ -42,11 +45,14 @@ void op_task_body(void const *args, ASSERT(task_args.origin_proc.address_space() == proc.address_space()); RealmContext ctx{proc}; + device_handle_t device_handle = + device_handle_t_from_device_specific_managed_handle( + task_args.device_handle, ctx.get_current_device_idx()); execute_dynamic_node_invocation( /*invocation=*/*task_args.invocation, /*allocator=*/ctx.get_current_device_allocator(), /*profiling_settings=*/*task_args.profiling_settings, - /*ff_handle=*/ctx.get_current_device_handle(), + /*ff_handle=*/device_handle, /*per_device_op_state=*/ transform(task_args.invocation->node_attrs.per_device_op_state, [&](DeviceSpecificPerDeviceOpState const &op_state) { @@ -58,15 +64,18 @@ void op_task_body(void const *args, /*device_idx=*/ctx.get_current_device_idx()); } -Realm::Event spawn_op_task(RealmContext &ctx, - Realm::Processor target_proc, - DynamicNodeInvocation const &invocation, - ProfilingSettings const &profiling_settings, - FFIterationConfig const &iteration_config, - std::optional const &optimizer_attrs, - Realm::Event precondition) { +Realm::Event + spawn_op_task(RealmContext &ctx, + Realm::Processor target_proc, + DynamicNodeInvocation const &invocation, + ProfilingSettings const &profiling_settings, + DeviceSpecificManagedPerDeviceFFHandle const &device_handle, + FFIterationConfig const &iteration_config, + std::optional const &optimizer_attrs, + Realm::Event precondition) { OpTaskArgs task_args{&invocation, &profiling_settings, + device_handle, &iteration_config, &optimizer_attrs, ctx.get_current_processor()}; From 2fc992cf4af7a6d58fa29c1a9270eecaf8418cc4 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Thu, 12 Feb 2026 17:02:00 -0800 Subject: [PATCH 48/88] Test distributed device handle. --- .../realm-execution/distributed_device_handle.h | 6 ++++-- .../realm-execution/distributed_device_handle.cc | 7 ++++--- .../realm-execution/tasks/realm_task_registry.cc | 14 ++++++++++++++ .../test/src/realm-execution/realm_manager.cc | 14 ++++++++++++-- 4 files changed, 34 insertions(+), 7 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/distributed_device_handle.h b/lib/realm-execution/include/realm-execution/distributed_device_handle.h index 3f55c47192..40f3b98fb3 100644 --- a/lib/realm-execution/include/realm-execution/distributed_device_handle.h +++ b/lib/realm-execution/include/realm-execution/distributed_device_handle.h @@ -13,14 +13,16 @@ struct DistributedDeviceHandle { public: DistributedDeviceHandle() = delete; explicit DistributedDeviceHandle( - std::unordered_map const + std::unordered_map const &handles); DeviceSpecificManagedPerDeviceFFHandle const & at(Realm::Processor processor) const; private: - std::unordered_map handles; + std::unordered_map + handles; }; DistributedDeviceHandle create_distributed_device_handle( diff --git a/lib/realm-execution/src/realm-execution/distributed_device_handle.cc b/lib/realm-execution/src/realm-execution/distributed_device_handle.cc index 404feb014c..3cd01f292e 100644 --- a/lib/realm-execution/src/realm-execution/distributed_device_handle.cc +++ b/lib/realm-execution/src/realm-execution/distributed_device_handle.cc @@ -6,8 +6,8 @@ namespace FlexFlow { DistributedDeviceHandle::DistributedDeviceHandle( - std::unordered_map const - &handles) + std::unordered_map const &handles) : handles(handles) {} DeviceSpecificManagedPerDeviceFFHandle const & @@ -20,7 +20,8 @@ DistributedDeviceHandle size_t workSpaceSize, bool allowTensorOpMathConversion, Realm::Event precondition) { - std::unordered_map handles; + std::unordered_map + handles; // Allocate space for the result before launching any tasks Realm::Machine::ProcessorQuery pq(Realm::Machine::get_machine()); diff --git a/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc b/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc index 9150ce6892..cff12c2391 100644 --- a/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc +++ b/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc @@ -1,5 +1,7 @@ #include "realm-execution/tasks/realm_task_registry.h" #include "realm-execution/tasks/impl/controller_task.h" +#include "realm-execution/tasks/impl/device_handle_init_return_task.h" +#include "realm-execution/tasks/impl/device_handle_init_task.h" #include "realm-execution/tasks/impl/device_state_init_return_task.h" #include "realm-execution/tasks/impl/device_state_init_task.h" #include "realm-execution/tasks/impl/op_task.h" @@ -125,6 +127,18 @@ Realm::Event register_all_tasks() { pending_registrations.push_back(register_task(Realm::Processor::LOC_PROC, task_id_t::CONTROLLER_TASK_ID, controller_task_body)); + pending_registrations.push_back( + register_task(Realm::Processor::LOC_PROC, + task_id_t::DEVICE_HANDLE_INIT_TASK_ID, + device_handle_init_task_body)); + pending_registrations.push_back( + register_task(Realm::Processor::TOC_PROC, + task_id_t::DEVICE_HANDLE_INIT_TASK_ID, + device_handle_init_task_body)); + pending_registrations.push_back( + register_task(Realm::Processor::LOC_PROC, + task_id_t::DEVICE_HANDLE_INIT_RETURN_TASK_ID, + device_handle_init_return_task_body)); pending_registrations.push_back( register_task(Realm::Processor::LOC_PROC, task_id_t::DEVICE_STATE_INIT_RETURN_TASK_ID, diff --git a/lib/realm-execution/test/src/realm-execution/realm_manager.cc b/lib/realm-execution/test/src/realm-execution/realm_manager.cc index 94e0d7d0f4..41fa63f4f9 100644 --- a/lib/realm-execution/test/src/realm-execution/realm_manager.cc +++ b/lib/realm-execution/test/src/realm-execution/realm_manager.cc @@ -1,4 +1,5 @@ #include "realm-execution/realm_manager.h" +#include "realm-execution/distributed_device_handle.h" #include using namespace ::FlexFlow; @@ -16,8 +17,17 @@ TEST_SUITE(FF_TEST_SUITE) { // Launch a controller int some_data = 123; - FlexFlow::Realm::Event event = manager.start_controller( - [&](RealmContext &ctx) { ASSERT(some_data == 123); }); + FlexFlow::Realm::Event event = + manager.start_controller([&](RealmContext &ctx) { + // Data is captured and retains value + ASSERT(some_data == 123); + + // Launch some basic task to ensure everything works + DistributedDeviceHandle handle = create_distributed_device_handle( + /*ctx=*/ctx, + /*workSpaceSize=*/1024 * 1024, + /*allowTensorOpMathConversion=*/true); + }); // Need to block on the completion of the event to ensure we don't race event.wait(); } From 939c49aef2346113fd8777261be1577af3032dd5 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Thu, 12 Feb 2026 20:54:38 -0800 Subject: [PATCH 49/88] Guard the kinds of procs we run on. --- .../src/realm-execution/distributed_device_handle.cc | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/lib/realm-execution/src/realm-execution/distributed_device_handle.cc b/lib/realm-execution/src/realm-execution/distributed_device_handle.cc index 3cd01f292e..87376be9b1 100644 --- a/lib/realm-execution/src/realm-execution/distributed_device_handle.cc +++ b/lib/realm-execution/src/realm-execution/distributed_device_handle.cc @@ -26,9 +26,12 @@ DistributedDeviceHandle // Allocate space for the result before launching any tasks Realm::Machine::ProcessorQuery pq(Realm::Machine::get_machine()); for (Realm::Processor proc : pq) { - handles.insert({proc, - make_device_specific_managed_handle( - ctx.get_current_device_idx(), std::nullopt)}); + if (proc.kind() == Realm::Processor::LOC_PROC || + proc.kind() == Realm::Processor::TOC_PROC) { + handles.insert({proc, + make_device_specific_managed_handle( + ctx.get_current_device_idx(), std::nullopt)}); + } } for (auto &[proc, handle] : handles) { From d21558ab6716e47bd2c9116e32f4ebbc1b229f4f Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Fri, 13 Feb 2026 10:30:17 -0800 Subject: [PATCH 50/88] Switch to own DeviceSpecific implementation with raw pointers. --- ...pecific_managed_per_device_ff_handle.dtg.toml | 16 ---------------- ...evice_specific_managed_per_device_ff_handle.h | 14 +++++++++++++- .../realm-execution/distributed_device_handle.h | 2 +- .../tasks/impl/device_handle_init_return_task.h | 2 +- .../tasks/impl/device_handle_init_task.h | 2 +- .../tasks/impl/device_state_init_task.h | 2 +- .../include/realm-execution/tasks/impl/op_task.h | 2 +- ...vice_specific_managed_per_device_ff_handle.cc | 16 ++++++++++++---- 8 files changed, 30 insertions(+), 26 deletions(-) delete mode 100644 lib/realm-execution/include/realm-execution/device_specific_managed_per_device_ff_handle.dtg.toml diff --git a/lib/realm-execution/include/realm-execution/device_specific_managed_per_device_ff_handle.dtg.toml b/lib/realm-execution/include/realm-execution/device_specific_managed_per_device_ff_handle.dtg.toml deleted file mode 100644 index 1458adcba3..0000000000 --- a/lib/realm-execution/include/realm-execution/device_specific_managed_per_device_ff_handle.dtg.toml +++ /dev/null @@ -1,16 +0,0 @@ -namespace = "FlexFlow" -name = "DeviceSpecificManagedPerDeviceFFHandle" -type = "struct" -features = [ - "eq", -] - -includes = [ - "", - "kernels/managed_per_device_ff_handle.h", - "task-spec/device_specific.h", -] - -[[fields]] -name = "handle" -type = "::FlexFlow::DeviceSpecific>" diff --git a/lib/realm-execution/include/realm-execution/device_specific_managed_per_device_ff_handle.h b/lib/realm-execution/include/realm-execution/device_specific_managed_per_device_ff_handle.h index eefa6c86ac..19a70491a2 100644 --- a/lib/realm-execution/include/realm-execution/device_specific_managed_per_device_ff_handle.h +++ b/lib/realm-execution/include/realm-execution/device_specific_managed_per_device_ff_handle.h @@ -4,10 +4,22 @@ #include "kernels/device_handle_t.dtg.h" #include "kernels/managed_per_device_ff_handle.h" #include "pcg/device_id_t.dtg.h" -#include "realm-execution/device_specific_managed_per_device_ff_handle.dtg.h" namespace FlexFlow { +struct DeviceSpecificManagedPerDeviceFFHandle { +public: + DeviceSpecificManagedPerDeviceFFHandle() = delete; + explicit DeviceSpecificManagedPerDeviceFFHandle( + device_id_t owner, std::optional handle); + + std::optional get(device_id_t device_idx) const; + +private: + device_id_t owner; + std::optional handle; +}; + DeviceSpecificManagedPerDeviceFFHandle make_device_specific_managed_handle( device_id_t const &, std::optional const &); diff --git a/lib/realm-execution/include/realm-execution/distributed_device_handle.h b/lib/realm-execution/include/realm-execution/distributed_device_handle.h index 40f3b98fb3..268be3583d 100644 --- a/lib/realm-execution/include/realm-execution/distributed_device_handle.h +++ b/lib/realm-execution/include/realm-execution/distributed_device_handle.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_DISTRIBUTED_DEVICE_HANDLE_H #define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_DISTRIBUTED_DEVICE_HANDLE_H -#include "realm-execution/device_specific_managed_per_device_ff_handle.dtg.h" +#include "realm-execution/device_specific_managed_per_device_ff_handle.h" #include "realm-execution/hash/processor.h" #include "realm-execution/realm.h" #include "realm-execution/realm_context.h" diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_return_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_return_task.h index 9bae546403..a87652b5ce 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_return_task.h +++ b/lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_return_task.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_HANDLE_INIT_RETURN_TASK_H #define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_HANDLE_INIT_RETURN_TASK_H -#include "realm-execution/device_specific_managed_per_device_ff_handle.dtg.h" +#include "realm-execution/device_specific_managed_per_device_ff_handle.h" #include "realm-execution/realm.h" #include "realm-execution/realm_context.h" diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_task.h index 624eb6e682..312ed26add 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_task.h +++ b/lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_task.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_HANDLE_INIT_TASK_H #define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_HANDLE_INIT_TASK_H -#include "realm-execution/device_specific_managed_per_device_ff_handle.dtg.h" +#include "realm-execution/device_specific_managed_per_device_ff_handle.h" #include "realm-execution/realm.h" #include "realm-execution/realm_context.h" diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task.h index 933d4f9283..4ed8c1726d 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task.h +++ b/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task.h @@ -3,7 +3,7 @@ #include "kernels/profiling_settings.dtg.h" #include "pcg/optimizer_attrs.dtg.h" -#include "realm-execution/device_specific_managed_per_device_ff_handle.dtg.h" +#include "realm-execution/device_specific_managed_per_device_ff_handle.h" #include "realm-execution/realm.h" #include "realm-execution/realm_context.h" #include "task-spec/device_specific_per_device_op_state.dtg.h" diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h index 847154192a..9d4c2fd451 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h +++ b/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h @@ -4,7 +4,7 @@ #include "kernels/profiling_settings.dtg.h" #include "op-attrs/ops/loss_functions/loss_attrs.dtg.h" #include "pcg/optimizer_attrs.dtg.h" -#include "realm-execution/device_specific_managed_per_device_ff_handle.dtg.h" +#include "realm-execution/device_specific_managed_per_device_ff_handle.h" #include "realm-execution/realm.h" #include "realm-execution/realm_context.h" #include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" diff --git a/lib/realm-execution/src/realm-execution/device_specific_managed_per_device_ff_handle.cc b/lib/realm-execution/src/realm-execution/device_specific_managed_per_device_ff_handle.cc index 440b9d18f7..99ff7a6dd6 100644 --- a/lib/realm-execution/src/realm-execution/device_specific_managed_per_device_ff_handle.cc +++ b/lib/realm-execution/src/realm-execution/device_specific_managed_per_device_ff_handle.cc @@ -3,19 +3,27 @@ namespace FlexFlow { +DeviceSpecificManagedPerDeviceFFHandle::DeviceSpecificManagedPerDeviceFFHandle( + device_id_t owner, std::optional handle) + : owner(owner), handle(handle) {} + +std::optional + DeviceSpecificManagedPerDeviceFFHandle::get(device_id_t device_idx) const { + ASSERT(this->owner == device_idx); + return this->handle; +} + DeviceSpecificManagedPerDeviceFFHandle make_device_specific_managed_handle( device_id_t const &device_id, std::optional const &managed_handle) { - return DeviceSpecificManagedPerDeviceFFHandle{ - DeviceSpecific>::create( - device_id, managed_handle)}; + return DeviceSpecificManagedPerDeviceFFHandle{device_id, managed_handle}; } device_handle_t device_handle_t_from_device_specific_managed_handle( DeviceSpecificManagedPerDeviceFFHandle const &device_specific, device_id_t device_idx) { return device_handle_t_from_managed_handle_ptr( - *device_specific.handle.get(device_idx)); + *device_specific.get(device_idx)); } } // namespace FlexFlow From 1beaa05b5eb784add56d4fb8fb3cbbe6d5448db5 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Fri, 13 Feb 2026 10:57:11 -0800 Subject: [PATCH 51/88] Separate device handle test. --- .../distributed_device_handle.cc | 38 +++++++++++++++++++ .../test/src/realm-execution/realm_manager.cc | 16 ++++---- .../test/src/realm-execution/test_e2e.cc | 5 +++ 3 files changed, 51 insertions(+), 8 deletions(-) create mode 100644 lib/realm-execution/test/src/realm-execution/distributed_device_handle.cc diff --git a/lib/realm-execution/test/src/realm-execution/distributed_device_handle.cc b/lib/realm-execution/test/src/realm-execution/distributed_device_handle.cc new file mode 100644 index 0000000000..5a5402a140 --- /dev/null +++ b/lib/realm-execution/test/src/realm-execution/distributed_device_handle.cc @@ -0,0 +1,38 @@ +#include "realm-execution/distributed_device_handle.h" +#include "realm-execution/realm_manager.h" +#include + +namespace test { + +using namespace ::FlexFlow; +namespace Realm = ::FlexFlow::Realm; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("DistributedDeviceHandle") { + // Construct some fake command line for our test + char fake_executable_name[] = "fake_executable_name"; + char arg0[] = "-ll:cpu"; + char arg1[] = "2"; + std::vector fake_args{fake_executable_name, arg0, arg1}; + int fake_argc = fake_args.size(); + char **fake_argv = fake_args.data(); + + RealmManager manager(&fake_argc, &fake_argv); + + (void)manager.start_controller([](RealmContext &ctx) { + DistributedDeviceHandle handle = create_distributed_device_handle( + /*ctx=*/ctx, + /*workSpaceSize=*/1024 * 1024, + /*allowTensorOpMathConversion=*/true); + + // Make sure we have handles for the processors we're expecting + Realm::Machine::ProcessorQuery pq(Realm::Machine::get_machine()); + pq.only_kind(Realm::Processor::LOC_PROC); + for (Realm::Processor proc : pq) { + handle.at(proc); + } + }); + } +} + +} // namespace test diff --git a/lib/realm-execution/test/src/realm-execution/realm_manager.cc b/lib/realm-execution/test/src/realm-execution/realm_manager.cc index 41fa63f4f9..5fe659cdc2 100644 --- a/lib/realm-execution/test/src/realm-execution/realm_manager.cc +++ b/lib/realm-execution/test/src/realm-execution/realm_manager.cc @@ -2,7 +2,10 @@ #include "realm-execution/distributed_device_handle.h" #include +namespace test { + using namespace ::FlexFlow; +namespace Realm = ::FlexFlow::Realm; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("RealmManager") { @@ -17,18 +20,15 @@ TEST_SUITE(FF_TEST_SUITE) { // Launch a controller int some_data = 123; - FlexFlow::Realm::Event event = + Realm::Event event = manager.start_controller([&](RealmContext &ctx) { // Data is captured and retains value ASSERT(some_data == 123); - - // Launch some basic task to ensure everything works - DistributedDeviceHandle handle = create_distributed_device_handle( - /*ctx=*/ctx, - /*workSpaceSize=*/1024 * 1024, - /*allowTensorOpMathConversion=*/true); }); - // Need to block on the completion of the event to ensure we don't race + // Need to block on the completion of the event to ensure we don't race, + // because the lambda captures the environment event.wait(); } } + +} // namespace test diff --git a/lib/realm-execution/test/src/realm-execution/test_e2e.cc b/lib/realm-execution/test/src/realm-execution/test_e2e.cc index 37f1a9b42c..9592cb221c 100644 --- a/lib/realm-execution/test/src/realm-execution/test_e2e.cc +++ b/lib/realm-execution/test/src/realm-execution/test_e2e.cc @@ -2,7 +2,10 @@ #include "realm-execution/realm_manager.h" #include +namespace test { + using namespace ::FlexFlow; +namespace Realm = ::FlexFlow::Realm; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("RealmBackend e2e Training") { @@ -14,3 +17,5 @@ TEST_SUITE(FF_TEST_SUITE) { (void)manager.start_controller([](RealmContext &ctx) {}); } } + +} // namespace test From 68ce681189b6acf1d5d8770c2b0f94a62b5e5488 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Fri, 13 Feb 2026 12:26:57 -0800 Subject: [PATCH 52/88] More work on Realm tests. --- .../parallel_computation_graph.h | 4 + .../parallel_computation_graph.cc | 21 +++ .../pcg_instance/pcg_instance.h | 32 +++- .../pcg_instance/pcg_instance.cc | 10 +- .../test/src/internal/realm_test_utils.cc | 28 +++ .../test/src/internal/realm_test_utils.h | 15 ++ .../distributed_device_handle.cc | 8 +- .../test/src/realm-execution/realm_manager.cc | 15 +- .../test/src/realm-execution/test_e2e.cc | 173 +++++++++++++++++- 9 files changed, 283 insertions(+), 23 deletions(-) create mode 100644 lib/realm-execution/test/src/internal/realm_test_utils.cc create mode 100644 lib/realm-execution/test/src/internal/realm_test_utils.h diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h index 3d948ac107..21f33f6d3d 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h @@ -32,6 +32,10 @@ ParallelLayerAddedResult add_parallel_layer( ParallelLayerAddedResult pcg_add_input_layer(ParallelComputationGraph &pcg, TensorShape const &tensor_shape); +ParallelLayerAddedResult + pcg_add_input_layer_with_grad(ParallelComputationGraph &pcg, + TensorShape const &tensor_shape); + OperatorTaskSpace get_operator_task_space(ParallelComputationGraph const &pcg, parallel_layer_guid_t const &layer); diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc index 907dc05620..959747dbc7 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc @@ -142,6 +142,27 @@ ParallelLayerAddedResult pcg_add_input_layer(ParallelComputationGraph &pcg, }); } +ParallelLayerAddedResult + pcg_add_input_layer_with_grad(ParallelComputationGraph &pcg, + TensorShape const &tensor_shape) { + ParallelLayerAttrs layer_attrs = ParallelLayerAttrs{ + /*op_attrs=*/PCGOperatorAttrs{InputAttrs{tensor_shape}}, + /*name=*/std::nullopt, + }; + + return add_parallel_layer(/*pcg=*/pcg, + /*layer_attrs=*/layer_attrs, + /*inputs=*/{}, + /*weights=*/{}, + /*output_flags=*/ + std::unordered_map{ + { + TensorSlotName::OUTPUT, + CreateGrad::YES, + }, + }); +} + OperatorTaskSpace get_operator_task_space(ParallelComputationGraph const &pcg, parallel_layer_guid_t const &layer) { PCGOperatorAttrs op_attrs = pcg_get_op_attrs(pcg, layer); diff --git a/lib/realm-execution/include/realm-execution/pcg_instance/pcg_instance.h b/lib/realm-execution/include/realm-execution/pcg_instance/pcg_instance.h index b917477df4..b0037f51b2 100644 --- a/lib/realm-execution/include/realm-execution/pcg_instance/pcg_instance.h +++ b/lib/realm-execution/include/realm-execution/pcg_instance/pcg_instance.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_PCG_INSTANCE_H -#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_PCG_INSTANCE_H +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_PCG_INSTANCE_PCG_INSTANCE_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_PCG_INSTANCE_PCG_INSTANCE_H #include "kernels/accessor.h" #include "kernels/allocation.h" @@ -57,6 +57,34 @@ PCGInstance create_pcg_instance( DistributedDeviceHandle const &device_handle, FFIterationConfig const &iteration_config); +std::unordered_map + perform_all_passes_for_pcg_instance( + PCGInstance &instance, + ProfilingSettings const &profiling_settings, + DistributedDeviceHandle const &device_handle, + FFIterationConfig iteration_config); + +std::unordered_map + perform_forward_pass_for_pcg_instance( + PCGInstance &instance, + ProfilingSettings const &profiling_settings, + DistributedDeviceHandle const &device_handle, + FFIterationConfig iteration_config); + +std::unordered_map + perform_backward_pass_for_pcg_instance( + PCGInstance &instance, + ProfilingSettings const &profiling_settings, + DistributedDeviceHandle const &device_handle, + FFIterationConfig iteration_config); + +std::unordered_map + perform_update_pass_for_pcg_instance( + PCGInstance &instance, + ProfilingSettings const &profiling_settings, + DistributedDeviceHandle const &device_handle, + FFIterationConfig iteration_config); + } // namespace FlexFlow #endif diff --git a/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc b/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc index c79d8e8abd..de7cdcb687 100644 --- a/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc +++ b/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc @@ -47,7 +47,7 @@ std::optional return this->logit_grad_tensor; } -PCGInstance create_parallel_computation_graph_instance( +PCGInstance create_pcg_instance( RealmContext &ctx, MappedParallelComputationGraph const &mpcg, OptimizerAttrs const &optimizer_attrs, @@ -160,7 +160,7 @@ static std::unordered_map } std::unordered_map - perform_all_passes_for_parallel_computation_graph_instance( + perform_all_passes_for_pcg_instance( PCGInstance &instance, ProfilingSettings const &profiling_settings, DistributedDeviceHandle const &device_handle, @@ -180,7 +180,7 @@ std::unordered_map } std::unordered_map - perform_forward_pass_for_parallel_computation_graph_instance( + perform_forward_pass_for_pcg_instance( PCGInstance &instance, ProfilingSettings const &profiling_settings, DistributedDeviceHandle const &device_handle, @@ -203,7 +203,7 @@ std::unordered_map } std::unordered_map - perform_backward_pass_for_parallel_computation_graph_instance( + perform_backward_pass_for_pcg_instance( PCGInstance &instance, ProfilingSettings const &profiling_settings, DistributedDeviceHandle const &device_handle, @@ -226,7 +226,7 @@ std::unordered_map } std::unordered_map - perform_update_pass_for_parallel_computation_graph_instance( + perform_update_pass_for_pcg_instance( PCGInstance &instance, ProfilingSettings const &profiling_settings, DistributedDeviceHandle const &device_handle, diff --git a/lib/realm-execution/test/src/internal/realm_test_utils.cc b/lib/realm-execution/test/src/internal/realm_test_utils.cc new file mode 100644 index 0000000000..e381feb8de --- /dev/null +++ b/lib/realm-execution/test/src/internal/realm_test_utils.cc @@ -0,0 +1,28 @@ +#include "internal/realm_test_utils.h" +#include +#include + +namespace FlexFlow { + +static char *leak_string_contents(std::string const &str) { + // Realm command-line arguments require char* so intentionally leak the + // allocated string contents here + std::vector *content = new std::vector{str.begin(), str.end()}; + content->push_back(0); // NUL byte + return content->data(); +} + +std::vector make_fake_realm_args(positive_int num_cpus, + nonnegative_int num_gpus) { + std::vector result; + result.push_back(leak_string_contents("fake_executable_name")); + result.push_back(leak_string_contents("-ll:cpu")); + result.push_back(leak_string_contents(fmt::to_string(num_cpus))); + if (num_gpus > 0) { + result.push_back(leak_string_contents("-ll:gpu")); + result.push_back(leak_string_contents(fmt::to_string(num_gpus))); + } + return result; +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/test/src/internal/realm_test_utils.h b/lib/realm-execution/test/src/internal/realm_test_utils.h new file mode 100644 index 0000000000..8e2775ad8b --- /dev/null +++ b/lib/realm-execution/test/src/internal/realm_test_utils.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_TEST_SRC_INTERNAL_REALM_TEST_UTILS_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_TEST_SRC_INTERNAL_REALM_TEST_UTILS_H + +#include "utils/nonnegative_int/nonnegative_int.h" +#include "utils/positive_int/positive_int.h" +#include + +namespace FlexFlow { + +std::vector make_fake_realm_args(positive_int num_cpus, + nonnegative_int num_gpus); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/test/src/realm-execution/distributed_device_handle.cc b/lib/realm-execution/test/src/realm-execution/distributed_device_handle.cc index 5a5402a140..fb7dff01e3 100644 --- a/lib/realm-execution/test/src/realm-execution/distributed_device_handle.cc +++ b/lib/realm-execution/test/src/realm-execution/distributed_device_handle.cc @@ -1,4 +1,5 @@ #include "realm-execution/distributed_device_handle.h" +#include "internal/realm_test_utils.h" #include "realm-execution/realm_manager.h" #include @@ -9,11 +10,8 @@ namespace Realm = ::FlexFlow::Realm; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("DistributedDeviceHandle") { - // Construct some fake command line for our test - char fake_executable_name[] = "fake_executable_name"; - char arg0[] = "-ll:cpu"; - char arg1[] = "2"; - std::vector fake_args{fake_executable_name, arg0, arg1}; + std::vector fake_args = + make_fake_realm_args(/*num_cpus=*/2_p, /*num_gpus=*/0_n); int fake_argc = fake_args.size(); char **fake_argv = fake_args.data(); diff --git a/lib/realm-execution/test/src/realm-execution/realm_manager.cc b/lib/realm-execution/test/src/realm-execution/realm_manager.cc index 5fe659cdc2..450d7fd3ec 100644 --- a/lib/realm-execution/test/src/realm-execution/realm_manager.cc +++ b/lib/realm-execution/test/src/realm-execution/realm_manager.cc @@ -1,4 +1,5 @@ #include "realm-execution/realm_manager.h" +#include "internal/realm_test_utils.h" #include "realm-execution/distributed_device_handle.h" #include @@ -9,9 +10,8 @@ namespace Realm = ::FlexFlow::Realm; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("RealmManager") { - // Construct some fake command line for our test - char fake_executable_name[] = "fake_executable_name"; - std::vector fake_args{fake_executable_name}; + std::vector fake_args = + make_fake_realm_args(/*num_cpus=*/1_p, /*num_gpus=*/0_n); int fake_argc = fake_args.size(); char **fake_argv = fake_args.data(); @@ -20,11 +20,10 @@ TEST_SUITE(FF_TEST_SUITE) { // Launch a controller int some_data = 123; - Realm::Event event = - manager.start_controller([&](RealmContext &ctx) { - // Data is captured and retains value - ASSERT(some_data == 123); - }); + Realm::Event event = manager.start_controller([&](RealmContext &ctx) { + // Data is captured and retains value + ASSERT(some_data == 123); + }); // Need to block on the completion of the event to ensure we don't race, // because the lambda captures the environment event.wait(); diff --git a/lib/realm-execution/test/src/realm-execution/test_e2e.cc b/lib/realm-execution/test/src/realm-execution/test_e2e.cc index 9592cb221c..33ad2bbbc1 100644 --- a/lib/realm-execution/test/src/realm-execution/test_e2e.cc +++ b/lib/realm-execution/test/src/realm-execution/test_e2e.cc @@ -1,5 +1,12 @@ +#include "internal/realm_test_utils.h" +#include "kernels/allocation.h" +#include "op-attrs/tensor_shape.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" +#include "realm-execution/distributed_device_handle.h" #include "realm-execution/pcg_instance/pcg_instance.h" #include "realm-execution/realm_manager.h" +#include "utils/containers/require_only_key.h" #include namespace test { @@ -9,12 +16,172 @@ namespace Realm = ::FlexFlow::Realm; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("RealmBackend e2e Training") { - char fake_executable_name[] = "fake_executable_name"; - std::vector fake_args{fake_executable_name}; + std::vector fake_args = + make_fake_realm_args(/*num_cpus=*/1_p, /*num_gpus=*/0_n); int fake_argc = fake_args.size(); char **fake_argv = fake_args.data(); + RealmManager manager(&fake_argc, &fake_argv); - (void)manager.start_controller([](RealmContext &ctx) {}); + + (void)manager.start_controller([](RealmContext &ctx) { + Allocator allocator = ctx.get_current_device_allocator(); + + positive_int batch_size = 10_p; + positive_int data_dim = 16_p; + positive_int hidden_dim = 32_p; + positive_int output_dim = 1_p; + + TensorShape output_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, output_dim}}, DataType::FLOAT}; + + GenericTensorAccessorW label_tensor_backing = + allocator.allocate_tensor(output_tensor_shape); + + // construct computation graph + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + TensorShape input_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, data_dim}}, DataType::FLOAT}; + + TensorShape label_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, output_dim}}, DataType::FLOAT}; + GenericTensorAccessorW label_tensor = + allocator.allocate_tensor(label_tensor_shape); + + TensorShape weight_shape_1 = TensorShape{ + TensorDims{FFOrdered{hidden_dim, data_dim}}, DataType::FLOAT}; + TensorShape weight_shape_2 = TensorShape{ + TensorDims{FFOrdered{output_dim, hidden_dim}}, DataType::FLOAT}; + + ParallelLayerAddedResult inputs_layer = + pcg_add_input_layer_with_grad(pcg, input_tensor_shape); + parallel_tensor_guid_t t_input = + require_only_key(inputs_layer.outputs, TensorSlotName::OUTPUT); + + ParallelLayerAddedResult weights_layer_1 = add_parallel_layer( + pcg, + ParallelLayerAttrs{ + PCGOperatorAttrs{WeightAttrs{ + weight_shape_1, InitializerAttrs{GlorotNormalAttrs{0}}}}, + std::nullopt}, + {}, + {}); + parallel_tensor_guid_t t_weights_1 = + require_only_key(weights_layer_1.outputs, TensorSlotName::OUTPUT); + + ParallelLayerAddedResult weights_layer_2 = add_parallel_layer( + pcg, + ParallelLayerAttrs{ + PCGOperatorAttrs{WeightAttrs{ + weight_shape_2, InitializerAttrs{GlorotNormalAttrs{0}}}}, + std::nullopt}, + {}, + {}); + parallel_tensor_guid_t t_weights_2 = + require_only_key(weights_layer_2.outputs, TensorSlotName::OUTPUT); + + ParallelLayerAddedResult linear_operator_1 = add_parallel_layer( + pcg, + ParallelLayerAttrs{PCGOperatorAttrs{LinearAttrs{hidden_dim, + /*use_bias=*/false, + DataType::FLOAT, + Activation::RELU, + std::nullopt}}, + std::nullopt}, + { + { + TensorSlotName::INPUT, + t_input, + }, + }, + { + { + TensorSlotName::WEIGHT, + t_weights_1, + }, + }); + parallel_tensor_guid_t t_linear_1 = + require_only_key(linear_operator_1.outputs, TensorSlotName::OUTPUT); + + ParallelLayerAddedResult linear_operator_2 = add_parallel_layer( + pcg, + ParallelLayerAttrs{PCGOperatorAttrs{LinearAttrs{output_dim, + /*use_bias=*/false, + DataType::FLOAT, + Activation::RELU, + std::nullopt}}, + std::nullopt}, + { + { + TensorSlotName::INPUT, + t_linear_1, + }, + }, + { + { + TensorSlotName::WEIGHT, + t_weights_2, + }, + }); + parallel_tensor_guid_t t_linear_2 = + require_only_key(linear_operator_2.outputs, TensorSlotName::OUTPUT); + + MappedParallelComputationGraph mpcg{pcg, {}}; + + // instantiate computation graph + LossAttrs loss_attrs = LossAttrs{ + NonconfigurableLossAttrs{LossFunction::CATEGORICAL_CROSSENTROPY}}; + OptimizerAttrs optimizer_attrs = + OptimizerAttrs{SGDOptimizerAttrs{/*lr=*/0.001, + /*momentum=*/0.9, + /*nesterov=*/false, + /*weight_decay=*/0.001}}; + + std::unordered_map + input_tensors; + + DistributedDeviceHandle device_handle = create_distributed_device_handle( + ctx, + /*workSpaceSize=*/1024 * 1024, + /*allowTensorOpMathConversion=*/true); + + PCGInstance pcg_instance = create_pcg_instance( + /*ctx=*/ctx, + /*mpcg=*/mpcg, + /*optimizer=*/optimizer_attrs, + /*loss=*/loss_attrs, + /*label_tensor=*/label_tensor, + /*logit_tensor=*/t_linear_2, + /*input_tensors=*/input_tensors, + /*profiling_settings=*/ProfilingSettings{0, 0}, + /*device_handle=*/device_handle, + /*iteration_config=*/FFIterationConfig{1_p}); + + // begin training loop + int num_epochs = 5; + std::vector loss_values; + + for (int i = 0; i < num_epochs; i++) { + perform_all_passes_for_pcg_instance( + /*instance=*/pcg_instance, + /*profiling_settings=*/ProfilingSettings{0, 0}, + /*device_handle=*/device_handle, + /*iteration_config=*/FFIterationConfig{1_p}); + // loss_values.push_back(copy_tensor_accessor_r( + // pcg_instance.get_loss_tensor_accessor().value(), + // allocator)); + } + + // // Assert that each sample in the batch has a lower loss in last epoch + // // than the first epoch + // GenericTensorAccessorR first_epoch_loss = loss_values.at(0); + // GenericTensorAccessorR last_epoch_loss = loss_values.back(); + // CHECK_MESSAGE(did_loss_decrease(first_epoch_loss, last_epoch_loss), + // check_kv("first_epoch_loss", + // format_accessor_r_contents(first_epoch_loss)), + // check_kv("last_epoch_loss", + // format_accessor_r_contents(last_epoch_loss))); + }); } } From 2476d92f5b62ff6b96245babb02aa82b8ae6a834 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Fri, 13 Feb 2026 16:55:22 -0800 Subject: [PATCH 53/88] JSON serialization of a bunch of data types. --- lib/pcg/include/pcg/layer_guid_t.dtg.toml | 1 + .../mapped_operator_task_group.h | 12 ++++++ .../parallel_layer_guid_t.dtg.toml | 1 + .../mapped_operator_task_group.cc | 17 ++++++++ .../mapped_operator_task_group.cc | 42 ++++++++++++++++++ .../dynamic_layer_guid_t.dtg.toml | 1 + .../serializable_dynamic_node_attrs.dtg.toml | 43 +++++++++++++++++++ ...ializable_dynamic_node_invocation.dtg.toml | 33 ++++++++++++++ .../serializable_dynamic_value_attrs.dtg.toml | 34 +++++++++++++++ .../training_operation_attrs.dtg.toml | 1 + 10 files changed, 185 insertions(+) create mode 100644 lib/pcg/test/src/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.cc create mode 100644 lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_node_attrs.dtg.toml create mode 100644 lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_node_invocation.dtg.toml create mode 100644 lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_value_attrs.dtg.toml diff --git a/lib/pcg/include/pcg/layer_guid_t.dtg.toml b/lib/pcg/include/pcg/layer_guid_t.dtg.toml index d73cf547da..2f2f7694a0 100644 --- a/lib/pcg/include/pcg/layer_guid_t.dtg.toml +++ b/lib/pcg/include/pcg/layer_guid_t.dtg.toml @@ -6,6 +6,7 @@ features = [ "ord", "hash", "fmt", + "json", ] includes = [ diff --git a/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.h b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.h index 5b1cad5e99..ebfdefa478 100644 --- a/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.h +++ b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.h @@ -5,6 +5,7 @@ #include "pcg/machine_space_coordinate.dtg.h" #include "pcg/mapped_parallel_computation_graph/operator_atomic_task_shard_binding.dtg.h" #include "utils/bidict/bidict.h" +#include namespace FlexFlow { @@ -45,4 +46,15 @@ struct hash<::FlexFlow::MappedOperatorTaskGroup> { }; } // namespace std + +namespace nlohmann { + +template <> +struct adl_serializer<::FlexFlow::MappedOperatorTaskGroup> { + static ::FlexFlow::MappedOperatorTaskGroup from_json(json const &j); + static void to_json(json &j, ::FlexFlow::MappedOperatorTaskGroup const &t); +}; + +} // namespace nlohmann + #endif diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.toml index 618bcb0dc4..292b361fc8 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.toml +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.toml @@ -6,6 +6,7 @@ features = [ "ord", "hash", "fmt", + "json", ] includes = [ diff --git a/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.cc b/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.cc index b96a447383..4436efd727 100644 --- a/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.cc +++ b/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.cc @@ -90,3 +90,20 @@ size_t hash<::FlexFlow::MappedOperatorTaskGroup>::operator()( } } // namespace std + +namespace nlohmann { + +::FlexFlow::MappedOperatorTaskGroup + adl_serializer<::FlexFlow::MappedOperatorTaskGroup>::from_json( + json const &j) { + return ::FlexFlow::MappedOperatorTaskGroup{j.template get< + ::FlexFlow::bidict<::FlexFlow::MachineSpaceCoordinate, + ::FlexFlow::OperatorAtomicTaskShardBinding>>()}; +} + +void adl_serializer<::FlexFlow::MappedOperatorTaskGroup>::to_json( + json &j, ::FlexFlow::MappedOperatorTaskGroup const &t) { + j = t.get_shard_bindings(); +} + +} // namespace nlohmann diff --git a/lib/pcg/test/src/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.cc b/lib/pcg/test/src/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.cc new file mode 100644 index 0000000000..1c3667afc7 --- /dev/null +++ b/lib/pcg/test/src/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.cc @@ -0,0 +1,42 @@ +#include "pcg/mapped_parallel_computation_graph/mapped_operator_task_group.h" +#include "op-attrs/parallel_tensor_space_coordinate.dtg.h" +#include "op-attrs/tensor_slot_name.dtg.h" +#include "pcg/device_type.dtg.h" +#include "pcg/machine_space_coordinate.dtg.h" +#include "pcg/mapped_parallel_computation_graph/operator_atomic_task_shard_binding.dtg.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("adl_serializer") { + bidict + shard_bindings{ + {MachineSpaceCoordinate{0_n, 0_n, DeviceType::CPU}, + OperatorAtomicTaskShardBinding{ + { + {TensorSlotName::INPUT, + ParallelTensorSpaceCoordinate{ + 0_n, 0_n, FFOrdered{1_n, 2_n, 3_n}}}, + }, + }}, + }; + MappedOperatorTaskGroup deserialized{shard_bindings}; + nlohmann::json serialized = shard_bindings; + + SUBCASE("to_json") { + nlohmann::json result = deserialized; + nlohmann::json correct = serialized; + + CHECK(result == correct); + } + + SUBCASE("from_json") { + MappedOperatorTaskGroup result = serialized; + MappedOperatorTaskGroup correct = deserialized; + + CHECK(result == correct); + } + } +} diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_layer_guid_t.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_layer_guid_t.dtg.toml index c6e6673f33..bd64f52567 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_layer_guid_t.dtg.toml +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_layer_guid_t.dtg.toml @@ -5,6 +5,7 @@ features = [ "eq", "hash", "fmt", + "json", ] includes = [ diff --git a/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_node_attrs.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_node_attrs.dtg.toml new file mode 100644 index 0000000000..3c43e1d637 --- /dev/null +++ b/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_node_attrs.dtg.toml @@ -0,0 +1,43 @@ +namespace = "FlexFlow" +name = "SerializableDynamicNodeAttrs" +type = "struct" +features = [ + "eq", + "hash", + "fmt", + "json", +] + +includes = [ + "", + "task-spec/dynamic_graph/dynamic_task_type.dtg.h", + "pcg/machine_space_coordinate.dtg.h", + "pcg/mapped_parallel_computation_graph/mapped_operator_task_group.h", + "task-spec/dynamic_graph/dynamic_layer_guid_t.dtg.h", + "task-spec/dynamic_graph/training_operation_attrs.dtg.h", +] + +src_includes = [ + "utils/fmt/optional.h", + "utils/json/optional.h", +] + +[[fields]] +name = "task_type" +type = "std::optional<::FlexFlow::DynamicTaskType>" + +[[fields]] +name = "device_coord" +type = "std::optional<::FlexFlow::MachineSpaceCoordinate>" + +[[fields]] +name = "mapping" +type = "std::optional<::FlexFlow::MappedOperatorTaskGroup>" + +[[fields]] +name = "op_attrs" +type = "std::optional<::FlexFlow::TrainingOperationAttrs>" + +[[fields]] +name = "layer_guid" +type = "::FlexFlow::dynamic_layer_guid_t" diff --git a/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_node_invocation.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_node_invocation.dtg.toml new file mode 100644 index 0000000000..01f4cc8876 --- /dev/null +++ b/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_node_invocation.dtg.toml @@ -0,0 +1,33 @@ +namespace = "FlexFlow" +name = "SerializableDynamicNodeInvocation" +type = "struct" +features = [ + "eq", + "fmt", + "hash", + "json", +] + +includes = [ + "", + "task-spec/dynamic_graph/serializable_dynamic_node_attrs.dtg.h", + "task-spec/dynamic_graph/dynamic_tensor_slot.dtg.h", + "task-spec/dynamic_graph/serializable_dynamic_value_attrs.dtg.h", +] + +src_includes = [ + "utils/hash/unordered_map.h", + "utils/fmt/unordered_map.h", +] + +[[fields]] +name = "inputs" +type = "std::unordered_map<::FlexFlow::DynamicTensorSlot, ::FlexFlow::SerializableDynamicValueAttrs>" + +[[fields]] +name = "node_attrs" +type = "::FlexFlow::SerializableDynamicNodeAttrs" + +[[fields]] +name = "outputs" +type = "std::unordered_map<::FlexFlow::DynamicTensorSlot, ::FlexFlow::SerializableDynamicValueAttrs>" diff --git a/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_value_attrs.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_value_attrs.dtg.toml new file mode 100644 index 0000000000..05864b4b47 --- /dev/null +++ b/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_value_attrs.dtg.toml @@ -0,0 +1,34 @@ +namespace = "FlexFlow" +name = "SerializableDynamicValueAttrs" +type = "struct" +features = [ + "eq", + "hash", + "fmt", + "json", +] + +includes = [ + "", + "task-spec/dynamic_graph/dynamic_tensor_guid_t.dtg.h", + "op-attrs/parallel_tensor_shape.dtg.h", + "op-attrs/parallel_tensor_space_coordinate.dtg.h", + "task-spec/dynamic_graph/dynamic_tensor_role.dtg.h", +] + +src_includes = [ + "utils/fmt/optional.h", + "utils/json/optional.h", +] + +[[fields]] +name = "parallel_tensor_shape" +type = "std::optional<::FlexFlow::ParallelTensorShape>" + +[[fields]] +name = "shard_coord" +type = "std::optional<::FlexFlow::ParallelTensorSpaceCoordinate>" + +[[fields]] +name = "role" +type = "std::optional<::FlexFlow::DynamicTensorRole>" diff --git a/lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.dtg.toml index 66c475b3a9..1051d8ac13 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.dtg.toml +++ b/lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.dtg.toml @@ -5,6 +5,7 @@ features = [ "eq", "hash", "fmt", + "json", ] includes = [ From 9c6de3c8a8b869739096f164eafbf7fbeb64debb Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Fri, 13 Feb 2026 17:20:49 -0800 Subject: [PATCH 54/88] Make more stuff serializable. --- .../parallel_tensor_guid_t.dtg.toml | 1 + lib/pcg/include/pcg/tensor_guid_t.dtg.toml | 1 + .../dynamic_tensor_guid_t.dtg.toml | 1 + .../serializable_dynamic_value_attrs.dtg.toml | 4 +++ .../serializable_dynamic_value_attrs.h | 16 +++++++++++ .../serializable_dynamic_value_attrs.cc | 27 +++++++++++++++++++ .../kwarg_dataflow_output.dtg.toml | 1 + 7 files changed, 51 insertions(+) create mode 100644 lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_value_attrs.h create mode 100644 lib/task-spec/src/task-spec/dynamic_graph/serializable_dynamic_value_attrs.cc diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.toml index 4494a31ac2..2710a15664 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.toml +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.toml @@ -6,6 +6,7 @@ features = [ "ord", "hash", "fmt", + "json", ] includes = [ diff --git a/lib/pcg/include/pcg/tensor_guid_t.dtg.toml b/lib/pcg/include/pcg/tensor_guid_t.dtg.toml index 151f7b1f0f..e8caf0021f 100644 --- a/lib/pcg/include/pcg/tensor_guid_t.dtg.toml +++ b/lib/pcg/include/pcg/tensor_guid_t.dtg.toml @@ -6,6 +6,7 @@ features = [ "ord", "hash", "fmt", + "json", ] includes = [ diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_tensor_guid_t.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_tensor_guid_t.dtg.toml index 75e9099104..c9171b928b 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_tensor_guid_t.dtg.toml +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_tensor_guid_t.dtg.toml @@ -5,6 +5,7 @@ features = [ "eq", "hash", "fmt", + "json", ] includes = [ diff --git a/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_value_attrs.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_value_attrs.dtg.toml index 05864b4b47..6209bfa247 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_value_attrs.dtg.toml +++ b/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_value_attrs.dtg.toml @@ -21,6 +21,10 @@ src_includes = [ "utils/json/optional.h", ] +[[fields]] +name = "tensor_guid" +type = "::FlexFlow::dynamic_tensor_guid_t" + [[fields]] name = "parallel_tensor_shape" type = "std::optional<::FlexFlow::ParallelTensorShape>" diff --git a/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_value_attrs.h b/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_value_attrs.h new file mode 100644 index 0000000000..6272265b7e --- /dev/null +++ b/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_value_attrs.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_SERIALIZABLE_DYNAMIC_VALUE_ATTRS_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_SERIALIZABLE_DYNAMIC_VALUE_ATTRS_H + +#include "task-spec/dynamic_graph/dynamic_value_attrs.dtg.h" +#include "task-spec/dynamic_graph/serializable_dynamic_value_attrs.dtg.h" + +namespace FlexFlow { + +SerializableDynamicValueAttrs + dynamic_value_attrs_to_serializable(DynamicValueAttrs const &); +DynamicValueAttrs dynamic_value_attrs_from_serializable( + SerializableDynamicValueAttrs const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/task-spec/src/task-spec/dynamic_graph/serializable_dynamic_value_attrs.cc b/lib/task-spec/src/task-spec/dynamic_graph/serializable_dynamic_value_attrs.cc new file mode 100644 index 0000000000..2dc0b509ab --- /dev/null +++ b/lib/task-spec/src/task-spec/dynamic_graph/serializable_dynamic_value_attrs.cc @@ -0,0 +1,27 @@ +#include "task-spec/dynamic_graph/serializable_dynamic_value_attrs.h" +#include + +namespace FlexFlow { + +SerializableDynamicValueAttrs + dynamic_value_attrs_to_serializable(DynamicValueAttrs const &attrs) { + return SerializableDynamicValueAttrs{ + /*tensor_guid=*/attrs.tensor_guid, + /*parallel_tensor_shape=*/attrs.parallel_tensor_shape, + /*shard_coord=*/attrs.shard_coord, + /*role=*/attrs.role, + }; +} + +DynamicValueAttrs dynamic_value_attrs_from_serializable( + SerializableDynamicValueAttrs const &attrs) { + return DynamicValueAttrs{ + /*tensor_guid=*/attrs.tensor_guid, + /*parallel_tensor_shape=*/attrs.parallel_tensor_shape, + /*shard_coord=*/attrs.shard_coord, + /*accessor=*/std::nullopt, + /*role=*/attrs.role, + }; +} + +} // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/kwarg_dataflow_output.dtg.toml b/lib/utils/include/utils/graph/kwarg_dataflow_graph/kwarg_dataflow_output.dtg.toml index f286fb90a7..5b537eac88 100644 --- a/lib/utils/include/utils/graph/kwarg_dataflow_graph/kwarg_dataflow_output.dtg.toml +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/kwarg_dataflow_output.dtg.toml @@ -6,6 +6,7 @@ features = [ "ord", "hash", "fmt", + "json", ] template_params = [ From 1d1586f532c0253658266357f935df7783b1dac0 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Fri, 13 Feb 2026 22:18:55 -0800 Subject: [PATCH 55/88] To-do notes. --- .../src/realm-execution/pcg_instance/pcg_instance.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc b/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc index de7cdcb687..199f2dc090 100644 --- a/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc +++ b/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc @@ -111,6 +111,10 @@ PCGInstance create_pcg_instance( // * external instances // * task argument serializer // * copies + // * parallel operator implementation (partition, reduce, gather, etc.) + // * and fused parallel operators (reduce + broadcast = allreduce) + // * memory-optimizing compiler integration (tensor creation/destruction, + // tensor reuse) } static std::unordered_map From 8e9cefce62ef9478a7cbe16b0be459f9329a8e3a Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Sat, 14 Feb 2026 12:17:15 -0800 Subject: [PATCH 56/88] More serialization routines. --- .../serializable_dynamic_node_attrs.h | 16 ++++++++++ .../serializable_dynamic_node_invocation.h | 16 ++++++++++ .../serializable_dynamic_node_attrs.cc | 29 +++++++++++++++++ .../serializable_dynamic_node_invocation.cc | 31 +++++++++++++++++++ 4 files changed, 92 insertions(+) create mode 100644 lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_node_attrs.h create mode 100644 lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_node_invocation.h create mode 100644 lib/task-spec/src/task-spec/dynamic_graph/serializable_dynamic_node_attrs.cc create mode 100644 lib/task-spec/src/task-spec/dynamic_graph/serializable_dynamic_node_invocation.cc diff --git a/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_node_attrs.h b/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_node_attrs.h new file mode 100644 index 0000000000..7a274a1e7b --- /dev/null +++ b/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_node_attrs.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_SERIALIZABLE_DYNAMIC_NODE_ATTRS_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_SERIALIZABLE_DYNAMIC_NODE_ATTRS_H + +#include "task-spec/dynamic_graph/dynamic_node_attrs.dtg.h" +#include "task-spec/dynamic_graph/serializable_dynamic_node_attrs.dtg.h" + +namespace FlexFlow { + +SerializableDynamicNodeAttrs + dynamic_node_attrs_to_serializable(DynamicNodeAttrs const &); +DynamicNodeAttrs + dynamic_node_attrs_from_serializable(SerializableDynamicNodeAttrs const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_node_invocation.h b/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_node_invocation.h new file mode 100644 index 0000000000..2bcdb9a898 --- /dev/null +++ b/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_node_invocation.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_SERIALIZABLE_DYNAMIC_NODE_INVOCATION_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_SERIALIZABLE_DYNAMIC_NODE_INVOCATION_H + +#include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" +#include "task-spec/dynamic_graph/serializable_dynamic_node_invocation.dtg.h" + +namespace FlexFlow { + +SerializableDynamicNodeInvocation + dynamic_node_invocation_to_serializable(DynamicNodeInvocation const &); +DynamicNodeInvocation dynamic_node_invocation_from_serializable( + SerializableDynamicNodeInvocation const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/task-spec/src/task-spec/dynamic_graph/serializable_dynamic_node_attrs.cc b/lib/task-spec/src/task-spec/dynamic_graph/serializable_dynamic_node_attrs.cc new file mode 100644 index 0000000000..d613194d14 --- /dev/null +++ b/lib/task-spec/src/task-spec/dynamic_graph/serializable_dynamic_node_attrs.cc @@ -0,0 +1,29 @@ +#include "task-spec/dynamic_graph/serializable_dynamic_node_attrs.h" +#include + +namespace FlexFlow { + +SerializableDynamicNodeAttrs + dynamic_node_attrs_to_serializable(DynamicNodeAttrs const &attrs) { + return SerializableDynamicNodeAttrs{ + /*task_type=*/attrs.task_type, + /*device_coord=*/attrs.device_coord, + /*mapping=*/attrs.mapping, + /*op_attrs=*/attrs.op_attrs, + /*layer_guid=*/attrs.layer_guid, + }; +} + +DynamicNodeAttrs dynamic_node_attrs_from_serializable( + SerializableDynamicNodeAttrs const &attrs) { + return DynamicNodeAttrs{ + /*task_type=*/attrs.task_type, + /*device_coord=*/attrs.device_coord, + /*mapping=*/attrs.mapping, + /*op_attrs=*/attrs.op_attrs, + /*layer_guid=*/attrs.layer_guid, + /*per_device_op_state=*/std::nullopt, + }; +} + +} // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/dynamic_graph/serializable_dynamic_node_invocation.cc b/lib/task-spec/src/task-spec/dynamic_graph/serializable_dynamic_node_invocation.cc new file mode 100644 index 0000000000..334623ee67 --- /dev/null +++ b/lib/task-spec/src/task-spec/dynamic_graph/serializable_dynamic_node_invocation.cc @@ -0,0 +1,31 @@ +#include "task-spec/dynamic_graph/serializable_dynamic_node_invocation.h" +#include "task-spec/dynamic_graph/serializable_dynamic_node_attrs.h" +#include "task-spec/dynamic_graph/serializable_dynamic_value_attrs.h" +#include "utils/containers/map_values.h" + +namespace FlexFlow { + +SerializableDynamicNodeInvocation dynamic_node_invocation_to_serializable( + DynamicNodeInvocation const &invocation) { + return SerializableDynamicNodeInvocation{ + /*inputs=*/map_values(invocation.inputs, + dynamic_value_attrs_to_serializable), + /*node_attrs=*/dynamic_node_attrs_to_serializable(invocation.node_attrs), + /*outputs=*/ + map_values(invocation.outputs, dynamic_value_attrs_to_serializable), + }; +} + +DynamicNodeInvocation dynamic_node_invocation_from_serializable( + SerializableDynamicNodeInvocation const &invocation) { + return DynamicNodeInvocation{ + /*inputs=*/map_values(invocation.inputs, + dynamic_value_attrs_from_serializable), + /*node_attrs=*/ + dynamic_node_attrs_from_serializable(invocation.node_attrs), + /*outputs=*/ + map_values(invocation.outputs, dynamic_value_attrs_from_serializable), + }; +} + +} // namespace FlexFlow From 365dca0b036197c70630b06f18e32b9c031a62a6 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Sat, 14 Feb 2026 12:18:49 -0800 Subject: [PATCH 57/88] Most of serializer finished. --- .../serializable_realm_processor.dtg.toml | 17 ++++++ .../serializer/serializable_realm_processor.h | 16 +++++ .../tasks/serializer/task_arg_serializer.h | 26 ++++++++ .../tasks/impl/device_state_init_task.cc | 61 +++++++++++++------ .../serializable_realm_processor.cc | 15 +++++ 5 files changed, 115 insertions(+), 20 deletions(-) create mode 100644 lib/realm-execution/include/realm-execution/tasks/serializer/serializable_realm_processor.dtg.toml create mode 100644 lib/realm-execution/include/realm-execution/tasks/serializer/serializable_realm_processor.h create mode 100644 lib/realm-execution/include/realm-execution/tasks/serializer/task_arg_serializer.h create mode 100644 lib/realm-execution/src/realm-execution/tasks/serializer/serializable_realm_processor.cc diff --git a/lib/realm-execution/include/realm-execution/tasks/serializer/serializable_realm_processor.dtg.toml b/lib/realm-execution/include/realm-execution/tasks/serializer/serializable_realm_processor.dtg.toml new file mode 100644 index 0000000000..3cb64d95c1 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tasks/serializer/serializable_realm_processor.dtg.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "SerializableRealmProcessor" +type = "struct" +features = [ + "eq", + "fmt", + "hash", + "json", +] + +includes = [ + "realm-execution/realm.h", +] + +[[fields]] +name = "id" +type = "::FlexFlow::Realm::Processor::id_t" diff --git a/lib/realm-execution/include/realm-execution/tasks/serializer/serializable_realm_processor.h b/lib/realm-execution/include/realm-execution/tasks/serializer/serializable_realm_processor.h new file mode 100644 index 0000000000..6b29b6e223 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tasks/serializer/serializable_realm_processor.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_SERIALIZER_SERIALIZABLE_REALM_PROCESSOR_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_SERIALIZER_SERIALIZABLE_REALM_PROCESSOR_H + +#include "realm-execution/realm.h" +#include "realm-execution/tasks/serializer/serializable_realm_processor.dtg.h" + +namespace FlexFlow { + +SerializableRealmProcessor + realm_processor_to_serializable(Realm::Processor const &); +Realm::Processor + realm_processor_from_serializable(SerializableRealmProcessor const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/tasks/serializer/task_arg_serializer.h b/lib/realm-execution/include/realm-execution/tasks/serializer/task_arg_serializer.h new file mode 100644 index 0000000000..fc5abba587 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tasks/serializer/task_arg_serializer.h @@ -0,0 +1,26 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_SERIALIZER_TASK_ARG_SERIALIZER_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_SERIALIZER_TASK_ARG_SERIALIZER_H + +#include +#include +#include + +namespace FlexFlow { + +template +std::string serialize_task_args(T const &args) { + nlohmann::json j; + args.serialize(j); + return j.dump(); +} + +template +T deserialize_task_args(void const *args, size_t arglen) { + nlohmann::json j = nlohmann::json::parse( + std::string_view{reinterpret_cast(args), arglen}); + return T::deserialize(j); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_task.cc index 5a51b1c803..0e7730e485 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_task.cc @@ -3,11 +3,16 @@ #include "local-execution/device_state_initialization.h" #include "realm-execution/device_specific_managed_per_device_ff_handle.h" #include "realm-execution/tasks/impl/device_state_init_return_task.h" +#include "realm-execution/tasks/serializer/serializable_realm_processor.h" +#include "realm-execution/tasks/serializer/task_arg_serializer.h" #include "realm-execution/tasks/task_id_t.dtg.h" #include "realm-execution/tasks/task_id_t.h" #include "task-spec/device_specific_per_device_op_state.dtg.h" +#include "task-spec/dynamic_graph/serializable_dynamic_node_invocation.h" #include "task-spec/dynamic_graph/training_operation_attrs.dtg.h" +#include "utils/exception.h" #include "utils/optional.h" +#include #include #include @@ -19,11 +24,11 @@ namespace FlexFlow { struct DeviceStateInitTaskArgs { DeviceStateInitTaskArgs() = delete; DeviceStateInitTaskArgs( - DynamicNodeInvocation const *invocation, - ProfilingSettings const *profiling_settings, + DynamicNodeInvocation const &invocation, + ProfilingSettings const &profiling_settings, DeviceSpecificManagedPerDeviceFFHandle const &device_handle, - FFIterationConfig const *iteration_config, - OptimizerAttrs const *optimizer_attrs, + FFIterationConfig const &iteration_config, + OptimizerAttrs const &optimizer_attrs, Realm::Processor origin_proc, DeviceSpecificPerDeviceOpState *origin_result_ptr) : invocation(invocation), profiling_settings(profiling_settings), @@ -31,12 +36,28 @@ struct DeviceStateInitTaskArgs { optimizer_attrs(optimizer_attrs), origin_proc(origin_proc), origin_result_ptr(origin_result_ptr) {} + void serialize(nlohmann::json &j) const { + j = { + {"invocation", dynamic_node_invocation_to_serializable(invocation)}, + {"profiling_settings", profiling_settings}, + // {"device_handle", device_handle}, + {"iteration_config", iteration_config}, + {"optimizer_attrs", optimizer_attrs}, + {"origin_proc", realm_processor_to_serializable(origin_proc)}, + {"origin_result_ptr", reinterpret_cast(origin_result_ptr)}, + }; + } + + static DeviceStateInitTaskArgs deserialize(nlohmann::json const &j) { + NOT_IMPLEMENTED(); + } + public: - DynamicNodeInvocation const *invocation; - ProfilingSettings const *profiling_settings; + DynamicNodeInvocation invocation; + ProfilingSettings profiling_settings; DeviceSpecificManagedPerDeviceFFHandle device_handle; - FFIterationConfig const *iteration_config; - OptimizerAttrs const *optimizer_attrs; + FFIterationConfig iteration_config; + OptimizerAttrs optimizer_attrs; Realm::Processor origin_proc; DeviceSpecificPerDeviceOpState *origin_result_ptr; }; @@ -46,9 +67,8 @@ void device_state_init_task_body(void const *args, void const *userdata, size_t userlen, Realm::Processor proc) { - ASSERT(arglen == sizeof(DeviceStateInitTaskArgs)); DeviceStateInitTaskArgs task_args = - *reinterpret_cast(args); + deserialize_task_args(args, arglen); // FIXME: serialize instead of passing pointers around ASSERT(task_args.origin_proc.address_space() == proc.address_space()); @@ -58,12 +78,12 @@ void device_state_init_task_body(void const *args, device_handle_t_from_device_specific_managed_handle( task_args.device_handle, ctx.get_current_device_idx()); DynamicNodeInvocation result_invocation = - initialize_node(*task_args.invocation, + initialize_node(task_args.invocation, ctx.get_current_device_allocator(), - *task_args.profiling_settings, + task_args.profiling_settings, device_handle, - *task_args.iteration_config, - *task_args.optimizer_attrs, + task_args.iteration_config, + task_args.optimizer_attrs, ctx.get_current_device_idx()); DeviceSpecificPerDeviceOpState result_state = assert_unwrap(result_invocation.node_attrs.per_device_op_state); @@ -89,11 +109,11 @@ std::optional spawn_device_state_init_task( DeviceSpecificPerDeviceOpState *result_ptr, Realm::Event precondition) { DeviceStateInitTaskArgs task_args{ - &invocation, - &profiling_settings, + invocation, + profiling_settings, device_handle, - &iteration_config, - &optimizer_attrs, + iteration_config, + optimizer_attrs, ctx.get_current_processor(), result_ptr, }; @@ -105,10 +125,11 @@ std::optional spawn_device_state_init_task( }), get_init_task_id_for_op_attrs); if (task_id.has_value()) { + std::string args = serialize_task_args(task_args); return ctx.spawn_task(target_proc, assert_unwrap(task_id), - &task_args, - sizeof(task_args), + args.data(), + args.size(), Realm::ProfilingRequestSet{}, precondition); } diff --git a/lib/realm-execution/src/realm-execution/tasks/serializer/serializable_realm_processor.cc b/lib/realm-execution/src/realm-execution/tasks/serializer/serializable_realm_processor.cc new file mode 100644 index 0000000000..b16e2891c4 --- /dev/null +++ b/lib/realm-execution/src/realm-execution/tasks/serializer/serializable_realm_processor.cc @@ -0,0 +1,15 @@ +#include "realm-execution/tasks/serializer/serializable_realm_processor.h" + +namespace FlexFlow { + +SerializableRealmProcessor + realm_processor_to_serializable(Realm::Processor const &proc) { + return SerializableRealmProcessor{proc.id}; +} + +Realm::Processor + realm_processor_from_serializable(SerializableRealmProcessor const &proc) { + return Realm::Processor{proc.id}; +} + +} // namespace FlexFlow From 2c1949326b2369b72446af6a5c58146383ec2c69 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Sat, 14 Feb 2026 12:41:51 -0800 Subject: [PATCH 58/88] Finish serialization of device init task. --- ...ce_specific_managed_per_device_ff_handle.h | 6 ++++ ...e_specific_managed_per_device_ff_handle.cc | 28 +++++++++++++++++++ .../tasks/impl/device_state_init_task.cc | 24 ++++++++++++++-- 3 files changed, 55 insertions(+), 3 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/device_specific_managed_per_device_ff_handle.h b/lib/realm-execution/include/realm-execution/device_specific_managed_per_device_ff_handle.h index 19a70491a2..45617ffcbf 100644 --- a/lib/realm-execution/include/realm-execution/device_specific_managed_per_device_ff_handle.h +++ b/lib/realm-execution/include/realm-execution/device_specific_managed_per_device_ff_handle.h @@ -4,6 +4,8 @@ #include "kernels/device_handle_t.dtg.h" #include "kernels/managed_per_device_ff_handle.h" #include "pcg/device_id_t.dtg.h" +#include +#include namespace FlexFlow { @@ -15,6 +17,10 @@ struct DeviceSpecificManagedPerDeviceFFHandle { std::optional get(device_id_t device_idx) const; + void serialize(nlohmann::json &j) const; + static DeviceSpecificManagedPerDeviceFFHandle + deserialize(nlohmann::json const &j); + private: device_id_t owner; std::optional handle; diff --git a/lib/realm-execution/src/realm-execution/device_specific_managed_per_device_ff_handle.cc b/lib/realm-execution/src/realm-execution/device_specific_managed_per_device_ff_handle.cc index 99ff7a6dd6..ea0782fd4b 100644 --- a/lib/realm-execution/src/realm-execution/device_specific_managed_per_device_ff_handle.cc +++ b/lib/realm-execution/src/realm-execution/device_specific_managed_per_device_ff_handle.cc @@ -1,5 +1,8 @@ #include "realm-execution/device_specific_managed_per_device_ff_handle.h" #include "kernels/device_handle_t.h" +#include "utils/containers/transform.h" +#include "utils/json/optional.h" +#include namespace FlexFlow { @@ -13,6 +16,31 @@ std::optional return this->handle; } +void DeviceSpecificManagedPerDeviceFFHandle::serialize( + nlohmann::json &j) const { + j = { + {"owner", owner}, + {"handle", + transform(handle, + [](ManagedPerDeviceFFHandle *ptr) { + return reinterpret_cast(ptr); + })}, + }; +} + +DeviceSpecificManagedPerDeviceFFHandle + DeviceSpecificManagedPerDeviceFFHandle::deserialize( + nlohmann::json const &j) { + return DeviceSpecificManagedPerDeviceFFHandle{ + /*owner=*/j.at("owner").get(), + /*handle=*/ + transform(j.at("handle").get>(), + [](uintptr_t ptrval) { + return reinterpret_cast(ptrval); + }), + }; +} + DeviceSpecificManagedPerDeviceFFHandle make_device_specific_managed_handle( device_id_t const &device_id, std::optional const &managed_handle) { diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_task.cc index 0e7730e485..312c3f2401 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_task.cc @@ -3,6 +3,7 @@ #include "local-execution/device_state_initialization.h" #include "realm-execution/device_specific_managed_per_device_ff_handle.h" #include "realm-execution/tasks/impl/device_state_init_return_task.h" +#include "realm-execution/tasks/serializer/serializable_realm_processor.dtg.h" #include "realm-execution/tasks/serializer/serializable_realm_processor.h" #include "realm-execution/tasks/serializer/task_arg_serializer.h" #include "realm-execution/tasks/task_id_t.dtg.h" @@ -10,7 +11,6 @@ #include "task-spec/device_specific_per_device_op_state.dtg.h" #include "task-spec/dynamic_graph/serializable_dynamic_node_invocation.h" #include "task-spec/dynamic_graph/training_operation_attrs.dtg.h" -#include "utils/exception.h" #include "utils/optional.h" #include #include @@ -37,10 +37,12 @@ struct DeviceStateInitTaskArgs { origin_result_ptr(origin_result_ptr) {} void serialize(nlohmann::json &j) const { + nlohmann::json j_device_handle; + device_handle.serialize(j_device_handle); j = { {"invocation", dynamic_node_invocation_to_serializable(invocation)}, {"profiling_settings", profiling_settings}, - // {"device_handle", device_handle}, + {"device_handle", j_device_handle}, {"iteration_config", iteration_config}, {"optimizer_attrs", optimizer_attrs}, {"origin_proc", realm_processor_to_serializable(origin_proc)}, @@ -49,7 +51,23 @@ struct DeviceStateInitTaskArgs { } static DeviceStateInitTaskArgs deserialize(nlohmann::json const &j) { - NOT_IMPLEMENTED(); + return DeviceStateInitTaskArgs{ + /*invocation=*/dynamic_node_invocation_from_serializable( + j.at("invocation").get()), + /*profiling_settings=*/ + j.at("profiling_settings").get(), + /*device_handle=*/ + DeviceSpecificManagedPerDeviceFFHandle::deserialize( + j.at("device_handle")), + /*iteration_config=*/j.at("iteration_config").get(), + /*optimizer_attrs=*/j.at("optimizer_attrs").get(), + /*origin_proc=*/ + realm_processor_from_serializable( + j.at("origin_proc").get()), + /*origin_result_ptr=*/ + reinterpret_cast( + j.at("origin_result_ptr").get()), + }; } public: From d05b73ef50c932207038d27d18280d61bde704f2 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Sat, 14 Feb 2026 14:55:30 -0800 Subject: [PATCH 59/88] Switch over to explicit DTGs for task arguments and serialization. --- ...ce_specific_managed_per_device_ff_handle.h | 5 +- .../device_handle_init_task_args.dtg.toml | 26 ++++++ .../impl/device_state_init_task_args.dtg.toml | 42 ++++++++++ ...able_device_handle_init_task_args.dtg.toml | 30 +++++++ ...erializable_device_handle_init_task_args.h | 17 ++++ ...zable_device_state_init_task_args.dtg.toml | 48 +++++++++++ ...serializable_device_state_init_task_args.h | 16 ++++ .../serializable_device_specific_ptr.dtg.toml | 28 +++++++ .../tasks/serializer/task_arg_serializer.h | 5 +- ...e_specific_managed_per_device_ff_handle.cc | 24 +++--- .../pcg_instance/pcg_instance.cc | 1 + .../tasks/impl/device_handle_init_task.cc | 35 ++------ .../tasks/impl/device_state_init_task.cc | 82 ++----------------- ...rializable_device_handle_init_task_args.cc | 28 +++++++ ...erializable_device_state_init_task_args.cc | 36 ++++++++ 15 files changed, 304 insertions(+), 119 deletions(-) create mode 100644 lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_task_args.dtg.toml create mode 100644 lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task_args.dtg.toml create mode 100644 lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_handle_init_task_args.dtg.toml create mode 100644 lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_handle_init_task_args.h create mode 100644 lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_state_init_task_args.dtg.toml create mode 100644 lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_state_init_task_args.h create mode 100644 lib/realm-execution/include/realm-execution/tasks/serializer/serializable_device_specific_ptr.dtg.toml create mode 100644 lib/realm-execution/src/realm-execution/tasks/impl/serializable_device_handle_init_task_args.cc create mode 100644 lib/realm-execution/src/realm-execution/tasks/impl/serializable_device_state_init_task_args.cc diff --git a/lib/realm-execution/include/realm-execution/device_specific_managed_per_device_ff_handle.h b/lib/realm-execution/include/realm-execution/device_specific_managed_per_device_ff_handle.h index 45617ffcbf..d48a80f438 100644 --- a/lib/realm-execution/include/realm-execution/device_specific_managed_per_device_ff_handle.h +++ b/lib/realm-execution/include/realm-execution/device_specific_managed_per_device_ff_handle.h @@ -4,6 +4,7 @@ #include "kernels/device_handle_t.dtg.h" #include "kernels/managed_per_device_ff_handle.h" #include "pcg/device_id_t.dtg.h" +#include "realm-execution/tasks/serializer/serializable_device_specific_ptr.dtg.h" #include #include @@ -17,9 +18,9 @@ struct DeviceSpecificManagedPerDeviceFFHandle { std::optional get(device_id_t device_idx) const; - void serialize(nlohmann::json &j) const; + SerializableDeviceSpecificPtr serialize() const; static DeviceSpecificManagedPerDeviceFFHandle - deserialize(nlohmann::json const &j); + deserialize(SerializableDeviceSpecificPtr const &j); private: device_id_t owner; diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_task_args.dtg.toml b/lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_task_args.dtg.toml new file mode 100644 index 0000000000..c0ba37bb5d --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_task_args.dtg.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "DeviceHandleInitTaskArgs" +type = "struct" +features = [] + +includes = [ + "realm-execution/device_specific_managed_per_device_ff_handle.h", + "realm-execution/realm.h", + "realm-execution/tasks/serializer/serializable_realm_processor.h", +] + +[[fields]] +name = "workSpaceSize" +type = "size_t" + +[[fields]] +name = "allowTensorOpMathConversion" +type = "bool" + +[[fields]] +name = "origin_proc" +type = "::FlexFlow::Realm::Processor" + +[[fields]] +name = "origin_result_ptr" +type = "::FlexFlow::DeviceSpecificManagedPerDeviceFFHandle *" diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task_args.dtg.toml b/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task_args.dtg.toml new file mode 100644 index 0000000000..a9aa77dde9 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task_args.dtg.toml @@ -0,0 +1,42 @@ +namespace = "FlexFlow" +name = "DeviceStateInitTaskArgs" +type = "struct" +features = [] + +includes = [ + "kernels/profiling_settings.dtg.h", + "pcg/optimizer_attrs.dtg.h", + "realm-execution/device_specific_managed_per_device_ff_handle.h", + "realm-execution/realm.h", + "task-spec/device_specific_per_device_op_state.dtg.h", + "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h", + "task-spec/ff_iteration_config.dtg.h", +] + +[[fields]] +name = "invocation" +type = "::FlexFlow::DynamicNodeInvocation" + +[[fields]] +name = "profiling_settings" +type = "::FlexFlow::ProfilingSettings" + +[[fields]] +name = "device_handle" +type = "::FlexFlow::DeviceSpecificManagedPerDeviceFFHandle" + +[[fields]] +name = "iteration_config" +type = "::FlexFlow::FFIterationConfig" + +[[fields]] +name = "optimizer_attrs" +type = "::FlexFlow::OptimizerAttrs" + +[[fields]] +name = "origin_proc" +type = "::FlexFlow::Realm::Processor" + +[[fields]] +name = "origin_result_ptr" +type = "::FlexFlow::DeviceSpecificPerDeviceOpState *" diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_handle_init_task_args.dtg.toml b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_handle_init_task_args.dtg.toml new file mode 100644 index 0000000000..3a187924c8 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_handle_init_task_args.dtg.toml @@ -0,0 +1,30 @@ +namespace = "FlexFlow" +name = "SerializableDeviceHandleInitTaskArgs" +type = "struct" +features = [ + "eq", + "fmt", + "hash", + "json", +] + +includes = [ + "realm-execution/realm.h", + "realm-execution/tasks/serializer/serializable_realm_processor.dtg.h", +] + +[[fields]] +name = "workSpaceSize" +type = "size_t" + +[[fields]] +name = "allowTensorOpMathConversion" +type = "bool" + +[[fields]] +name = "origin_proc" +type = "::FlexFlow::SerializableRealmProcessor" + +[[fields]] +name = "origin_result_ptr" +type = "uintptr_t" diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_handle_init_task_args.h b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_handle_init_task_args.h new file mode 100644 index 0000000000..b239221c16 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_handle_init_task_args.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_SERIALIZABLE_DEVICE_HANDLE_INIT_TASK_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_SERIALIZABLE_DEVICE_HANDLE_INIT_TASK_H + +#include "realm-execution/tasks/impl/device_handle_init_task_args.dtg.h" +#include "realm-execution/tasks/impl/serializable_device_handle_init_task_args.dtg.h" + +namespace FlexFlow { + +SerializableDeviceHandleInitTaskArgs + device_handle_init_task_args_to_serializable( + DeviceHandleInitTaskArgs const &); +DeviceHandleInitTaskArgs device_handle_init_task_args_from_serializable( + SerializableDeviceHandleInitTaskArgs const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_state_init_task_args.dtg.toml b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_state_init_task_args.dtg.toml new file mode 100644 index 0000000000..68076b7d70 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_state_init_task_args.dtg.toml @@ -0,0 +1,48 @@ +namespace = "FlexFlow" +name = "SerializableDeviceStateInitTaskArgs" +type = "struct" +features = [ + "eq", + "fmt", + "hash", + "json", +] + +includes = [ + "kernels/profiling_settings.dtg.h", + "pcg/optimizer_attrs.dtg.h", + "realm-execution/realm.h", + "realm-execution/tasks/serializer/serializable_device_specific_ptr.dtg.h", + "realm-execution/tasks/serializer/serializable_realm_processor.dtg.h", + "task-spec/device_specific_per_device_op_state.dtg.h", + "task-spec/dynamic_graph/serializable_dynamic_node_invocation.dtg.h", + "task-spec/ff_iteration_config.dtg.h", +] + +[[fields]] +name = "invocation" +type = "::FlexFlow::SerializableDynamicNodeInvocation" + +[[fields]] +name = "profiling_settings" +type = "::FlexFlow::ProfilingSettings" + +[[fields]] +name = "device_handle" +type = "::FlexFlow::SerializableDeviceSpecificPtr" + +[[fields]] +name = "iteration_config" +type = "::FlexFlow::FFIterationConfig" + +[[fields]] +name = "optimizer_attrs" +type = "::FlexFlow::OptimizerAttrs" + +[[fields]] +name = "origin_proc" +type = "::FlexFlow::SerializableRealmProcessor" + +[[fields]] +name = "origin_result_ptr" +type = "uintptr_t" diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_state_init_task_args.h b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_state_init_task_args.h new file mode 100644 index 0000000000..2467f2067c --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_state_init_task_args.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_SERIALIZABLE_DEVICE_STATE_INIT_TASK_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_SERIALIZABLE_DEVICE_STATE_INIT_TASK_H + +#include "realm-execution/tasks/impl/device_state_init_task_args.dtg.h" +#include "realm-execution/tasks/impl/serializable_device_state_init_task_args.dtg.h" + +namespace FlexFlow { + +SerializableDeviceStateInitTaskArgs device_state_init_task_args_to_serializable( + DeviceStateInitTaskArgs const &); +DeviceStateInitTaskArgs device_state_init_task_args_from_serializable( + SerializableDeviceStateInitTaskArgs const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/tasks/serializer/serializable_device_specific_ptr.dtg.toml b/lib/realm-execution/include/realm-execution/tasks/serializer/serializable_device_specific_ptr.dtg.toml new file mode 100644 index 0000000000..07cf61f7e1 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tasks/serializer/serializable_device_specific_ptr.dtg.toml @@ -0,0 +1,28 @@ +namespace = "FlexFlow" +name = "SerializableDeviceSpecificPtr" +type = "struct" +features = [ + "eq", + "fmt", + "hash", + "json", +] + +includes = [ + "pcg/device_id_t.dtg.h", + "cstdint", + "optional", +] + +src_includes = [ + "utils/fmt/optional.h", + "utils/json/optional.h", +] + +[[fields]] +name = "device_idx" +type = "::FlexFlow::device_id_t" + +[[fields]] +name = "ptr" +type = "std::optional" diff --git a/lib/realm-execution/include/realm-execution/tasks/serializer/task_arg_serializer.h b/lib/realm-execution/include/realm-execution/tasks/serializer/task_arg_serializer.h index fc5abba587..3208368d2d 100644 --- a/lib/realm-execution/include/realm-execution/tasks/serializer/task_arg_serializer.h +++ b/lib/realm-execution/include/realm-execution/tasks/serializer/task_arg_serializer.h @@ -9,8 +9,7 @@ namespace FlexFlow { template std::string serialize_task_args(T const &args) { - nlohmann::json j; - args.serialize(j); + nlohmann::json j = args; return j.dump(); } @@ -18,7 +17,7 @@ template T deserialize_task_args(void const *args, size_t arglen) { nlohmann::json j = nlohmann::json::parse( std::string_view{reinterpret_cast(args), arglen}); - return T::deserialize(j); + return j.get(); } } // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/device_specific_managed_per_device_ff_handle.cc b/lib/realm-execution/src/realm-execution/device_specific_managed_per_device_ff_handle.cc index ea0782fd4b..6e0cef0bb2 100644 --- a/lib/realm-execution/src/realm-execution/device_specific_managed_per_device_ff_handle.cc +++ b/lib/realm-execution/src/realm-execution/device_specific_managed_per_device_ff_handle.cc @@ -16,25 +16,25 @@ std::optional return this->handle; } -void DeviceSpecificManagedPerDeviceFFHandle::serialize( - nlohmann::json &j) const { - j = { - {"owner", owner}, - {"handle", - transform(handle, - [](ManagedPerDeviceFFHandle *ptr) { - return reinterpret_cast(ptr); - })}, +SerializableDeviceSpecificPtr + DeviceSpecificManagedPerDeviceFFHandle::serialize() const { + return SerializableDeviceSpecificPtr{ + /*device_idx=*/owner, + /*ptr=*/ + transform(handle, + [](ManagedPerDeviceFFHandle *ptr) { + return reinterpret_cast(ptr); + }), }; } DeviceSpecificManagedPerDeviceFFHandle DeviceSpecificManagedPerDeviceFFHandle::deserialize( - nlohmann::json const &j) { + SerializableDeviceSpecificPtr const &handle) { return DeviceSpecificManagedPerDeviceFFHandle{ - /*owner=*/j.at("owner").get(), + /*owner=*/handle.device_idx, /*handle=*/ - transform(j.at("handle").get>(), + transform(handle.ptr, [](uintptr_t ptrval) { return reinterpret_cast(ptrval); }), diff --git a/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc b/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc index 199f2dc090..8e6ab022aa 100644 --- a/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc +++ b/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc @@ -110,6 +110,7 @@ PCGInstance create_pcg_instance( // TODO list: // * external instances // * task argument serializer + // * pass instances to task and convert to tensor accessor // * copies // * parallel operator implementation (partition, reduce, gather, etc.) // * and fused parallel operators (reduce + broadcast = allreduce) diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/device_handle_init_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/device_handle_init_task.cc index cd5608ca7e..5cd53ea062 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/device_handle_init_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/device_handle_init_task.cc @@ -1,33 +1,14 @@ #include "realm-execution/tasks/impl/device_handle_init_task.h" #include "realm-execution/device_specific_managed_per_device_ff_handle.h" #include "realm-execution/tasks/impl/device_handle_init_return_task.h" +#include "realm-execution/tasks/impl/device_handle_init_task_args.dtg.h" +#include "realm-execution/tasks/impl/serializable_device_handle_init_task_args.h" +#include "realm-execution/tasks/serializer/task_arg_serializer.h" #include "realm-execution/tasks/task_id_t.dtg.h" #include namespace FlexFlow { -// TODO: at some point we're going to have to actually serialize these, but for -// now just pass the pointer and assume we're running inside a single address -// space -struct DeviceHandleInitTaskArgs { - DeviceHandleInitTaskArgs() = delete; - DeviceHandleInitTaskArgs( - size_t workSpaceSize, - bool allowTensorOpMathConversion, - Realm::Processor origin_proc, - DeviceSpecificManagedPerDeviceFFHandle *origin_result_ptr) - : workSpaceSize(workSpaceSize), - allowTensorOpMathConversion(allowTensorOpMathConversion), - origin_proc(origin_proc), origin_result_ptr(origin_result_ptr) {} - -public: - size_t workSpaceSize; - bool allowTensorOpMathConversion; - Realm::Processor origin_proc; - DeviceSpecificManagedPerDeviceFFHandle *origin_result_ptr; -}; -static_assert(std::is_trivially_copy_constructible_v); - static std::optional make_device_handle_for_processor(Realm::Processor processor, size_t workSpaceSize, @@ -52,12 +33,10 @@ void device_handle_init_task_body(void const *args, void const *userdata, size_t userlen, Realm::Processor proc) { - ASSERT(arglen == sizeof(DeviceHandleInitTaskArgs)); DeviceHandleInitTaskArgs task_args = - *reinterpret_cast(args); - - // FIXME: serialize instead of passing pointers around - ASSERT(task_args.origin_proc.address_space() == proc.address_space()); + device_handle_init_task_args_from_serializable( + deserialize_task_args(args, + arglen)); RealmContext ctx{proc}; DeviceSpecificManagedPerDeviceFFHandle managed_handle = @@ -89,6 +68,8 @@ Realm::Event spawn_device_handle_init_task( result_ptr, }; + std::string args = serialize_task_args( + device_handle_init_task_args_to_serializable(task_args)); return ctx.spawn_task(target_proc, task_id_t::DEVICE_HANDLE_INIT_TASK_ID, &task_args, diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_task.cc index 312c3f2401..99c72cf5e7 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_task.cc @@ -1,95 +1,26 @@ #include "realm-execution/tasks/impl/device_state_init_task.h" -#include "kernels/device_handle_t.dtg.h" #include "local-execution/device_state_initialization.h" -#include "realm-execution/device_specific_managed_per_device_ff_handle.h" #include "realm-execution/tasks/impl/device_state_init_return_task.h" -#include "realm-execution/tasks/serializer/serializable_realm_processor.dtg.h" -#include "realm-execution/tasks/serializer/serializable_realm_processor.h" +#include "realm-execution/tasks/impl/device_state_init_task_args.dtg.h" +#include "realm-execution/tasks/impl/serializable_device_state_init_task_args.h" #include "realm-execution/tasks/serializer/task_arg_serializer.h" #include "realm-execution/tasks/task_id_t.dtg.h" #include "realm-execution/tasks/task_id_t.h" -#include "task-spec/device_specific_per_device_op_state.dtg.h" -#include "task-spec/dynamic_graph/serializable_dynamic_node_invocation.h" -#include "task-spec/dynamic_graph/training_operation_attrs.dtg.h" #include "utils/optional.h" -#include #include #include namespace FlexFlow { -// TODO: at some point we're going to have to actually serialize these, but for -// now just pass the pointer and assume we're running inside a single address -// space -struct DeviceStateInitTaskArgs { - DeviceStateInitTaskArgs() = delete; - DeviceStateInitTaskArgs( - DynamicNodeInvocation const &invocation, - ProfilingSettings const &profiling_settings, - DeviceSpecificManagedPerDeviceFFHandle const &device_handle, - FFIterationConfig const &iteration_config, - OptimizerAttrs const &optimizer_attrs, - Realm::Processor origin_proc, - DeviceSpecificPerDeviceOpState *origin_result_ptr) - : invocation(invocation), profiling_settings(profiling_settings), - device_handle(device_handle), iteration_config(iteration_config), - optimizer_attrs(optimizer_attrs), origin_proc(origin_proc), - origin_result_ptr(origin_result_ptr) {} - - void serialize(nlohmann::json &j) const { - nlohmann::json j_device_handle; - device_handle.serialize(j_device_handle); - j = { - {"invocation", dynamic_node_invocation_to_serializable(invocation)}, - {"profiling_settings", profiling_settings}, - {"device_handle", j_device_handle}, - {"iteration_config", iteration_config}, - {"optimizer_attrs", optimizer_attrs}, - {"origin_proc", realm_processor_to_serializable(origin_proc)}, - {"origin_result_ptr", reinterpret_cast(origin_result_ptr)}, - }; - } - - static DeviceStateInitTaskArgs deserialize(nlohmann::json const &j) { - return DeviceStateInitTaskArgs{ - /*invocation=*/dynamic_node_invocation_from_serializable( - j.at("invocation").get()), - /*profiling_settings=*/ - j.at("profiling_settings").get(), - /*device_handle=*/ - DeviceSpecificManagedPerDeviceFFHandle::deserialize( - j.at("device_handle")), - /*iteration_config=*/j.at("iteration_config").get(), - /*optimizer_attrs=*/j.at("optimizer_attrs").get(), - /*origin_proc=*/ - realm_processor_from_serializable( - j.at("origin_proc").get()), - /*origin_result_ptr=*/ - reinterpret_cast( - j.at("origin_result_ptr").get()), - }; - } - -public: - DynamicNodeInvocation invocation; - ProfilingSettings profiling_settings; - DeviceSpecificManagedPerDeviceFFHandle device_handle; - FFIterationConfig iteration_config; - OptimizerAttrs optimizer_attrs; - Realm::Processor origin_proc; - DeviceSpecificPerDeviceOpState *origin_result_ptr; -}; - void device_state_init_task_body(void const *args, size_t arglen, void const *userdata, size_t userlen, Realm::Processor proc) { DeviceStateInitTaskArgs task_args = - deserialize_task_args(args, arglen); - - // FIXME: serialize instead of passing pointers around - ASSERT(task_args.origin_proc.address_space() == proc.address_space()); + device_state_init_task_args_from_serializable( + deserialize_task_args(args, + arglen)); RealmContext ctx{proc}; device_handle_t device_handle = @@ -143,7 +74,8 @@ std::optional spawn_device_state_init_task( }), get_init_task_id_for_op_attrs); if (task_id.has_value()) { - std::string args = serialize_task_args(task_args); + std::string args = serialize_task_args( + device_state_init_task_args_to_serializable(task_args)); return ctx.spawn_task(target_proc, assert_unwrap(task_id), args.data(), diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/serializable_device_handle_init_task_args.cc b/lib/realm-execution/src/realm-execution/tasks/impl/serializable_device_handle_init_task_args.cc new file mode 100644 index 0000000000..a44a5a5db1 --- /dev/null +++ b/lib/realm-execution/src/realm-execution/tasks/impl/serializable_device_handle_init_task_args.cc @@ -0,0 +1,28 @@ +#include "realm-execution/tasks/impl/serializable_device_handle_init_task_args.h" + +namespace FlexFlow { + +SerializableDeviceHandleInitTaskArgs + device_handle_init_task_args_to_serializable( + DeviceHandleInitTaskArgs const &args) { + return SerializableDeviceHandleInitTaskArgs{ + /*workSpaceSize=*/args.workSpaceSize, + /*allowTensorOpMathConversion=*/args.allowTensorOpMathConversion, + /*origin_proc=*/realm_processor_to_serializable(args.origin_proc), + /*origin_result_ptr=*/reinterpret_cast(args.origin_result_ptr), + }; +} + +DeviceHandleInitTaskArgs device_handle_init_task_args_from_serializable( + SerializableDeviceHandleInitTaskArgs const &args) { + return DeviceHandleInitTaskArgs{ + /*workSpaceSize=*/args.workSpaceSize, + /*allowTensorOpMathConversion=*/args.allowTensorOpMathConversion, + /*origin_proc=*/realm_processor_from_serializable(args.origin_proc), + /*origin_result_ptr=*/ + reinterpret_cast( + args.origin_result_ptr), + }; +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/serializable_device_state_init_task_args.cc b/lib/realm-execution/src/realm-execution/tasks/impl/serializable_device_state_init_task_args.cc new file mode 100644 index 0000000000..528ff26867 --- /dev/null +++ b/lib/realm-execution/src/realm-execution/tasks/impl/serializable_device_state_init_task_args.cc @@ -0,0 +1,36 @@ +#include "realm-execution/tasks/impl/serializable_device_state_init_task_args.h" +#include "realm-execution/tasks/serializer/serializable_realm_processor.h" +#include "task-spec/dynamic_graph/serializable_dynamic_node_invocation.h" + +namespace FlexFlow { + +SerializableDeviceStateInitTaskArgs device_state_init_task_args_to_serializable( + DeviceStateInitTaskArgs const &args) { + return SerializableDeviceStateInitTaskArgs{ + /*invocation=*/dynamic_node_invocation_to_serializable(args.invocation), + /*profiling_settings=*/args.profiling_settings, + /*device_handle=*/args.device_handle.serialize(), + /*iteration_config=*/args.iteration_config, + /*optimizer_attrs=*/args.optimizer_attrs, + /*origin_proc=*/realm_processor_to_serializable(args.origin_proc), + /*origin_result_ptr=*/reinterpret_cast(args.origin_result_ptr), + }; +} + +DeviceStateInitTaskArgs device_state_init_task_args_from_serializable( + SerializableDeviceStateInitTaskArgs const &args) { + return DeviceStateInitTaskArgs{ + /*invocation=*/dynamic_node_invocation_from_serializable(args.invocation), + /*profiling_settings=*/args.profiling_settings, + /*device_handle=*/ + DeviceSpecificManagedPerDeviceFFHandle::deserialize(args.device_handle), + /*iteration_config=*/args.iteration_config, + /*optimizer_attrs=*/args.optimizer_attrs, + /*origin_proc=*/realm_processor_from_serializable(args.origin_proc), + /*origin_result_ptr=*/ + reinterpret_cast( + args.origin_result_ptr), + }; +} + +} // namespace FlexFlow From 6a380ce941a56c5b994c0dfe1b32e7c6928a9a31 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Sat, 14 Feb 2026 15:13:38 -0800 Subject: [PATCH 60/88] Convert op task args. --- .../tasks/impl/op_task_args.dtg.toml | 32 ++++++++++ ...able_device_handle_init_task_args.dtg.toml | 1 - ...erializable_device_handle_init_task_args.h | 4 +- ...zable_device_state_init_task_args.dtg.toml | 1 - ...serializable_device_state_init_task_args.h | 4 +- .../impl/serializable_op_task_args.dtg.toml | 42 +++++++++++++ .../tasks/impl/serializable_op_task_args.h | 14 +++++ .../tasks/impl/device_handle_init_task.cc | 4 +- .../src/realm-execution/tasks/impl/op_task.cc | 60 ++++++------------- .../tasks/impl/serializable_op_task_args.cc | 27 +++++++++ 10 files changed, 139 insertions(+), 50 deletions(-) create mode 100644 lib/realm-execution/include/realm-execution/tasks/impl/op_task_args.dtg.toml create mode 100644 lib/realm-execution/include/realm-execution/tasks/impl/serializable_op_task_args.dtg.toml create mode 100644 lib/realm-execution/include/realm-execution/tasks/impl/serializable_op_task_args.h create mode 100644 lib/realm-execution/src/realm-execution/tasks/impl/serializable_op_task_args.cc diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/op_task_args.dtg.toml b/lib/realm-execution/include/realm-execution/tasks/impl/op_task_args.dtg.toml new file mode 100644 index 0000000000..814f9f802b --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tasks/impl/op_task_args.dtg.toml @@ -0,0 +1,32 @@ +namespace = "FlexFlow" +name = "OpTaskArgs" +type = "struct" +features = [] + +includes = [ + "kernels/profiling_settings.dtg.h", + "pcg/optimizer_attrs.dtg.h", + "realm-execution/device_specific_managed_per_device_ff_handle.h", + "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h", + "task-spec/ff_iteration_config.dtg.h", +] + +[[fields]] +name = "invocation" +type = "::FlexFlow::DynamicNodeInvocation" + +[[fields]] +name = "profiling_settings" +type = "::FlexFlow::ProfilingSettings" + +[[fields]] +name = "device_handle" +type = "::FlexFlow::DeviceSpecificManagedPerDeviceFFHandle" + +[[fields]] +name = "iteration_config" +type = "::FlexFlow::FFIterationConfig" + +[[fields]] +name = "optimizer_attrs" +type = "std::optional<::FlexFlow::OptimizerAttrs>" diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_handle_init_task_args.dtg.toml b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_handle_init_task_args.dtg.toml index 3a187924c8..34f52880f8 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_handle_init_task_args.dtg.toml +++ b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_handle_init_task_args.dtg.toml @@ -9,7 +9,6 @@ features = [ ] includes = [ - "realm-execution/realm.h", "realm-execution/tasks/serializer/serializable_realm_processor.dtg.h", ] diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_handle_init_task_args.h b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_handle_init_task_args.h index b239221c16..63d70fe10a 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_handle_init_task_args.h +++ b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_handle_init_task_args.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_SERIALIZABLE_DEVICE_HANDLE_INIT_TASK_H -#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_SERIALIZABLE_DEVICE_HANDLE_INIT_TASK_H +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_SERIALIZABLE_DEVICE_HANDLE_INIT_TASK_ARGS_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_SERIALIZABLE_DEVICE_HANDLE_INIT_TASK_ARGS_H #include "realm-execution/tasks/impl/device_handle_init_task_args.dtg.h" #include "realm-execution/tasks/impl/serializable_device_handle_init_task_args.dtg.h" diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_state_init_task_args.dtg.toml b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_state_init_task_args.dtg.toml index 68076b7d70..c99d2758c0 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_state_init_task_args.dtg.toml +++ b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_state_init_task_args.dtg.toml @@ -11,7 +11,6 @@ features = [ includes = [ "kernels/profiling_settings.dtg.h", "pcg/optimizer_attrs.dtg.h", - "realm-execution/realm.h", "realm-execution/tasks/serializer/serializable_device_specific_ptr.dtg.h", "realm-execution/tasks/serializer/serializable_realm_processor.dtg.h", "task-spec/device_specific_per_device_op_state.dtg.h", diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_state_init_task_args.h b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_state_init_task_args.h index 2467f2067c..f028820974 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_state_init_task_args.h +++ b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_state_init_task_args.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_SERIALIZABLE_DEVICE_STATE_INIT_TASK_H -#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_SERIALIZABLE_DEVICE_STATE_INIT_TASK_H +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_SERIALIZABLE_DEVICE_STATE_INIT_TASK_ARGS_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_SERIALIZABLE_DEVICE_STATE_INIT_TASK_ARGS_H #include "realm-execution/tasks/impl/device_state_init_task_args.dtg.h" #include "realm-execution/tasks/impl/serializable_device_state_init_task_args.dtg.h" diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/serializable_op_task_args.dtg.toml b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_op_task_args.dtg.toml new file mode 100644 index 0000000000..a0f89e3ae2 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_op_task_args.dtg.toml @@ -0,0 +1,42 @@ +namespace = "FlexFlow" +name = "SerializableOpTaskArgs" +type = "struct" +features = [ + "eq", + "fmt", + "hash", + "json", +] + +includes = [ + "kernels/profiling_settings.dtg.h", + "pcg/optimizer_attrs.dtg.h", + "realm-execution/tasks/serializer/serializable_device_specific_ptr.dtg.h", + "task-spec/dynamic_graph/serializable_dynamic_node_invocation.dtg.h", + "task-spec/ff_iteration_config.dtg.h", +] + +src_includes = [ + "utils/fmt/optional.h", + "utils/json/optional.h", +] + +[[fields]] +name = "invocation" +type = "::FlexFlow::SerializableDynamicNodeInvocation" + +[[fields]] +name = "profiling_settings" +type = "::FlexFlow::ProfilingSettings" + +[[fields]] +name = "device_handle" +type = "::FlexFlow::SerializableDeviceSpecificPtr" + +[[fields]] +name = "iteration_config" +type = "::FlexFlow::FFIterationConfig" + +[[fields]] +name = "optimizer_attrs" +type = "std::optional<::FlexFlow::OptimizerAttrs>" diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/serializable_op_task_args.h b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_op_task_args.h new file mode 100644 index 0000000000..3b2d05d0b6 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_op_task_args.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_SERIALIZABLE_OP_TASK_ARGS_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_SERIALIZABLE_OP_TASK_ARGS_H + +#include "realm-execution/tasks/impl/op_task_args.dtg.h" +#include "realm-execution/tasks/impl/serializable_op_task_args.dtg.h" + +namespace FlexFlow { + +SerializableOpTaskArgs op_task_args_to_serializable(OpTaskArgs const &); +OpTaskArgs op_task_args_from_serializable(SerializableOpTaskArgs const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/device_handle_init_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/device_handle_init_task.cc index 5cd53ea062..b806aa1277 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/device_handle_init_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/device_handle_init_task.cc @@ -72,8 +72,8 @@ Realm::Event spawn_device_handle_init_task( device_handle_init_task_args_to_serializable(task_args)); return ctx.spawn_task(target_proc, task_id_t::DEVICE_HANDLE_INIT_TASK_ID, - &task_args, - sizeof(task_args), + args.data(), + args.size(), Realm::ProfilingRequestSet{}, precondition); } diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc index e17973febb..d8b8873442 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc @@ -1,6 +1,9 @@ #include "realm-execution/tasks/impl/op_task.h" #include "local-execution/task_execution.h" #include "realm-execution/device_specific_managed_per_device_ff_handle.h" +#include "realm-execution/tasks/impl/op_task_args.dtg.h" +#include "realm-execution/tasks/impl/serializable_op_task_args.h" +#include "realm-execution/tasks/serializer/task_arg_serializer.h" #include "realm-execution/tasks/task_id_t.h" #include "task-spec/per_device_op_state.h" #include "utils/optional.h" @@ -8,59 +11,31 @@ namespace FlexFlow { -// TODO: at some point we're going to have to actually serialize these, but for -// now just pass the pointer and assume we're running inside a single address -// space -struct OpTaskArgs { -public: - OpTaskArgs() = delete; - OpTaskArgs(DynamicNodeInvocation const *invocation, - ProfilingSettings const *profiling_settings, - DeviceSpecificManagedPerDeviceFFHandle const &device_handle, - FFIterationConfig const *iteration_config, - std::optional const *optimizer_attrs, - Realm::Processor origin_proc) - : invocation(invocation), profiling_settings(profiling_settings), - device_handle(device_handle), iteration_config(iteration_config), - optimizer_attrs(optimizer_attrs) {} - -public: - DynamicNodeInvocation const *invocation; - ProfilingSettings const *profiling_settings; - DeviceSpecificManagedPerDeviceFFHandle device_handle; - FFIterationConfig const *iteration_config; - std::optional const *optimizer_attrs; - Realm::Processor origin_proc; -}; - void op_task_body(void const *args, size_t arglen, void const *userdata, size_t userlen, Realm::Processor proc) { - ASSERT(arglen == sizeof(OpTaskArgs)); - OpTaskArgs task_args = *reinterpret_cast(args); - - // FIXME: serialize instead of passing pointers around - ASSERT(task_args.origin_proc.address_space() == proc.address_space()); + OpTaskArgs task_args = op_task_args_from_serializable( + deserialize_task_args(args, arglen)); RealmContext ctx{proc}; device_handle_t device_handle = device_handle_t_from_device_specific_managed_handle( task_args.device_handle, ctx.get_current_device_idx()); execute_dynamic_node_invocation( - /*invocation=*/*task_args.invocation, + /*invocation=*/task_args.invocation, /*allocator=*/ctx.get_current_device_allocator(), - /*profiling_settings=*/*task_args.profiling_settings, + /*profiling_settings=*/task_args.profiling_settings, /*ff_handle=*/device_handle, /*per_device_op_state=*/ - transform(task_args.invocation->node_attrs.per_device_op_state, + transform(task_args.invocation.node_attrs.per_device_op_state, [&](DeviceSpecificPerDeviceOpState const &op_state) { return get_device_state_from_device_specific( op_state, ctx.get_current_device_idx()); }), - /*iteration_config=*/*task_args.iteration_config, - /*optimizer_attrs=*/*task_args.optimizer_attrs, + /*iteration_config=*/task_args.iteration_config, + /*optimizer_attrs=*/task_args.optimizer_attrs, /*device_idx=*/ctx.get_current_device_idx()); } @@ -73,17 +48,18 @@ Realm::Event FFIterationConfig const &iteration_config, std::optional const &optimizer_attrs, Realm::Event precondition) { - OpTaskArgs task_args{&invocation, - &profiling_settings, + OpTaskArgs task_args{invocation, + profiling_settings, device_handle, - &iteration_config, - &optimizer_attrs, - ctx.get_current_processor()}; + iteration_config, + optimizer_attrs}; + std::string args = + serialize_task_args(op_task_args_to_serializable(task_args)); return ctx.spawn_task( target_proc, assert_unwrap(get_task_id_for_op(invocation.node_attrs, optimizer_attrs)), - &task_args, - sizeof(task_args), + args.data(), + args.size(), Realm::ProfilingRequestSet{}, precondition); } diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/serializable_op_task_args.cc b/lib/realm-execution/src/realm-execution/tasks/impl/serializable_op_task_args.cc new file mode 100644 index 0000000000..0513bc6df7 --- /dev/null +++ b/lib/realm-execution/src/realm-execution/tasks/impl/serializable_op_task_args.cc @@ -0,0 +1,27 @@ +#include "realm-execution/tasks/impl/serializable_op_task_args.h" +#include "task-spec/dynamic_graph/serializable_dynamic_node_invocation.h" + +namespace FlexFlow { + +SerializableOpTaskArgs op_task_args_to_serializable(OpTaskArgs const &args) { + return SerializableOpTaskArgs{ + /*invocation=*/dynamic_node_invocation_to_serializable(args.invocation), + /*profiling_settings=*/args.profiling_settings, + /*device_handle=*/args.device_handle.serialize(), + /*iteration_config=*/args.iteration_config, + /*optimizer_attrs=*/args.optimizer_attrs, + }; +} + +OpTaskArgs op_task_args_from_serializable(SerializableOpTaskArgs const &args) { + return OpTaskArgs{ + /*invocation=*/dynamic_node_invocation_from_serializable(args.invocation), + /*profiling_settings=*/args.profiling_settings, + /*device_handle=*/ + DeviceSpecificManagedPerDeviceFFHandle::deserialize(args.device_handle), + /*iteration_config=*/args.iteration_config, + /*optimizer_attrs=*/args.optimizer_attrs, + }; +} + +} // namespace FlexFlow From a46dd46e40a9c81f0de77924a3b810915eda56fb Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Sat, 14 Feb 2026 17:37:07 -0800 Subject: [PATCH 61/88] Map the PCG for test. --- .../test/src/realm-execution/test_e2e.cc | 44 ++++++++++++++++++- 1 file changed, 43 insertions(+), 1 deletion(-) diff --git a/lib/realm-execution/test/src/realm-execution/test_e2e.cc b/lib/realm-execution/test/src/realm-execution/test_e2e.cc index 33ad2bbbc1..8e5edf72ad 100644 --- a/lib/realm-execution/test/src/realm-execution/test_e2e.cc +++ b/lib/realm-execution/test/src/realm-execution/test_e2e.cc @@ -1,7 +1,12 @@ #include "internal/realm_test_utils.h" #include "kernels/allocation.h" #include "op-attrs/tensor_shape.dtg.h" +#include "op-attrs/tensor_slot_name.dtg.h" +#include "pcg/device_type.dtg.h" +#include "pcg/machine_space_coordinate.dtg.h" +#include "pcg/mapped_parallel_computation_graph/operator_atomic_task_shard_binding.dtg.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" #include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" #include "realm-execution/distributed_device_handle.h" #include "realm-execution/pcg_instance/pcg_instance.h" @@ -126,7 +131,44 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_tensor_guid_t t_linear_2 = require_only_key(linear_operator_2.outputs, TensorSlotName::OUTPUT); - MappedParallelComputationGraph mpcg{pcg, {}}; + MachineSpaceCoordinate cpu0{0_n, 0_n, DeviceType::CPU}; + ParallelTensorSpaceCoordinate tensor_coord0{0_n, 0_n, FFOrdered{0_n}}; + MappedParallelComputationGraph mpcg{ + pcg, + { + {inputs_layer.parallel_layer, + MappedOperatorTaskGroup{ + {{cpu0, + OperatorAtomicTaskShardBinding{ + {{TensorSlotName::OUTPUT, tensor_coord0}}}}}}}, + {weights_layer_1.parallel_layer, + MappedOperatorTaskGroup{ + {{cpu0, + OperatorAtomicTaskShardBinding{ + {{TensorSlotName::OUTPUT, tensor_coord0}}}}}}}, + {weights_layer_2.parallel_layer, + MappedOperatorTaskGroup{ + {{cpu0, + OperatorAtomicTaskShardBinding{ + {{TensorSlotName::OUTPUT, tensor_coord0}}}}}}}, + {linear_operator_1.parallel_layer, + MappedOperatorTaskGroup{ + {{cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord0}, + {TensorSlotName::WEIGHT, tensor_coord0}, + {TensorSlotName::OUTPUT, tensor_coord0}, + }}}}}}, + {linear_operator_2.parallel_layer, + MappedOperatorTaskGroup{ + {{cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord0}, + {TensorSlotName::WEIGHT, tensor_coord0}, + {TensorSlotName::OUTPUT, tensor_coord0}, + }}}}}}, + }, + }; // instantiate computation graph LossAttrs loss_attrs = LossAttrs{ From 056312fbe2f1c393537703dd1f8ee934badd0532 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Sat, 14 Feb 2026 17:43:29 -0800 Subject: [PATCH 62/88] Fix a bug in shard expansion. --- lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc b/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc index 33b7fb8591..402e0ef055 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc @@ -15,7 +15,7 @@ bool value_is_shard_expanded(DynamicValueAttrs const &n) { bool no_part_of_graph_is_shard_expanded(DynamicOpenDataflowGraph const &g) { auto slot_is_shard_expanded = [](DynamicTensorSlot const &) -> bool { - return true; + return false; }; return no_part_of_dynamic_graph_satisfies(g, From c44035f55b0515a091f6e722a9fc1193c54dfcb5 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Sat, 14 Feb 2026 17:53:00 -0800 Subject: [PATCH 63/88] Finish body of instance allocation. --- .../src/realm-execution/instance_allocation.cc | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/lib/realm-execution/src/realm-execution/instance_allocation.cc b/lib/realm-execution/src/realm-execution/instance_allocation.cc index c033f0bac1..b740859e22 100644 --- a/lib/realm-execution/src/realm-execution/instance_allocation.cc +++ b/lib/realm-execution/src/realm-execution/instance_allocation.cc @@ -5,6 +5,7 @@ #include "realm-execution/realm_context.h" #include "realm-execution/tensor_instance_backing.h" #include "task-spec/dynamic_graph/dynamic_node_attrs.dtg.h" +#include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" #include "task-spec/dynamic_graph/dynamic_tensor_accessor.dtg.h" #include "task-spec/dynamic_graph/dynamic_value_attrs.dtg.h" @@ -14,6 +15,7 @@ #include "utils/containers/make.h" #include "utils/containers/map_values.h" #include "utils/containers/unordered_set_of.h" +#include "utils/containers/values.h" #include "utils/exception.h" #include "utils/optional.h" @@ -59,6 +61,15 @@ TensorInstanceBacking perform_instance_allocation( } }; + for (DynamicNodeInvocation const &invocation : g.invocations) { + for (DynamicValueAttrs const &input : values(invocation.inputs)) { + allocate(invocation.node_attrs, input); + } + for (DynamicValueAttrs const &output : values(invocation.outputs)) { + allocate(invocation.node_attrs, output); + } + } + return result; } From b9417d0c5105d50379ecbfc798252f2e054e01c9 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Tue, 17 Feb 2026 09:59:14 -0800 Subject: [PATCH 64/88] Fix some bugs in loss insertion, instance allocation. --- .../computation_graph_instance.cc | 3 ++- .../realm-execution/pcg_instance/pcg_instance.h | 1 + .../src/realm-execution/instance_allocation.cc | 5 ++--- .../src/realm-execution/pcg_instance/pcg_instance.cc | 4 +++- .../test/src/realm-execution/test_e2e.cc | 7 +++++++ .../include/task-spec/dynamic_graph/loss_insertion.h | 9 ++++++--- .../src/task-spec/dynamic_graph/loss_insertion.cc | 10 ++++++---- 7 files changed, 27 insertions(+), 12 deletions(-) diff --git a/lib/local-execution/src/local-execution/computation_graph_instance/computation_graph_instance.cc b/lib/local-execution/src/local-execution/computation_graph_instance/computation_graph_instance.cc index e251fafe5f..40d9b187c4 100644 --- a/lib/local-execution/src/local-execution/computation_graph_instance/computation_graph_instance.cc +++ b/lib/local-execution/src/local-execution/computation_graph_instance/computation_graph_instance.cc @@ -81,7 +81,8 @@ ComputationGraphInstance create_computation_graph_instance( auto [loss_inserted_dg, label_v, logit_grad_v] = perform_loss_insertion( dg, assert_unwrap(loss_attrs), - dynamic_tensor_guid_t{assert_unwrap(logit_tensor)}); + dynamic_tensor_guid_t{assert_unwrap(logit_tensor)}, + std::nullopt); dg = loss_inserted_dg; logit_grad_value = logit_grad_v; inputs.insert(std::pair{label_v, assert_unwrap(label_tensor)}); diff --git a/lib/realm-execution/include/realm-execution/pcg_instance/pcg_instance.h b/lib/realm-execution/include/realm-execution/pcg_instance/pcg_instance.h index b0037f51b2..fa163d1419 100644 --- a/lib/realm-execution/include/realm-execution/pcg_instance/pcg_instance.h +++ b/lib/realm-execution/include/realm-execution/pcg_instance/pcg_instance.h @@ -51,6 +51,7 @@ PCGInstance create_pcg_instance( std::optional const &loss_attrs, std::optional label_tensor, std::optional logit_tensor, + std::optional const &loss_mapping, std::unordered_map const &input_tensors, ProfilingSettings const &profiling_settings, diff --git a/lib/realm-execution/src/realm-execution/instance_allocation.cc b/lib/realm-execution/src/realm-execution/instance_allocation.cc index b740859e22..797455573c 100644 --- a/lib/realm-execution/src/realm-execution/instance_allocation.cc +++ b/lib/realm-execution/src/realm-execution/instance_allocation.cc @@ -52,12 +52,11 @@ TensorInstanceBacking perform_instance_allocation( // FIXME: Attach external instance to existing allocation and use that NOT_IMPLEMENTED(); } else { - if (contains_key(result.backing, v)) { - return result.backing.at(v); - } else { + if (!contains_key(result.backing, v)) { result.backing.insert( std::pair{v, perform_instance_allocation_for_value(n, v, ctx)}); } + return result.backing.at(v); } }; diff --git a/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc b/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc index 8e6ab022aa..7b047bcb72 100644 --- a/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc +++ b/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc @@ -54,6 +54,7 @@ PCGInstance create_pcg_instance( std::optional const &loss_attrs, std::optional label_tensor, std::optional logit_tensor, + std::optional const &loss_mapping, std::unordered_map const &input_tensors, ProfilingSettings const &profiling_settings, @@ -71,7 +72,8 @@ PCGInstance create_pcg_instance( auto [dg2, label_v, logit_grad_v] = perform_loss_insertion( dg, assert_unwrap(loss_attrs), - dynamic_tensor_guid_t{assert_unwrap(logit_tensor)}); + dynamic_tensor_guid_t{assert_unwrap(logit_tensor)}, + loss_mapping); dg = dg2; logit_grad_value = logit_grad_v; inputs.insert(std::pair{label_v, assert_unwrap(label_tensor)}); diff --git a/lib/realm-execution/test/src/realm-execution/test_e2e.cc b/lib/realm-execution/test/src/realm-execution/test_e2e.cc index 8e5edf72ad..4dbfe09045 100644 --- a/lib/realm-execution/test/src/realm-execution/test_e2e.cc +++ b/lib/realm-execution/test/src/realm-execution/test_e2e.cc @@ -169,6 +169,12 @@ TEST_SUITE(FF_TEST_SUITE) { }}}}}}, }, }; + MappedOperatorTaskGroup loss_mapping{ + {{cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord0}, + {TensorSlotName::LOGIT, tensor_coord0}, + }}}}}; // instantiate computation graph LossAttrs loss_attrs = LossAttrs{ @@ -194,6 +200,7 @@ TEST_SUITE(FF_TEST_SUITE) { /*loss=*/loss_attrs, /*label_tensor=*/label_tensor, /*logit_tensor=*/t_linear_2, + /*loss_mapping=*/loss_mapping, /*input_tensors=*/input_tensors, /*profiling_settings=*/ProfilingSettings{0, 0}, /*device_handle=*/device_handle, diff --git a/lib/task-spec/include/task-spec/dynamic_graph/loss_insertion.h b/lib/task-spec/include/task-spec/dynamic_graph/loss_insertion.h index c7cef3f06f..b3b2a465f8 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/loss_insertion.h +++ b/lib/task-spec/include/task-spec/dynamic_graph/loss_insertion.h @@ -6,12 +6,15 @@ #include "task-spec/dynamic_graph/dynamic_tensor_guid_t.dtg.h" #include "task-spec/dynamic_graph/dynamic_value_attrs.dtg.h" #include "task-spec/dynamic_graph/loss_insertion_result.dtg.h" +#include namespace FlexFlow { -LossInsertionResult perform_loss_insertion(DynamicOpenDataflowGraph const &dg, - LossAttrs const &loss_attrs, - dynamic_tensor_guid_t logit_tensor); +LossInsertionResult perform_loss_insertion( + DynamicOpenDataflowGraph const &dg, + LossAttrs const &loss_attrs, + dynamic_tensor_guid_t logit_tensor, + std::optional const &loss_mapping); } // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/dynamic_graph/loss_insertion.cc b/lib/task-spec/src/task-spec/dynamic_graph/loss_insertion.cc index 4270119612..857fed1a84 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/loss_insertion.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/loss_insertion.cc @@ -12,9 +12,11 @@ namespace FlexFlow { -LossInsertionResult perform_loss_insertion(DynamicOpenDataflowGraph const &dg, - LossAttrs const &loss_attrs, - dynamic_tensor_guid_t logit_tensor) { +LossInsertionResult perform_loss_insertion( + DynamicOpenDataflowGraph const &dg, + LossAttrs const &loss_attrs, + dynamic_tensor_guid_t logit_tensor, + std::optional const &loss_mapping) { DynamicValueAttrs logit_value = assert_unwrap( find_output_value_attrs(dg, logit_tensor, mk_dynamic_tensor_role_fwd())); @@ -45,7 +47,7 @@ LossInsertionResult perform_loss_insertion(DynamicOpenDataflowGraph const &dg, DynamicNodeAttrs{ /*task_type=*/DynamicTaskType::LOSS, /*device_coord=*/std::nullopt, - /*mapping=*/std::nullopt, + /*mapping=*/loss_mapping, /*op_attrs=*/TrainingOperationAttrs{loss_attrs}, /*layer_guid=*/mk_dynamic_layer_guid_for_loss(), /*per_device_op_state=*/std::nullopt, From aec4a197fdc9cdde7b7d48fef2fe27446fb8a1f7 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Tue, 17 Feb 2026 10:53:26 -0800 Subject: [PATCH 65/88] Fixes for PCG initialization. --- .../include/realm-execution/fmt/realm_event.h | 35 +++++++++++++++++++ .../fmt/{instance.h => realm_instance.h} | 4 +-- .../tensor_instance_backing.dtg.toml | 5 +-- .../src/realm-execution/fmt/realm_event.cc | 10 ++++++ .../fmt/{instance.cc => realm_instance.cc} | 2 +- .../pcg_instance/pcg_instance.cc | 17 +++++++++ .../tasks/realm_task_registry.cc | 2 ++ 7 files changed, 70 insertions(+), 5 deletions(-) create mode 100644 lib/realm-execution/include/realm-execution/fmt/realm_event.h rename lib/realm-execution/include/realm-execution/fmt/{instance.h => realm_instance.h} (96%) create mode 100644 lib/realm-execution/src/realm-execution/fmt/realm_event.cc rename lib/realm-execution/src/realm-execution/fmt/{instance.cc => realm_instance.cc} (80%) diff --git a/lib/realm-execution/include/realm-execution/fmt/realm_event.h b/lib/realm-execution/include/realm-execution/fmt/realm_event.h new file mode 100644 index 0000000000..a7df28ced6 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/fmt/realm_event.h @@ -0,0 +1,35 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_FMT_INSTANCE_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_FMT_INSTANCE_H + +#include "realm-execution/realm.h" +#include "utils/check_fmtable.h" +#include +#include + +namespace fmt { + +template +struct formatter<::FlexFlow::Realm::Event, + Char, + std::enable_if_t::value>> + : formatter<::std::string> { + template + auto format(::FlexFlow::Realm::Event const &m, FormatContext &ctx) + -> decltype(ctx.out()) { + std::string result = fmt::format("", m.id); + + return formatter::format(result, ctx); + } +}; + +} // namespace fmt + +namespace FlexFlow { + +std::ostream &operator<<(std::ostream &s, + ::FlexFlow::Realm::Event const &m); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/fmt/instance.h b/lib/realm-execution/include/realm-execution/fmt/realm_instance.h similarity index 96% rename from lib/realm-execution/include/realm-execution/fmt/instance.h rename to lib/realm-execution/include/realm-execution/fmt/realm_instance.h index c7c2df6735..e6d2846c1f 100644 --- a/lib/realm-execution/include/realm-execution/fmt/instance.h +++ b/lib/realm-execution/include/realm-execution/fmt/realm_instance.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_FMT_INSTANCE_H -#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_FMT_INSTANCE_H +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_FMT_REALM_INSTANCE_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_FMT_REALM_INSTANCE_H #include "realm-execution/realm.h" #include "utils/check_fmtable.h" diff --git a/lib/realm-execution/include/realm-execution/tensor_instance_backing.dtg.toml b/lib/realm-execution/include/realm-execution/tensor_instance_backing.dtg.toml index e6a8bd58d9..6c43990282 100644 --- a/lib/realm-execution/include/realm-execution/tensor_instance_backing.dtg.toml +++ b/lib/realm-execution/include/realm-execution/tensor_instance_backing.dtg.toml @@ -3,7 +3,7 @@ name = "TensorInstanceBacking" type = "struct" features = [ "eq", - #"fmt", + "fmt", #"hash", ] @@ -14,7 +14,8 @@ includes = [ ] src_includes = [ - "realm-execution/fmt/instance.h", + "realm-execution/fmt/realm_event.h", + "realm-execution/fmt/realm_instance.h", "utils/hash/unordered_map.h", "utils/fmt/unordered_map.h", ] diff --git a/lib/realm-execution/src/realm-execution/fmt/realm_event.cc b/lib/realm-execution/src/realm-execution/fmt/realm_event.cc new file mode 100644 index 0000000000..7c5ad7d848 --- /dev/null +++ b/lib/realm-execution/src/realm-execution/fmt/realm_event.cc @@ -0,0 +1,10 @@ +#include "realm-execution/fmt/realm_event.h" + +namespace FlexFlow { + +std::ostream &operator<<(std::ostream &s, + ::FlexFlow::Realm::Event const &m) { + return s << fmt::to_string(m); +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/fmt/instance.cc b/lib/realm-execution/src/realm-execution/fmt/realm_instance.cc similarity index 80% rename from lib/realm-execution/src/realm-execution/fmt/instance.cc rename to lib/realm-execution/src/realm-execution/fmt/realm_instance.cc index f8eabe9bb0..301954f824 100644 --- a/lib/realm-execution/src/realm-execution/fmt/instance.cc +++ b/lib/realm-execution/src/realm-execution/fmt/realm_instance.cc @@ -1,4 +1,4 @@ -#include "realm-execution/fmt/instance.h" +#include "realm-execution/fmt/realm_instance.h" namespace FlexFlow { diff --git a/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc b/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc index 7b047bcb72..c21737300c 100644 --- a/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc +++ b/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc @@ -1,11 +1,14 @@ #include "realm-execution/pcg_instance/pcg_instance.h" +#include "op-attrs/tensor_slot_name.dtg.h" #include "pcg/optimizer_attrs.h" #include "realm-execution/dependency_set.h" #include "realm-execution/distributed_device_state_initialization.h" #include "realm-execution/instance_allocation.h" #include "realm-execution/realm_context.h" #include "realm-execution/tasks/impl/op_task.h" +#include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" +#include "task-spec/dynamic_graph/dynamic_task_type.dtg.h" #include "task-spec/dynamic_graph/dynamic_tensor_guid_t.dtg.h" #include "task-spec/dynamic_graph/dynamic_value_attrs.dtg.h" #include "task-spec/dynamic_graph/loss_insertion.h" @@ -83,6 +86,20 @@ PCGInstance create_pcg_instance( dg = perform_shard_expansion(dg); TensorInstanceBacking backing = perform_instance_allocation(dg, inputs, ctx); + logit_grad_value = transform(logit_grad_value, [&](DynamicValueAttrs const &lgv) { + for (DynamicNodeInvocation const &invocation : dg.invocations) { + if (invocation.node_attrs.task_type != DynamicTaskType::LOSS) { + continue; + } + for (auto const &[slot, value] : invocation.outputs) { + if (slot.slot_name == TensorSlotName::LOGIT && value.tensor_guid == lgv.tensor_guid && value.role == lgv.role) { + return value; + } + } + } + PANIC("couldn't find updated logit grad in the shard-expanded dynamic graph"); + }); + std::optional logit_grad_tensor = transform(logit_grad_value, [&](DynamicValueAttrs const &lgv) { return backing.backing.at(lgv).first; diff --git a/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc b/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc index cff12c2391..914e8d1e29 100644 --- a/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc +++ b/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc @@ -49,6 +49,8 @@ Realm::Event register_all_tasks() { }; for (task_id_t task_id : init_task_ids) { + pending_registrations.push_back(register_task( + Realm::Processor::LOC_PROC, task_id, device_state_init_task_body)); pending_registrations.push_back(register_task( Realm::Processor::TOC_PROC, task_id, device_state_init_task_body)); } From 6adf137b80b3fafe4ef7f87832a056a4d701e168 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Tue, 17 Feb 2026 11:12:26 -0800 Subject: [PATCH 66/88] Fix a bug in device state handling. --- .../device_specific_managed_per_device_ff_handle.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/realm-execution/src/realm-execution/device_specific_managed_per_device_ff_handle.cc b/lib/realm-execution/src/realm-execution/device_specific_managed_per_device_ff_handle.cc index 6e0cef0bb2..bcc0a22ccf 100644 --- a/lib/realm-execution/src/realm-execution/device_specific_managed_per_device_ff_handle.cc +++ b/lib/realm-execution/src/realm-execution/device_specific_managed_per_device_ff_handle.cc @@ -51,7 +51,7 @@ device_handle_t device_handle_t_from_device_specific_managed_handle( DeviceSpecificManagedPerDeviceFFHandle const &device_specific, device_id_t device_idx) { return device_handle_t_from_managed_handle_ptr( - *device_specific.get(device_idx)); + device_specific.get(device_idx)); } } // namespace FlexFlow From 27660ecd63e039c63e56ffda48a81f9fd4778bed Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Tue, 17 Feb 2026 12:27:19 -0800 Subject: [PATCH 67/88] Implement most of tensor backing in task. --- .../distributed_device_state_initialization.h | 4 +- .../dynamic_tensor_accessor_from_instance.h | 14 +++ .../include/realm-execution/fmt/realm_event.h | 11 +-- .../include/realm-execution/hash/processor.h | 4 + .../pcg_instance/pcg_instance.h | 12 ++- .../include/realm-execution/realm.h | 2 +- .../tasks/impl/device_state_init_task.h | 4 + .../impl/device_state_init_task_args.dtg.toml | 5 + .../realm-execution/tasks/impl/op_task.h | 20 ++-- .../tasks/impl/op_task_args.dtg.toml | 6 ++ ...zable_device_state_init_task_args.dtg.toml | 5 + .../impl/serializable_op_task_args.dtg.toml | 5 + .../serializable_realm_instance.dtg.toml | 17 ++++ .../serializer/serializable_realm_instance.h | 16 +++ .../realm-execution/tensor_instance_backing.h | 4 + ...distributed_device_state_initialization.cc | 16 ++- .../dynamic_tensor_accessor_from_instance.cc | 11 +++ .../src/realm-execution/fmt/realm_event.cc | 3 +- .../src/realm-execution/hash/processor.cc | 4 + .../pcg_instance/pcg_instance.cc | 99 ++++++++++++------- .../tasks/impl/device_state_init_task.cc | 21 +++- .../src/realm-execution/tasks/impl/op_task.cc | 37 +++++-- ...erializable_device_state_init_task_args.cc | 11 +++ .../tasks/impl/serializable_op_task_args.cc | 11 +++ .../serializer/serializable_realm_instance.cc | 15 +++ .../tensor_instance_backing.cc | 14 +++ 26 files changed, 303 insertions(+), 68 deletions(-) create mode 100644 lib/realm-execution/include/realm-execution/dynamic_tensor_accessor_from_instance.h create mode 100644 lib/realm-execution/include/realm-execution/tasks/serializer/serializable_realm_instance.dtg.toml create mode 100644 lib/realm-execution/include/realm-execution/tasks/serializer/serializable_realm_instance.h create mode 100644 lib/realm-execution/src/realm-execution/dynamic_tensor_accessor_from_instance.cc create mode 100644 lib/realm-execution/src/realm-execution/tasks/serializer/serializable_realm_instance.cc diff --git a/lib/realm-execution/include/realm-execution/distributed_device_state_initialization.h b/lib/realm-execution/include/realm-execution/distributed_device_state_initialization.h index ca24ecdd4c..e257834e65 100644 --- a/lib/realm-execution/include/realm-execution/distributed_device_state_initialization.h +++ b/lib/realm-execution/include/realm-execution/distributed_device_state_initialization.h @@ -5,14 +5,16 @@ #include "pcg/optimizer_attrs.dtg.h" #include "realm-execution/distributed_device_handle.h" #include "realm-execution/realm_context.h" +#include "realm-execution/tensor_instance_backing.dtg.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.dtg.h" #include "task-spec/ff_iteration_config.dtg.h" namespace FlexFlow { DynamicOpenDataflowGraph perform_distributed_device_state_initialization( - DynamicOpenDataflowGraph const &dg, RealmContext &ctx, + DynamicOpenDataflowGraph const &dg, + TensorInstanceBacking const &tensor_instance_backing, ProfilingSettings const &profiling_settings, DistributedDeviceHandle const &device_handle, FFIterationConfig const &iteration_config, diff --git a/lib/realm-execution/include/realm-execution/dynamic_tensor_accessor_from_instance.h b/lib/realm-execution/include/realm-execution/dynamic_tensor_accessor_from_instance.h new file mode 100644 index 0000000000..48cfbde924 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/dynamic_tensor_accessor_from_instance.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_DYNAMIC_TENSOR_ACCESSOR_FROM_INSTANCE_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_DYNAMIC_TENSOR_ACCESSOR_FROM_INSTANCE_H + +#include "realm-execution/realm.h" +#include "task-spec/dynamic_graph/dynamic_tensor_accessor.dtg.h" + +namespace FlexFlow { + +DynamicTensorAccessor + dynamic_tensor_accessor_from_instance(Realm::RegionInstance const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/fmt/realm_event.h b/lib/realm-execution/include/realm-execution/fmt/realm_event.h index a7df28ced6..a245968f39 100644 --- a/lib/realm-execution/include/realm-execution/fmt/realm_event.h +++ b/lib/realm-execution/include/realm-execution/fmt/realm_event.h @@ -9,10 +9,10 @@ namespace fmt { template -struct formatter<::FlexFlow::Realm::Event, - Char, - std::enable_if_t::value>> +struct formatter< + ::FlexFlow::Realm::Event, + Char, + std::enable_if_t::value>> : formatter<::std::string> { template auto format(::FlexFlow::Realm::Event const &m, FormatContext &ctx) @@ -27,8 +27,7 @@ struct formatter<::FlexFlow::Realm::Event, namespace FlexFlow { -std::ostream &operator<<(std::ostream &s, - ::FlexFlow::Realm::Event const &m); +std::ostream &operator<<(std::ostream &s, ::FlexFlow::Realm::Event const &m); } // namespace FlexFlow diff --git a/lib/realm-execution/include/realm-execution/hash/processor.h b/lib/realm-execution/include/realm-execution/hash/processor.h index e5eb8eb503..efe6e6186b 100644 --- a/lib/realm-execution/include/realm-execution/hash/processor.h +++ b/lib/realm-execution/include/realm-execution/hash/processor.h @@ -4,6 +4,8 @@ #include "realm-execution/realm.h" #include +#ifdef FLEXFLOW_USE_PREALM + namespace std { template <> @@ -14,3 +16,5 @@ struct hash<::FlexFlow::Realm::Processor> { } // namespace std #endif + +#endif diff --git a/lib/realm-execution/include/realm-execution/pcg_instance/pcg_instance.h b/lib/realm-execution/include/realm-execution/pcg_instance/pcg_instance.h index fa163d1419..1238097b2a 100644 --- a/lib/realm-execution/include/realm-execution/pcg_instance/pcg_instance.h +++ b/lib/realm-execution/include/realm-execution/pcg_instance/pcg_instance.h @@ -12,6 +12,7 @@ #include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" #include "realm-execution/distributed_device_handle.h" #include "realm-execution/realm_context.h" +#include "realm-execution/tensor_instance_backing.dtg.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.dtg.h" #include "task-spec/dynamic_graph/dynamic_tensor_accessor.dtg.h" #include "task-spec/dynamic_graph/dynamic_value_attrs.dtg.h" @@ -29,10 +30,12 @@ struct PCGInstance { explicit PCGInstance( RealmContext &ctx, std::vector const &execution_order, + TensorInstanceBacking const &tensor_instance_backing, OptimizerAttrs const &optimizer_attrs, std::optional logit_grad_tensor); RealmContext &get_realm_context(); std::vector const &get_execution_order() const; + TensorInstanceBacking const &get_tensor_instance_backing() const; OptimizerAttrs const &get_optimizer_attrs() const; void update_optimizer_attrs_for_next_iter(); std::optional get_loss_tensor_instance() const; @@ -40,6 +43,7 @@ struct PCGInstance { private: RealmContext &ctx; std::vector execution_order; + TensorInstanceBacking tensor_instance_backing; OptimizerAttrs optimizer_attrs; std::optional logit_grad_tensor; }; @@ -60,28 +64,28 @@ PCGInstance create_pcg_instance( std::unordered_map perform_all_passes_for_pcg_instance( - PCGInstance &instance, + PCGInstance &pcg_instance, ProfilingSettings const &profiling_settings, DistributedDeviceHandle const &device_handle, FFIterationConfig iteration_config); std::unordered_map perform_forward_pass_for_pcg_instance( - PCGInstance &instance, + PCGInstance &pcg_instance, ProfilingSettings const &profiling_settings, DistributedDeviceHandle const &device_handle, FFIterationConfig iteration_config); std::unordered_map perform_backward_pass_for_pcg_instance( - PCGInstance &instance, + PCGInstance &pcg_instance, ProfilingSettings const &profiling_settings, DistributedDeviceHandle const &device_handle, FFIterationConfig iteration_config); std::unordered_map perform_update_pass_for_pcg_instance( - PCGInstance &instance, + PCGInstance &pcg_instance, ProfilingSettings const &profiling_settings, DistributedDeviceHandle const &device_handle, FFIterationConfig iteration_config); diff --git a/lib/realm-execution/include/realm-execution/realm.h b/lib/realm-execution/include/realm-execution/realm.h index b6913e66f5..fe83e69583 100644 --- a/lib/realm-execution/include/realm-execution/realm.h +++ b/lib/realm-execution/include/realm-execution/realm.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_REALM_H #define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_REALM_H -#define FLEXFLOW_USE_PREALM +// #define FLEXFLOW_USE_PREALM #ifdef FLEXFLOW_USE_PREALM #include diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task.h index 4ed8c1726d..9c53748916 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task.h +++ b/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task.h @@ -8,7 +8,9 @@ #include "realm-execution/realm_context.h" #include "task-spec/device_specific_per_device_op_state.dtg.h" #include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" +#include "task-spec/dynamic_graph/dynamic_value_attrs.dtg.h" #include "task-spec/ff_iteration_config.dtg.h" +#include namespace FlexFlow { @@ -19,6 +21,8 @@ std::optional spawn_device_state_init_task( RealmContext &ctx, Realm::Processor target_proc, DynamicNodeInvocation const &invocation, + std::unordered_map const + &tensor_backing, ProfilingSettings const &profiling_settings, DeviceSpecificManagedPerDeviceFFHandle const &device_handle, FFIterationConfig const &iteration_config, diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task_args.dtg.toml b/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task_args.dtg.toml index a9aa77dde9..888c62af54 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task_args.dtg.toml +++ b/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task_args.dtg.toml @@ -10,6 +10,7 @@ includes = [ "realm-execution/realm.h", "task-spec/device_specific_per_device_op_state.dtg.h", "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h", + "task-spec/dynamic_graph/dynamic_value_attrs.dtg.h", "task-spec/ff_iteration_config.dtg.h", ] @@ -17,6 +18,10 @@ includes = [ name = "invocation" type = "::FlexFlow::DynamicNodeInvocation" +[[fields]] +name = "tensor_backing" +type = "std::unordered_map<::FlexFlow::DynamicValueAttrs, ::FlexFlow::Realm::RegionInstance>" + [[fields]] name = "profiling_settings" type = "::FlexFlow::ProfilingSettings" diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h index 9d4c2fd451..37a801a508 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h +++ b/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h @@ -14,15 +14,17 @@ namespace FlexFlow { void op_task_body(void const *, size_t, void const *, size_t, Realm::Processor); -Realm::Event - spawn_op_task(RealmContext &ctx, - Realm::Processor target_proc, - DynamicNodeInvocation const &invocation, - ProfilingSettings const &profiling_settings, - DeviceSpecificManagedPerDeviceFFHandle const &device_handle, - FFIterationConfig const &iteration_config, - std::optional const &optimizer_attrs, - Realm::Event precondition); +Realm::Event spawn_op_task( + RealmContext &ctx, + Realm::Processor target_proc, + DynamicNodeInvocation const &invocation, + std::unordered_map const + &tensor_backing, + ProfilingSettings const &profiling_settings, + DeviceSpecificManagedPerDeviceFFHandle const &device_handle, + FFIterationConfig const &iteration_config, + std::optional const &optimizer_attrs, + Realm::Event precondition); } // namespace FlexFlow diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/op_task_args.dtg.toml b/lib/realm-execution/include/realm-execution/tasks/impl/op_task_args.dtg.toml index 814f9f802b..84fa384d25 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/op_task_args.dtg.toml +++ b/lib/realm-execution/include/realm-execution/tasks/impl/op_task_args.dtg.toml @@ -7,7 +7,9 @@ includes = [ "kernels/profiling_settings.dtg.h", "pcg/optimizer_attrs.dtg.h", "realm-execution/device_specific_managed_per_device_ff_handle.h", + "realm-execution/realm.h", "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h", + "task-spec/dynamic_graph/dynamic_value_attrs.dtg.h", "task-spec/ff_iteration_config.dtg.h", ] @@ -15,6 +17,10 @@ includes = [ name = "invocation" type = "::FlexFlow::DynamicNodeInvocation" +[[fields]] +name = "tensor_backing" +type = "std::unordered_map<::FlexFlow::DynamicValueAttrs, ::FlexFlow::Realm::RegionInstance>" + [[fields]] name = "profiling_settings" type = "::FlexFlow::ProfilingSettings" diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_state_init_task_args.dtg.toml b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_state_init_task_args.dtg.toml index c99d2758c0..f3847c9137 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_state_init_task_args.dtg.toml +++ b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_state_init_task_args.dtg.toml @@ -12,6 +12,7 @@ includes = [ "kernels/profiling_settings.dtg.h", "pcg/optimizer_attrs.dtg.h", "realm-execution/tasks/serializer/serializable_device_specific_ptr.dtg.h", + "realm-execution/tasks/serializer/serializable_realm_instance.dtg.h", "realm-execution/tasks/serializer/serializable_realm_processor.dtg.h", "task-spec/device_specific_per_device_op_state.dtg.h", "task-spec/dynamic_graph/serializable_dynamic_node_invocation.dtg.h", @@ -22,6 +23,10 @@ includes = [ name = "invocation" type = "::FlexFlow::SerializableDynamicNodeInvocation" +[[fields]] +name = "tensor_backing" +type = "std::unordered_map<::FlexFlow::SerializableDynamicValueAttrs, ::FlexFlow::SerializableRealmInstance>" + [[fields]] name = "profiling_settings" type = "::FlexFlow::ProfilingSettings" diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/serializable_op_task_args.dtg.toml b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_op_task_args.dtg.toml index a0f89e3ae2..3ca338689a 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/serializable_op_task_args.dtg.toml +++ b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_op_task_args.dtg.toml @@ -12,6 +12,7 @@ includes = [ "kernels/profiling_settings.dtg.h", "pcg/optimizer_attrs.dtg.h", "realm-execution/tasks/serializer/serializable_device_specific_ptr.dtg.h", + "realm-execution/tasks/serializer/serializable_realm_instance.dtg.h", "task-spec/dynamic_graph/serializable_dynamic_node_invocation.dtg.h", "task-spec/ff_iteration_config.dtg.h", ] @@ -25,6 +26,10 @@ src_includes = [ name = "invocation" type = "::FlexFlow::SerializableDynamicNodeInvocation" +[[fields]] +name = "tensor_backing" +type = "std::unordered_map<::FlexFlow::SerializableDynamicValueAttrs, ::FlexFlow::SerializableRealmInstance>" + [[fields]] name = "profiling_settings" type = "::FlexFlow::ProfilingSettings" diff --git a/lib/realm-execution/include/realm-execution/tasks/serializer/serializable_realm_instance.dtg.toml b/lib/realm-execution/include/realm-execution/tasks/serializer/serializable_realm_instance.dtg.toml new file mode 100644 index 0000000000..150801367d --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tasks/serializer/serializable_realm_instance.dtg.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "SerializableRealmInstance" +type = "struct" +features = [ + "eq", + "fmt", + "hash", + "json", +] + +includes = [ + "realm-execution/realm.h", +] + +[[fields]] +name = "id" +type = "::FlexFlow::Realm::RegionInstance::id_t" diff --git a/lib/realm-execution/include/realm-execution/tasks/serializer/serializable_realm_instance.h b/lib/realm-execution/include/realm-execution/tasks/serializer/serializable_realm_instance.h new file mode 100644 index 0000000000..7262ec4f09 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tasks/serializer/serializable_realm_instance.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_SERIALIZER_SERIALIZABLE_REALM_INSTANCE_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_SERIALIZER_SERIALIZABLE_REALM_INSTANCE_H + +#include "realm-execution/realm.h" +#include "realm-execution/tasks/serializer/serializable_realm_instance.dtg.h" + +namespace FlexFlow { + +SerializableRealmInstance + realm_instance_to_serializable(Realm::RegionInstance const &); +Realm::RegionInstance + realm_instance_from_serializable(SerializableRealmInstance const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/tensor_instance_backing.h b/lib/realm-execution/include/realm-execution/tensor_instance_backing.h index 1d143b7409..72a8bf439a 100644 --- a/lib/realm-execution/include/realm-execution/tensor_instance_backing.h +++ b/lib/realm-execution/include/realm-execution/tensor_instance_backing.h @@ -2,11 +2,15 @@ #define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TENSOR_INSTANCE_BACKING_H #include "realm-execution/tensor_instance_backing.dtg.h" +#include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" namespace FlexFlow { TensorInstanceBacking make_empty_tensor_instance_backing(); +TensorInstanceBacking subset_tensor_instance_backing_for_invocation( + TensorInstanceBacking const &, DynamicNodeInvocation const &); + } // namespace FlexFlow #endif diff --git a/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc b/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc index cab2b49e15..de8060aa12 100644 --- a/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc +++ b/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc @@ -1,8 +1,11 @@ #include "realm-execution/distributed_device_state_initialization.h" #include "local-execution/device_state_initialization.h" #include "realm-execution/tasks/impl/device_state_init_task.h" +#include "realm-execution/tensor_instance_backing.h" #include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" +#include "task-spec/dynamic_graph/dynamic_value_attrs.dtg.h" +#include "utils/containers/map_values.h" #include "utils/optional.h" #include #include @@ -10,8 +13,9 @@ namespace FlexFlow { DynamicOpenDataflowGraph perform_distributed_device_state_initialization( - DynamicOpenDataflowGraph const &dg, RealmContext &ctx, + DynamicOpenDataflowGraph const &dg, + TensorInstanceBacking const &tensor_instance_backing, ProfilingSettings const &profiling_settings, DistributedDeviceHandle const &device_handle, FFIterationConfig const &iteration_config, @@ -27,6 +31,15 @@ DynamicOpenDataflowGraph perform_distributed_device_state_initialization( Realm::Processor target_proc = ctx.map_device_coord_to_processor( assert_unwrap(invocation.node_attrs.device_coord)); + std::unordered_map + tensor_backing = map_values( + subset_tensor_instance_backing_for_invocation( + tensor_instance_backing, invocation) + .backing, + [](std::pair const &v) { + return v.first; + }); + // FIXME: in the absense of a real serializer we're just tossing around raw // bytes, which means we need to bypass the constructor for this type (yes, // ugh) @@ -37,6 +50,7 @@ DynamicOpenDataflowGraph perform_distributed_device_state_initialization( spawn_device_state_init_task(ctx, target_proc, invocation, + tensor_backing, profiling_settings, device_handle.at(target_proc), iteration_config, diff --git a/lib/realm-execution/src/realm-execution/dynamic_tensor_accessor_from_instance.cc b/lib/realm-execution/src/realm-execution/dynamic_tensor_accessor_from_instance.cc new file mode 100644 index 0000000000..cb9382cfe0 --- /dev/null +++ b/lib/realm-execution/src/realm-execution/dynamic_tensor_accessor_from_instance.cc @@ -0,0 +1,11 @@ +#include "realm-execution/dynamic_tensor_accessor_from_instance.h" +#include "utils/exception.h" + +namespace FlexFlow { + +DynamicTensorAccessor + dynamic_tensor_accessor_from_instance(Realm::RegionInstance const &) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/fmt/realm_event.cc b/lib/realm-execution/src/realm-execution/fmt/realm_event.cc index 7c5ad7d848..a5aed9481d 100644 --- a/lib/realm-execution/src/realm-execution/fmt/realm_event.cc +++ b/lib/realm-execution/src/realm-execution/fmt/realm_event.cc @@ -2,8 +2,7 @@ namespace FlexFlow { -std::ostream &operator<<(std::ostream &s, - ::FlexFlow::Realm::Event const &m) { +std::ostream &operator<<(std::ostream &s, ::FlexFlow::Realm::Event const &m) { return s << fmt::to_string(m); } diff --git a/lib/realm-execution/src/realm-execution/hash/processor.cc b/lib/realm-execution/src/realm-execution/hash/processor.cc index dcc1bc5d06..5a8624f676 100644 --- a/lib/realm-execution/src/realm-execution/hash/processor.cc +++ b/lib/realm-execution/src/realm-execution/hash/processor.cc @@ -1,6 +1,8 @@ #include "realm-execution/hash/processor.h" #include +#ifdef FLEXFLOW_USE_PREALM + namespace std { size_t hash<::FlexFlow::Realm::Processor>::operator()( @@ -9,3 +11,5 @@ size_t hash<::FlexFlow::Realm::Processor>::operator()( } } // namespace std + +#endif diff --git a/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc b/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc index c21737300c..496c3210c0 100644 --- a/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc +++ b/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc @@ -6,6 +6,7 @@ #include "realm-execution/instance_allocation.h" #include "realm-execution/realm_context.h" #include "realm-execution/tasks/impl/op_task.h" +#include "realm-execution/tensor_instance_backing.h" #include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" #include "task-spec/dynamic_graph/dynamic_task_type.dtg.h" @@ -16,6 +17,7 @@ #include "task-spec/dynamic_graph/pass_expansion.h" #include "task-spec/dynamic_graph/shard_expansion.h" #include "task-spec/dynamic_graph/update_insertion.h" +#include "utils/containers/map_values.h" #include "utils/containers/transform.h" #include "utils/containers/values.h" #include "utils/graph/digraph/algorithms/get_topological_ordering.h" @@ -26,9 +28,11 @@ namespace FlexFlow { PCGInstance::PCGInstance( RealmContext &ctx, std::vector const &execution_order, + TensorInstanceBacking const &tensor_instance_backing, OptimizerAttrs const &optimizer_attrs, std::optional logit_grad_tensor) : ctx(ctx), execution_order(execution_order), + tensor_instance_backing(tensor_instance_backing), optimizer_attrs(optimizer_attrs), logit_grad_tensor(logit_grad_tensor) {} RealmContext &PCGInstance::get_realm_context() { @@ -38,6 +42,9 @@ std::vector const & PCGInstance::get_execution_order() const { return this->execution_order; } +TensorInstanceBacking const &PCGInstance::get_tensor_instance_backing() const { + return this->tensor_instance_backing; +} OptimizerAttrs const &PCGInstance::get_optimizer_attrs() const { return this->optimizer_attrs; } @@ -86,19 +93,23 @@ PCGInstance create_pcg_instance( dg = perform_shard_expansion(dg); TensorInstanceBacking backing = perform_instance_allocation(dg, inputs, ctx); - logit_grad_value = transform(logit_grad_value, [&](DynamicValueAttrs const &lgv) { - for (DynamicNodeInvocation const &invocation : dg.invocations) { - if (invocation.node_attrs.task_type != DynamicTaskType::LOSS) { - continue; - } - for (auto const &[slot, value] : invocation.outputs) { - if (slot.slot_name == TensorSlotName::LOGIT && value.tensor_guid == lgv.tensor_guid && value.role == lgv.role) { - return value; + logit_grad_value = + transform(logit_grad_value, [&](DynamicValueAttrs const &lgv) { + for (DynamicNodeInvocation const &invocation : dg.invocations) { + if (invocation.node_attrs.task_type != DynamicTaskType::LOSS) { + continue; + } + for (auto const &[slot, value] : invocation.outputs) { + if (slot.slot_name == TensorSlotName::LOGIT && + value.tensor_guid == lgv.tensor_guid && + value.role == lgv.role) { + return value; + } + } } - } - } - PANIC("couldn't find updated logit grad in the shard-expanded dynamic graph"); - }); + PANIC("couldn't find updated logit grad in the shard-expanded dynamic " + "graph"); + }); std::optional logit_grad_tensor = transform(logit_grad_value, [&](DynamicValueAttrs const &lgv) { @@ -108,8 +119,9 @@ PCGInstance create_pcg_instance( // FIXME: for now we're going to be lazy and block on everything rather than // do fine-grained dependencies on instances dg = perform_distributed_device_state_initialization( - dg, ctx, + dg, + backing, profiling_settings, device_handle, iteration_config, @@ -123,8 +135,11 @@ PCGInstance create_pcg_instance( std::vector invocation_topo_order = transform( node_topo_order, [&](Node node) { return node_map.at_l(node); }); - return PCGInstance{ - ctx, invocation_topo_order, optimizer_attrs, logit_grad_tensor}; + return PCGInstance{/*ctx=*/ctx, + /*execution_order=*/invocation_topo_order, + /*tensor_instance_backing=*/backing, + /*optimizer_attrs=*/optimizer_attrs, + /*logit_grad_tensor=*/logit_grad_tensor}; // TODO list: // * external instances @@ -141,6 +156,7 @@ static std::unordered_map execute_distributed_dynamic_node_invocation_set( RealmContext &ctx, std::vector const &invocations, + TensorInstanceBacking const &tensor_instance_backing, OptimizerAttrs const &optimizer_attrs, ProfilingSettings const &profiling_settings, DistributedDeviceHandle const &device_handle, @@ -165,9 +181,20 @@ static std::unordered_map Realm::Event::merge_events(output_dependencies)); Realm::Processor target_proc = ctx.map_device_coord_to_processor( assert_unwrap(invocation.node_attrs.device_coord)); + + std::unordered_map + tensor_backing = map_values( + subset_tensor_instance_backing_for_invocation( + tensor_instance_backing, invocation) + .backing, + [](std::pair const &v) { + return v.first; + }); + Realm::Event result = spawn_op_task(ctx, target_proc, invocation, + tensor_backing, profiling_settings, device_handle.at(target_proc), iteration_config, @@ -185,32 +212,34 @@ static std::unordered_map std::unordered_map perform_all_passes_for_pcg_instance( - PCGInstance &instance, + PCGInstance &pcg_instance, ProfilingSettings const &profiling_settings, DistributedDeviceHandle const &device_handle, FFIterationConfig iteration_config) { std::vector execution_order = - instance.get_execution_order(); + pcg_instance.get_execution_order(); std::unordered_map result = execute_distributed_dynamic_node_invocation_set( - /*ctx=*/instance.get_realm_context(), + /*ctx=*/pcg_instance.get_realm_context(), /*invocations=*/execution_order, - /*optimizer_attrs=*/instance.get_optimizer_attrs(), + /*tensor_instance_backing=*/ + pcg_instance.get_tensor_instance_backing(), + /*optimizer_attrs=*/pcg_instance.get_optimizer_attrs(), /*profiling_settings=*/profiling_settings, /*device_handle=*/device_handle, /*iteration_config=*/iteration_config); - instance.update_optimizer_attrs_for_next_iter(); + pcg_instance.update_optimizer_attrs_for_next_iter(); return result; } std::unordered_map perform_forward_pass_for_pcg_instance( - PCGInstance &instance, + PCGInstance &pcg_instance, ProfilingSettings const &profiling_settings, DistributedDeviceHandle const &device_handle, FFIterationConfig iteration_config) { std::vector execution_order = - filter(instance.get_execution_order(), + filter(pcg_instance.get_execution_order(), [](DynamicNodeInvocation const &invocation) { DynamicTaskType task_type = assert_unwrap(invocation.node_attrs.task_type); @@ -218,9 +247,10 @@ std::unordered_map }); return execute_distributed_dynamic_node_invocation_set( - /*ctx=*/instance.get_realm_context(), + /*ctx=*/pcg_instance.get_realm_context(), /*invocations=*/execution_order, - /*optimizer_attrs=*/instance.get_optimizer_attrs(), + /*tensor_instance_backing=*/pcg_instance.get_tensor_instance_backing(), + /*optimizer_attrs=*/pcg_instance.get_optimizer_attrs(), /*profiling_settings=*/profiling_settings, /*device_handle=*/device_handle, /*iteration_config=*/iteration_config); @@ -228,12 +258,12 @@ std::unordered_map std::unordered_map perform_backward_pass_for_pcg_instance( - PCGInstance &instance, + PCGInstance &pcg_instance, ProfilingSettings const &profiling_settings, DistributedDeviceHandle const &device_handle, FFIterationConfig iteration_config) { std::vector execution_order = - filter(instance.get_execution_order(), + filter(pcg_instance.get_execution_order(), [](DynamicNodeInvocation const &invocation) { DynamicTaskType task_type = assert_unwrap(invocation.node_attrs.task_type); @@ -241,9 +271,10 @@ std::unordered_map }); return execute_distributed_dynamic_node_invocation_set( - /*ctx=*/instance.get_realm_context(), + /*ctx=*/pcg_instance.get_realm_context(), /*invocations=*/execution_order, - /*optimizer_attrs=*/instance.get_optimizer_attrs(), + /*tensor_instance_backing=*/pcg_instance.get_tensor_instance_backing(), + /*optimizer_attrs=*/pcg_instance.get_optimizer_attrs(), /*profiling_settings=*/profiling_settings, /*device_handle=*/device_handle, /*iteration_config=*/iteration_config); @@ -251,12 +282,12 @@ std::unordered_map std::unordered_map perform_update_pass_for_pcg_instance( - PCGInstance &instance, + PCGInstance &pcg_instance, ProfilingSettings const &profiling_settings, DistributedDeviceHandle const &device_handle, FFIterationConfig iteration_config) { std::vector execution_order = - filter(instance.get_execution_order(), + filter(pcg_instance.get_execution_order(), [](DynamicNodeInvocation const &invocation) { DynamicTaskType task_type = assert_unwrap(invocation.node_attrs.task_type); @@ -265,13 +296,15 @@ std::unordered_map std::unordered_map result = execute_distributed_dynamic_node_invocation_set( - /*ctx=*/instance.get_realm_context(), + /*ctx=*/pcg_instance.get_realm_context(), /*invocations=*/execution_order, - /*optimizer_attrs=*/instance.get_optimizer_attrs(), + /*tensor_instance_backing=*/ + pcg_instance.get_tensor_instance_backing(), + /*optimizer_attrs=*/pcg_instance.get_optimizer_attrs(), /*profiling_settings=*/profiling_settings, /*device_handle=*/device_handle, /*iteration_config=*/iteration_config); - instance.update_optimizer_attrs_for_next_iter(); + pcg_instance.update_optimizer_attrs_for_next_iter(); return result; } diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_task.cc index 99c72cf5e7..d455b493da 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_task.cc @@ -1,11 +1,15 @@ #include "realm-execution/tasks/impl/device_state_init_task.h" #include "local-execution/device_state_initialization.h" +#include "realm-execution/dynamic_tensor_accessor_from_instance.h" #include "realm-execution/tasks/impl/device_state_init_return_task.h" #include "realm-execution/tasks/impl/device_state_init_task_args.dtg.h" #include "realm-execution/tasks/impl/serializable_device_state_init_task_args.h" #include "realm-execution/tasks/serializer/task_arg_serializer.h" #include "realm-execution/tasks/task_id_t.dtg.h" #include "realm-execution/tasks/task_id_t.h" +#include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" +#include "task-spec/dynamic_graph/dynamic_value_attrs.dtg.h" +#include "utils/containers/map_values.h" #include "utils/optional.h" #include #include @@ -26,8 +30,20 @@ void device_state_init_task_body(void const *args, device_handle_t device_handle = device_handle_t_from_device_specific_managed_handle( task_args.device_handle, ctx.get_current_device_idx()); + + // Patch the invocation to include the provided instances + auto map_instance_to_accessor = [&](DynamicValueAttrs const &value) { + DynamicValueAttrs result = value; + result.accessor = dynamic_tensor_accessor_from_instance( + task_args.tensor_backing.at(value)); + return result; + }; + DynamicNodeInvocation invocation = task_args.invocation; + invocation.inputs = map_values(invocation.inputs, map_instance_to_accessor); + invocation.outputs = map_values(invocation.outputs, map_instance_to_accessor); + DynamicNodeInvocation result_invocation = - initialize_node(task_args.invocation, + initialize_node(invocation, ctx.get_current_device_allocator(), task_args.profiling_settings, device_handle, @@ -51,6 +67,8 @@ std::optional spawn_device_state_init_task( RealmContext &ctx, Realm::Processor target_proc, DynamicNodeInvocation const &invocation, + std::unordered_map const + &tensor_backing, ProfilingSettings const &profiling_settings, DeviceSpecificManagedPerDeviceFFHandle const &device_handle, FFIterationConfig const &iteration_config, @@ -59,6 +77,7 @@ std::optional spawn_device_state_init_task( Realm::Event precondition) { DeviceStateInitTaskArgs task_args{ invocation, + tensor_backing, profiling_settings, device_handle, iteration_config, diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc index d8b8873442..0f65b808aa 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc @@ -1,11 +1,13 @@ #include "realm-execution/tasks/impl/op_task.h" #include "local-execution/task_execution.h" #include "realm-execution/device_specific_managed_per_device_ff_handle.h" +#include "realm-execution/dynamic_tensor_accessor_from_instance.h" #include "realm-execution/tasks/impl/op_task_args.dtg.h" #include "realm-execution/tasks/impl/serializable_op_task_args.h" #include "realm-execution/tasks/serializer/task_arg_serializer.h" #include "realm-execution/tasks/task_id_t.h" #include "task-spec/per_device_op_state.h" +#include "utils/containers/map_values.h" #include "utils/optional.h" #include @@ -23,8 +25,20 @@ void op_task_body(void const *args, device_handle_t device_handle = device_handle_t_from_device_specific_managed_handle( task_args.device_handle, ctx.get_current_device_idx()); + + // Patch the invocation to include the provided instances + auto map_instance_to_accessor = [&](DynamicValueAttrs const &value) { + DynamicValueAttrs result = value; + result.accessor = dynamic_tensor_accessor_from_instance( + task_args.tensor_backing.at(value)); + return result; + }; + DynamicNodeInvocation invocation = task_args.invocation; + invocation.inputs = map_values(invocation.inputs, map_instance_to_accessor); + invocation.outputs = map_values(invocation.outputs, map_instance_to_accessor); + execute_dynamic_node_invocation( - /*invocation=*/task_args.invocation, + /*invocation=*/invocation, /*allocator=*/ctx.get_current_device_allocator(), /*profiling_settings=*/task_args.profiling_settings, /*ff_handle=*/device_handle, @@ -39,16 +53,19 @@ void op_task_body(void const *args, /*device_idx=*/ctx.get_current_device_idx()); } -Realm::Event - spawn_op_task(RealmContext &ctx, - Realm::Processor target_proc, - DynamicNodeInvocation const &invocation, - ProfilingSettings const &profiling_settings, - DeviceSpecificManagedPerDeviceFFHandle const &device_handle, - FFIterationConfig const &iteration_config, - std::optional const &optimizer_attrs, - Realm::Event precondition) { +Realm::Event spawn_op_task( + RealmContext &ctx, + Realm::Processor target_proc, + DynamicNodeInvocation const &invocation, + std::unordered_map const + &tensor_backing, + ProfilingSettings const &profiling_settings, + DeviceSpecificManagedPerDeviceFFHandle const &device_handle, + FFIterationConfig const &iteration_config, + std::optional const &optimizer_attrs, + Realm::Event precondition) { OpTaskArgs task_args{invocation, + tensor_backing, profiling_settings, device_handle, iteration_config, diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/serializable_device_state_init_task_args.cc b/lib/realm-execution/src/realm-execution/tasks/impl/serializable_device_state_init_task_args.cc index 528ff26867..59a1dd71a6 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/serializable_device_state_init_task_args.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/serializable_device_state_init_task_args.cc @@ -1,6 +1,9 @@ #include "realm-execution/tasks/impl/serializable_device_state_init_task_args.h" +#include "realm-execution/tasks/serializer/serializable_realm_instance.h" #include "realm-execution/tasks/serializer/serializable_realm_processor.h" #include "task-spec/dynamic_graph/serializable_dynamic_node_invocation.h" +#include "task-spec/dynamic_graph/serializable_dynamic_value_attrs.h" +#include "utils/containers/map_keys_and_values.h" namespace FlexFlow { @@ -8,6 +11,10 @@ SerializableDeviceStateInitTaskArgs device_state_init_task_args_to_serializable( DeviceStateInitTaskArgs const &args) { return SerializableDeviceStateInitTaskArgs{ /*invocation=*/dynamic_node_invocation_to_serializable(args.invocation), + /*tensor_backing*/ + map_keys_and_values(args.tensor_backing, + dynamic_value_attrs_to_serializable, + realm_instance_to_serializable), /*profiling_settings=*/args.profiling_settings, /*device_handle=*/args.device_handle.serialize(), /*iteration_config=*/args.iteration_config, @@ -21,6 +28,10 @@ DeviceStateInitTaskArgs device_state_init_task_args_from_serializable( SerializableDeviceStateInitTaskArgs const &args) { return DeviceStateInitTaskArgs{ /*invocation=*/dynamic_node_invocation_from_serializable(args.invocation), + /*tensor_backing*/ + map_keys_and_values(args.tensor_backing, + dynamic_value_attrs_from_serializable, + realm_instance_from_serializable), /*profiling_settings=*/args.profiling_settings, /*device_handle=*/ DeviceSpecificManagedPerDeviceFFHandle::deserialize(args.device_handle), diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/serializable_op_task_args.cc b/lib/realm-execution/src/realm-execution/tasks/impl/serializable_op_task_args.cc index 0513bc6df7..04a213e906 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/serializable_op_task_args.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/serializable_op_task_args.cc @@ -1,11 +1,18 @@ #include "realm-execution/tasks/impl/serializable_op_task_args.h" +#include "realm-execution/tasks/serializer/serializable_realm_instance.h" #include "task-spec/dynamic_graph/serializable_dynamic_node_invocation.h" +#include "task-spec/dynamic_graph/serializable_dynamic_value_attrs.h" +#include "utils/containers/map_keys_and_values.h" namespace FlexFlow { SerializableOpTaskArgs op_task_args_to_serializable(OpTaskArgs const &args) { return SerializableOpTaskArgs{ /*invocation=*/dynamic_node_invocation_to_serializable(args.invocation), + /*tensor_backing*/ + map_keys_and_values(args.tensor_backing, + dynamic_value_attrs_to_serializable, + realm_instance_to_serializable), /*profiling_settings=*/args.profiling_settings, /*device_handle=*/args.device_handle.serialize(), /*iteration_config=*/args.iteration_config, @@ -16,6 +23,10 @@ SerializableOpTaskArgs op_task_args_to_serializable(OpTaskArgs const &args) { OpTaskArgs op_task_args_from_serializable(SerializableOpTaskArgs const &args) { return OpTaskArgs{ /*invocation=*/dynamic_node_invocation_from_serializable(args.invocation), + /*tensor_backing*/ + map_keys_and_values(args.tensor_backing, + dynamic_value_attrs_from_serializable, + realm_instance_from_serializable), /*profiling_settings=*/args.profiling_settings, /*device_handle=*/ DeviceSpecificManagedPerDeviceFFHandle::deserialize(args.device_handle), diff --git a/lib/realm-execution/src/realm-execution/tasks/serializer/serializable_realm_instance.cc b/lib/realm-execution/src/realm-execution/tasks/serializer/serializable_realm_instance.cc new file mode 100644 index 0000000000..f2d42a96ca --- /dev/null +++ b/lib/realm-execution/src/realm-execution/tasks/serializer/serializable_realm_instance.cc @@ -0,0 +1,15 @@ +#include "realm-execution/tasks/serializer/serializable_realm_instance.h" + +namespace FlexFlow { + +SerializableRealmInstance + realm_instance_to_serializable(Realm::RegionInstance const &inst) { + return SerializableRealmInstance{inst.id}; +} + +Realm::RegionInstance + realm_instance_from_serializable(SerializableRealmInstance const &inst) { + return Realm::RegionInstance{inst.id}; +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/tensor_instance_backing.cc b/lib/realm-execution/src/realm-execution/tensor_instance_backing.cc index 53c2a2b271..dea51d8c92 100644 --- a/lib/realm-execution/src/realm-execution/tensor_instance_backing.cc +++ b/lib/realm-execution/src/realm-execution/tensor_instance_backing.cc @@ -1,4 +1,5 @@ #include "realm-execution/tensor_instance_backing.h" +#include "utils/containers/values.h" namespace FlexFlow { @@ -8,4 +9,17 @@ TensorInstanceBacking make_empty_tensor_instance_backing() { }; } +TensorInstanceBacking subset_tensor_instance_backing_for_invocation( + TensorInstanceBacking const &backing, + DynamicNodeInvocation const &invocation) { + TensorInstanceBacking result = make_empty_tensor_instance_backing(); + for (DynamicValueAttrs const &value : values(invocation.inputs)) { + result.backing.insert(std::pair{value, backing.backing.at(value)}); + } + for (DynamicValueAttrs const &value : values(invocation.outputs)) { + result.backing.insert(std::pair{value, backing.backing.at(value)}); + } + return result; +} + } // namespace FlexFlow From afad03b735a0d2e6edb513fa51ecd7afe3232143 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Tue, 17 Feb 2026 14:42:49 -0800 Subject: [PATCH 68/88] Refactor and finish tensor instance backing. --- .../dynamic_tensor_accessor_from_instance.h | 10 ++++-- .../tasks/impl/device_state_init_task.h | 6 ++-- .../impl/device_state_init_task_args.dtg.toml | 4 +-- .../realm-execution/tasks/impl/op_task.h | 22 ++++++------ .../tasks/impl/op_task_args.dtg.toml | 5 ++- ...zable_device_state_init_task_args.dtg.toml | 4 +-- .../impl/serializable_op_task_args.dtg.toml | 4 +-- .../serializable_realm_event.dtg.toml | 17 +++++++++ .../serializer/serializable_realm_event.h | 14 ++++++++ ...ializable_tensor_instance_backing.dtg.toml | 26 ++++++++++++++ .../serializable_tensor_instance_backing.h | 16 +++++++++ ...distributed_device_state_initialization.cc | 12 +++---- .../dynamic_tensor_accessor_from_instance.cc | 36 +++++++++++++++++-- .../pcg_instance/pcg_instance.cc | 11 ++---- .../tasks/impl/device_state_init_task.cc | 10 ++++-- .../src/realm-execution/tasks/impl/op_task.cc | 29 ++++++++------- ...erializable_device_state_init_task_args.cc | 12 ++----- .../tasks/impl/serializable_op_task_args.cc | 12 ++----- .../serializer/serializable_realm_event.cc | 14 ++++++++ .../serializable_tensor_instance_backing.cc | 32 +++++++++++++++++ 20 files changed, 218 insertions(+), 78 deletions(-) create mode 100644 lib/realm-execution/include/realm-execution/tasks/serializer/serializable_realm_event.dtg.toml create mode 100644 lib/realm-execution/include/realm-execution/tasks/serializer/serializable_realm_event.h create mode 100644 lib/realm-execution/include/realm-execution/tasks/serializer/serializable_tensor_instance_backing.dtg.toml create mode 100644 lib/realm-execution/include/realm-execution/tasks/serializer/serializable_tensor_instance_backing.h create mode 100644 lib/realm-execution/src/realm-execution/tasks/serializer/serializable_realm_event.cc create mode 100644 lib/realm-execution/src/realm-execution/tasks/serializer/serializable_tensor_instance_backing.cc diff --git a/lib/realm-execution/include/realm-execution/dynamic_tensor_accessor_from_instance.h b/lib/realm-execution/include/realm-execution/dynamic_tensor_accessor_from_instance.h index 48cfbde924..8c8ccf6ac4 100644 --- a/lib/realm-execution/include/realm-execution/dynamic_tensor_accessor_from_instance.h +++ b/lib/realm-execution/include/realm-execution/dynamic_tensor_accessor_from_instance.h @@ -1,13 +1,19 @@ #ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_DYNAMIC_TENSOR_ACCESSOR_FROM_INSTANCE_H #define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_DYNAMIC_TENSOR_ACCESSOR_FROM_INSTANCE_H +#include "op-attrs/parallel_tensor_shape.dtg.h" #include "realm-execution/realm.h" #include "task-spec/dynamic_graph/dynamic_tensor_accessor.dtg.h" +#include "task-spec/permissions.h" namespace FlexFlow { -DynamicTensorAccessor - dynamic_tensor_accessor_from_instance(Realm::RegionInstance const &); +DynamicTensorAccessor dynamic_tensor_accessor_from_instance( + Realm::RegionInstance inst, + Realm::Event ready, + ParallelTensorShape const ¶llel_tensor_shape, + Permissions const &permissions, + Realm::Processor for_processor); } // namespace FlexFlow diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task.h index 9c53748916..54bddc1ddd 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task.h +++ b/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task.h @@ -6,11 +6,10 @@ #include "realm-execution/device_specific_managed_per_device_ff_handle.h" #include "realm-execution/realm.h" #include "realm-execution/realm_context.h" +#include "realm-execution/tensor_instance_backing.dtg.h" #include "task-spec/device_specific_per_device_op_state.dtg.h" #include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" -#include "task-spec/dynamic_graph/dynamic_value_attrs.dtg.h" #include "task-spec/ff_iteration_config.dtg.h" -#include namespace FlexFlow { @@ -21,8 +20,7 @@ std::optional spawn_device_state_init_task( RealmContext &ctx, Realm::Processor target_proc, DynamicNodeInvocation const &invocation, - std::unordered_map const - &tensor_backing, + TensorInstanceBacking const &tensor_backing, ProfilingSettings const &profiling_settings, DeviceSpecificManagedPerDeviceFFHandle const &device_handle, FFIterationConfig const &iteration_config, diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task_args.dtg.toml b/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task_args.dtg.toml index 888c62af54..fbec9298dd 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task_args.dtg.toml +++ b/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task_args.dtg.toml @@ -7,10 +7,10 @@ includes = [ "kernels/profiling_settings.dtg.h", "pcg/optimizer_attrs.dtg.h", "realm-execution/device_specific_managed_per_device_ff_handle.h", + "realm-execution/tensor_instance_backing.dtg.h", "realm-execution/realm.h", "task-spec/device_specific_per_device_op_state.dtg.h", "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h", - "task-spec/dynamic_graph/dynamic_value_attrs.dtg.h", "task-spec/ff_iteration_config.dtg.h", ] @@ -20,7 +20,7 @@ type = "::FlexFlow::DynamicNodeInvocation" [[fields]] name = "tensor_backing" -type = "std::unordered_map<::FlexFlow::DynamicValueAttrs, ::FlexFlow::Realm::RegionInstance>" +type = "TensorInstanceBacking" [[fields]] name = "profiling_settings" diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h index 37a801a508..330da4d2b2 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h +++ b/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h @@ -7,6 +7,7 @@ #include "realm-execution/device_specific_managed_per_device_ff_handle.h" #include "realm-execution/realm.h" #include "realm-execution/realm_context.h" +#include "realm-execution/tensor_instance_backing.dtg.h" #include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" #include "task-spec/ff_iteration_config.dtg.h" @@ -14,17 +15,16 @@ namespace FlexFlow { void op_task_body(void const *, size_t, void const *, size_t, Realm::Processor); -Realm::Event spawn_op_task( - RealmContext &ctx, - Realm::Processor target_proc, - DynamicNodeInvocation const &invocation, - std::unordered_map const - &tensor_backing, - ProfilingSettings const &profiling_settings, - DeviceSpecificManagedPerDeviceFFHandle const &device_handle, - FFIterationConfig const &iteration_config, - std::optional const &optimizer_attrs, - Realm::Event precondition); +Realm::Event + spawn_op_task(RealmContext &ctx, + Realm::Processor target_proc, + DynamicNodeInvocation const &invocation, + TensorInstanceBacking const &tensor_backing, + ProfilingSettings const &profiling_settings, + DeviceSpecificManagedPerDeviceFFHandle const &device_handle, + FFIterationConfig const &iteration_config, + std::optional const &optimizer_attrs, + Realm::Event precondition); } // namespace FlexFlow diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/op_task_args.dtg.toml b/lib/realm-execution/include/realm-execution/tasks/impl/op_task_args.dtg.toml index 84fa384d25..2a55ffbf80 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/op_task_args.dtg.toml +++ b/lib/realm-execution/include/realm-execution/tasks/impl/op_task_args.dtg.toml @@ -7,9 +7,8 @@ includes = [ "kernels/profiling_settings.dtg.h", "pcg/optimizer_attrs.dtg.h", "realm-execution/device_specific_managed_per_device_ff_handle.h", - "realm-execution/realm.h", + "realm-execution/tensor_instance_backing.dtg.h", "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h", - "task-spec/dynamic_graph/dynamic_value_attrs.dtg.h", "task-spec/ff_iteration_config.dtg.h", ] @@ -19,7 +18,7 @@ type = "::FlexFlow::DynamicNodeInvocation" [[fields]] name = "tensor_backing" -type = "std::unordered_map<::FlexFlow::DynamicValueAttrs, ::FlexFlow::Realm::RegionInstance>" +type = "TensorInstanceBacking" [[fields]] name = "profiling_settings" diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_state_init_task_args.dtg.toml b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_state_init_task_args.dtg.toml index f3847c9137..034132f9d1 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_state_init_task_args.dtg.toml +++ b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_state_init_task_args.dtg.toml @@ -12,8 +12,8 @@ includes = [ "kernels/profiling_settings.dtg.h", "pcg/optimizer_attrs.dtg.h", "realm-execution/tasks/serializer/serializable_device_specific_ptr.dtg.h", - "realm-execution/tasks/serializer/serializable_realm_instance.dtg.h", "realm-execution/tasks/serializer/serializable_realm_processor.dtg.h", + "realm-execution/tasks/serializer/serializable_tensor_instance_backing.dtg.h", "task-spec/device_specific_per_device_op_state.dtg.h", "task-spec/dynamic_graph/serializable_dynamic_node_invocation.dtg.h", "task-spec/ff_iteration_config.dtg.h", @@ -25,7 +25,7 @@ type = "::FlexFlow::SerializableDynamicNodeInvocation" [[fields]] name = "tensor_backing" -type = "std::unordered_map<::FlexFlow::SerializableDynamicValueAttrs, ::FlexFlow::SerializableRealmInstance>" +type = "::FlexFlow::SerializableTensorInstanceBacking" [[fields]] name = "profiling_settings" diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/serializable_op_task_args.dtg.toml b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_op_task_args.dtg.toml index 3ca338689a..ac31e78d0d 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/serializable_op_task_args.dtg.toml +++ b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_op_task_args.dtg.toml @@ -12,7 +12,7 @@ includes = [ "kernels/profiling_settings.dtg.h", "pcg/optimizer_attrs.dtg.h", "realm-execution/tasks/serializer/serializable_device_specific_ptr.dtg.h", - "realm-execution/tasks/serializer/serializable_realm_instance.dtg.h", + "realm-execution/tasks/serializer/serializable_tensor_instance_backing.dtg.h", "task-spec/dynamic_graph/serializable_dynamic_node_invocation.dtg.h", "task-spec/ff_iteration_config.dtg.h", ] @@ -28,7 +28,7 @@ type = "::FlexFlow::SerializableDynamicNodeInvocation" [[fields]] name = "tensor_backing" -type = "std::unordered_map<::FlexFlow::SerializableDynamicValueAttrs, ::FlexFlow::SerializableRealmInstance>" +type = "::FlexFlow::SerializableTensorInstanceBacking" [[fields]] name = "profiling_settings" diff --git a/lib/realm-execution/include/realm-execution/tasks/serializer/serializable_realm_event.dtg.toml b/lib/realm-execution/include/realm-execution/tasks/serializer/serializable_realm_event.dtg.toml new file mode 100644 index 0000000000..3217d58608 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tasks/serializer/serializable_realm_event.dtg.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "SerializableRealmEvent" +type = "struct" +features = [ + "eq", + "fmt", + "hash", + "json", +] + +includes = [ + "realm-execution/realm.h", +] + +[[fields]] +name = "id" +type = "::FlexFlow::Realm::Event::id_t" diff --git a/lib/realm-execution/include/realm-execution/tasks/serializer/serializable_realm_event.h b/lib/realm-execution/include/realm-execution/tasks/serializer/serializable_realm_event.h new file mode 100644 index 0000000000..ae1f1e8265 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tasks/serializer/serializable_realm_event.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_SERIALIZER_SERIALIZABLE_REALM_EVENT_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_SERIALIZER_SERIALIZABLE_REALM_EVENT_H + +#include "realm-execution/realm.h" +#include "realm-execution/tasks/serializer/serializable_realm_event.dtg.h" + +namespace FlexFlow { + +SerializableRealmEvent realm_event_to_serializable(Realm::Event const &); +Realm::Event realm_event_from_serializable(SerializableRealmEvent const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/tasks/serializer/serializable_tensor_instance_backing.dtg.toml b/lib/realm-execution/include/realm-execution/tasks/serializer/serializable_tensor_instance_backing.dtg.toml new file mode 100644 index 0000000000..75a796b2ee --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tasks/serializer/serializable_tensor_instance_backing.dtg.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "SerializableTensorInstanceBacking" +type = "struct" +features = [ + "eq", + "fmt", + "hash", + "json", +] + +includes = [ + "", + "realm-execution/tasks/serializer/serializable_realm_event.dtg.h", + "realm-execution/tasks/serializer/serializable_realm_instance.dtg.h", + "task-spec/dynamic_graph/serializable_dynamic_value_attrs.dtg.h", +] + +src_includes = [ + "utils/hash/unordered_map.h", + "utils/fmt/pair.h", + "utils/fmt/unordered_map.h", +] + +[[fields]] +name = "backing" +type = "std::unordered_map<::FlexFlow::SerializableDynamicValueAttrs, std::pair<::FlexFlow::SerializableRealmInstance, ::FlexFlow::SerializableRealmEvent>>" diff --git a/lib/realm-execution/include/realm-execution/tasks/serializer/serializable_tensor_instance_backing.h b/lib/realm-execution/include/realm-execution/tasks/serializer/serializable_tensor_instance_backing.h new file mode 100644 index 0000000000..b536972b40 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tasks/serializer/serializable_tensor_instance_backing.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_SERIALIZER_SERIALIZABLE_TENSOR_INSTANCE_BACKING_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_SERIALIZER_SERIALIZABLE_TENSOR_INSTANCE_BACKING_H + +#include "realm-execution/tasks/serializer/serializable_tensor_instance_backing.dtg.h" +#include "realm-execution/tensor_instance_backing.dtg.h" + +namespace FlexFlow { + +SerializableTensorInstanceBacking + tensor_instance_backing_to_serializable(TensorInstanceBacking const &); +TensorInstanceBacking tensor_instance_backing_from_serializable( + SerializableTensorInstanceBacking const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc b/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc index de8060aa12..d2d876a50b 100644 --- a/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc +++ b/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc @@ -1,6 +1,7 @@ #include "realm-execution/distributed_device_state_initialization.h" #include "local-execution/device_state_initialization.h" #include "realm-execution/tasks/impl/device_state_init_task.h" +#include "realm-execution/tensor_instance_backing.dtg.h" #include "realm-execution/tensor_instance_backing.h" #include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" @@ -31,14 +32,9 @@ DynamicOpenDataflowGraph perform_distributed_device_state_initialization( Realm::Processor target_proc = ctx.map_device_coord_to_processor( assert_unwrap(invocation.node_attrs.device_coord)); - std::unordered_map - tensor_backing = map_values( - subset_tensor_instance_backing_for_invocation( - tensor_instance_backing, invocation) - .backing, - [](std::pair const &v) { - return v.first; - }); + TensorInstanceBacking tensor_backing = + subset_tensor_instance_backing_for_invocation(tensor_instance_backing, + invocation); // FIXME: in the absense of a real serializer we're just tossing around raw // bytes, which means we need to bypass the constructor for this type (yes, diff --git a/lib/realm-execution/src/realm-execution/dynamic_tensor_accessor_from_instance.cc b/lib/realm-execution/src/realm-execution/dynamic_tensor_accessor_from_instance.cc index cb9382cfe0..d1c773b1fa 100644 --- a/lib/realm-execution/src/realm-execution/dynamic_tensor_accessor_from_instance.cc +++ b/lib/realm-execution/src/realm-execution/dynamic_tensor_accessor_from_instance.cc @@ -1,11 +1,41 @@ #include "realm-execution/dynamic_tensor_accessor_from_instance.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "pcg/device_type.dtg.h" +#include "task-spec/permissions.h" #include "utils/exception.h" namespace FlexFlow { -DynamicTensorAccessor - dynamic_tensor_accessor_from_instance(Realm::RegionInstance const &) { - NOT_IMPLEMENTED(); +DynamicTensorAccessor dynamic_tensor_accessor_from_instance( + Realm::RegionInstance inst, + Realm::Event ready, + ParallelTensorShape const ¶llel_tensor_shape, + Permissions const &permissions, + Realm::Processor for_processor) { + ready.wait(); + + DeviceType device_type; + switch (for_processor.kind()) { + case Realm::Processor::LOC_PROC: + device_type = DeviceType::CPU; + break; + case Realm::Processor::TOC_PROC: + device_type = DeviceType::GPU; + break; + default: + PANIC("Unexpected Realm Processor kind", for_processor.kind()); + } + + size_t expected_size = + int{get_piece_size_in_bytes(parallel_tensor_shape).unwrap_num_bytes()}; + void *ptr = inst.pointer_untyped(/*offset=*/0, /*datalen=*/expected_size); + if (permissions == Permissions::RO) { + return DynamicTensorAccessor{GenericTensorAccessorR{ + get_piece_shape(parallel_tensor_shape), ptr, device_type}}; + } else { + return DynamicTensorAccessor{GenericTensorAccessorW{ + get_piece_shape(parallel_tensor_shape), ptr, device_type}}; + } } } // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc b/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc index 496c3210c0..8390d12cc9 100644 --- a/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc +++ b/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc @@ -182,14 +182,9 @@ static std::unordered_map Realm::Processor target_proc = ctx.map_device_coord_to_processor( assert_unwrap(invocation.node_attrs.device_coord)); - std::unordered_map - tensor_backing = map_values( - subset_tensor_instance_backing_for_invocation( - tensor_instance_backing, invocation) - .backing, - [](std::pair const &v) { - return v.first; - }); + TensorInstanceBacking tensor_backing = + subset_tensor_instance_backing_for_invocation( + tensor_instance_backing, invocation); Realm::Event result = spawn_op_task(ctx, target_proc, diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_task.cc index d455b493da..7f3f2d185c 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_task.cc @@ -34,8 +34,13 @@ void device_state_init_task_body(void const *args, // Patch the invocation to include the provided instances auto map_instance_to_accessor = [&](DynamicValueAttrs const &value) { DynamicValueAttrs result = value; + auto const &[inst, event] = task_args.tensor_backing.backing.at(value); result.accessor = dynamic_tensor_accessor_from_instance( - task_args.tensor_backing.at(value)); + inst, + event, + assert_unwrap(value.parallel_tensor_shape), + Permissions::RW, // FIXME: get real permissions? + ctx.get_current_processor()); return result; }; DynamicNodeInvocation invocation = task_args.invocation; @@ -67,8 +72,7 @@ std::optional spawn_device_state_init_task( RealmContext &ctx, Realm::Processor target_proc, DynamicNodeInvocation const &invocation, - std::unordered_map const - &tensor_backing, + TensorInstanceBacking const &tensor_backing, ProfilingSettings const &profiling_settings, DeviceSpecificManagedPerDeviceFFHandle const &device_handle, FFIterationConfig const &iteration_config, diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc index 0f65b808aa..dc262bbdb1 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc @@ -7,6 +7,7 @@ #include "realm-execution/tasks/serializer/task_arg_serializer.h" #include "realm-execution/tasks/task_id_t.h" #include "task-spec/per_device_op_state.h" +#include "task-spec/permissions.h" #include "utils/containers/map_values.h" #include "utils/optional.h" #include @@ -29,8 +30,13 @@ void op_task_body(void const *args, // Patch the invocation to include the provided instances auto map_instance_to_accessor = [&](DynamicValueAttrs const &value) { DynamicValueAttrs result = value; + auto const &[inst, event] = task_args.tensor_backing.backing.at(value); result.accessor = dynamic_tensor_accessor_from_instance( - task_args.tensor_backing.at(value)); + inst, + event, + assert_unwrap(value.parallel_tensor_shape), + Permissions::RW, // FIXME: get real permissions? + ctx.get_current_processor()); return result; }; DynamicNodeInvocation invocation = task_args.invocation; @@ -53,17 +59,16 @@ void op_task_body(void const *args, /*device_idx=*/ctx.get_current_device_idx()); } -Realm::Event spawn_op_task( - RealmContext &ctx, - Realm::Processor target_proc, - DynamicNodeInvocation const &invocation, - std::unordered_map const - &tensor_backing, - ProfilingSettings const &profiling_settings, - DeviceSpecificManagedPerDeviceFFHandle const &device_handle, - FFIterationConfig const &iteration_config, - std::optional const &optimizer_attrs, - Realm::Event precondition) { +Realm::Event + spawn_op_task(RealmContext &ctx, + Realm::Processor target_proc, + DynamicNodeInvocation const &invocation, + TensorInstanceBacking const &tensor_backing, + ProfilingSettings const &profiling_settings, + DeviceSpecificManagedPerDeviceFFHandle const &device_handle, + FFIterationConfig const &iteration_config, + std::optional const &optimizer_attrs, + Realm::Event precondition) { OpTaskArgs task_args{invocation, tensor_backing, profiling_settings, diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/serializable_device_state_init_task_args.cc b/lib/realm-execution/src/realm-execution/tasks/impl/serializable_device_state_init_task_args.cc index 59a1dd71a6..64669b9f1e 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/serializable_device_state_init_task_args.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/serializable_device_state_init_task_args.cc @@ -1,9 +1,7 @@ #include "realm-execution/tasks/impl/serializable_device_state_init_task_args.h" -#include "realm-execution/tasks/serializer/serializable_realm_instance.h" #include "realm-execution/tasks/serializer/serializable_realm_processor.h" +#include "realm-execution/tasks/serializer/serializable_tensor_instance_backing.h" #include "task-spec/dynamic_graph/serializable_dynamic_node_invocation.h" -#include "task-spec/dynamic_graph/serializable_dynamic_value_attrs.h" -#include "utils/containers/map_keys_and_values.h" namespace FlexFlow { @@ -12,9 +10,7 @@ SerializableDeviceStateInitTaskArgs device_state_init_task_args_to_serializable( return SerializableDeviceStateInitTaskArgs{ /*invocation=*/dynamic_node_invocation_to_serializable(args.invocation), /*tensor_backing*/ - map_keys_and_values(args.tensor_backing, - dynamic_value_attrs_to_serializable, - realm_instance_to_serializable), + tensor_instance_backing_to_serializable(args.tensor_backing), /*profiling_settings=*/args.profiling_settings, /*device_handle=*/args.device_handle.serialize(), /*iteration_config=*/args.iteration_config, @@ -29,9 +25,7 @@ DeviceStateInitTaskArgs device_state_init_task_args_from_serializable( return DeviceStateInitTaskArgs{ /*invocation=*/dynamic_node_invocation_from_serializable(args.invocation), /*tensor_backing*/ - map_keys_and_values(args.tensor_backing, - dynamic_value_attrs_from_serializable, - realm_instance_from_serializable), + tensor_instance_backing_from_serializable(args.tensor_backing), /*profiling_settings=*/args.profiling_settings, /*device_handle=*/ DeviceSpecificManagedPerDeviceFFHandle::deserialize(args.device_handle), diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/serializable_op_task_args.cc b/lib/realm-execution/src/realm-execution/tasks/impl/serializable_op_task_args.cc index 04a213e906..0ef2fb0442 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/serializable_op_task_args.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/serializable_op_task_args.cc @@ -1,8 +1,6 @@ #include "realm-execution/tasks/impl/serializable_op_task_args.h" -#include "realm-execution/tasks/serializer/serializable_realm_instance.h" +#include "realm-execution/tasks/serializer/serializable_tensor_instance_backing.h" #include "task-spec/dynamic_graph/serializable_dynamic_node_invocation.h" -#include "task-spec/dynamic_graph/serializable_dynamic_value_attrs.h" -#include "utils/containers/map_keys_and_values.h" namespace FlexFlow { @@ -10,9 +8,7 @@ SerializableOpTaskArgs op_task_args_to_serializable(OpTaskArgs const &args) { return SerializableOpTaskArgs{ /*invocation=*/dynamic_node_invocation_to_serializable(args.invocation), /*tensor_backing*/ - map_keys_and_values(args.tensor_backing, - dynamic_value_attrs_to_serializable, - realm_instance_to_serializable), + tensor_instance_backing_to_serializable(args.tensor_backing), /*profiling_settings=*/args.profiling_settings, /*device_handle=*/args.device_handle.serialize(), /*iteration_config=*/args.iteration_config, @@ -24,9 +20,7 @@ OpTaskArgs op_task_args_from_serializable(SerializableOpTaskArgs const &args) { return OpTaskArgs{ /*invocation=*/dynamic_node_invocation_from_serializable(args.invocation), /*tensor_backing*/ - map_keys_and_values(args.tensor_backing, - dynamic_value_attrs_from_serializable, - realm_instance_from_serializable), + tensor_instance_backing_from_serializable(args.tensor_backing), /*profiling_settings=*/args.profiling_settings, /*device_handle=*/ DeviceSpecificManagedPerDeviceFFHandle::deserialize(args.device_handle), diff --git a/lib/realm-execution/src/realm-execution/tasks/serializer/serializable_realm_event.cc b/lib/realm-execution/src/realm-execution/tasks/serializer/serializable_realm_event.cc new file mode 100644 index 0000000000..806059f3ed --- /dev/null +++ b/lib/realm-execution/src/realm-execution/tasks/serializer/serializable_realm_event.cc @@ -0,0 +1,14 @@ +#include "realm-execution/tasks/serializer/serializable_realm_event.h" + +namespace FlexFlow { + +SerializableRealmEvent realm_event_to_serializable(Realm::Event const &event) { + return SerializableRealmEvent{event.id}; +} + +Realm::Event + realm_event_from_serializable(SerializableRealmEvent const &event) { + return Realm::Event{event.id}; +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/tasks/serializer/serializable_tensor_instance_backing.cc b/lib/realm-execution/src/realm-execution/tasks/serializer/serializable_tensor_instance_backing.cc new file mode 100644 index 0000000000..79a5176c4f --- /dev/null +++ b/lib/realm-execution/src/realm-execution/tasks/serializer/serializable_tensor_instance_backing.cc @@ -0,0 +1,32 @@ +#include "realm-execution/tasks/serializer/serializable_tensor_instance_backing.h" +#include "realm-execution/tasks/serializer/serializable_realm_event.h" +#include "realm-execution/tasks/serializer/serializable_realm_instance.h" +#include "task-spec/dynamic_graph/serializable_dynamic_value_attrs.h" +#include "utils/containers/map_keys_and_values.h" + +namespace FlexFlow { + +SerializableTensorInstanceBacking tensor_instance_backing_to_serializable( + TensorInstanceBacking const &backing) { + return SerializableTensorInstanceBacking{/*backing=*/map_keys_and_values( + backing.backing, + dynamic_value_attrs_to_serializable, + [](std::pair const &p) { + return std::pair{realm_instance_to_serializable(p.first), + realm_event_to_serializable(p.second)}; + })}; +} + +TensorInstanceBacking tensor_instance_backing_from_serializable( + SerializableTensorInstanceBacking const &backing) { + return TensorInstanceBacking{/*backing=*/map_keys_and_values( + backing.backing, + dynamic_value_attrs_from_serializable, + [](std::pair const + &p) { + return std::pair{realm_instance_from_serializable(p.first), + realm_event_from_serializable(p.second)}; + })}; +} + +} // namespace FlexFlow From 9da0b947a2efad54021814cf71bd447aa079dbab Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Tue, 17 Feb 2026 15:36:42 -0800 Subject: [PATCH 69/88] Don't execute tasks on input or weight nodes. --- .../src/realm-execution/pcg_instance/pcg_instance.cc | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc b/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc index 8390d12cc9..2287b9d54b 100644 --- a/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc +++ b/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc @@ -16,6 +16,7 @@ #include "task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mpcg.h" #include "task-spec/dynamic_graph/pass_expansion.h" #include "task-spec/dynamic_graph/shard_expansion.h" +#include "task-spec/dynamic_graph/training_operation_attrs.dtg.h" #include "task-spec/dynamic_graph/update_insertion.h" #include "utils/containers/map_values.h" #include "utils/containers/transform.h" @@ -166,6 +167,14 @@ static std::unordered_map DependencySet dependency_set{ctx.get_outstanding_events()}; return unordered_map_from_pairs( transform(invocations, [&](DynamicNodeInvocation const &invocation) { + TrainingOperationAttrs op_attrs = + assert_unwrap(invocation.node_attrs.op_attrs); + if (op_attrs.is_pcg_op() && (op_attrs.require_pcg_op().is_input() || + op_attrs.require_pcg_op().is_weight())) { + return std::pair{invocation.node_attrs.layer_guid, + Realm::Event::NO_EVENT}; + } + std::vector input_dependencies = transform(vector_of(values(invocation.inputs)), [&](DynamicValueAttrs const &value) { From ee32e03af47de26732b7553563d270ad1eeadffa Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Tue, 17 Feb 2026 16:03:19 -0800 Subject: [PATCH 70/88] Refactor device specific managed handle. --- ...ce_specific_managed_per_device_ff_handle.h | 21 ++--------- .../realm-execution/device_specific_ptr.h | 36 +++++++++++++++++++ .../serializable_device_specific_ptr.h | 32 +++++++++++++++++ ...e_specific_managed_per_device_ff_handle.cc | 35 ------------------ ...erializable_device_state_init_task_args.cc | 6 ++-- .../tasks/impl/serializable_op_task_args.cc | 6 ++-- 6 files changed, 79 insertions(+), 57 deletions(-) create mode 100644 lib/realm-execution/include/realm-execution/device_specific_ptr.h create mode 100644 lib/realm-execution/include/realm-execution/tasks/serializer/serializable_device_specific_ptr.h diff --git a/lib/realm-execution/include/realm-execution/device_specific_managed_per_device_ff_handle.h b/lib/realm-execution/include/realm-execution/device_specific_managed_per_device_ff_handle.h index d48a80f438..9a42861fcd 100644 --- a/lib/realm-execution/include/realm-execution/device_specific_managed_per_device_ff_handle.h +++ b/lib/realm-execution/include/realm-execution/device_specific_managed_per_device_ff_handle.h @@ -4,28 +4,13 @@ #include "kernels/device_handle_t.dtg.h" #include "kernels/managed_per_device_ff_handle.h" #include "pcg/device_id_t.dtg.h" -#include "realm-execution/tasks/serializer/serializable_device_specific_ptr.dtg.h" -#include +#include "realm-execution/device_specific_ptr.h" #include namespace FlexFlow { -struct DeviceSpecificManagedPerDeviceFFHandle { -public: - DeviceSpecificManagedPerDeviceFFHandle() = delete; - explicit DeviceSpecificManagedPerDeviceFFHandle( - device_id_t owner, std::optional handle); - - std::optional get(device_id_t device_idx) const; - - SerializableDeviceSpecificPtr serialize() const; - static DeviceSpecificManagedPerDeviceFFHandle - deserialize(SerializableDeviceSpecificPtr const &j); - -private: - device_id_t owner; - std::optional handle; -}; +using DeviceSpecificManagedPerDeviceFFHandle = + DeviceSpecificPtr; DeviceSpecificManagedPerDeviceFFHandle make_device_specific_managed_handle( device_id_t const &, std::optional const &); diff --git a/lib/realm-execution/include/realm-execution/device_specific_ptr.h b/lib/realm-execution/include/realm-execution/device_specific_ptr.h new file mode 100644 index 0000000000..81d41131b7 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/device_specific_ptr.h @@ -0,0 +1,36 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_DEVICE_SPECIFIC_PTR_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_DEVICE_SPECIFIC_PTR_H + +#include "pcg/device_id_t.dtg.h" +#include + +namespace FlexFlow { + +template +struct DeviceSpecificPtr { +public: + DeviceSpecificPtr() = delete; + explicit DeviceSpecificPtr(device_id_t device_idx, std::optional handle) + : device_idx(device_idx), ptr(ptr) {} + + std::optional get(device_id_t device_idx) const { + ASSERT(this->device_idx == device_idx); + return this->ptr; + } + + device_id_t get_device_idx() const { + return this->device_idx; + } + + std::optional get_unsafe_raw_ptr() const { + return this->ptr; + } + +private: + device_id_t device_idx; + std::optional ptr; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/tasks/serializer/serializable_device_specific_ptr.h b/lib/realm-execution/include/realm-execution/tasks/serializer/serializable_device_specific_ptr.h new file mode 100644 index 0000000000..726aef84ba --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tasks/serializer/serializable_device_specific_ptr.h @@ -0,0 +1,32 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_SERIALIZER_SERIALIZABLE_DEVICE_SPECIFIC_PTR_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_SERIALIZER_SERIALIZABLE_DEVICE_SPECIFIC_PTR_H + +#include "realm-execution/device_specific_ptr.h" +#include "realm-execution/tasks/serializer/serializable_device_specific_ptr.dtg.h" + +namespace FlexFlow { + +template +SerializableDeviceSpecificPtr device_specific_ptr_to_serializable( + DeviceSpecificPtr const &device_specific) { + return SerializableDeviceSpecificPtr{ + /*device_idx=*/device_specific.get_device_idx(), + /*ptr=*/ + transform(device_specific.get_unsafe_raw_ptr(), + [](T *ptr) { return reinterpret_cast(ptr); }), + }; +} + +template +DeviceSpecificPtr device_specific_ptr_from_serializable( + SerializableDeviceSpecificPtr const &device_specific) { + return DeviceSpecificPtr{ + /*device_idx*/ device_specific.device_idx, + /*ptr=*/transform(device_specific.ptr, [](uintptr_t ptrval) { + return reinterpret_cast(ptrval); + })}; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/src/realm-execution/device_specific_managed_per_device_ff_handle.cc b/lib/realm-execution/src/realm-execution/device_specific_managed_per_device_ff_handle.cc index bcc0a22ccf..ae9fc669d3 100644 --- a/lib/realm-execution/src/realm-execution/device_specific_managed_per_device_ff_handle.cc +++ b/lib/realm-execution/src/realm-execution/device_specific_managed_per_device_ff_handle.cc @@ -6,41 +6,6 @@ namespace FlexFlow { -DeviceSpecificManagedPerDeviceFFHandle::DeviceSpecificManagedPerDeviceFFHandle( - device_id_t owner, std::optional handle) - : owner(owner), handle(handle) {} - -std::optional - DeviceSpecificManagedPerDeviceFFHandle::get(device_id_t device_idx) const { - ASSERT(this->owner == device_idx); - return this->handle; -} - -SerializableDeviceSpecificPtr - DeviceSpecificManagedPerDeviceFFHandle::serialize() const { - return SerializableDeviceSpecificPtr{ - /*device_idx=*/owner, - /*ptr=*/ - transform(handle, - [](ManagedPerDeviceFFHandle *ptr) { - return reinterpret_cast(ptr); - }), - }; -} - -DeviceSpecificManagedPerDeviceFFHandle - DeviceSpecificManagedPerDeviceFFHandle::deserialize( - SerializableDeviceSpecificPtr const &handle) { - return DeviceSpecificManagedPerDeviceFFHandle{ - /*owner=*/handle.device_idx, - /*handle=*/ - transform(handle.ptr, - [](uintptr_t ptrval) { - return reinterpret_cast(ptrval); - }), - }; -} - DeviceSpecificManagedPerDeviceFFHandle make_device_specific_managed_handle( device_id_t const &device_id, std::optional const &managed_handle) { diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/serializable_device_state_init_task_args.cc b/lib/realm-execution/src/realm-execution/tasks/impl/serializable_device_state_init_task_args.cc index 64669b9f1e..fed22ff393 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/serializable_device_state_init_task_args.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/serializable_device_state_init_task_args.cc @@ -1,4 +1,5 @@ #include "realm-execution/tasks/impl/serializable_device_state_init_task_args.h" +#include "realm-execution/tasks/serializer/serializable_device_specific_ptr.h" #include "realm-execution/tasks/serializer/serializable_realm_processor.h" #include "realm-execution/tasks/serializer/serializable_tensor_instance_backing.h" #include "task-spec/dynamic_graph/serializable_dynamic_node_invocation.h" @@ -12,7 +13,7 @@ SerializableDeviceStateInitTaskArgs device_state_init_task_args_to_serializable( /*tensor_backing*/ tensor_instance_backing_to_serializable(args.tensor_backing), /*profiling_settings=*/args.profiling_settings, - /*device_handle=*/args.device_handle.serialize(), + /*device_handle=*/device_specific_ptr_to_serializable(args.device_handle), /*iteration_config=*/args.iteration_config, /*optimizer_attrs=*/args.optimizer_attrs, /*origin_proc=*/realm_processor_to_serializable(args.origin_proc), @@ -28,7 +29,8 @@ DeviceStateInitTaskArgs device_state_init_task_args_from_serializable( tensor_instance_backing_from_serializable(args.tensor_backing), /*profiling_settings=*/args.profiling_settings, /*device_handle=*/ - DeviceSpecificManagedPerDeviceFFHandle::deserialize(args.device_handle), + device_specific_ptr_from_serializable( + args.device_handle), /*iteration_config=*/args.iteration_config, /*optimizer_attrs=*/args.optimizer_attrs, /*origin_proc=*/realm_processor_from_serializable(args.origin_proc), diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/serializable_op_task_args.cc b/lib/realm-execution/src/realm-execution/tasks/impl/serializable_op_task_args.cc index 0ef2fb0442..80994d4298 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/serializable_op_task_args.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/serializable_op_task_args.cc @@ -1,4 +1,5 @@ #include "realm-execution/tasks/impl/serializable_op_task_args.h" +#include "realm-execution/tasks/serializer/serializable_device_specific_ptr.h" #include "realm-execution/tasks/serializer/serializable_tensor_instance_backing.h" #include "task-spec/dynamic_graph/serializable_dynamic_node_invocation.h" @@ -10,7 +11,7 @@ SerializableOpTaskArgs op_task_args_to_serializable(OpTaskArgs const &args) { /*tensor_backing*/ tensor_instance_backing_to_serializable(args.tensor_backing), /*profiling_settings=*/args.profiling_settings, - /*device_handle=*/args.device_handle.serialize(), + /*device_handle=*/device_specific_ptr_to_serializable(args.device_handle), /*iteration_config=*/args.iteration_config, /*optimizer_attrs=*/args.optimizer_attrs, }; @@ -23,7 +24,8 @@ OpTaskArgs op_task_args_from_serializable(SerializableOpTaskArgs const &args) { tensor_instance_backing_from_serializable(args.tensor_backing), /*profiling_settings=*/args.profiling_settings, /*device_handle=*/ - DeviceSpecificManagedPerDeviceFFHandle::deserialize(args.device_handle), + device_specific_ptr_from_serializable( + args.device_handle), /*iteration_config=*/args.iteration_config, /*optimizer_attrs=*/args.optimizer_attrs, }; From f7bb5ec35e165d3b9556c67f5fea7861d71cf6fb Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Tue, 17 Feb 2026 17:26:02 -0800 Subject: [PATCH 71/88] Refactor per-device op state backing. --- .../realm-execution/device_specific_ptr.h | 2 +- .../distributed_device_state_initialization.h | 7 ++- .../pcg_instance/pcg_instance.h | 4 ++ .../per_device_op_state_backing.dtg.toml | 15 +++++ .../impl/device_state_init_return_task.h | 7 ++- .../tasks/impl/device_state_init_task.h | 5 +- .../impl/device_state_init_task_args.dtg.toml | 3 +- .../realm-execution/tasks/impl/op_task.h | 3 + .../tasks/impl/op_task_args.dtg.toml | 8 ++- .../impl/serializable_op_task_args.dtg.toml | 4 ++ ...distributed_device_state_initialization.cc | 62 +++++++------------ .../pcg_instance/pcg_instance.cc | 61 ++++++++++-------- .../impl/device_state_init_return_task.cc | 12 ++-- .../tasks/impl/device_state_init_task.cc | 13 ++-- .../src/realm-execution/tasks/impl/op_task.cc | 11 ++-- ...erializable_device_state_init_task_args.cc | 2 +- .../tasks/impl/serializable_op_task_args.cc | 4 ++ 17 files changed, 131 insertions(+), 92 deletions(-) create mode 100644 lib/realm-execution/include/realm-execution/per_device_op_state_backing.dtg.toml diff --git a/lib/realm-execution/include/realm-execution/device_specific_ptr.h b/lib/realm-execution/include/realm-execution/device_specific_ptr.h index 81d41131b7..590b7dbc74 100644 --- a/lib/realm-execution/include/realm-execution/device_specific_ptr.h +++ b/lib/realm-execution/include/realm-execution/device_specific_ptr.h @@ -10,7 +10,7 @@ template struct DeviceSpecificPtr { public: DeviceSpecificPtr() = delete; - explicit DeviceSpecificPtr(device_id_t device_idx, std::optional handle) + explicit DeviceSpecificPtr(device_id_t device_idx, std::optional ptr) : device_idx(device_idx), ptr(ptr) {} std::optional get(device_id_t device_idx) const { diff --git a/lib/realm-execution/include/realm-execution/distributed_device_state_initialization.h b/lib/realm-execution/include/realm-execution/distributed_device_state_initialization.h index e257834e65..b26a69078e 100644 --- a/lib/realm-execution/include/realm-execution/distributed_device_state_initialization.h +++ b/lib/realm-execution/include/realm-execution/distributed_device_state_initialization.h @@ -1,9 +1,10 @@ -#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_DISTRIBUTED_DEVICE_STATE_INITIALIZATION_H -#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_DISTRIBUTED_DEVICE_STATE_INITIALIZATION_H +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_DISTRIBUTED_PER_DEVICE_OP_STATE_BACKING_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_DISTRIBUTED_PER_DEVICE_OP_STATE_BACKING_H #include "kernels/profiling_settings.dtg.h" #include "pcg/optimizer_attrs.dtg.h" #include "realm-execution/distributed_device_handle.h" +#include "realm-execution/per_device_op_state_backing.dtg.h" #include "realm-execution/realm_context.h" #include "realm-execution/tensor_instance_backing.dtg.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.dtg.h" @@ -11,7 +12,7 @@ namespace FlexFlow { -DynamicOpenDataflowGraph perform_distributed_device_state_initialization( +PerDeviceOpStateBacking perform_distributed_device_state_initialization( RealmContext &ctx, DynamicOpenDataflowGraph const &dg, TensorInstanceBacking const &tensor_instance_backing, diff --git a/lib/realm-execution/include/realm-execution/pcg_instance/pcg_instance.h b/lib/realm-execution/include/realm-execution/pcg_instance/pcg_instance.h index 1238097b2a..e754fbbf5c 100644 --- a/lib/realm-execution/include/realm-execution/pcg_instance/pcg_instance.h +++ b/lib/realm-execution/include/realm-execution/pcg_instance/pcg_instance.h @@ -11,6 +11,7 @@ #include "pcg/optimizer_attrs.dtg.h" #include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" #include "realm-execution/distributed_device_handle.h" +#include "realm-execution/per_device_op_state_backing.dtg.h" #include "realm-execution/realm_context.h" #include "realm-execution/tensor_instance_backing.dtg.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.dtg.h" @@ -31,11 +32,13 @@ struct PCGInstance { RealmContext &ctx, std::vector const &execution_order, TensorInstanceBacking const &tensor_instance_backing, + PerDeviceOpStateBacking const &device_state_backing, OptimizerAttrs const &optimizer_attrs, std::optional logit_grad_tensor); RealmContext &get_realm_context(); std::vector const &get_execution_order() const; TensorInstanceBacking const &get_tensor_instance_backing() const; + PerDeviceOpStateBacking const &get_device_state_backing() const; OptimizerAttrs const &get_optimizer_attrs() const; void update_optimizer_attrs_for_next_iter(); std::optional get_loss_tensor_instance() const; @@ -44,6 +47,7 @@ struct PCGInstance { RealmContext &ctx; std::vector execution_order; TensorInstanceBacking tensor_instance_backing; + PerDeviceOpStateBacking device_state_backing; OptimizerAttrs optimizer_attrs; std::optional logit_grad_tensor; }; diff --git a/lib/realm-execution/include/realm-execution/per_device_op_state_backing.dtg.toml b/lib/realm-execution/include/realm-execution/per_device_op_state_backing.dtg.toml new file mode 100644 index 0000000000..90a9d01e69 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/per_device_op_state_backing.dtg.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "PerDeviceOpStateBacking" +type = "struct" +features = [] + +includes = [ + "", + "realm-execution/device_specific_ptr.h", + "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h", + "task-spec/per_device_op_state.dtg.h", +] + +[[fields]] +name = "backing" +type = "std::unordered_map<::FlexFlow::DynamicNodeInvocation, ::FlexFlow::DeviceSpecificPtr<::FlexFlow::PerDeviceOpState>>" diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_return_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_return_task.h index 8f44680815..4de7e5689f 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_return_task.h +++ b/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_return_task.h @@ -1,9 +1,10 @@ #ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_STATE_INIT_RETURN_TASK_H #define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_STATE_INIT_RETURN_TASK_H +#include "realm-execution/device_specific_ptr.h" #include "realm-execution/realm.h" #include "realm-execution/realm_context.h" -#include "task-spec/device_specific_per_device_op_state.dtg.h" +#include "task-spec/per_device_op_state.dtg.h" namespace FlexFlow { @@ -13,8 +14,8 @@ void device_state_init_return_task_body( Realm::Event spawn_device_state_init_return_task( RealmContext &ctx, Realm::Processor origin_proc, - DeviceSpecificPerDeviceOpState const &result, - DeviceSpecificPerDeviceOpState *origin_result_ptr, + DeviceSpecificPtr const &result, + DeviceSpecificPtr *origin_result_ptr, Realm::Event precondition); } // namespace FlexFlow diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task.h index 54bddc1ddd..657d2e8401 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task.h +++ b/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task.h @@ -4,12 +4,13 @@ #include "kernels/profiling_settings.dtg.h" #include "pcg/optimizer_attrs.dtg.h" #include "realm-execution/device_specific_managed_per_device_ff_handle.h" +#include "realm-execution/device_specific_ptr.h" #include "realm-execution/realm.h" #include "realm-execution/realm_context.h" #include "realm-execution/tensor_instance_backing.dtg.h" -#include "task-spec/device_specific_per_device_op_state.dtg.h" #include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" #include "task-spec/ff_iteration_config.dtg.h" +#include "task-spec/per_device_op_state.dtg.h" namespace FlexFlow { @@ -25,7 +26,7 @@ std::optional spawn_device_state_init_task( DeviceSpecificManagedPerDeviceFFHandle const &device_handle, FFIterationConfig const &iteration_config, OptimizerAttrs const &optimizer_attrs, - DeviceSpecificPerDeviceOpState *result_ptr, + DeviceSpecificPtr *result_ptr, Realm::Event precondition); } // namespace FlexFlow diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task_args.dtg.toml b/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task_args.dtg.toml index fbec9298dd..9a7c2781d2 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task_args.dtg.toml +++ b/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task_args.dtg.toml @@ -12,6 +12,7 @@ includes = [ "task-spec/device_specific_per_device_op_state.dtg.h", "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h", "task-spec/ff_iteration_config.dtg.h", + "task-spec/per_device_op_state.dtg.h", ] [[fields]] @@ -44,4 +45,4 @@ type = "::FlexFlow::Realm::Processor" [[fields]] name = "origin_result_ptr" -type = "::FlexFlow::DeviceSpecificPerDeviceOpState *" +type = "::FlexFlow::DeviceSpecificPtr *" diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h index 330da4d2b2..33dcbff895 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h +++ b/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h @@ -5,11 +5,13 @@ #include "op-attrs/ops/loss_functions/loss_attrs.dtg.h" #include "pcg/optimizer_attrs.dtg.h" #include "realm-execution/device_specific_managed_per_device_ff_handle.h" +#include "realm-execution/device_specific_ptr.h" #include "realm-execution/realm.h" #include "realm-execution/realm_context.h" #include "realm-execution/tensor_instance_backing.dtg.h" #include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" #include "task-spec/ff_iteration_config.dtg.h" +#include "task-spec/per_device_op_state.dtg.h" namespace FlexFlow { @@ -20,6 +22,7 @@ Realm::Event Realm::Processor target_proc, DynamicNodeInvocation const &invocation, TensorInstanceBacking const &tensor_backing, + DeviceSpecificPtr const &device_state, ProfilingSettings const &profiling_settings, DeviceSpecificManagedPerDeviceFFHandle const &device_handle, FFIterationConfig const &iteration_config, diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/op_task_args.dtg.toml b/lib/realm-execution/include/realm-execution/tasks/impl/op_task_args.dtg.toml index 2a55ffbf80..a15c8dce11 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/op_task_args.dtg.toml +++ b/lib/realm-execution/include/realm-execution/tasks/impl/op_task_args.dtg.toml @@ -7,9 +7,11 @@ includes = [ "kernels/profiling_settings.dtg.h", "pcg/optimizer_attrs.dtg.h", "realm-execution/device_specific_managed_per_device_ff_handle.h", + "realm-execution/device_specific_ptr.h", "realm-execution/tensor_instance_backing.dtg.h", "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h", "task-spec/ff_iteration_config.dtg.h", + "task-spec/per_device_op_state.dtg.h", ] [[fields]] @@ -18,7 +20,11 @@ type = "::FlexFlow::DynamicNodeInvocation" [[fields]] name = "tensor_backing" -type = "TensorInstanceBacking" +type = "::FlexFlow::TensorInstanceBacking" + +[[fields]] +name = "device_state" +type = "::FlexFlow::DeviceSpecificPtr<::FlexFlow::PerDeviceOpState>" [[fields]] name = "profiling_settings" diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/serializable_op_task_args.dtg.toml b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_op_task_args.dtg.toml index ac31e78d0d..2be0034c46 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/serializable_op_task_args.dtg.toml +++ b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_op_task_args.dtg.toml @@ -30,6 +30,10 @@ type = "::FlexFlow::SerializableDynamicNodeInvocation" name = "tensor_backing" type = "::FlexFlow::SerializableTensorInstanceBacking" +[[fields]] +name = "device_state" +type = "::FlexFlow::SerializableDeviceSpecificPtr" + [[fields]] name = "profiling_settings" type = "::FlexFlow::ProfilingSettings" diff --git a/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc b/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc index d2d876a50b..c6d8ea3e69 100644 --- a/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc +++ b/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc @@ -13,7 +13,7 @@ namespace FlexFlow { -DynamicOpenDataflowGraph perform_distributed_device_state_initialization( +PerDeviceOpStateBacking perform_distributed_device_state_initialization( RealmContext &ctx, DynamicOpenDataflowGraph const &dg, TensorInstanceBacking const &tensor_instance_backing, @@ -26,8 +26,16 @@ DynamicOpenDataflowGraph perform_distributed_device_state_initialization( // Initialize all operators and save the per-device op state ASSERT(no_nodes_are_initialized(dg)); - std::unordered_map - result_map; + std::unordered_map> + result; + + // Preallocate output before launching tasks + for (DynamicNodeInvocation const &invocation : dg.invocations) { + result.insert(std::pair{invocation, + DeviceSpecificPtr{ + ctx.get_current_device_idx(), std::nullopt}}); + } + for (DynamicNodeInvocation const &invocation : dg.invocations) { Realm::Processor target_proc = ctx.map_device_coord_to_processor( assert_unwrap(invocation.node_attrs.device_coord)); @@ -36,47 +44,21 @@ DynamicOpenDataflowGraph perform_distributed_device_state_initialization( subset_tensor_instance_backing_for_invocation(tensor_instance_backing, invocation); - // FIXME: in the absense of a real serializer we're just tossing around raw - // bytes, which means we need to bypass the constructor for this type (yes, - // ugh) - DeviceSpecificPerDeviceOpState *output = - static_cast( - malloc(sizeof(DeviceSpecificPerDeviceOpState))); - std::optional result = - spawn_device_state_init_task(ctx, - target_proc, - invocation, - tensor_backing, - profiling_settings, - device_handle.at(target_proc), - iteration_config, - optimizer_attrs, - output, - precondition); - if (result) { - result_map[invocation] = output; - } else { - free(output); - } + spawn_device_state_init_task(ctx, + target_proc, + invocation, + tensor_backing, + profiling_settings, + device_handle.at(target_proc), + iteration_config, + optimizer_attrs, + &result.at(invocation), + precondition); } ctx.get_outstanding_events().wait(); - DynamicOpenDataflowGraph result = transform_dynamic_invocation_set( - dg, [&](DynamicNodeInvocation const &invocation) { - DynamicNodeInvocation result = invocation; - auto device_state = result_map.find(invocation); - if (device_state != result_map.end()) { - result.node_attrs.per_device_op_state = *device_state->second; - } - return result; - }); - - for (auto &[invocation, output] : result_map) { - free(output); - } - - return result; + return PerDeviceOpStateBacking{/*backing=*/result}; } } // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc b/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc index 2287b9d54b..5d1a63ba5b 100644 --- a/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc +++ b/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc @@ -30,10 +30,12 @@ PCGInstance::PCGInstance( RealmContext &ctx, std::vector const &execution_order, TensorInstanceBacking const &tensor_instance_backing, + PerDeviceOpStateBacking const &device_state_backing, OptimizerAttrs const &optimizer_attrs, std::optional logit_grad_tensor) : ctx(ctx), execution_order(execution_order), tensor_instance_backing(tensor_instance_backing), + device_state_backing(device_state_backing), optimizer_attrs(optimizer_attrs), logit_grad_tensor(logit_grad_tensor) {} RealmContext &PCGInstance::get_realm_context() { @@ -46,6 +48,9 @@ std::vector const & TensorInstanceBacking const &PCGInstance::get_tensor_instance_backing() const { return this->tensor_instance_backing; } +PerDeviceOpStateBacking const &PCGInstance::get_device_state_backing() const { + return this->device_state_backing; +} OptimizerAttrs const &PCGInstance::get_optimizer_attrs() const { return this->optimizer_attrs; } @@ -92,7 +97,8 @@ PCGInstance create_pcg_instance( dg = perform_update_insertion(dg, optimizer_attrs); dg = perform_shard_expansion(dg); - TensorInstanceBacking backing = perform_instance_allocation(dg, inputs, ctx); + TensorInstanceBacking tensor_instance_backing = + perform_instance_allocation(dg, inputs, ctx); logit_grad_value = transform(logit_grad_value, [&](DynamicValueAttrs const &lgv) { @@ -114,20 +120,19 @@ PCGInstance create_pcg_instance( std::optional logit_grad_tensor = transform(logit_grad_value, [&](DynamicValueAttrs const &lgv) { - return backing.backing.at(lgv).first; + return tensor_instance_backing.backing.at(lgv).first; }); - // FIXME: for now we're going to be lazy and block on everything rather than - // do fine-grained dependencies on instances - dg = perform_distributed_device_state_initialization( - ctx, - dg, - backing, - profiling_settings, - device_handle, - iteration_config, - optimizer_attrs, - ctx.get_outstanding_events()); + PerDeviceOpStateBacking device_state_backing = + perform_distributed_device_state_initialization( + ctx, + dg, + tensor_instance_backing, + profiling_settings, + device_handle, + iteration_config, + optimizer_attrs, + ctx.get_outstanding_events()); // Compute the topological ordering of the graph auto [kwarg_graph, node_map] = @@ -138,14 +143,13 @@ PCGInstance create_pcg_instance( return PCGInstance{/*ctx=*/ctx, /*execution_order=*/invocation_topo_order, - /*tensor_instance_backing=*/backing, + /*tensor_instance_backing=*/tensor_instance_backing, + /*device_state_backing=*/device_state_backing, /*optimizer_attrs=*/optimizer_attrs, /*logit_grad_tensor=*/logit_grad_tensor}; // TODO list: // * external instances - // * task argument serializer - // * pass instances to task and convert to tensor accessor // * copies // * parallel operator implementation (partition, reduce, gather, etc.) // * and fused parallel operators (reduce + broadcast = allreduce) @@ -158,6 +162,7 @@ static std::unordered_map RealmContext &ctx, std::vector const &invocations, TensorInstanceBacking const &tensor_instance_backing, + PerDeviceOpStateBacking const &device_state_backing, OptimizerAttrs const &optimizer_attrs, ProfilingSettings const &profiling_settings, DistributedDeviceHandle const &device_handle, @@ -195,15 +200,17 @@ static std::unordered_map subset_tensor_instance_backing_for_invocation( tensor_instance_backing, invocation); - Realm::Event result = spawn_op_task(ctx, - target_proc, - invocation, - tensor_backing, - profiling_settings, - device_handle.at(target_proc), - iteration_config, - optimizer_attrs, - dependencies); + Realm::Event result = + spawn_op_task(ctx, + target_proc, + invocation, + tensor_backing, + device_state_backing.backing.at(invocation), + profiling_settings, + device_handle.at(target_proc), + iteration_config, + optimizer_attrs, + dependencies); for (DynamicValueAttrs const &value : values(invocation.inputs)) { dependency_set.add_reader(value, result); } @@ -228,6 +235,7 @@ std::unordered_map /*invocations=*/execution_order, /*tensor_instance_backing=*/ pcg_instance.get_tensor_instance_backing(), + /*device_state_backing=*/pcg_instance.get_device_state_backing(), /*optimizer_attrs=*/pcg_instance.get_optimizer_attrs(), /*profiling_settings=*/profiling_settings, /*device_handle=*/device_handle, @@ -254,6 +262,7 @@ std::unordered_map /*ctx=*/pcg_instance.get_realm_context(), /*invocations=*/execution_order, /*tensor_instance_backing=*/pcg_instance.get_tensor_instance_backing(), + /*device_state_backing=*/pcg_instance.get_device_state_backing(), /*optimizer_attrs=*/pcg_instance.get_optimizer_attrs(), /*profiling_settings=*/profiling_settings, /*device_handle=*/device_handle, @@ -278,6 +287,7 @@ std::unordered_map /*ctx=*/pcg_instance.get_realm_context(), /*invocations=*/execution_order, /*tensor_instance_backing=*/pcg_instance.get_tensor_instance_backing(), + /*device_state_backing=*/pcg_instance.get_device_state_backing(), /*optimizer_attrs=*/pcg_instance.get_optimizer_attrs(), /*profiling_settings=*/profiling_settings, /*device_handle=*/device_handle, @@ -304,6 +314,7 @@ std::unordered_map /*invocations=*/execution_order, /*tensor_instance_backing=*/ pcg_instance.get_tensor_instance_backing(), + /*device_state_backing=*/pcg_instance.get_device_state_backing(), /*optimizer_attrs=*/pcg_instance.get_optimizer_attrs(), /*profiling_settings=*/profiling_settings, /*device_handle=*/device_handle, diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_return_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_return_task.cc index 306697e950..a1a7eb84a8 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_return_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_return_task.cc @@ -7,16 +7,16 @@ struct DeviceStateInitReturnTaskArgs { public: DeviceStateInitReturnTaskArgs() = delete; DeviceStateInitReturnTaskArgs( - DeviceSpecificPerDeviceOpState result, + DeviceSpecificPtr result, Realm::Processor origin_proc, - DeviceSpecificPerDeviceOpState *origin_result_ptr) + DeviceSpecificPtr *origin_result_ptr) : result(result), origin_proc(origin_proc), origin_result_ptr(origin_result_ptr) {} public: - DeviceSpecificPerDeviceOpState result; + DeviceSpecificPtr result; Realm::Processor origin_proc; - DeviceSpecificPerDeviceOpState *origin_result_ptr; + DeviceSpecificPtr *origin_result_ptr; }; void device_state_init_return_task_body(void const *args, @@ -35,8 +35,8 @@ void device_state_init_return_task_body(void const *args, Realm::Event spawn_device_state_init_return_task( RealmContext &ctx, Realm::Processor origin_proc, - DeviceSpecificPerDeviceOpState const &result, - DeviceSpecificPerDeviceOpState *origin_result_ptr, + DeviceSpecificPtr const &result, + DeviceSpecificPtr *origin_result_ptr, Realm::Event precondition) { DeviceStateInitReturnTaskArgs task_args{ result, origin_proc, origin_result_ptr}; diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_task.cc index 7f3f2d185c..50c8daffb0 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_task.cc @@ -9,7 +9,9 @@ #include "realm-execution/tasks/task_id_t.h" #include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" #include "task-spec/dynamic_graph/dynamic_value_attrs.dtg.h" +#include "task-spec/per_device_op_state.h" #include "utils/containers/map_values.h" +#include "utils/containers/transform.h" #include "utils/optional.h" #include #include @@ -59,11 +61,14 @@ void device_state_init_task_body(void const *args, assert_unwrap(result_invocation.node_attrs.per_device_op_state); // Important: to make sure this doesn't get deallocated, we intentionally leak // the allocation here - DeviceSpecificPerDeviceOpState *result_state_ptr = - new DeviceSpecificPerDeviceOpState{result_state}; + PerDeviceOpState *result_state_ptr = + new PerDeviceOpState{get_device_state_from_device_specific( + result_state, ctx.get_current_device_idx())}; + DeviceSpecificPtr result_device_specific{ + ctx.get_current_device_idx(), result_state_ptr}; spawn_device_state_init_return_task(ctx, task_args.origin_proc, - *result_state_ptr, + result_device_specific, task_args.origin_result_ptr, Realm::Event::NO_EVENT); } @@ -77,7 +82,7 @@ std::optional spawn_device_state_init_task( DeviceSpecificManagedPerDeviceFFHandle const &device_handle, FFIterationConfig const &iteration_config, OptimizerAttrs const &optimizer_attrs, - DeviceSpecificPerDeviceOpState *result_ptr, + DeviceSpecificPtr *result_ptr, Realm::Event precondition) { DeviceStateInitTaskArgs task_args{ invocation, diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc index dc262bbdb1..e67df885d3 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc @@ -6,9 +6,11 @@ #include "realm-execution/tasks/impl/serializable_op_task_args.h" #include "realm-execution/tasks/serializer/task_arg_serializer.h" #include "realm-execution/tasks/task_id_t.h" +#include "task-spec/per_device_op_state.dtg.h" #include "task-spec/per_device_op_state.h" #include "task-spec/permissions.h" #include "utils/containers/map_values.h" +#include "utils/containers/transform.h" #include "utils/optional.h" #include @@ -49,11 +51,8 @@ void op_task_body(void const *args, /*profiling_settings=*/task_args.profiling_settings, /*ff_handle=*/device_handle, /*per_device_op_state=*/ - transform(task_args.invocation.node_attrs.per_device_op_state, - [&](DeviceSpecificPerDeviceOpState const &op_state) { - return get_device_state_from_device_specific( - op_state, ctx.get_current_device_idx()); - }), + transform(task_args.device_state.get(ctx.get_current_device_idx()), + [](PerDeviceOpState *ptr) { return *ptr; }), /*iteration_config=*/task_args.iteration_config, /*optimizer_attrs=*/task_args.optimizer_attrs, /*device_idx=*/ctx.get_current_device_idx()); @@ -64,6 +63,7 @@ Realm::Event Realm::Processor target_proc, DynamicNodeInvocation const &invocation, TensorInstanceBacking const &tensor_backing, + DeviceSpecificPtr const &device_state, ProfilingSettings const &profiling_settings, DeviceSpecificManagedPerDeviceFFHandle const &device_handle, FFIterationConfig const &iteration_config, @@ -71,6 +71,7 @@ Realm::Event Realm::Event precondition) { OpTaskArgs task_args{invocation, tensor_backing, + device_state, profiling_settings, device_handle, iteration_config, diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/serializable_device_state_init_task_args.cc b/lib/realm-execution/src/realm-execution/tasks/impl/serializable_device_state_init_task_args.cc index fed22ff393..2e7e02b529 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/serializable_device_state_init_task_args.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/serializable_device_state_init_task_args.cc @@ -35,7 +35,7 @@ DeviceStateInitTaskArgs device_state_init_task_args_from_serializable( /*optimizer_attrs=*/args.optimizer_attrs, /*origin_proc=*/realm_processor_from_serializable(args.origin_proc), /*origin_result_ptr=*/ - reinterpret_cast( + reinterpret_cast *>( args.origin_result_ptr), }; } diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/serializable_op_task_args.cc b/lib/realm-execution/src/realm-execution/tasks/impl/serializable_op_task_args.cc index 80994d4298..a17e58da5e 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/serializable_op_task_args.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/serializable_op_task_args.cc @@ -10,6 +10,7 @@ SerializableOpTaskArgs op_task_args_to_serializable(OpTaskArgs const &args) { /*invocation=*/dynamic_node_invocation_to_serializable(args.invocation), /*tensor_backing*/ tensor_instance_backing_to_serializable(args.tensor_backing), + /*device_state=*/device_specific_ptr_to_serializable(args.device_state), /*profiling_settings=*/args.profiling_settings, /*device_handle=*/device_specific_ptr_to_serializable(args.device_handle), /*iteration_config=*/args.iteration_config, @@ -22,6 +23,9 @@ OpTaskArgs op_task_args_from_serializable(SerializableOpTaskArgs const &args) { /*invocation=*/dynamic_node_invocation_from_serializable(args.invocation), /*tensor_backing*/ tensor_instance_backing_from_serializable(args.tensor_backing), + /*device_state=*/ + device_specific_ptr_from_serializable( + args.device_state), /*profiling_settings=*/args.profiling_settings, /*device_handle=*/ device_specific_ptr_from_serializable( From 0fc66baf8562b720fb08cbd22f62c0e8edbce921 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Tue, 17 Feb 2026 17:34:28 -0800 Subject: [PATCH 72/88] Register loss task. --- .../src/realm-execution/tasks/realm_task_registry.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc b/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc index 914e8d1e29..fa056d6f33 100644 --- a/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc +++ b/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc @@ -117,6 +117,9 @@ Realm::Event register_all_tasks() { // Update tasks task_id_t::SGD_UPD_NCCL_TASK_ID, task_id_t::ADAM_UPD_NCCL_TASK_ID, + + // Loss task + task_id_t::LOSS_BWD_TASK_ID, }; for (task_id_t task_id : task_ids) { From 5ba6a61661040bc8abe9d1cc0c01795513ba821e Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Tue, 17 Feb 2026 19:28:44 -0800 Subject: [PATCH 73/88] Test loss in Realm. --- .../test/src/local-execution/test_e2e.cc | 4 +- .../test/src/realm-execution/test_e2e.cc | 49 ++++++++++++++----- 2 files changed, 39 insertions(+), 14 deletions(-) diff --git a/lib/local-execution/test/src/local-execution/test_e2e.cc b/lib/local-execution/test/src/local-execution/test_e2e.cc index a74d165a31..615ba204cf 100644 --- a/lib/local-execution/test/src/local-execution/test_e2e.cc +++ b/lib/local-execution/test/src/local-execution/test_e2e.cc @@ -21,8 +21,8 @@ using namespace ::FlexFlow; -bool did_loss_decrease(GenericTensorAccessorR const &first_epoch, - GenericTensorAccessorR const &last_epoch) { +static bool did_loss_decrease(GenericTensorAccessorR const &first_epoch, + GenericTensorAccessorR const &last_epoch) { Allocator cpu_allocator = create_local_cpu_memory_allocator(); return tensor_accessor_all( diff --git a/lib/realm-execution/test/src/realm-execution/test_e2e.cc b/lib/realm-execution/test/src/realm-execution/test_e2e.cc index 4dbfe09045..28665e840b 100644 --- a/lib/realm-execution/test/src/realm-execution/test_e2e.cc +++ b/lib/realm-execution/test/src/realm-execution/test_e2e.cc @@ -1,5 +1,11 @@ #include "internal/realm_test_utils.h" #include "kernels/allocation.h" +#include "kernels/compare_tensor_accessors.h" +#include "kernels/copy_tensor_accessor.h" +#include "kernels/format_accessor_contents.h" +#include "kernels/local_cpu_allocator.h" +#include "kernels/tensor_accessor_reductions.h" +#include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/tensor_shape.dtg.h" #include "op-attrs/tensor_slot_name.dtg.h" #include "pcg/device_type.dtg.h" @@ -9,8 +15,11 @@ #include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" #include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" #include "realm-execution/distributed_device_handle.h" +#include "realm-execution/dynamic_tensor_accessor_from_instance.h" #include "realm-execution/pcg_instance/pcg_instance.h" #include "realm-execution/realm_manager.h" +#include "task-spec/permissions.h" +#include "test/utils/doctest/check_kv.h" #include "utils/containers/require_only_key.h" #include @@ -19,6 +28,14 @@ namespace test { using namespace ::FlexFlow; namespace Realm = ::FlexFlow::Realm; +static bool did_loss_decrease(GenericTensorAccessorR const &first_epoch, + GenericTensorAccessorR const &last_epoch) { + Allocator cpu_allocator = create_local_cpu_memory_allocator(); + + return tensor_accessor_all( + compare_tensor_accessors_le(last_epoch, first_epoch, cpu_allocator)); +} + TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("RealmBackend e2e Training") { std::vector fake_args = @@ -216,20 +233,28 @@ TEST_SUITE(FF_TEST_SUITE) { /*profiling_settings=*/ProfilingSettings{0, 0}, /*device_handle=*/device_handle, /*iteration_config=*/FFIterationConfig{1_p}); - // loss_values.push_back(copy_tensor_accessor_r( - // pcg_instance.get_loss_tensor_accessor().value(), - // allocator)); + loss_values.push_back(copy_tensor_accessor_r( + dynamic_tensor_accessor_from_instance( + pcg_instance.get_loss_tensor_instance().value(), + Realm::Event::NO_EVENT, + lift_to_parallel( + TensorShape{TensorDims{FFOrdered{output_dim, hidden_dim}}, + DataType::FLOAT}), + Permissions::RO, + ctx.get_current_processor()) + .require_read(), + allocator)); } - // // Assert that each sample in the batch has a lower loss in last epoch - // // than the first epoch - // GenericTensorAccessorR first_epoch_loss = loss_values.at(0); - // GenericTensorAccessorR last_epoch_loss = loss_values.back(); - // CHECK_MESSAGE(did_loss_decrease(first_epoch_loss, last_epoch_loss), - // check_kv("first_epoch_loss", - // format_accessor_r_contents(first_epoch_loss)), - // check_kv("last_epoch_loss", - // format_accessor_r_contents(last_epoch_loss))); + // Assert that each sample in the batch has a lower loss in last epoch + // than the first epoch + GenericTensorAccessorR first_epoch_loss = loss_values.at(0); + GenericTensorAccessorR last_epoch_loss = loss_values.back(); + CHECK_MESSAGE(did_loss_decrease(first_epoch_loss, last_epoch_loss), + check_kv("first_epoch_loss", + format_accessor_r_contents(first_epoch_loss)), + check_kv("last_epoch_loss", + format_accessor_r_contents(last_epoch_loss))); }); } } From 8b13e27477e21a7a206ea64d28db5fe8ccaecc27 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Wed, 18 Feb 2026 10:05:10 -0800 Subject: [PATCH 74/88] Test CPU model parallelism. --- lib/realm-execution/test/src/realm-execution/test_e2e.cc | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/lib/realm-execution/test/src/realm-execution/test_e2e.cc b/lib/realm-execution/test/src/realm-execution/test_e2e.cc index 28665e840b..706fc002c1 100644 --- a/lib/realm-execution/test/src/realm-execution/test_e2e.cc +++ b/lib/realm-execution/test/src/realm-execution/test_e2e.cc @@ -37,9 +37,9 @@ static bool did_loss_decrease(GenericTensorAccessorR const &first_epoch, } TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("RealmBackend e2e Training") { + TEST_CASE("RealmBackend e2e Training (CPU Model Parallelism)") { std::vector fake_args = - make_fake_realm_args(/*num_cpus=*/1_p, /*num_gpus=*/0_n); + make_fake_realm_args(/*num_cpus=*/2_p, /*num_gpus=*/0_n); int fake_argc = fake_args.size(); char **fake_argv = fake_args.data(); @@ -149,6 +149,7 @@ TEST_SUITE(FF_TEST_SUITE) { require_only_key(linear_operator_2.outputs, TensorSlotName::OUTPUT); MachineSpaceCoordinate cpu0{0_n, 0_n, DeviceType::CPU}; + MachineSpaceCoordinate cpu1{0_n, 1_n, DeviceType::CPU}; ParallelTensorSpaceCoordinate tensor_coord0{0_n, 0_n, FFOrdered{0_n}}; MappedParallelComputationGraph mpcg{ pcg, @@ -165,7 +166,7 @@ TEST_SUITE(FF_TEST_SUITE) { {{TensorSlotName::OUTPUT, tensor_coord0}}}}}}}, {weights_layer_2.parallel_layer, MappedOperatorTaskGroup{ - {{cpu0, + {{cpu1, OperatorAtomicTaskShardBinding{ {{TensorSlotName::OUTPUT, tensor_coord0}}}}}}}, {linear_operator_1.parallel_layer, @@ -178,7 +179,7 @@ TEST_SUITE(FF_TEST_SUITE) { }}}}}}, {linear_operator_2.parallel_layer, MappedOperatorTaskGroup{ - {{cpu0, + {{cpu1, OperatorAtomicTaskShardBinding{{ {TensorSlotName::INPUT, tensor_coord0}, {TensorSlotName::WEIGHT, tensor_coord0}, From a05fa06d8fed302e076c204e70a3532a099a4883 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Wed, 18 Feb 2026 10:10:47 -0800 Subject: [PATCH 75/88] Use Realm's own allocator in test. --- .../test/src/realm-execution/test_e2e.cc | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/lib/realm-execution/test/src/realm-execution/test_e2e.cc b/lib/realm-execution/test/src/realm-execution/test_e2e.cc index 706fc002c1..02c1365039 100644 --- a/lib/realm-execution/test/src/realm-execution/test_e2e.cc +++ b/lib/realm-execution/test/src/realm-execution/test_e2e.cc @@ -3,7 +3,6 @@ #include "kernels/compare_tensor_accessors.h" #include "kernels/copy_tensor_accessor.h" #include "kernels/format_accessor_contents.h" -#include "kernels/local_cpu_allocator.h" #include "kernels/tensor_accessor_reductions.h" #include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/tensor_shape.dtg.h" @@ -17,6 +16,7 @@ #include "realm-execution/distributed_device_handle.h" #include "realm-execution/dynamic_tensor_accessor_from_instance.h" #include "realm-execution/pcg_instance/pcg_instance.h" +#include "realm-execution/realm_context.h" #include "realm-execution/realm_manager.h" #include "task-spec/permissions.h" #include "test/utils/doctest/check_kv.h" @@ -28,12 +28,11 @@ namespace test { using namespace ::FlexFlow; namespace Realm = ::FlexFlow::Realm; -static bool did_loss_decrease(GenericTensorAccessorR const &first_epoch, +static bool did_loss_decrease(RealmContext &ctx, + GenericTensorAccessorR const &first_epoch, GenericTensorAccessorR const &last_epoch) { - Allocator cpu_allocator = create_local_cpu_memory_allocator(); - - return tensor_accessor_all( - compare_tensor_accessors_le(last_epoch, first_epoch, cpu_allocator)); + return tensor_accessor_all(compare_tensor_accessors_le( + last_epoch, first_epoch, ctx.get_current_device_allocator())); } TEST_SUITE(FF_TEST_SUITE) { From a59ba1e194861a4df23b708aaf81e80785ebdf6e Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Wed, 18 Feb 2026 11:01:21 -0800 Subject: [PATCH 76/88] Fix typo. --- lib/realm-execution/test/src/realm-execution/test_e2e.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/realm-execution/test/src/realm-execution/test_e2e.cc b/lib/realm-execution/test/src/realm-execution/test_e2e.cc index 02c1365039..96a8ba49dc 100644 --- a/lib/realm-execution/test/src/realm-execution/test_e2e.cc +++ b/lib/realm-execution/test/src/realm-execution/test_e2e.cc @@ -250,7 +250,7 @@ TEST_SUITE(FF_TEST_SUITE) { // than the first epoch GenericTensorAccessorR first_epoch_loss = loss_values.at(0); GenericTensorAccessorR last_epoch_loss = loss_values.back(); - CHECK_MESSAGE(did_loss_decrease(first_epoch_loss, last_epoch_loss), + CHECK_MESSAGE(did_loss_decrease(ctx, first_epoch_loss, last_epoch_loss), check_kv("first_epoch_loss", format_accessor_r_contents(first_epoch_loss)), check_kv("last_epoch_loss", From ea76b0f13803239abe85ee1e358358a3a087c210 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Wed, 18 Feb 2026 16:22:19 -0800 Subject: [PATCH 77/88] Add Realm top-level README. --- lib/realm-execution/README.md | 32 +++++++++++++++++++ .../pcg_instance/pcg_instance.cc | 8 ----- 2 files changed, 32 insertions(+), 8 deletions(-) create mode 100644 lib/realm-execution/README.md diff --git a/lib/realm-execution/README.md b/lib/realm-execution/README.md new file mode 100644 index 0000000000..1454c7eac8 --- /dev/null +++ b/lib/realm-execution/README.md @@ -0,0 +1,32 @@ +The Realm backend for distributed execution. + +This is a single-controller implementation. That means the controller (the task that launches all other work) runs on a single node and remotely launches work onto other nodes. Aside from caveats mentioned below, this implementation is (mostly) capable of distributed execution. + +Major components: + +* `PCGInstance`: the main public interface for the Realm backend. It takes a mapped PCG and lowers it through the dynamic graph to get the fully-specified execution order of tasks to be executed. Besides the usual dynamic graph passes (pass expansion, update insertion, shard expansion), this class also tracks the allocation of Realm instances for tensors. +* `RealmManager`: manages the initialization and shutdown of the Realm runtime. Provides the interface to launch the controller that runs the rest of the computation. +* `RealmContext`: an interface that wraps the rest of Realm and protects against certain classes of bugs, such as shutdown bugs. **Do NOT call Realm directly unless you know what you are doing.** +* `tasks/`: the Realm task implementations and their supporting infrastructure. + * `impl/`: the actual bodies of Realm tasks, along with interfaces to call them, and the serialization infrastructure for their arguments. + * `serializer/`: additional support for serializing Realm data types. + * `realm_task_registry.h`: manages the registration of Realm tasks. All Realm tasks go through this interface. + * `task_id_t.h` and `realm_task_id_t.h`: types to represent Realm tasks, along with an encoding to Realm's native task ID type. + +Other components used mainly within `PCGInstance`: + + * `DistributedDeviceHandle`: represents a distributed device handle (i.e., device handles on all the GPUs on the system), for convenience. + * `DependenceSet`: tracks dependencies during execution of tasks. + * `distributed_device_state_initialization.h`: performs device state initialization of dynamic graph nodes and returns the resulting `PerDeviceOpStateBacking`. + * `instance_allocation.h`: allocates instances for tensors in the dynamic graph and returns the resulting `TensorInstanceBacking`. + +TODO list: + +* external instances +* copies +* task fusion +* parallel operator implementation (partition, reduce, gather, etc.) +* and fused parallel operators (reduce + broadcast = allreduce) +* memory-optimizing compiler integration (tensor creation/destruction, tensor reuse) +* control replication +* Realm subgraphs diff --git a/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc b/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc index 5d1a63ba5b..9f359ded10 100644 --- a/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc +++ b/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc @@ -147,14 +147,6 @@ PCGInstance create_pcg_instance( /*device_state_backing=*/device_state_backing, /*optimizer_attrs=*/optimizer_attrs, /*logit_grad_tensor=*/logit_grad_tensor}; - - // TODO list: - // * external instances - // * copies - // * parallel operator implementation (partition, reduce, gather, etc.) - // * and fused parallel operators (reduce + broadcast = allreduce) - // * memory-optimizing compiler integration (tensor creation/destruction, - // tensor reuse) } static std::unordered_map From 8be3fdce80e1957370bcb40ec828f2cf66ab73b7 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Fri, 20 Feb 2026 12:02:23 -0800 Subject: [PATCH 78/88] Add and fix GPU test (no loss so far). --- .../realm-execution/tasks/impl/op_task.h | 23 +- .../tasks/impl/op_task_args.dtg.toml | 2 +- .../impl/serializable_op_task_args.dtg.toml | 2 +- ...distributed_device_state_initialization.cc | 60 +++-- .../pcg_instance/pcg_instance.cc | 3 +- .../src/realm-execution/tasks/impl/op_task.cc | 27 ++- .../tasks/impl/serializable_op_task_args.cc | 9 +- .../test/src/realm-execution/test_e2e.cc | 224 ++++++++++++++++++ 8 files changed, 301 insertions(+), 49 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h index 33dcbff895..8399742424 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h +++ b/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h @@ -12,22 +12,23 @@ #include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" #include "task-spec/ff_iteration_config.dtg.h" #include "task-spec/per_device_op_state.dtg.h" +#include namespace FlexFlow { void op_task_body(void const *, size_t, void const *, size_t, Realm::Processor); -Realm::Event - spawn_op_task(RealmContext &ctx, - Realm::Processor target_proc, - DynamicNodeInvocation const &invocation, - TensorInstanceBacking const &tensor_backing, - DeviceSpecificPtr const &device_state, - ProfilingSettings const &profiling_settings, - DeviceSpecificManagedPerDeviceFFHandle const &device_handle, - FFIterationConfig const &iteration_config, - std::optional const &optimizer_attrs, - Realm::Event precondition); +Realm::Event spawn_op_task( + RealmContext &ctx, + Realm::Processor target_proc, + DynamicNodeInvocation const &invocation, + TensorInstanceBacking const &tensor_backing, + std::optional> const &device_state, + ProfilingSettings const &profiling_settings, + DeviceSpecificManagedPerDeviceFFHandle const &device_handle, + FFIterationConfig const &iteration_config, + std::optional const &optimizer_attrs, + Realm::Event precondition); } // namespace FlexFlow diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/op_task_args.dtg.toml b/lib/realm-execution/include/realm-execution/tasks/impl/op_task_args.dtg.toml index a15c8dce11..f6bb83fbca 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/op_task_args.dtg.toml +++ b/lib/realm-execution/include/realm-execution/tasks/impl/op_task_args.dtg.toml @@ -24,7 +24,7 @@ type = "::FlexFlow::TensorInstanceBacking" [[fields]] name = "device_state" -type = "::FlexFlow::DeviceSpecificPtr<::FlexFlow::PerDeviceOpState>" +type = "std::optional<::FlexFlow::DeviceSpecificPtr<::FlexFlow::PerDeviceOpState>>" [[fields]] name = "profiling_settings" diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/serializable_op_task_args.dtg.toml b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_op_task_args.dtg.toml index 2be0034c46..adac6631ee 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/serializable_op_task_args.dtg.toml +++ b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_op_task_args.dtg.toml @@ -32,7 +32,7 @@ type = "::FlexFlow::SerializableTensorInstanceBacking" [[fields]] name = "device_state" -type = "::FlexFlow::SerializableDeviceSpecificPtr" +type = "std::optional<::FlexFlow::SerializableDeviceSpecificPtr>" [[fields]] name = "profiling_settings" diff --git a/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc b/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc index c6d8ea3e69..5c0aff00c2 100644 --- a/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc +++ b/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc @@ -7,9 +7,12 @@ #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" #include "task-spec/dynamic_graph/dynamic_value_attrs.dtg.h" #include "utils/containers/map_values.h" +#include "utils/containers/transform.h" +#include "utils/containers/values.h" #include "utils/optional.h" #include #include +#include namespace FlexFlow { @@ -26,16 +29,9 @@ PerDeviceOpStateBacking perform_distributed_device_state_initialization( // Initialize all operators and save the per-device op state ASSERT(no_nodes_are_initialized(dg)); - std::unordered_map> - result; - - // Preallocate output before launching tasks - for (DynamicNodeInvocation const &invocation : dg.invocations) { - result.insert(std::pair{invocation, - DeviceSpecificPtr{ - ctx.get_current_device_idx(), std::nullopt}}); - } - + std::unordered_map *> + device_state_map; for (DynamicNodeInvocation const &invocation : dg.invocations) { Realm::Processor target_proc = ctx.map_device_coord_to_processor( assert_unwrap(invocation.node_attrs.device_coord)); @@ -44,20 +40,44 @@ PerDeviceOpStateBacking perform_distributed_device_state_initialization( subset_tensor_instance_backing_for_invocation(tensor_instance_backing, invocation); - spawn_device_state_init_task(ctx, - target_proc, - invocation, - tensor_backing, - profiling_settings, - device_handle.at(target_proc), - iteration_config, - optimizer_attrs, - &result.at(invocation), - precondition); + DeviceSpecificPtr *device_state_ptr = + new DeviceSpecificPtr{ctx.get_current_device_idx(), + std::nullopt}; + + std::optional completion_event = + spawn_device_state_init_task(ctx, + target_proc, + invocation, + tensor_backing, + profiling_settings, + device_handle.at(target_proc), + iteration_config, + optimizer_attrs, + device_state_ptr, + precondition); + + if (completion_event.has_value()) { + device_state_map.insert(std::pair{invocation, device_state_ptr}); + } else { + // Task doesn't require initialization, clean up and don't store result + delete device_state_ptr; + } } ctx.get_outstanding_events().wait(); + auto deref = [](DynamicNodeInvocation const &i, + DeviceSpecificPtr *const &p) { + return std::pair{i, *p}; + }; + std::unordered_map> + result = transform(device_state_map, deref); + + for (DeviceSpecificPtr *device_state_ptr : + values(device_state_map)) { + delete device_state_ptr; + } + return PerDeviceOpStateBacking{/*backing=*/result}; } diff --git a/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc b/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc index 9f359ded10..4b08e9a430 100644 --- a/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc +++ b/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc @@ -20,6 +20,7 @@ #include "task-spec/dynamic_graph/update_insertion.h" #include "utils/containers/map_values.h" #include "utils/containers/transform.h" +#include "utils/containers/try_at.h" #include "utils/containers/values.h" #include "utils/graph/digraph/algorithms/get_topological_ordering.h" #include "utils/optional.h" @@ -197,7 +198,7 @@ static std::unordered_map target_proc, invocation, tensor_backing, - device_state_backing.backing.at(invocation), + try_at(device_state_backing.backing, invocation), profiling_settings, device_handle.at(target_proc), iteration_config, diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc index e67df885d3..c7dcdb39c2 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc @@ -51,24 +51,27 @@ void op_task_body(void const *args, /*profiling_settings=*/task_args.profiling_settings, /*ff_handle=*/device_handle, /*per_device_op_state=*/ - transform(task_args.device_state.get(ctx.get_current_device_idx()), + transform(and_then(task_args.device_state, + [&](DeviceSpecificPtr const &d) { + return d.get(ctx.get_current_device_idx()); + }), [](PerDeviceOpState *ptr) { return *ptr; }), /*iteration_config=*/task_args.iteration_config, /*optimizer_attrs=*/task_args.optimizer_attrs, /*device_idx=*/ctx.get_current_device_idx()); } -Realm::Event - spawn_op_task(RealmContext &ctx, - Realm::Processor target_proc, - DynamicNodeInvocation const &invocation, - TensorInstanceBacking const &tensor_backing, - DeviceSpecificPtr const &device_state, - ProfilingSettings const &profiling_settings, - DeviceSpecificManagedPerDeviceFFHandle const &device_handle, - FFIterationConfig const &iteration_config, - std::optional const &optimizer_attrs, - Realm::Event precondition) { +Realm::Event spawn_op_task( + RealmContext &ctx, + Realm::Processor target_proc, + DynamicNodeInvocation const &invocation, + TensorInstanceBacking const &tensor_backing, + std::optional> const &device_state, + ProfilingSettings const &profiling_settings, + DeviceSpecificManagedPerDeviceFFHandle const &device_handle, + FFIterationConfig const &iteration_config, + std::optional const &optimizer_attrs, + Realm::Event precondition) { OpTaskArgs task_args{invocation, tensor_backing, device_state, diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/serializable_op_task_args.cc b/lib/realm-execution/src/realm-execution/tasks/impl/serializable_op_task_args.cc index a17e58da5e..32d54adc37 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/serializable_op_task_args.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/serializable_op_task_args.cc @@ -2,6 +2,7 @@ #include "realm-execution/tasks/serializer/serializable_device_specific_ptr.h" #include "realm-execution/tasks/serializer/serializable_tensor_instance_backing.h" #include "task-spec/dynamic_graph/serializable_dynamic_node_invocation.h" +#include "utils/containers/transform.h" namespace FlexFlow { @@ -10,7 +11,9 @@ SerializableOpTaskArgs op_task_args_to_serializable(OpTaskArgs const &args) { /*invocation=*/dynamic_node_invocation_to_serializable(args.invocation), /*tensor_backing*/ tensor_instance_backing_to_serializable(args.tensor_backing), - /*device_state=*/device_specific_ptr_to_serializable(args.device_state), + /*device_state=*/ + transform(args.device_state, + device_specific_ptr_to_serializable), /*profiling_settings=*/args.profiling_settings, /*device_handle=*/device_specific_ptr_to_serializable(args.device_handle), /*iteration_config=*/args.iteration_config, @@ -24,8 +27,8 @@ OpTaskArgs op_task_args_from_serializable(SerializableOpTaskArgs const &args) { /*tensor_backing*/ tensor_instance_backing_from_serializable(args.tensor_backing), /*device_state=*/ - device_specific_ptr_from_serializable( - args.device_state), + transform(args.device_state, + device_specific_ptr_from_serializable), /*profiling_settings=*/args.profiling_settings, /*device_handle=*/ device_specific_ptr_from_serializable( diff --git a/lib/realm-execution/test/src/realm-execution/test_e2e.cc b/lib/realm-execution/test/src/realm-execution/test_e2e.cc index 96a8ba49dc..d9252693a1 100644 --- a/lib/realm-execution/test/src/realm-execution/test_e2e.cc +++ b/lib/realm-execution/test/src/realm-execution/test_e2e.cc @@ -259,4 +259,228 @@ TEST_SUITE(FF_TEST_SUITE) { } } +TEST_SUITE(FF_CUDA_TEST_SUITE) { + TEST_CASE("RealmBackend e2e Training (GPU Model Parallelism)") { + std::vector fake_args = + make_fake_realm_args(/*num_cpus=*/1_p, /*num_gpus=*/1_n); + int fake_argc = fake_args.size(); + char **fake_argv = fake_args.data(); + + RealmManager manager(&fake_argc, &fake_argv); + + (void)manager.start_controller([](RealmContext &ctx) { + Allocator allocator = ctx.get_current_device_allocator(); + + positive_int batch_size = 10_p; + positive_int data_dim = 16_p; + positive_int hidden_dim = 32_p; + positive_int output_dim = 1_p; + + TensorShape output_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, output_dim}}, DataType::FLOAT}; + + GenericTensorAccessorW label_tensor_backing = + allocator.allocate_tensor(output_tensor_shape); + + // construct computation graph + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + TensorShape input_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, data_dim}}, DataType::FLOAT}; + + TensorShape label_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, output_dim}}, DataType::FLOAT}; + GenericTensorAccessorW label_tensor = + allocator.allocate_tensor(label_tensor_shape); + + TensorShape weight_shape_1 = TensorShape{ + TensorDims{FFOrdered{hidden_dim, data_dim}}, DataType::FLOAT}; + TensorShape weight_shape_2 = TensorShape{ + TensorDims{FFOrdered{output_dim, hidden_dim}}, DataType::FLOAT}; + + ParallelLayerAddedResult inputs_layer = + pcg_add_input_layer_with_grad(pcg, input_tensor_shape); + parallel_tensor_guid_t t_input = + require_only_key(inputs_layer.outputs, TensorSlotName::OUTPUT); + + ParallelLayerAddedResult weights_layer_1 = add_parallel_layer( + pcg, + ParallelLayerAttrs{ + PCGOperatorAttrs{WeightAttrs{ + weight_shape_1, InitializerAttrs{GlorotNormalAttrs{0}}}}, + std::nullopt}, + {}, + {}); + parallel_tensor_guid_t t_weights_1 = + require_only_key(weights_layer_1.outputs, TensorSlotName::OUTPUT); + + ParallelLayerAddedResult weights_layer_2 = add_parallel_layer( + pcg, + ParallelLayerAttrs{ + PCGOperatorAttrs{WeightAttrs{ + weight_shape_2, InitializerAttrs{GlorotNormalAttrs{0}}}}, + std::nullopt}, + {}, + {}); + parallel_tensor_guid_t t_weights_2 = + require_only_key(weights_layer_2.outputs, TensorSlotName::OUTPUT); + + ParallelLayerAddedResult linear_operator_1 = add_parallel_layer( + pcg, + ParallelLayerAttrs{PCGOperatorAttrs{LinearAttrs{hidden_dim, + /*use_bias=*/false, + DataType::FLOAT, + Activation::RELU, + std::nullopt}}, + std::nullopt}, + { + { + TensorSlotName::INPUT, + t_input, + }, + }, + { + { + TensorSlotName::WEIGHT, + t_weights_1, + }, + }); + parallel_tensor_guid_t t_linear_1 = + require_only_key(linear_operator_1.outputs, TensorSlotName::OUTPUT); + + ParallelLayerAddedResult linear_operator_2 = add_parallel_layer( + pcg, + ParallelLayerAttrs{PCGOperatorAttrs{LinearAttrs{output_dim, + /*use_bias=*/false, + DataType::FLOAT, + Activation::RELU, + std::nullopt}}, + std::nullopt}, + { + { + TensorSlotName::INPUT, + t_linear_1, + }, + }, + { + { + TensorSlotName::WEIGHT, + t_weights_2, + }, + }); + parallel_tensor_guid_t t_linear_2 = + require_only_key(linear_operator_2.outputs, TensorSlotName::OUTPUT); + + MachineSpaceCoordinate gpu0{0_n, 0_n, DeviceType::GPU}; + ParallelTensorSpaceCoordinate tensor_coord0{0_n, 0_n, FFOrdered{0_n}}; + MappedParallelComputationGraph mpcg{ + pcg, + { + {inputs_layer.parallel_layer, + MappedOperatorTaskGroup{ + {{gpu0, + OperatorAtomicTaskShardBinding{ + {{TensorSlotName::OUTPUT, tensor_coord0}}}}}}}, + {weights_layer_1.parallel_layer, + MappedOperatorTaskGroup{ + {{gpu0, + OperatorAtomicTaskShardBinding{ + {{TensorSlotName::OUTPUT, tensor_coord0}}}}}}}, + {weights_layer_2.parallel_layer, + MappedOperatorTaskGroup{ + {{gpu0, + OperatorAtomicTaskShardBinding{ + {{TensorSlotName::OUTPUT, tensor_coord0}}}}}}}, + {linear_operator_1.parallel_layer, + MappedOperatorTaskGroup{ + {{gpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord0}, + {TensorSlotName::WEIGHT, tensor_coord0}, + {TensorSlotName::OUTPUT, tensor_coord0}, + }}}}}}, + {linear_operator_2.parallel_layer, + MappedOperatorTaskGroup{ + {{gpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord0}, + {TensorSlotName::WEIGHT, tensor_coord0}, + {TensorSlotName::OUTPUT, tensor_coord0}, + }}}}}}, + }, + }; + MappedOperatorTaskGroup loss_mapping{ + {{gpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord0}, + {TensorSlotName::LOGIT, tensor_coord0}, + }}}}}; + + // instantiate computation graph + LossAttrs loss_attrs = LossAttrs{ + NonconfigurableLossAttrs{LossFunction::CATEGORICAL_CROSSENTROPY}}; + OptimizerAttrs optimizer_attrs = + OptimizerAttrs{SGDOptimizerAttrs{/*lr=*/0.001, + /*momentum=*/0.9, + /*nesterov=*/false, + /*weight_decay=*/0.001}}; + + std::unordered_map + input_tensors; + + DistributedDeviceHandle device_handle = create_distributed_device_handle( + ctx, + /*workSpaceSize=*/1024 * 1024, + /*allowTensorOpMathConversion=*/true); + + PCGInstance pcg_instance = create_pcg_instance( + /*ctx=*/ctx, + /*mpcg=*/mpcg, + /*optimizer=*/optimizer_attrs, + /*loss=*/loss_attrs, + /*label_tensor=*/label_tensor, + /*logit_tensor=*/t_linear_2, + /*loss_mapping=*/loss_mapping, + /*input_tensors=*/input_tensors, + /*profiling_settings=*/ProfilingSettings{0, 0}, + /*device_handle=*/device_handle, + /*iteration_config=*/FFIterationConfig{1_p}); + + // begin training loop + int num_epochs = 5; + std::vector loss_values; + + for (int i = 0; i < num_epochs; i++) { + perform_all_passes_for_pcg_instance( + /*instance=*/pcg_instance, + /*profiling_settings=*/ProfilingSettings{0, 0}, + /*device_handle=*/device_handle, + /*iteration_config=*/FFIterationConfig{1_p}); + // loss_values.push_back(copy_tensor_accessor_r( + // dynamic_tensor_accessor_from_instance( + // pcg_instance.get_loss_tensor_instance().value(), + // Realm::Event::NO_EVENT, + // lift_to_parallel( + // TensorShape{TensorDims{FFOrdered{output_dim, + // hidden_dim}}, + // DataType::FLOAT}), + // Permissions::RO, + // ctx.get_current_processor()) + // .require_read(), + // allocator)); + } + + // // Assert that each sample in the batch has a lower loss in last epoch + // // than the first epoch + // GenericTensorAccessorR first_epoch_loss = loss_values.at(0); + // GenericTensorAccessorR last_epoch_loss = loss_values.back(); + // CHECK_MESSAGE(did_loss_decrease(first_epoch_loss, last_epoch_loss), + // check_kv("first_epoch_loss", + // format_accessor_r_contents(first_epoch_loss)), + // check_kv("last_epoch_loss", + // format_accessor_r_contents(last_epoch_loss))); + }); + } +} + } // namespace test From bd0227ca590e81010887152b44d7a1ec6608f702 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Fri, 20 Feb 2026 13:05:23 -0800 Subject: [PATCH 79/88] Add a GPU distributed handle test. --- .../distributed_device_handle.cc | 40 +++++++++++++++++-- 1 file changed, 37 insertions(+), 3 deletions(-) diff --git a/lib/realm-execution/test/src/realm-execution/distributed_device_handle.cc b/lib/realm-execution/test/src/realm-execution/distributed_device_handle.cc index fb7dff01e3..aaefe337db 100644 --- a/lib/realm-execution/test/src/realm-execution/distributed_device_handle.cc +++ b/lib/realm-execution/test/src/realm-execution/distributed_device_handle.cc @@ -24,9 +24,43 @@ TEST_SUITE(FF_TEST_SUITE) { /*allowTensorOpMathConversion=*/true); // Make sure we have handles for the processors we're expecting - Realm::Machine::ProcessorQuery pq(Realm::Machine::get_machine()); - pq.only_kind(Realm::Processor::LOC_PROC); - for (Realm::Processor proc : pq) { + Realm::Machine::ProcessorQuery cpus(Realm::Machine::get_machine()); + cpus.only_kind(Realm::Processor::LOC_PROC); + CHECK(cpus.count() == 2); + for (Realm::Processor proc : cpus) { + handle.at(proc); + } + }); + } +} + +TEST_SUITE(FF_CUDA_TEST_SUITE) { + TEST_CASE("DistributedDeviceHandle (GPU)") { + std::vector fake_args = + make_fake_realm_args(/*num_cpus=*/2_p, /*num_gpus=*/1_n); + int fake_argc = fake_args.size(); + char **fake_argv = fake_args.data(); + + RealmManager manager(&fake_argc, &fake_argv); + + (void)manager.start_controller([](RealmContext &ctx) { + DistributedDeviceHandle handle = create_distributed_device_handle( + /*ctx=*/ctx, + /*workSpaceSize=*/1024 * 1024, + /*allowTensorOpMathConversion=*/true); + + // Make sure we have handles for the processors we're expecting + Realm::Machine::ProcessorQuery cpus(Realm::Machine::get_machine()); + cpus.only_kind(Realm::Processor::LOC_PROC); + CHECK(cpus.count() == 2); + for (Realm::Processor proc : cpus) { + handle.at(proc); + } + + Realm::Machine::ProcessorQuery gpus(Realm::Machine::get_machine()); + gpus.only_kind(Realm::Processor::TOC_PROC); + CHECK(gpus.count() == 1); + for (Realm::Processor proc : gpus) { handle.at(proc); } }); From 9b726fbfd62324a4c857f2fa64417bc0f099bc3e Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Fri, 20 Feb 2026 15:07:25 -0800 Subject: [PATCH 80/88] Test GPU loss values. --- .../dynamic_tensor_accessor_from_instance.cc | 45 ++++++++++++++----- .../test/src/realm-execution/test_e2e.cc | 41 +++++++++-------- 2 files changed, 54 insertions(+), 32 deletions(-) diff --git a/lib/realm-execution/src/realm-execution/dynamic_tensor_accessor_from_instance.cc b/lib/realm-execution/src/realm-execution/dynamic_tensor_accessor_from_instance.cc index d1c773b1fa..a2a40e3752 100644 --- a/lib/realm-execution/src/realm-execution/dynamic_tensor_accessor_from_instance.cc +++ b/lib/realm-execution/src/realm-execution/dynamic_tensor_accessor_from_instance.cc @@ -6,6 +6,38 @@ namespace FlexFlow { +static DeviceType infer_device_type_from_memory_and_processor( + Realm::Memory inst_memory, Realm::Processor for_processor) { + DeviceType device_type; + switch (inst_memory.kind()) { + case Realm::Memory::SYSTEM_MEM: + // Only accessible on CPU + device_type = DeviceType::CPU; + break; + case Realm::Memory::GPU_FB_MEM: + // Only accessible on GPU + device_type = DeviceType::GPU; + break; + case Realm::Memory::Z_COPY_MEM: { + // Accessible on either CPU or GPU, so infer based on where we're trying + // to access from + switch (for_processor.kind()) { + case Realm::Processor::LOC_PROC: + device_type = DeviceType::CPU; + break; + case Realm::Processor::TOC_PROC: + device_type = DeviceType::GPU; + break; + default: + PANIC("Unexpected Realm Processor kind", for_processor.kind()); + } + } break; + default: + PANIC("Unexpected Realm Memory kind", inst_memory.kind()); + } + return device_type; +} + DynamicTensorAccessor dynamic_tensor_accessor_from_instance( Realm::RegionInstance inst, Realm::Event ready, @@ -14,17 +46,8 @@ DynamicTensorAccessor dynamic_tensor_accessor_from_instance( Realm::Processor for_processor) { ready.wait(); - DeviceType device_type; - switch (for_processor.kind()) { - case Realm::Processor::LOC_PROC: - device_type = DeviceType::CPU; - break; - case Realm::Processor::TOC_PROC: - device_type = DeviceType::GPU; - break; - default: - PANIC("Unexpected Realm Processor kind", for_processor.kind()); - } + DeviceType device_type = infer_device_type_from_memory_and_processor( + inst.get_location(), for_processor); size_t expected_size = int{get_piece_size_in_bytes(parallel_tensor_shape).unwrap_num_bytes()}; diff --git a/lib/realm-execution/test/src/realm-execution/test_e2e.cc b/lib/realm-execution/test/src/realm-execution/test_e2e.cc index d9252693a1..f5f7357105 100644 --- a/lib/realm-execution/test/src/realm-execution/test_e2e.cc +++ b/lib/realm-execution/test/src/realm-execution/test_e2e.cc @@ -456,29 +456,28 @@ TEST_SUITE(FF_CUDA_TEST_SUITE) { /*profiling_settings=*/ProfilingSettings{0, 0}, /*device_handle=*/device_handle, /*iteration_config=*/FFIterationConfig{1_p}); - // loss_values.push_back(copy_tensor_accessor_r( - // dynamic_tensor_accessor_from_instance( - // pcg_instance.get_loss_tensor_instance().value(), - // Realm::Event::NO_EVENT, - // lift_to_parallel( - // TensorShape{TensorDims{FFOrdered{output_dim, - // hidden_dim}}, - // DataType::FLOAT}), - // Permissions::RO, - // ctx.get_current_processor()) - // .require_read(), - // allocator)); + loss_values.push_back(copy_tensor_accessor_r( + dynamic_tensor_accessor_from_instance( + pcg_instance.get_loss_tensor_instance().value(), + Realm::Event::NO_EVENT, + lift_to_parallel( + TensorShape{TensorDims{FFOrdered{output_dim, hidden_dim}}, + DataType::FLOAT}), + Permissions::RO, + ctx.get_current_processor()) + .require_read(), + allocator)); } - // // Assert that each sample in the batch has a lower loss in last epoch - // // than the first epoch - // GenericTensorAccessorR first_epoch_loss = loss_values.at(0); - // GenericTensorAccessorR last_epoch_loss = loss_values.back(); - // CHECK_MESSAGE(did_loss_decrease(first_epoch_loss, last_epoch_loss), - // check_kv("first_epoch_loss", - // format_accessor_r_contents(first_epoch_loss)), - // check_kv("last_epoch_loss", - // format_accessor_r_contents(last_epoch_loss))); + // Assert that each sample in the batch has a lower loss in last epoch + // than the first epoch + GenericTensorAccessorR first_epoch_loss = loss_values.at(0); + GenericTensorAccessorR last_epoch_loss = loss_values.back(); + CHECK_MESSAGE(did_loss_decrease(ctx, first_epoch_loss, last_epoch_loss), + check_kv("first_epoch_loss", + format_accessor_r_contents(first_epoch_loss)), + check_kv("last_epoch_loss", + format_accessor_r_contents(last_epoch_loss))); }); } } From 0b75bbdbca42b7c095710e1e3d2c69debff8a2bc Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Fri, 20 Feb 2026 15:07:41 -0800 Subject: [PATCH 81/88] Update Realm to include build fixes. --- .flake/pkgs/realm.nix | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.flake/pkgs/realm.nix b/.flake/pkgs/realm.nix index b809573690..9f1fb8832c 100644 --- a/.flake/pkgs/realm.nix +++ b/.flake/pkgs/realm.nix @@ -13,13 +13,13 @@ in stdenv.mkDerivation rec { pname = "realm"; - version = "2026-02-06"; + version = "2026-02-18"; src = fetchFromGitHub { owner = "StanfordLegion"; repo = "realm"; - rev = "0405b67ca14b586f7dec0dcddee194cecee7efa6"; - sha256 = "sha256-iUPVV1rh3QuyDKgXuu8aDlaZGlNwcpPvPsSVLWp8tr4="; + rev = "47f18543592cb69c5bc7c97ee7e2bc521d377d3e"; + sha256 = "sha256-brAWh2p67hIyfrtNKN+6XZjIB0V2gYGBjdIocuwtmj4="; }; nativeBuildInputs = [ From 91904e822b5ca553e5e0173bb3398221ae825cd5 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Mon, 23 Feb 2026 11:19:16 -0800 Subject: [PATCH 82/88] Ensure that Realm tests do not leak instances. --- .../include/realm-execution/realm_allocator.h | 2 +- .../src/realm-execution/realm_allocator.cc | 4 ++ .../test/src/realm-execution/test_e2e.cc | 47 +++++++++++++------ 3 files changed, 37 insertions(+), 16 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/realm_allocator.h b/lib/realm-execution/include/realm-execution/realm_allocator.h index d72f2d7f91..e53994dd20 100644 --- a/lib/realm-execution/include/realm-execution/realm_allocator.h +++ b/lib/realm-execution/include/realm-execution/realm_allocator.h @@ -8,11 +8,11 @@ namespace FlexFlow { struct RealmAllocator : public IAllocator { RealmAllocator(Realm::Processor processor, Realm::Memory memory); + ~RealmAllocator(); RealmAllocator() = delete; RealmAllocator(RealmAllocator const &) = delete; RealmAllocator(RealmAllocator &&) = delete; - ~RealmAllocator() = default; void *allocate(size_t) override; void deallocate(void *) override; diff --git a/lib/realm-execution/src/realm-execution/realm_allocator.cc b/lib/realm-execution/src/realm-execution/realm_allocator.cc index f24106b0bc..37721fbcee 100644 --- a/lib/realm-execution/src/realm-execution/realm_allocator.cc +++ b/lib/realm-execution/src/realm-execution/realm_allocator.cc @@ -7,6 +7,10 @@ namespace FlexFlow { RealmAllocator::RealmAllocator(Realm::Processor processor, Realm::Memory memory) : processor(processor), memory(memory) {} +RealmAllocator::~RealmAllocator() { + ASSERT(this->ptr_instances.empty()); +} + void *RealmAllocator::allocate(size_t requested_memory_size) { Realm::Rect<1> bounds{Realm::Point<1>::ZEROES(), Realm::Point<1>{requested_memory_size} - diff --git a/lib/realm-execution/test/src/realm-execution/test_e2e.cc b/lib/realm-execution/test/src/realm-execution/test_e2e.cc index f5f7357105..1ac471d491 100644 --- a/lib/realm-execution/test/src/realm-execution/test_e2e.cc +++ b/lib/realm-execution/test/src/realm-execution/test_e2e.cc @@ -28,11 +28,14 @@ namespace test { using namespace ::FlexFlow; namespace Realm = ::FlexFlow::Realm; -static bool did_loss_decrease(RealmContext &ctx, - GenericTensorAccessorR const &first_epoch, - GenericTensorAccessorR const &last_epoch) { - return tensor_accessor_all(compare_tensor_accessors_le( - last_epoch, first_epoch, ctx.get_current_device_allocator())); +static bool did_loss_decrease(GenericTensorAccessorR const &first_epoch, + GenericTensorAccessorR const &last_epoch, + Allocator &allocator) { + GenericTensorAccessorW tensor_le = + compare_tensor_accessors_le(last_epoch, first_epoch, allocator); + bool result = tensor_accessor_all(tensor_le); + allocator.deallocate_tensor(tensor_le); + return result; } TEST_SUITE(FF_TEST_SUITE) { @@ -250,11 +253,18 @@ TEST_SUITE(FF_TEST_SUITE) { // than the first epoch GenericTensorAccessorR first_epoch_loss = loss_values.at(0); GenericTensorAccessorR last_epoch_loss = loss_values.back(); - CHECK_MESSAGE(did_loss_decrease(ctx, first_epoch_loss, last_epoch_loss), - check_kv("first_epoch_loss", - format_accessor_r_contents(first_epoch_loss)), - check_kv("last_epoch_loss", - format_accessor_r_contents(last_epoch_loss))); + CHECK_MESSAGE( + did_loss_decrease(first_epoch_loss, last_epoch_loss, allocator), + check_kv("first_epoch_loss", + format_accessor_r_contents(first_epoch_loss)), + check_kv("last_epoch_loss", + format_accessor_r_contents(last_epoch_loss))); + + for (GenericTensorAccessorR const &loss_value : loss_values) { + allocator.deallocate_tensor(loss_value); + } + allocator.deallocate_tensor(label_tensor); + allocator.deallocate_tensor(label_tensor_backing); }); } } @@ -473,11 +483,18 @@ TEST_SUITE(FF_CUDA_TEST_SUITE) { // than the first epoch GenericTensorAccessorR first_epoch_loss = loss_values.at(0); GenericTensorAccessorR last_epoch_loss = loss_values.back(); - CHECK_MESSAGE(did_loss_decrease(ctx, first_epoch_loss, last_epoch_loss), - check_kv("first_epoch_loss", - format_accessor_r_contents(first_epoch_loss)), - check_kv("last_epoch_loss", - format_accessor_r_contents(last_epoch_loss))); + CHECK_MESSAGE( + did_loss_decrease(first_epoch_loss, last_epoch_loss, allocator), + check_kv("first_epoch_loss", + format_accessor_r_contents(first_epoch_loss)), + check_kv("last_epoch_loss", + format_accessor_r_contents(last_epoch_loss))); + + for (GenericTensorAccessorR const &loss_value : loss_values) { + allocator.deallocate_tensor(loss_value); + } + allocator.deallocate_tensor(label_tensor); + allocator.deallocate_tensor(label_tensor_backing); }); } } From 2f5decb17345b01e8aafde90b2a057718d3633d7 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Tue, 24 Feb 2026 13:36:10 -0800 Subject: [PATCH 83/88] Update Realm allocator to follow pattern of other allocators. --- .../include/realm-execution/realm_allocator.h | 2 +- .../src/realm-execution/realm_allocator.cc | 9 ++++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/realm_allocator.h b/lib/realm-execution/include/realm-execution/realm_allocator.h index e53994dd20..d716016676 100644 --- a/lib/realm-execution/include/realm-execution/realm_allocator.h +++ b/lib/realm-execution/include/realm-execution/realm_allocator.h @@ -8,11 +8,11 @@ namespace FlexFlow { struct RealmAllocator : public IAllocator { RealmAllocator(Realm::Processor processor, Realm::Memory memory); - ~RealmAllocator(); RealmAllocator() = delete; RealmAllocator(RealmAllocator const &) = delete; RealmAllocator(RealmAllocator &&) = delete; + ~RealmAllocator() override; void *allocate(size_t) override; void deallocate(void *) override; diff --git a/lib/realm-execution/src/realm-execution/realm_allocator.cc b/lib/realm-execution/src/realm-execution/realm_allocator.cc index 37721fbcee..194210cf5a 100644 --- a/lib/realm-execution/src/realm-execution/realm_allocator.cc +++ b/lib/realm-execution/src/realm-execution/realm_allocator.cc @@ -1,6 +1,8 @@ #include "realm-execution/realm_allocator.h" #include "kernels/device.h" #include "pcg/device_type.dtg.h" +#include "utils/containers/contains_key.h" +#include "utils/containers/values.h" namespace FlexFlow { @@ -8,7 +10,9 @@ RealmAllocator::RealmAllocator(Realm::Processor processor, Realm::Memory memory) : processor(processor), memory(memory) {} RealmAllocator::~RealmAllocator() { - ASSERT(this->ptr_instances.empty()); + for (Realm::RegionInstance const &instance : values(this->ptr_instances)) { + instance.destroy(Realm::Event::NO_EVENT); + } } void *RealmAllocator::allocate(size_t requested_memory_size) { @@ -33,6 +37,9 @@ void *RealmAllocator::allocate(size_t requested_memory_size) { } void RealmAllocator::deallocate(void *ptr) { + ASSERT(contains_key(this->ptr_instances, ptr), + "Deallocating a pointer that was not allocated by this Allocator"); + this->ptr_instances.at(ptr).destroy(Realm::Event::NO_EVENT); this->ptr_instances.erase(ptr); } From 1db1448d40f6513a78b7e8e946ac4526880f2c6c Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Tue, 24 Feb 2026 13:40:08 -0800 Subject: [PATCH 84/88] Remove explicit deallocation which is not required by updated allocator. --- .../test/src/realm-execution/test_e2e.cc | 19 ++----------------- 1 file changed, 2 insertions(+), 17 deletions(-) diff --git a/lib/realm-execution/test/src/realm-execution/test_e2e.cc b/lib/realm-execution/test/src/realm-execution/test_e2e.cc index 1ac471d491..0914c054d7 100644 --- a/lib/realm-execution/test/src/realm-execution/test_e2e.cc +++ b/lib/realm-execution/test/src/realm-execution/test_e2e.cc @@ -31,11 +31,8 @@ namespace Realm = ::FlexFlow::Realm; static bool did_loss_decrease(GenericTensorAccessorR const &first_epoch, GenericTensorAccessorR const &last_epoch, Allocator &allocator) { - GenericTensorAccessorW tensor_le = - compare_tensor_accessors_le(last_epoch, first_epoch, allocator); - bool result = tensor_accessor_all(tensor_le); - allocator.deallocate_tensor(tensor_le); - return result; + return tensor_accessor_all( + compare_tensor_accessors_le(last_epoch, first_epoch, allocator)); } TEST_SUITE(FF_TEST_SUITE) { @@ -259,12 +256,6 @@ TEST_SUITE(FF_TEST_SUITE) { format_accessor_r_contents(first_epoch_loss)), check_kv("last_epoch_loss", format_accessor_r_contents(last_epoch_loss))); - - for (GenericTensorAccessorR const &loss_value : loss_values) { - allocator.deallocate_tensor(loss_value); - } - allocator.deallocate_tensor(label_tensor); - allocator.deallocate_tensor(label_tensor_backing); }); } } @@ -489,12 +480,6 @@ TEST_SUITE(FF_CUDA_TEST_SUITE) { format_accessor_r_contents(first_epoch_loss)), check_kv("last_epoch_loss", format_accessor_r_contents(last_epoch_loss))); - - for (GenericTensorAccessorR const &loss_value : loss_values) { - allocator.deallocate_tensor(loss_value); - } - allocator.deallocate_tensor(label_tensor); - allocator.deallocate_tensor(label_tensor_backing); }); } } From 7a66f5a698e223fb64515baa709536a35d9c5656 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Mon, 23 Feb 2026 14:34:47 -0800 Subject: [PATCH 85/88] Support for PRealm. --- .flake/pkgs/realm.nix | 6 +++--- .../distributed_device_handle.h | 1 - .../include/realm-execution/hash/processor.h | 20 ------------------- .../realm-execution/instance_allocation.h | 3 +++ .../pcg_instance/pcg_instance.h | 1 + .../include/realm-execution/realm.h | 2 +- .../serializable_realm_instance.dtg.toml | 10 ++++++++-- .../tensor_instance_backing.dtg.toml | 2 +- .../src/realm-execution/hash/processor.cc | 15 -------------- .../realm-execution/instance_allocation.cc | 7 +++++++ .../pcg_instance/pcg_instance.cc | 5 +++++ .../tasks/realm_task_registry.cc | 7 ++++++- .../serializer/serializable_realm_instance.cc | 12 +++++++++-- 13 files changed, 45 insertions(+), 46 deletions(-) delete mode 100644 lib/realm-execution/include/realm-execution/hash/processor.h delete mode 100644 lib/realm-execution/src/realm-execution/hash/processor.cc diff --git a/.flake/pkgs/realm.nix b/.flake/pkgs/realm.nix index 9f1fb8832c..b7985b497d 100644 --- a/.flake/pkgs/realm.nix +++ b/.flake/pkgs/realm.nix @@ -13,13 +13,13 @@ in stdenv.mkDerivation rec { pname = "realm"; - version = "2026-02-18"; + version = "2026-02-22-prealm"; src = fetchFromGitHub { owner = "StanfordLegion"; repo = "realm"; - rev = "47f18543592cb69c5bc7c97ee7e2bc521d377d3e"; - sha256 = "sha256-brAWh2p67hIyfrtNKN+6XZjIB0V2gYGBjdIocuwtmj4="; + rev = "6ab01f413926a2428c3c799a345f69b4807d5595"; + sha256 = "sha256-MN8nJ9O6oCZbbrE/ROvIlogtXJiSLsVZxoVXJUTeSHs="; }; nativeBuildInputs = [ diff --git a/lib/realm-execution/include/realm-execution/distributed_device_handle.h b/lib/realm-execution/include/realm-execution/distributed_device_handle.h index 268be3583d..1173d75b27 100644 --- a/lib/realm-execution/include/realm-execution/distributed_device_handle.h +++ b/lib/realm-execution/include/realm-execution/distributed_device_handle.h @@ -2,7 +2,6 @@ #define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_DISTRIBUTED_DEVICE_HANDLE_H #include "realm-execution/device_specific_managed_per_device_ff_handle.h" -#include "realm-execution/hash/processor.h" #include "realm-execution/realm.h" #include "realm-execution/realm_context.h" #include diff --git a/lib/realm-execution/include/realm-execution/hash/processor.h b/lib/realm-execution/include/realm-execution/hash/processor.h deleted file mode 100644 index efe6e6186b..0000000000 --- a/lib/realm-execution/include/realm-execution/hash/processor.h +++ /dev/null @@ -1,20 +0,0 @@ -#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_HASH_PROCESSOR_H -#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_HASH_PROCESSOR_H - -#include "realm-execution/realm.h" -#include - -#ifdef FLEXFLOW_USE_PREALM - -namespace std { - -template <> -struct hash<::FlexFlow::Realm::Processor> { - size_t operator()(::FlexFlow::Realm::Processor const &p) const; -}; - -} // namespace std - -#endif - -#endif diff --git a/lib/realm-execution/include/realm-execution/instance_allocation.h b/lib/realm-execution/include/realm-execution/instance_allocation.h index 09709201ce..95530c0eee 100644 --- a/lib/realm-execution/include/realm-execution/instance_allocation.h +++ b/lib/realm-execution/include/realm-execution/instance_allocation.h @@ -18,6 +18,9 @@ TensorInstanceBacking perform_instance_allocation( &preallocated, RealmContext &ctx); +void destroy_instances(TensorInstanceBacking const &instances, + Realm::Event precondition); + } // namespace FlexFlow #endif diff --git a/lib/realm-execution/include/realm-execution/pcg_instance/pcg_instance.h b/lib/realm-execution/include/realm-execution/pcg_instance/pcg_instance.h index e754fbbf5c..db338e4e4b 100644 --- a/lib/realm-execution/include/realm-execution/pcg_instance/pcg_instance.h +++ b/lib/realm-execution/include/realm-execution/pcg_instance/pcg_instance.h @@ -35,6 +35,7 @@ struct PCGInstance { PerDeviceOpStateBacking const &device_state_backing, OptimizerAttrs const &optimizer_attrs, std::optional logit_grad_tensor); + ~PCGInstance(); RealmContext &get_realm_context(); std::vector const &get_execution_order() const; TensorInstanceBacking const &get_tensor_instance_backing() const; diff --git a/lib/realm-execution/include/realm-execution/realm.h b/lib/realm-execution/include/realm-execution/realm.h index fe83e69583..b6913e66f5 100644 --- a/lib/realm-execution/include/realm-execution/realm.h +++ b/lib/realm-execution/include/realm-execution/realm.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_REALM_H #define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_REALM_H -// #define FLEXFLOW_USE_PREALM +#define FLEXFLOW_USE_PREALM #ifdef FLEXFLOW_USE_PREALM #include diff --git a/lib/realm-execution/include/realm-execution/tasks/serializer/serializable_realm_instance.dtg.toml b/lib/realm-execution/include/realm-execution/tasks/serializer/serializable_realm_instance.dtg.toml index 150801367d..5b70c6888b 100644 --- a/lib/realm-execution/include/realm-execution/tasks/serializer/serializable_realm_instance.dtg.toml +++ b/lib/realm-execution/include/realm-execution/tasks/serializer/serializable_realm_instance.dtg.toml @@ -12,6 +12,12 @@ includes = [ "realm-execution/realm.h", ] +src_includes = [ + "utils/fmt/vector.h", + "utils/hash/vector.h", +] + [[fields]] -name = "id" -type = "::FlexFlow::Realm::RegionInstance::id_t" +name = "instance" +# Realm::RegionInstance has hidden fields in PRealm so we need to encode it as bytes +type = "std::vector" diff --git a/lib/realm-execution/include/realm-execution/tensor_instance_backing.dtg.toml b/lib/realm-execution/include/realm-execution/tensor_instance_backing.dtg.toml index 6c43990282..b8533dbcc9 100644 --- a/lib/realm-execution/include/realm-execution/tensor_instance_backing.dtg.toml +++ b/lib/realm-execution/include/realm-execution/tensor_instance_backing.dtg.toml @@ -16,8 +16,8 @@ includes = [ src_includes = [ "realm-execution/fmt/realm_event.h", "realm-execution/fmt/realm_instance.h", - "utils/hash/unordered_map.h", "utils/fmt/unordered_map.h", + "utils/hash/unordered_map.h", ] [[fields]] diff --git a/lib/realm-execution/src/realm-execution/hash/processor.cc b/lib/realm-execution/src/realm-execution/hash/processor.cc deleted file mode 100644 index 5a8624f676..0000000000 --- a/lib/realm-execution/src/realm-execution/hash/processor.cc +++ /dev/null @@ -1,15 +0,0 @@ -#include "realm-execution/hash/processor.h" -#include - -#ifdef FLEXFLOW_USE_PREALM - -namespace std { - -size_t hash<::FlexFlow::Realm::Processor>::operator()( - ::FlexFlow::Realm::Processor const &p) const { - return hash<::FlexFlow::Realm::Processor::id_t>{}(p.id); -} - -} // namespace std - -#endif diff --git a/lib/realm-execution/src/realm-execution/instance_allocation.cc b/lib/realm-execution/src/realm-execution/instance_allocation.cc index 797455573c..e003e5b71a 100644 --- a/lib/realm-execution/src/realm-execution/instance_allocation.cc +++ b/lib/realm-execution/src/realm-execution/instance_allocation.cc @@ -72,4 +72,11 @@ TensorInstanceBacking perform_instance_allocation( return result; } +void destroy_instances(TensorInstanceBacking const &instances, + Realm::Event precondition) { + for (auto const &[instance, ready] : values(instances.backing)) { + instance.destroy(Realm::Event::merge_events(precondition, ready)); + } +} + } // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc b/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc index 4b08e9a430..d78ed68988 100644 --- a/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc +++ b/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc @@ -39,6 +39,11 @@ PCGInstance::PCGInstance( device_state_backing(device_state_backing), optimizer_attrs(optimizer_attrs), logit_grad_tensor(logit_grad_tensor) {} +PCGInstance::~PCGInstance() { + destroy_instances(this->tensor_instance_backing, + ctx.get_outstanding_events()); +} + RealmContext &PCGInstance::get_realm_context() { return this->ctx; } diff --git a/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc b/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc index fa056d6f33..a9c134af01 100644 --- a/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc +++ b/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc @@ -17,10 +17,15 @@ Realm::Event register_task(Realm::Processor::Kind target_kind, void const *, size_t, Realm::Processor)) { + Realm::Processor::TaskFuncID realm_task_id = + get_realm_task_id_for_task_id(func_id); +#ifdef FLEXFLOW_USE_PREALM + Realm::prealm_task_name(realm_task_id, fmt::format("{}", func_id)); +#endif return Realm::Processor::register_task_by_kind( target_kind, /*global=*/false, - get_realm_task_id_for_task_id(func_id), + realm_task_id, Realm::CodeDescriptor(task_body), Realm::ProfilingRequestSet()); } diff --git a/lib/realm-execution/src/realm-execution/tasks/serializer/serializable_realm_instance.cc b/lib/realm-execution/src/realm-execution/tasks/serializer/serializable_realm_instance.cc index f2d42a96ca..0e58d6e36c 100644 --- a/lib/realm-execution/src/realm-execution/tasks/serializer/serializable_realm_instance.cc +++ b/lib/realm-execution/src/realm-execution/tasks/serializer/serializable_realm_instance.cc @@ -1,15 +1,23 @@ #include "realm-execution/tasks/serializer/serializable_realm_instance.h" +#include "utils/exception.h" +#include namespace FlexFlow { +// Realm::RegionInstance is trivially copyable so it's safe to treat it as bytes +static_assert(std::is_trivially_copy_constructible_v); + SerializableRealmInstance realm_instance_to_serializable(Realm::RegionInstance const &inst) { - return SerializableRealmInstance{inst.id}; + uint8_t const *data = reinterpret_cast(&inst); + return SerializableRealmInstance{ + std::vector{data, data + sizeof(inst)}}; } Realm::RegionInstance realm_instance_from_serializable(SerializableRealmInstance const &inst) { - return Realm::RegionInstance{inst.id}; + ASSERT(inst.instance.size() == sizeof(Realm::RegionInstance)); + return *reinterpret_cast(inst.instance.data()); } } // namespace FlexFlow From 1dff7afbfb722881e9910ad6bd548c7ab5e3de8b Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Tue, 24 Feb 2026 14:08:29 -0800 Subject: [PATCH 86/88] Update to Realm main commit for PRealm. --- .flake/pkgs/realm.nix | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.flake/pkgs/realm.nix b/.flake/pkgs/realm.nix index b7985b497d..336b1c050c 100644 --- a/.flake/pkgs/realm.nix +++ b/.flake/pkgs/realm.nix @@ -13,13 +13,13 @@ in stdenv.mkDerivation rec { pname = "realm"; - version = "2026-02-22-prealm"; + version = "2026-02-24"; src = fetchFromGitHub { owner = "StanfordLegion"; repo = "realm"; - rev = "6ab01f413926a2428c3c799a345f69b4807d5595"; - sha256 = "sha256-MN8nJ9O6oCZbbrE/ROvIlogtXJiSLsVZxoVXJUTeSHs="; + rev = "42f7484a80e0bdacaf47d9a758822f5327348dd0"; + sha256 = "sha256-IHiokPmTjEV5df3fr1Xubuyt2N1CFI2fA7Q2TsbxS3Y="; }; nativeBuildInputs = [ From 2abfb8d44684a42fc7f45687cf04b13a9096b89e Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Wed, 25 Feb 2026 13:46:29 -0800 Subject: [PATCH 87/88] Add a switch to control PRealm. --- CMakeLists.txt | 1 + cmake/flexflow-utils.cmake | 1 + lib/realm-execution/include/realm-execution/realm.h | 6 ++---- .../src/realm-execution/tasks/realm_task_registry.cc | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index c2239cdcb0..df60e24d72 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -33,6 +33,7 @@ set(FF_MAX_NUM_TASK_REGIONS "20" CACHE STRING set(FF_MAX_NUM_TASK_ARGUMENTS "5" CACHE STRING "Maximum number of arguments that can be declared in a TaskSignature") option(FF_USE_NCCL "Run FlexFlow with NCCL" OFF) +option(FF_USE_PREALM "Build with PRealm profiling interface" ON) option(FF_USE_ALL_PREBUILT_LIBRARIES "Enable use of all pre-compiled libraries, if available" OFF) option(FF_USE_PYTHON "Enable Python" ON) option(FF_BUILD_FROM_PYPI "Build from pypi" OFF) diff --git a/cmake/flexflow-utils.cmake b/cmake/flexflow-utils.cmake index ef5d6d9d11..795668e32a 100644 --- a/cmake/flexflow-utils.cmake +++ b/cmake/flexflow-utils.cmake @@ -17,6 +17,7 @@ function(define_ff_vars target) MAX_NUM_FUSED_TENSORS=${FF_MAX_NUM_FUSED_TENSORS} MAX_NUM_WORKERS=${FF_MAX_NUM_WORKERS} FF_USE_NCCL=${FF_USE_NCCL} + FF_USE_PREALM=${FF_USE_PREALM} MAX_TENSOR_DIM=${FF_MAX_DIM} MAX_NUM_TASK_REGIONS=${FF_MAX_NUM_TASK_REGIONS} MAX_NUM_TASK_ARGUMENTS=${FF_MAX_NUM_TASK_ARGUMENTS} diff --git a/lib/realm-execution/include/realm-execution/realm.h b/lib/realm-execution/include/realm-execution/realm.h index b6913e66f5..814132d355 100644 --- a/lib/realm-execution/include/realm-execution/realm.h +++ b/lib/realm-execution/include/realm-execution/realm.h @@ -1,9 +1,7 @@ #ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_REALM_H #define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_REALM_H -#define FLEXFLOW_USE_PREALM - -#ifdef FLEXFLOW_USE_PREALM +#ifdef FF_USE_PREALM #include #else #include @@ -11,7 +9,7 @@ namespace FlexFlow { -#ifdef FLEXFLOW_USE_PREALM +#ifdef FF_USE_PREALM namespace Realm = ::PRealm; #else namespace Realm = ::Realm; diff --git a/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc b/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc index a9c134af01..09d99655c0 100644 --- a/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc +++ b/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc @@ -19,7 +19,7 @@ Realm::Event register_task(Realm::Processor::Kind target_kind, Realm::Processor)) { Realm::Processor::TaskFuncID realm_task_id = get_realm_task_id_for_task_id(func_id); -#ifdef FLEXFLOW_USE_PREALM +#ifdef FF_USE_PREALM Realm::prealm_task_name(realm_task_id, fmt::format("{}", func_id)); #endif return Realm::Processor::register_task_by_kind( From 9dd1f12466fe1b3a4bddd06cdf9122340810bfb8 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Fri, 27 Feb 2026 10:15:00 -0800 Subject: [PATCH 88/88] Update rect constructor. --- .../include/realm-execution/realm_context.h | 1 + .../src/realm-execution/realm_context.cc | 11 ++++++----- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/realm_context.h b/lib/realm-execution/include/realm-execution/realm_context.h index b8baad41b9..b018a04a87 100644 --- a/lib/realm-execution/include/realm-execution/realm_context.h +++ b/lib/realm-execution/include/realm-execution/realm_context.h @@ -4,6 +4,7 @@ #include "kernels/allocation.h" #include "kernels/device_handle_t.dtg.h" #include "kernels/managed_per_device_ff_handle.h" +#include "op-attrs/tensor_shape.dtg.h" #include "pcg/device_id_t.dtg.h" #include "pcg/machine_space_coordinate.dtg.h" #include "realm-execution/realm.h" diff --git a/lib/realm-execution/src/realm-execution/realm_context.cc b/lib/realm-execution/src/realm-execution/realm_context.cc index 3427e8cbee..10ed07118b 100644 --- a/lib/realm-execution/src/realm-execution/realm_context.cc +++ b/lib/realm-execution/src/realm-execution/realm_context.cc @@ -137,12 +137,13 @@ Realm::Event RealmContext::collective_spawn_task(Realm::Processor target_proc, return result; } -template -static Realm::Rect rect_from_dims(TensorDims const &dims) { +template +static Realm::Rect rect_from_dims(TensorDims const &dims) { std::vector values{dims.ff_ordered.begin(), dims.ff_ordered.end()}; - return Realm::Rect{Realm::Point::ZEROES(), - Realm::Point{values.data()} - - Realm::Point::ONES()}; + ASSERT(values.size() == N); + return Realm::Rect{Realm::Point::ZEROES(), + Realm::Point{values.data()} - + Realm::Point::ONES()}; } std::pair