From 518ad39a13161d9bcc373c5c5609bef1f7549dd6 Mon Sep 17 00:00:00 2001 From: pseudo-rnd-thoughts Date: Thu, 27 Nov 2025 23:18:50 +0000 Subject: [PATCH 1/8] Zero-copy observations --- .../python/ale_vector_python_interface.cpp | 214 +++++++----------- .../python/ale_vector_python_interface.hpp | 24 +- src/ale/vector/async_vectorizer.hpp | 151 ++++++++---- src/ale/vector/preprocessed_env.hpp | 59 +++++ src/ale/vector/utils.hpp | 201 +++++++++------- tests/python/test_atari_vector_env.py | 1 + 6 files changed, 390 insertions(+), 260 deletions(-) diff --git a/src/ale/python/ale_vector_python_interface.cpp b/src/ale/python/ale_vector_python_interface.cpp index 8e29d8b93..3f063000f 100644 --- a/src/ale/python/ale_vector_python_interface.cpp +++ b/src/ale/python/ale_vector_python_interface.cpp @@ -40,20 +40,30 @@ void init_vector_module(nb::module_& m) { .def("reset", [](ale::vector::ALEVectorInterface& self, const std::vector reset_indices, const std::vector reset_seeds) { // Call C++ reset method with GIL released nb::gil_scoped_release release; - auto timesteps = self.reset(reset_indices, reset_seeds); + auto result = self.reset(reset_indices, reset_seeds); nb::gil_scoped_acquire acquire; // Get shape information - int batch_size = timesteps.size(); - auto obs_shape = self.get_observation_shape(); - int stack_num = std::get<0>(obs_shape); - int height = std::get<1>(obs_shape); - int width = std::get<2>(obs_shape); - int channels = self.is_grayscale() ? 
1 : 3; + const int batch_size = result.batch_size; + const auto obs_shape = self.get_observation_shape(); + const int stack_num = std::get<0>(obs_shape); + const int height = std::get<1>(obs_shape); + const int width = std::get<2>(obs_shape); + const bool grayscale = self.is_grayscale(); + + // Wrap observation buffer - capsule takes ownership + nb::capsule obs_owner(result.obs_data, [](void *p) noexcept { + delete[] static_cast(p); + }); - // Create a single NumPy array for all observations - size_t obs_total_size = batch_size * stack_num * height * width * channels; - uint8_t* obs_data = new uint8_t[obs_total_size]; + nb::ndarray observations; + if (grayscale) { + size_t shape[4] = {(size_t)batch_size, (size_t)stack_num, (size_t)height, (size_t)width}; + observations = nb::ndarray(result.obs_data, 4, shape, obs_owner); + } else { + size_t shape[5] = {(size_t)batch_size, (size_t)stack_num, (size_t)height, (size_t)width, 3}; + observations = nb::ndarray(result.obs_data, 5, shape, obs_owner); + } // Create arrays for info fields int* env_ids_data = new int[batch_size]; @@ -61,42 +71,21 @@ void init_vector_module(nb::module_& m) { int* frame_numbers_data = new int[batch_size]; int* episode_frame_numbers_data = new int[batch_size]; - // Copy data from observations to arrays - size_t obs_size = stack_num * height * width * channels; - for (int i = 0; i < batch_size; i++) { - const auto& timestep = timesteps[i]; - - // Copy screen data - std::memcpy( - obs_data + i * obs_size, - timestep.observation.data(), - obs_size * sizeof(uint8_t) - ); - - // Copy info fields - env_ids_data[i] = timestep.env_id; - lives_data[i] = timestep.lives; - frame_numbers_data[i] = timestep.frame_number; - episode_frame_numbers_data[i] = timestep.episode_frame_number; + for (size_t i = 0; i < batch_size; i++) { + const auto& meta = result.metadata[i]; + env_ids_data[i] = meta.env_id; + lives_data[i] = meta.lives; + frame_numbers_data[i] = meta.frame_number; + episode_frame_numbers_data[i] = 
meta.episode_frame_number; } // Create capsules for cleanup - nb::capsule obs_owner(obs_data, [](void *p) noexcept { delete[] (uint8_t *) p; }); - nb::capsule env_ids_owner(env_ids_data, [](void *p) noexcept { delete[] (int *) p; }); - nb::capsule lives_owner(lives_data, [](void *p) noexcept { delete[] (int *) p; }); - nb::capsule frame_numbers_owner(frame_numbers_data, [](void *p) noexcept { delete[] (int *) p; }); - nb::capsule episode_frame_numbers_owner(episode_frame_numbers_data, [](void *p) noexcept { delete[] (int *) p; }); - - // Create numpy arrays with allocated data - nb::ndarray observations; - if (self.is_grayscale()) { - size_t shape[4] = {(size_t)batch_size, (size_t)stack_num, (size_t)height, (size_t)width}; - observations = nb::ndarray(obs_data, 4, shape, obs_owner); - } else { - size_t shape[5] = {(size_t)batch_size, (size_t)stack_num, (size_t)height, (size_t)width, 3}; - observations = nb::ndarray(obs_data, 5, shape, obs_owner); - } + nb::capsule env_ids_owner(env_ids_data, [](void *p) noexcept { delete[] (int*)p; }); + nb::capsule lives_owner(lives_data, [](void *p) noexcept { delete[] (int*)p; }); + nb::capsule frame_numbers_owner(frame_numbers_data, [](void *p) noexcept { delete[] (int*)p; }); + nb::capsule episode_frame_numbers_owner(episode_frame_numbers_data, [](void *p) noexcept { delete[] (int*)p; }); + // Create numpy arrays size_t info_shape[1] = {(size_t)batch_size}; auto env_ids = nb::ndarray(env_ids_data, 1, info_shape, env_ids_owner); auto lives = nb::ndarray(lives_data, 1, info_shape, lives_owner); @@ -116,21 +105,34 @@ void init_vector_module(nb::module_& m) { self.send(action_ids, paddle_strengths); }) .def("recv", [](ale::vector::ALEVectorInterface& self) { - const auto timesteps = self.recv(); + // Release GIL while waiting for workers + nb::gil_scoped_release release; + auto result = self.recv(); nb::gil_scoped_acquire acquire; - // Get shape information - int batch_size = timesteps.size(); + // Get shape info const auto 
shape_info = self.get_observation_shape(); - int stack_num = std::get<0>(shape_info); - int height = std::get<1>(shape_info); - int width = std::get<2>(shape_info); - int channels = self.is_grayscale() ? 1 : 3; - ale::vector::AutoresetMode autoreset_mode = self.get_autoreset_mode(); + const int stack_num = std::get<0>(shape_info); + const int height = std::get<1>(shape_info); + const int width = std::get<2>(shape_info); + const int batch_size = result.batch_size; + const bool grayscale = self.is_grayscale(); - // Allocate memory for arrays - size_t obs_total_size = batch_size * stack_num * height * width * channels; - uint8_t* obs_data = new uint8_t[obs_total_size]; + // Wrap obs buffer - capsule takes ownership and will delete[] + nb::capsule obs_owner(result.obs_data, [](void *p) noexcept { + delete[] static_cast(p); + }); + + nb::ndarray observations; + if (grayscale) { + size_t shape[4] = {(size_t)batch_size, (size_t)stack_num, (size_t)height, (size_t)width}; + observations = nb::ndarray(result.obs_data, 4, shape, obs_owner); + } else { + size_t shape[5] = {(size_t)batch_size, (size_t)stack_num, (size_t)height, (size_t)width, 3}; + observations = nb::ndarray(result.obs_data, 5, shape, obs_owner); + } + + // Allocate metadata arrays int* rewards_data = new int[batch_size]; bool* terminations_data = new bool[batch_size]; bool* truncations_data = new bool[batch_size]; @@ -139,48 +141,27 @@ void init_vector_module(nb::module_& m) { int* frame_numbers_data = new int[batch_size]; int* episode_frame_numbers_data = new int[batch_size]; - // Copy data from timesteps to arrays - const size_t obs_size = stack_num * height * width * channels; - for (int i = 0; i < batch_size; i++) { - const auto& timestep = timesteps[i]; - - // Copy screen data - std::memcpy( - obs_data + i * obs_size, - timestep.observation.data(), - obs_size * sizeof(uint8_t) - ); - - // Copy other fields - rewards_data[i] = timestep.reward; - terminations_data[i] = timestep.terminated; - 
truncations_data[i] = timestep.truncated; - env_ids_data[i] = timestep.env_id; - lives_data[i] = timestep.lives; - frame_numbers_data[i] = timestep.frame_number; - episode_frame_numbers_data[i] = timestep.episode_frame_number; + for (size_t i = 0; i < batch_size; i++) { + const auto& meta = result.metadata[i]; + rewards_data[i] = meta.reward; + terminations_data[i] = meta.terminated; + truncations_data[i] = meta.truncated; + env_ids_data[i] = meta.env_id; + lives_data[i] = meta.lives; + frame_numbers_data[i] = meta.frame_number; + episode_frame_numbers_data[i] = meta.episode_frame_number; } - // Create capsules for cleanup - nb::capsule obs_owner(obs_data, [](void *p) noexcept { delete[] (uint8_t *) p; }); - nb::capsule rewards_owner(rewards_data, [](void *p) noexcept { delete[] (int *) p; }); - nb::capsule terminations_owner(terminations_data, [](void *p) noexcept { delete[] (bool *) p; }); - nb::capsule truncations_owner(truncations_data, [](void *p) noexcept { delete[] (bool *) p; }); - nb::capsule env_ids_owner(env_ids_data, [](void *p) noexcept { delete[] (int *) p; }); - nb::capsule lives_owner(lives_data, [](void *p) noexcept { delete[] (int *) p; }); - nb::capsule frame_numbers_owner(frame_numbers_data, [](void *p) noexcept { delete[] (int *) p; }); - nb::capsule episode_frame_numbers_owner(episode_frame_numbers_data, [](void *p) noexcept { delete[] (int *) p; }); - - // Create numpy arrays with allocated data - nb::ndarray observations; - if (self.is_grayscale()) { - size_t shape[4] = {(size_t)batch_size, (size_t)stack_num, (size_t)height, (size_t)width}; - observations = nb::ndarray(obs_data, 4, shape, obs_owner); - } else { - size_t shape[5] = {(size_t)batch_size, (size_t)stack_num, (size_t)height, (size_t)width, 3}; - observations = nb::ndarray(obs_data, 5, shape, obs_owner); - } + // Create capsules + nb::capsule rewards_owner(rewards_data, [](void *p) noexcept { delete[] (int*)p; }); + nb::capsule terminations_owner(terminations_data, [](void *p) 
noexcept { delete[] (bool*)p; }); + nb::capsule truncations_owner(truncations_data, [](void *p) noexcept { delete[] (bool*)p; }); + nb::capsule env_ids_owner(env_ids_data, [](void *p) noexcept { delete[] (int*)p; }); + nb::capsule lives_owner(lives_data, [](void *p) noexcept { delete[] (int*)p; }); + nb::capsule frame_numbers_owner(frame_numbers_data, [](void *p) noexcept { delete[] (int*)p; }); + nb::capsule episode_frame_numbers_owner(episode_frame_numbers_data, [](void *p) noexcept { delete[] (int*)p; }); + // Create numpy arrays size_t info_shape[1] = {(size_t)batch_size}; auto rewards = nb::ndarray(rewards_data, 1, info_shape, rewards_owner); auto terminations = nb::ndarray(terminations_data, 1, info_shape, terminations_owner); @@ -190,47 +171,28 @@ void init_vector_module(nb::module_& m) { auto frame_numbers = nb::ndarray(frame_numbers_data, 1, info_shape, frame_numbers_owner); auto episode_frame_numbers = nb::ndarray(episode_frame_numbers_data, 1, info_shape, episode_frame_numbers_owner); - // Create info dict + // Build info dict nb::dict info; info["env_id"] = env_ids; info["lives"] = lives; info["frame_number"] = frame_numbers; info["episode_frame_number"] = episode_frame_numbers; - if (autoreset_mode == ale::vector::AutoresetMode::SameStep) { - bool any_terminated = std::any_of(terminations_data, terminations_data + batch_size, [](bool b) { return b; }); - bool any_truncated = std::any_of(truncations_data, truncations_data + batch_size, [](bool b) { return b; }); - - if (any_terminated || any_truncated) { - uint8_t* final_obs_data = new uint8_t[obs_total_size]; - - for (int i = 0; i < batch_size; i++) { - const auto& timestep = timesteps[i]; - - // Use final_observation if available, otherwise use current observation - const std::vector* obs_src = (timestep.terminated || timestep.truncated) ? 
- timestep.final_observation : ×tep.observation; - - std::memcpy( - final_obs_data + i * obs_size, - obs_src->data(), - obs_size * sizeof(uint8_t) - ); - } - - nb::capsule final_obs_owner(final_obs_data, [](void *p) noexcept { delete[] (uint8_t *) p; }); - - nb::ndarray final_observations; - if (self.is_grayscale()) { - size_t shape[4] = {(size_t)batch_size, (size_t)stack_num, (size_t)height, (size_t)width}; - final_observations = nb::ndarray(final_obs_data, 4, shape, final_obs_owner); - } else { - size_t shape[5] = {(size_t)batch_size, (size_t)stack_num, (size_t)height, (size_t)width, 3}; - final_observations = nb::ndarray(final_obs_data, 5, shape, final_obs_owner); - } - - info["final_obs"] = final_observations; + // Handle final_obs for SameStep mode + if (result.final_obs_data != nullptr) { + nb::capsule final_obs_owner(result.final_obs_data, [](void *p) noexcept { + delete[] static_cast(p); + }); + + nb::ndarray final_observations; + if (grayscale) { + size_t shape[4] = {(size_t)batch_size, (size_t)stack_num, (size_t)height, (size_t)width}; + final_observations = nb::ndarray(result.final_obs_data, 4, shape, final_obs_owner); + } else { + size_t shape[5] = {(size_t)batch_size, (size_t)stack_num, (size_t)height, (size_t)width, 3}; + final_observations = nb::ndarray(result.final_obs_data, 5, shape, final_obs_owner); } + info["final_obs"] = final_observations; } return nb::make_tuple(observations, rewards, terminations, truncations, info); diff --git a/src/ale/python/ale_vector_python_interface.hpp b/src/ale/python/ale_vector_python_interface.hpp index 75970a5c1..12392b822 100644 --- a/src/ale/python/ale_vector_python_interface.hpp +++ b/src/ale/python/ale_vector_python_interface.hpp @@ -143,9 +143,9 @@ namespace ale::vector { * * @param reset_indices Vector of environment indices to be reset * @param reset_seeds Vector of environment seeds to use - * @return Timesteps from all environments after reset + * @return RecvResult with initial observations */ - 
std::vector reset(const std::vector &reset_indices, const std::vector &reset_seeds) { + RecvResult reset(const std::vector &reset_indices, const std::vector &reset_seeds) { vectorizer_->reset(reset_indices, reset_seeds); return recv(); } @@ -178,14 +178,15 @@ namespace ale::vector { } /** - * Returns the environment's data for the environments + * Returns the environment's data for the environments. + * Returns ownership of observation buffer to caller. */ - const std::vector recv() { - std::vector timesteps = vectorizer_->recv(); - for (size_t i = 0; i < timesteps.size(); i++) { - received_env_ids_[i] = timesteps[i].env_id; + RecvResult recv() { + RecvResult result = vectorizer_->recv(); + for (size_t i = 0; i < result.batch_size; i++) { + received_env_ids_[i] = result.metadata[i].env_id; } - return timesteps; + return result; } /** @@ -237,6 +238,13 @@ namespace ale::vector { return autoreset_mode_; } + /** + * Get the size of a single stacked observation in bytes. + */ + std::size_t get_stacked_obs_size() const { + return vectorizer_->get_stacked_obs_size(); + } + /** * Get the underlying vectorizer * diff --git a/src/ale/vector/async_vectorizer.hpp b/src/ale/vector/async_vectorizer.hpp index e8d8e214e..3b619801e 100644 --- a/src/ale/vector/async_vectorizer.hpp +++ b/src/ale/vector/async_vectorizer.hpp @@ -17,6 +17,17 @@ #endif namespace ale::vector { + /** + * Result from recv() - caller takes ownership of allocated buffers. + */ + struct RecvResult { + uint8_t* obs_data; // Newly allocated, caller owns + std::vector metadata; // Copied from internal buffer + uint8_t* final_obs_data; // nullptr or newly allocated, caller owns + std::vector has_final_obs; // Which slots have final_obs (uint8_t for compatibility) + std::size_t batch_size; // Number of results + }; + /** * AsyncVectorizer manages a collection of environments that can be stepped in parallel. * It handles the (async) distribution of actions to environments and collection of observations. 
@@ -45,8 +56,7 @@ namespace ale::vector { autoreset_mode_(autoreset_mode), stop_(false), action_queue_(new ActionQueue(num_envs_)), - state_buffer_(new StateBuffer(batch_size_, num_envs_)), - final_obs_storage_(num_envs_) { + pending_obs_buffer_(nullptr) { // Create environments envs_.resize(num_envs_); @@ -55,6 +65,9 @@ namespace ale::vector { } stacked_obs_size_ = envs_[0]->get_stacked_obs_size(); + // Create state buffer with observation size + state_buffer_ = std::make_unique(batch_size_, num_envs_, stacked_obs_size_); + // Setup worker threads const std::size_t processor_count = std::thread::hardware_concurrency(); if (num_threads <= 0) { @@ -98,6 +111,12 @@ namespace ale::vector { * @param seeds Vector of seeds to use on reset (use -1 to not change the environment's seed) */ void reset(const std::vector& reset_indices, const std::vector& seeds) { + // Allocate output buffer BEFORE enqueueing (prevents race condition) + const std::size_t total_obs_size = batch_size_ * stacked_obs_size_; + pending_obs_buffer_ = new uint8_t[total_obs_size]; + state_buffer_->set_output_buffer(pending_obs_buffer_); + + // Prepare reset actions std::vector reset_actions; reset_actions.reserve(reset_indices.size()); @@ -112,6 +131,7 @@ namespace ale::vector { reset_actions.emplace_back(action); } + // Enqueue actions - workers can now safely write to buffer action_queue_->enqueue_bulk(reset_actions); } @@ -121,6 +141,12 @@ namespace ale::vector { * @param actions Vector of actions to send to the sub-environments */ void send(const std::vector& actions) { + // Allocate output buffer BEFORE enqueueing (prevents race condition) + const std::size_t total_obs_size = batch_size_ * stacked_obs_size_; + pending_obs_buffer_ = new uint8_t[total_obs_size]; + state_buffer_->set_output_buffer(pending_obs_buffer_); + + // Prepare action slices std::vector action_slices; action_slices.reserve(actions.size()); @@ -135,30 +161,61 @@ namespace ale::vector { action_slices.emplace_back(action); } + // 
Enqueue actions - workers can now safely write to buffer action_queue_->enqueue_bulk(action_slices); } /** - * Receive timesteps from the environments - * This is the asynchronous version that waits for results after send() + * Receive timesteps from the environments. + * Returns ownership of allocated observation buffer to caller. * - * @return Vector of timesteps from the environments + * @return RecvResult containing observation data and metadata */ - const std::vector recv() { - std::vector timesteps = state_buffer_->collect(); - return timesteps; - } + RecvResult recv() { + // Wait for all workers to complete + state_buffer_->wait_for_batch(); + + // Build result + RecvResult result; + result.obs_data = pending_obs_buffer_; // Transfer ownership + result.batch_size = batch_size_; + pending_obs_buffer_ = nullptr; + + // Copy metadata (small - ~32 bytes per env) + result.metadata.resize(batch_size_); + std::memcpy( + result.metadata.data(), + state_buffer_->get_metadata(), + batch_size_ * sizeof(TimestepMetadata) + ); + + // Handle final_obs for SameStep mode + if (autoreset_mode_ == AutoresetMode::SameStep) { + const uint8_t* has_final = state_buffer_->get_has_final_obs(); + bool any_final = false; + for (std::size_t i = 0; i < batch_size_; i++) { + if (has_final[i]) { + any_final = true; + break; + } + } - /** - * Step the environments with actions and wait for results - * This is a convenience method that combines send() and recv() - * - * @param actions Vector of actions for the environments - * @return Vector of timesteps from the environments - */ - const std::vector step(const std::vector& actions) { - send(actions); - return recv(); + if (any_final) { + const std::size_t total_obs_size = batch_size_ * stacked_obs_size_; + result.final_obs_data = new uint8_t[total_obs_size]; + std::memcpy(result.final_obs_data, state_buffer_->get_final_obs_buffer(), total_obs_size); + result.has_final_obs.assign(has_final, has_final + batch_size_); + } else { + 
result.final_obs_data = nullptr; + } + } else { + result.final_obs_data = nullptr; + } + + // Reset state buffer for next batch + state_buffer_->reset(); + + return result; } const int get_num_envs() const { @@ -187,15 +244,16 @@ namespace ale::vector { std::atomic stop_; // Signal to stop worker threads std::vector workers_; // Worker threads std::unique_ptr action_queue_; // Queue for actions - std::unique_ptr state_buffer_; // Queue for observations + std::unique_ptr state_buffer_; // Buffer for observations and metadata std::vector> envs_; // Environment instances - mutable std::vector> final_obs_storage_; // For same-step autoreset + uint8_t* pending_obs_buffer_; // Buffer allocated in send(), returned in recv() /** - * Worker thread function that processes environment steps + * Worker thread function that processes environment steps. + * Writes results directly to pre-allocated output buffer. */ - void worker_function() const { + void worker_function() { while (!stop_) { try { ActionSlice action = action_queue_->dequeue(); @@ -204,6 +262,10 @@ namespace ale::vector { } const int env_id = action.env_id; + + // Get write slot - pointers are into the pre-allocated output buffer + WriteSlot slot = state_buffer_->allocate_write_slot(env_id); + if (autoreset_mode_ == AutoresetMode::NextStep) { if (action.force_reset || envs_[env_id]->is_episode_over()) { envs_[env_id]->reset(); @@ -211,45 +273,44 @@ namespace ale::vector { envs_[env_id]->step(); } - // Get timestep and write to state buffer - Timestep timestep = envs_[env_id]->get_timestep(); - timestep.final_observation = nullptr; // Not used in NextStep mode - state_buffer_->write(timestep); + // Write directly to output buffer (single copy: linearize frame stack) + envs_[env_id]->write_timestep_to(slot.obs_dest, *slot.meta); + } else if (autoreset_mode_ == AutoresetMode::SameStep) { if (action.force_reset) { - // on standard `reset` envs_[env_id]->reset(); - Timestep timestep = envs_[env_id]->get_timestep(); - 
timestep.final_observation = nullptr; - state_buffer_->write(timestep); + envs_[env_id]->write_timestep_to(slot.obs_dest, *slot.meta); } else { envs_[env_id]->step(); - Timestep step_timestep = envs_[env_id]->get_timestep(); - // if episode over, autoreset if (envs_[env_id]->is_episode_over()) { - final_obs_storage_[env_id] = step_timestep.observation; + // Save final observation before reset + envs_[env_id]->write_observation_to(slot.final_obs_dest); + state_buffer_->mark_slot_has_final_obs(slot.slot_index); - envs_[env_id]->reset(); - Timestep reset_timestep = envs_[env_id]->get_timestep(); + // Capture pre-reset metadata + TimestepMetadata pre_reset_meta; + envs_[env_id]->write_metadata_to(pre_reset_meta); - reset_timestep.final_observation = &final_obs_storage_[env_id]; - reset_timestep.reward = step_timestep.reward; - reset_timestep.terminated = step_timestep.terminated; - reset_timestep.truncated = step_timestep.truncated; + // Reset and write new observation + envs_[env_id]->reset(); + envs_[env_id]->write_timestep_to(slot.obs_dest, *slot.meta); - // Write the reset timestep with the some of the step timestep data - state_buffer_->write(reset_timestep); + // Restore pre-reset reward/terminated/truncated + slot.meta->reward = pre_reset_meta.reward; + slot.meta->terminated = pre_reset_meta.terminated; + slot.meta->truncated = pre_reset_meta.truncated; } else { - step_timestep.final_observation = nullptr; - state_buffer_->write(step_timestep); + envs_[env_id]->write_timestep_to(slot.obs_dest, *slot.meta); } } } else { throw std::runtime_error("Invalid autoreset mode"); } + + state_buffer_->mark_complete(); + } catch (const std::exception& e) { - // Log error but continue processing std::cerr << "Error in worker thread: " << e.what() << std::endl; } } diff --git a/src/ale/vector/preprocessed_env.hpp b/src/ale/vector/preprocessed_env.hpp index a95996d97..c7c1ddea1 100644 --- a/src/ale/vector/preprocessed_env.hpp +++ b/src/ale/vector/preprocessed_env.hpp @@ -235,6 
+235,65 @@ namespace ale::vector { reward_ = reward_clipping_ ? std::clamp(reward, -1, 1) : reward; } + /** + * Write timestep data directly to provided destinations. + * Avoids allocating intermediate vectors. + * + * @param obs_dest Pointer to write linearized observation (size: stack_num * obs_size) + * @param meta Reference to metadata struct to populate + */ + void write_timestep_to(uint8_t* obs_dest, TimestepMetadata& meta) const { + // Write metadata + meta.env_id = env_id_; + meta.reward = reward_; + meta.terminated = game_over_ || ((life_loss_info_ || episodic_life_) && was_life_lost_); + meta.truncated = elapsed_step_ >= max_episode_steps_ && !meta.terminated; + meta.lives = lives_; + meta.frame_number = env_->getFrameNumber(); + meta.episode_frame_number = env_->getEpisodeFrameNumber(); + + // Linearize circular frame_stack directly to destination + for (int i = 0; i < stack_num_; ++i) { + const int src_idx = (frame_stack_idx_ + i) % stack_num_; + std::memcpy( + obs_dest + i * obs_size_, + frame_stack_.data() + src_idx * obs_size_, + obs_size_ + ); + } + } + + /** + * Write only observation to destination (for final_obs in SameStep mode). + * + * @param obs_dest Pointer to write linearized observation + */ + void write_observation_to(uint8_t* obs_dest) const { + for (int i = 0; i < stack_num_; ++i) { + const int src_idx = (frame_stack_idx_ + i) % stack_num_; + std::memcpy( + obs_dest + i * obs_size_, + frame_stack_.data() + src_idx * obs_size_, + obs_size_ + ); + } + } + + /** + * Write only metadata (used to capture state before reset in SameStep mode). 
+ * + * @param meta Reference to metadata struct to populate + */ + void write_metadata_to(TimestepMetadata& meta) const { + meta.env_id = env_id_; + meta.reward = reward_; + meta.terminated = game_over_ || ((life_loss_info_ || episodic_life_) && was_life_lost_); + meta.truncated = elapsed_step_ >= max_episode_steps_ && !meta.terminated; + meta.lives = lives_; + meta.frame_number = env_->getFrameNumber(); + meta.episode_frame_number = env_->getEpisodeFrameNumber(); + } + /** * Get the current observation */ diff --git a/src/ale/vector/utils.hpp b/src/ale/vector/utils.hpp index c703519fa..c692ce546 100644 --- a/src/ale/vector/utils.hpp +++ b/src/ale/vector/utils.hpp @@ -48,6 +48,30 @@ namespace ale::vector { std::vector* final_observation; // Screen pixel data for previous episode last observation with Autoresetmode == SameStep }; + /** + * Lightweight metadata without observation data. + * Used when observations are written directly to output buffer. + */ + struct TimestepMetadata { + int env_id; // ID of the environment + reward_t reward; // Reward received + bool terminated; // Whether the game ended + bool truncated; // Whether episode was truncated + int lives; // Remaining lives + int frame_number; // Frame number since game start + int episode_frame_number; // Frame number since episode start + }; + + /** + * WriteSlot provides destinations for workers to write data directly. + */ + struct WriteSlot { + int slot_index; // Index in the batch + uint8_t* obs_dest; // Pointer to write observation data + TimestepMetadata* meta; // Pointer to write metadata + uint8_t* final_obs_dest; // Pointer for final_obs (SameStep mode) + }; + /** * Observation format enumeration */ @@ -122,119 +146,134 @@ namespace ale::vector { }; /** - * StateBuffer handles the collection of timesteps from environments + * StateBuffer manages output buffers for vectorized environment results. + * + * The buffer is set externally before workers begin writing. 
+ * Workers write directly to allocated slots, avoiding intermediate copies. * * Two modes of operation: - * 1. Ordered mode (batch_size == num_envs): Waits for all env_ids to be filled - * 2. Unordered mode (batch_size != num_envs): Uses circular buffer for continuous operation + * 1. Ordered mode (batch_size == num_envs): Slot index equals env_id + * 2. Unordered mode (batch_size != num_envs): Atomic slot allocation */ class StateBuffer { public: - StateBuffer(const std::size_t batch_size, const std::size_t num_envs) + StateBuffer(const std::size_t batch_size, const std::size_t num_envs, const std::size_t obs_size) : batch_size_(batch_size), num_envs_(num_envs), + obs_size_(obs_size), ordered_mode_(batch_size == num_envs), - timesteps_(num_envs_), + metadata_(batch_size), + final_obs_buffer_(batch_size * obs_size), + has_final_obs_(batch_size, false), + output_obs_buffer_(nullptr), count_(0), write_idx_(0), - read_idx_(0), - sem_ready_(0), // Initially no batches ready - sem_read_(1) { // Allow one reader at a time + sem_ready_(0), + sem_read_(1) {} + + /** + * Set the output buffer that workers will write observations into. + * MUST be called before enqueueing any actions that will use this buffer. + * + * @param obs_buffer Pointer to allocated buffer of size batch_size * obs_size + */ + void set_output_buffer(uint8_t* obs_buffer) { + output_obs_buffer_ = obs_buffer; } /** - * Write a timestep to the buffer - * Multiple threads can write simultaneously + * Allocate a write slot for a worker thread. + * Returns pointers for direct writing into the output buffer. + * + * Thread-safe: multiple workers can call simultaneously. 
+ * + * @param env_id The environment ID requesting a slot + * @return WriteSlot with pointers into output buffers */ - void write(const Timestep& timestep) { + WriteSlot allocate_write_slot(int env_id) { + WriteSlot slot; + if (ordered_mode_) { - // In ordered mode, place timestep at env_id position - const int env_id = timestep.env_id; - timesteps_[env_id] = timestep; - - // Atomically increment count and check if batch is ready - const auto old_count = count_.fetch_add(1); - if (old_count + 1 == batch_size_) { - // Exactly one thread will see count == batch_size_ in ordered mode - sem_ready_.signal(1); - } + // In ordered mode, slot index equals env_id + slot.slot_index = env_id; } else { - // In unordered mode, use circular buffer - // Each thread gets a unique index atomically - const auto idx = write_idx_.fetch_add(1) % num_envs_; - timesteps_[idx] = timestep; - - // Atomically increment count and check if batch is ready - const auto old_count = count_.fetch_add(1); - // Signal if we just crossed a batch boundary - if ((old_count + 1) / batch_size_ > old_count / batch_size_) { - sem_ready_.signal(1); - } + // In unordered mode, atomically allocate next available slot + slot.slot_index = static_cast(write_idx_.fetch_add(1) % batch_size_); } + + slot.obs_dest = output_obs_buffer_ + slot.slot_index * obs_size_; + slot.meta = &metadata_[slot.slot_index]; + slot.final_obs_dest = final_obs_buffer_.data() + slot.slot_index * obs_size_; + + return slot; } /** - * Collect timesteps when ready and return them + * Mark that a slot has final observation data (for SameStep autoreset). 
* - * @return Vector of timesteps + * @param slot_index The slot index to mark */ - std::vector collect() { - // Wait until a batch is ready - while (!sem_ready_.wait()) {} - - // Acquire read semaphore - while (!sem_read_.wait()) {} - - // Collect the results - std::vector result; - result.reserve(batch_size_); - - if (ordered_mode_) { - // In ordered mode, read in env_id order - for (size_t i = 0; i < batch_size_; ++i) { - result.push_back(std::move(timesteps_[i])); - } + void mark_slot_has_final_obs(int slot_index) { + has_final_obs_[slot_index] = true; + } - // Reset count for ordered mode (all items consumed) - count_.store(0); - } else { - // In unordered mode, read from circular buffer - for (size_t i = 0; i < batch_size_; ++i) { - const auto idx = read_idx_.fetch_add(1) % num_envs_; - result.push_back(std::move(timesteps_[idx])); - } - - // Atomically decrease count by batch_size_ - count_.fetch_sub(batch_size_); + /** + * Mark a slot as complete. Called by worker after writing all data. + * When all slots are complete, signals that batch is ready. + */ + void mark_complete() { + const auto old_count = count_.fetch_add(1); + if (old_count + 1 == batch_size_) { + sem_ready_.signal(1); } + } - // Release read semaphore - sem_read_.signal(1); - - return result; + /** + * Wait for batch to complete. Blocks until all slots are filled. + */ + void wait_for_batch() { + while (!sem_ready_.wait()) {} } /** - * Get the number of timesteps currently buffered + * Reset state for next batch. Must be called after collecting results. 
*/ - size_t filled_timesteps() const { - return count_.load(); + void reset() { + count_.store(0); + write_idx_.store(0); + std::fill(has_final_obs_.begin(), has_final_obs_.end(), false); + output_obs_buffer_ = nullptr; } + // Accessors + TimestepMetadata* get_metadata() { return metadata_.data(); } + const TimestepMetadata* get_metadata() const { return metadata_.data(); } + uint8_t* get_final_obs_buffer() { return final_obs_buffer_.data(); } + const uint8_t* get_final_obs_buffer() const { return final_obs_buffer_.data(); } + uint8_t* get_has_final_obs() { return has_final_obs_.data(); } + const uint8_t* get_has_final_obs() const { return has_final_obs_.data(); } + std::size_t get_batch_size() const { return batch_size_; } + std::size_t get_obs_size() const { return obs_size_; } + private: - const std::size_t batch_size_; // Size of each batch - const std::size_t num_envs_; // Number of environments - const bool ordered_mode_; // Whether we're in ordered mode - std::vector timesteps_; // Buffer for timesteps - - // Atomic counters for lock-free operations - std::atomic count_; // Current count of available timesteps - std::atomic write_idx_; // Write position (for unordered mode) - std::atomic read_idx_; // Read position (for unordered mode) - - // Semaphores for coordination - moodycamel::LightweightSemaphore sem_ready_; // Signals when a batch is ready for collection - moodycamel::LightweightSemaphore sem_read_; // Controls access to read operations + const std::size_t batch_size_; + const std::size_t num_envs_; + const std::size_t obs_size_; + const bool ordered_mode_; + + // Internal storage for metadata and final observations + std::vector metadata_; + std::vector final_obs_buffer_; + std::vector has_final_obs_; // uint8_t instead of bool for .data() access + + // External output buffer (set via set_output_buffer) + uint8_t* output_obs_buffer_; + + // Synchronization + std::atomic count_; + std::atomic write_idx_; + moodycamel::LightweightSemaphore 
sem_ready_; + moodycamel::LightweightSemaphore sem_read_; }; } diff --git a/tests/python/test_atari_vector_env.py b/tests/python/test_atari_vector_env.py index be951cd3e..e11c9531a 100644 --- a/tests/python/test_atari_vector_env.py +++ b/tests/python/test_atari_vector_env.py @@ -434,6 +434,7 @@ def test_batch_size_async( ) async_env_timestep[async_env_ids] += 1 + assert np.all(async_env_timestep > rollout_length / (num_envs * 2)), async_env_timestep sync_envs.close() async_envs.close() From 2cb0462125b081d3ce7d01791a38d7a3281e5e2e Mon Sep 17 00:00:00 2001 From: pseudo-rnd-thoughts Date: Fri, 28 Nov 2025 12:01:24 +0000 Subject: [PATCH 2/8] Improve implementation --- .../python/ale_vector_python_interface.cpp | 118 +++------ .../python/ale_vector_python_interface.hpp | 10 +- src/ale/vector/async_vectorizer.hpp | 239 +++++++++++++----- src/ale/vector/preprocessed_env.hpp | 99 ++++---- src/ale/vector/utils.hpp | 155 +++++++----- 5 files changed, 359 insertions(+), 262 deletions(-) diff --git a/src/ale/python/ale_vector_python_interface.cpp b/src/ale/python/ale_vector_python_interface.cpp index 3f063000f..9fa56da94 100644 --- a/src/ale/python/ale_vector_python_interface.cpp +++ b/src/ale/python/ale_vector_python_interface.cpp @@ -52,45 +52,31 @@ void init_vector_module(nb::module_& m) { const bool grayscale = self.is_grayscale(); // Wrap observation buffer - capsule takes ownership - nb::capsule obs_owner(result.obs_data, [](void *p) noexcept { + nb::capsule obs_owner(result.observations, [](void *p) noexcept { delete[] static_cast(p); }); nb::ndarray observations; if (grayscale) { size_t shape[4] = {(size_t)batch_size, (size_t)stack_num, (size_t)height, (size_t)width}; - observations = nb::ndarray(result.obs_data, 4, shape, obs_owner); + observations = nb::ndarray(result.observations, 4, shape, obs_owner); } else { size_t shape[5] = {(size_t)batch_size, (size_t)stack_num, (size_t)height, (size_t)width, 3}; - observations = nb::ndarray(result.obs_data, 5, shape, 
obs_owner); + observations = nb::ndarray(result.observations, 5, shape, obs_owner); } - // Create arrays for info fields - int* env_ids_data = new int[batch_size]; - int* lives_data = new int[batch_size]; - int* frame_numbers_data = new int[batch_size]; - int* episode_frame_numbers_data = new int[batch_size]; - - for (size_t i = 0; i < batch_size; i++) { - const auto& meta = result.metadata[i]; - env_ids_data[i] = meta.env_id; - lives_data[i] = meta.lives; - frame_numbers_data[i] = meta.frame_number; - episode_frame_numbers_data[i] = meta.episode_frame_number; - } - - // Create capsules for cleanup - nb::capsule env_ids_owner(env_ids_data, [](void *p) noexcept { delete[] (int*)p; }); - nb::capsule lives_owner(lives_data, [](void *p) noexcept { delete[] (int*)p; }); - nb::capsule frame_numbers_owner(frame_numbers_data, [](void *p) noexcept { delete[] (int*)p; }); - nb::capsule episode_frame_numbers_owner(episode_frame_numbers_data, [](void *p) noexcept { delete[] (int*)p; }); + // Create capsules - ownership transferred from BatchData + nb::capsule env_ids_owner(result.env_ids, [](void *p) noexcept { delete[] (int*)p; }); + nb::capsule lives_owner(result.lives, [](void *p) noexcept { delete[] (int*)p; }); + nb::capsule frame_numbers_owner(result.frame_numbers, [](void *p) noexcept { delete[] (int*)p; }); + nb::capsule episode_frame_numbers_owner(result.episode_frame_numbers, [](void *p) noexcept { delete[] (int*)p; }); - // Create numpy arrays + // Create numpy arrays (zero-copy - direct from BatchData) size_t info_shape[1] = {(size_t)batch_size}; - auto env_ids = nb::ndarray(env_ids_data, 1, info_shape, env_ids_owner); - auto lives = nb::ndarray(lives_data, 1, info_shape, lives_owner); - auto frame_numbers = nb::ndarray(frame_numbers_data, 1, info_shape, frame_numbers_owner); - auto episode_frame_numbers = nb::ndarray(episode_frame_numbers_data, 1, info_shape, episode_frame_numbers_owner); + auto env_ids = nb::ndarray(result.env_ids, 1, info_shape, env_ids_owner); 
+ auto lives = nb::ndarray(result.lives, 1, info_shape, lives_owner); + auto frame_numbers = nb::ndarray(result.frame_numbers, 1, info_shape, frame_numbers_owner); + auto episode_frame_numbers = nb::ndarray(result.episode_frame_numbers, 1, info_shape, episode_frame_numbers_owner); // Create info dict nb::dict info; @@ -111,65 +97,42 @@ void init_vector_module(nb::module_& m) { nb::gil_scoped_acquire acquire; // Get shape info - const auto shape_info = self.get_observation_shape(); - const int stack_num = std::get<0>(shape_info); - const int height = std::get<1>(shape_info); - const int width = std::get<2>(shape_info); + const auto obs_shape_info = self.get_observation_shape(); + const int stack_num = std::get<0>(obs_shape_info); + const int height = std::get<1>(obs_shape_info); + const int width = std::get<2>(obs_shape_info); const int batch_size = result.batch_size; const bool grayscale = self.is_grayscale(); // Wrap obs buffer - capsule takes ownership and will delete[] - nb::capsule obs_owner(result.obs_data, [](void *p) noexcept { - delete[] static_cast(p); - }); - + nb::capsule obs_owner(result.observations, [](void *p) noexcept { delete[] static_cast(p); }); nb::ndarray observations; if (grayscale) { size_t shape[4] = {(size_t)batch_size, (size_t)stack_num, (size_t)height, (size_t)width}; - observations = nb::ndarray(result.obs_data, 4, shape, obs_owner); + observations = nb::ndarray(result.observations, 4, shape, obs_owner); } else { size_t shape[5] = {(size_t)batch_size, (size_t)stack_num, (size_t)height, (size_t)width, 3}; - observations = nb::ndarray(result.obs_data, 5, shape, obs_owner); - } - - // Allocate metadata arrays - int* rewards_data = new int[batch_size]; - bool* terminations_data = new bool[batch_size]; - bool* truncations_data = new bool[batch_size]; - int* env_ids_data = new int[batch_size]; - int* lives_data = new int[batch_size]; - int* frame_numbers_data = new int[batch_size]; - int* episode_frame_numbers_data = new int[batch_size]; - - 
for (size_t i = 0; i < batch_size; i++) { - const auto& meta = result.metadata[i]; - rewards_data[i] = meta.reward; - terminations_data[i] = meta.terminated; - truncations_data[i] = meta.truncated; - env_ids_data[i] = meta.env_id; - lives_data[i] = meta.lives; - frame_numbers_data[i] = meta.frame_number; - episode_frame_numbers_data[i] = meta.episode_frame_number; + observations = nb::ndarray(result.observations, 5, shape, obs_owner); } - // Create capsules - nb::capsule rewards_owner(rewards_data, [](void *p) noexcept { delete[] (int*)p; }); - nb::capsule terminations_owner(terminations_data, [](void *p) noexcept { delete[] (bool*)p; }); - nb::capsule truncations_owner(truncations_data, [](void *p) noexcept { delete[] (bool*)p; }); - nb::capsule env_ids_owner(env_ids_data, [](void *p) noexcept { delete[] (int*)p; }); - nb::capsule lives_owner(lives_data, [](void *p) noexcept { delete[] (int*)p; }); - nb::capsule frame_numbers_owner(frame_numbers_data, [](void *p) noexcept { delete[] (int*)p; }); - nb::capsule episode_frame_numbers_owner(episode_frame_numbers_data, [](void *p) noexcept { delete[] (int*)p; }); + // Create capsules - ownership transferred from BatchData + nb::capsule rewards_owner(result.rewards, [](void *p) noexcept { delete[] (int*)p; }); + nb::capsule terminations_owner(result.terminations, [](void *p) noexcept { delete[] (bool*)p; }); + nb::capsule truncations_owner(result.truncations, [](void *p) noexcept { delete[] (bool*)p; }); + nb::capsule env_ids_owner(result.env_ids, [](void *p) noexcept { delete[] (int*)p; }); + nb::capsule lives_owner(result.lives, [](void *p) noexcept { delete[] (int*)p; }); + nb::capsule frame_numbers_owner(result.frame_numbers, [](void *p) noexcept { delete[] (int*)p; }); + nb::capsule episode_frame_numbers_owner(result.episode_frame_numbers, [](void *p) noexcept { delete[] (int*)p; }); - // Create numpy arrays + // Create numpy arrays (zero-copy - direct from BatchData) size_t info_shape[1] = {(size_t)batch_size}; - 
auto rewards = nb::ndarray(rewards_data, 1, info_shape, rewards_owner); - auto terminations = nb::ndarray(terminations_data, 1, info_shape, terminations_owner); - auto truncations = nb::ndarray(truncations_data, 1, info_shape, truncations_owner); - auto env_ids = nb::ndarray(env_ids_data, 1, info_shape, env_ids_owner); - auto lives = nb::ndarray(lives_data, 1, info_shape, lives_owner); - auto frame_numbers = nb::ndarray(frame_numbers_data, 1, info_shape, frame_numbers_owner); - auto episode_frame_numbers = nb::ndarray(episode_frame_numbers_data, 1, info_shape, episode_frame_numbers_owner); + auto rewards = nb::ndarray(result.rewards, 1, info_shape, rewards_owner); + auto terminations = nb::ndarray(result.terminations, 1, info_shape, terminations_owner); + auto truncations = nb::ndarray(result.truncations, 1, info_shape, truncations_owner); + auto env_ids = nb::ndarray(result.env_ids, 1, info_shape, env_ids_owner); + auto lives = nb::ndarray(result.lives, 1, info_shape, lives_owner); + auto frame_numbers = nb::ndarray(result.frame_numbers, 1, info_shape, frame_numbers_owner); + auto episode_frame_numbers = nb::ndarray(result.episode_frame_numbers, 1, info_shape, episode_frame_numbers_owner); // Build info dict nb::dict info; @@ -179,18 +142,19 @@ void init_vector_module(nb::module_& m) { info["episode_frame_number"] = episode_frame_numbers; // Handle final_obs for SameStep mode - if (result.final_obs_data != nullptr) { - nb::capsule final_obs_owner(result.final_obs_data, [](void *p) noexcept { + if (result.final_observations != nullptr) { + // Wrap the buffer directly - workers have already filled in all slots + nb::capsule final_obs_owner(result.final_observations, [](void *p) noexcept { delete[] static_cast(p); }); nb::ndarray final_observations; if (grayscale) { size_t shape[4] = {(size_t)batch_size, (size_t)stack_num, (size_t)height, (size_t)width}; - final_observations = nb::ndarray(result.final_obs_data, 4, shape, final_obs_owner); + final_observations = 
nb::ndarray(result.final_observations, 4, shape, final_obs_owner); } else { size_t shape[5] = {(size_t)batch_size, (size_t)stack_num, (size_t)height, (size_t)width, 3}; - final_observations = nb::ndarray(result.final_obs_data, 5, shape, final_obs_owner); + final_observations = nb::ndarray(result.final_observations, 5, shape, final_obs_owner); } info["final_obs"] = final_observations; } diff --git a/src/ale/python/ale_vector_python_interface.hpp b/src/ale/python/ale_vector_python_interface.hpp index 12392b822..acb807a6d 100644 --- a/src/ale/python/ale_vector_python_interface.hpp +++ b/src/ale/python/ale_vector_python_interface.hpp @@ -143,9 +143,9 @@ namespace ale::vector { * * @param reset_indices Vector of environment indices to be reset * @param reset_seeds Vector of environment seeds to use - * @return RecvResult with initial observations + * @return BatchData with initial observations */ - RecvResult reset(const std::vector &reset_indices, const std::vector &reset_seeds) { + BatchData reset(const std::vector &reset_indices, const std::vector &reset_seeds) { vectorizer_->reset(reset_indices, reset_seeds); return recv(); } @@ -181,10 +181,10 @@ namespace ale::vector { * Returns the environment's data for the environments. * Returns ownership of observation buffer to caller. */ - RecvResult recv() { - RecvResult result = vectorizer_->recv(); + BatchData recv() { + BatchData result = vectorizer_->recv(); for (size_t i = 0; i < result.batch_size; i++) { - received_env_ids_[i] = result.metadata[i].env_id; + received_env_ids_[i] = result.env_ids[i]; } return result; } diff --git a/src/ale/vector/async_vectorizer.hpp b/src/ale/vector/async_vectorizer.hpp index 3b619801e..f00f7dc5d 100644 --- a/src/ale/vector/async_vectorizer.hpp +++ b/src/ale/vector/async_vectorizer.hpp @@ -18,14 +18,20 @@ namespace ale::vector { /** - * Result from recv() - caller takes ownership of allocated buffers. + * Batch data from recv() - caller takes ownership of allocated buffers. 
*/ - struct RecvResult { - uint8_t* obs_data; // Newly allocated, caller owns - std::vector metadata; // Copied from internal buffer - uint8_t* final_obs_data; // nullptr or newly allocated, caller owns - std::vector has_final_obs; // Which slots have final_obs (uint8_t for compatibility) - std::size_t batch_size; // Number of results + struct BatchData { + int* env_ids; // Newly allocated, caller owns + uint8_t* observations; // Newly allocated, caller owns + int* rewards; // Newly allocated, caller owns + bool* terminations; // Newly allocated, caller owns + bool* truncations; // Newly allocated, caller owns + int* lives; // Newly allocated, caller owns + int* frame_numbers; // Newly allocated, caller owns + int* episode_frame_numbers; // Newly allocated, caller owns + + uint8_t* final_observations; // nullptr or newly allocated, caller owns + std::size_t batch_size; // Number of results }; /** @@ -56,7 +62,15 @@ namespace ale::vector { autoreset_mode_(autoreset_mode), stop_(false), action_queue_(new ActionQueue(num_envs_)), - pending_obs_buffer_(nullptr) { + pending_obs_buffer_(nullptr), + pending_final_obs_(nullptr), + pending_env_ids_(nullptr), + pending_rewards_(nullptr), + pending_terminations_(nullptr), + pending_truncations_(nullptr), + pending_lives_(nullptr), + pending_frame_numbers_(nullptr), + pending_episode_frame_numbers_(nullptr) { // Create environments envs_.resize(num_envs_); @@ -111,11 +125,35 @@ namespace ale::vector { * @param seeds Vector of seeds to use on reset (use -1 to not change the environment's seed) */ void reset(const std::vector& reset_indices, const std::vector& seeds) { - // Allocate output buffer BEFORE enqueueing (prevents race condition) + // Allocate output buffers BEFORE enqueueing (prevents race condition) const std::size_t total_obs_size = batch_size_ * stacked_obs_size_; pending_obs_buffer_ = new uint8_t[total_obs_size]; state_buffer_->set_output_buffer(pending_obs_buffer_); + // Allocate metadata buffers + 
pending_env_ids_ = new int[batch_size_]; + pending_rewards_ = new int[batch_size_]; + pending_terminations_ = new bool[batch_size_]; + pending_truncations_ = new bool[batch_size_]; + pending_lives_ = new int[batch_size_]; + pending_frame_numbers_ = new int[batch_size_]; + pending_episode_frame_numbers_ = new int[batch_size_]; + state_buffer_->set_metadata_buffers( + pending_env_ids_, + pending_rewards_, + pending_terminations_, + pending_truncations_, + pending_lives_, + pending_frame_numbers_, + pending_episode_frame_numbers_ + ); + + // In SameStep mode, also allocate final_obs buffer + if (autoreset_mode_ == AutoresetMode::SameStep) { + pending_final_obs_ = new uint8_t[total_obs_size]; + state_buffer_->set_final_obs_buffer(pending_final_obs_); + } + // Prepare reset actions std::vector reset_actions; reset_actions.reserve(reset_indices.size()); @@ -141,10 +179,34 @@ namespace ale::vector { * @param actions Vector of actions to send to the sub-environments */ void send(const std::vector& actions) { - // Allocate output buffer BEFORE enqueueing (prevents race condition) + // Allocate output buffers BEFORE enqueueing (prevents race condition) const std::size_t total_obs_size = batch_size_ * stacked_obs_size_; + pending_obs_buffer_ = new uint8_t[total_obs_size]; + pending_env_ids_ = new int[batch_size_]; + pending_rewards_ = new int[batch_size_]; + pending_terminations_ = new bool[batch_size_]; + pending_truncations_ = new bool[batch_size_]; + pending_lives_ = new int[batch_size_]; + pending_frame_numbers_ = new int[batch_size_]; + pending_episode_frame_numbers_ = new int[batch_size_]; + state_buffer_->set_output_buffer(pending_obs_buffer_); + state_buffer_->set_metadata_buffers( + pending_env_ids_, + pending_rewards_, + pending_terminations_, + pending_truncations_, + pending_lives_, + pending_frame_numbers_, + pending_episode_frame_numbers_ + ); + + // In SameStep mode, also allocate final_obs buffer + if (autoreset_mode_ == AutoresetMode::SameStep) { + 
pending_final_obs_ = new uint8_t[total_obs_size]; + state_buffer_->set_final_obs_buffer(pending_final_obs_); + } // Prepare action slices std::vector action_slices; @@ -169,48 +231,35 @@ namespace ale::vector { * Receive timesteps from the environments. * Returns ownership of allocated observation buffer to caller. * - * @return RecvResult containing observation data and metadata + * @return BatchData containing observation data and metadata */ - RecvResult recv() { + BatchData recv() { // Wait for all workers to complete state_buffer_->wait_for_batch(); - // Build result - RecvResult result; - result.obs_data = pending_obs_buffer_; // Transfer ownership + // Build result - transfer ownership of all buffers (no copying!) + BatchData result; + result.observations = pending_obs_buffer_; + result.final_observations = pending_final_obs_; + result.env_ids = pending_env_ids_; + result.rewards = pending_rewards_; + result.terminations = pending_terminations_; + result.truncations = pending_truncations_; + result.lives = pending_lives_; + result.frame_numbers = pending_frame_numbers_; + result.episode_frame_numbers = pending_episode_frame_numbers_; result.batch_size = batch_size_; - pending_obs_buffer_ = nullptr; - - // Copy metadata (small - ~32 bytes per env) - result.metadata.resize(batch_size_); - std::memcpy( - result.metadata.data(), - state_buffer_->get_metadata(), - batch_size_ * sizeof(TimestepMetadata) - ); - - // Handle final_obs for SameStep mode - if (autoreset_mode_ == AutoresetMode::SameStep) { - const uint8_t* has_final = state_buffer_->get_has_final_obs(); - bool any_final = false; - for (std::size_t i = 0; i < batch_size_; i++) { - if (has_final[i]) { - any_final = true; - break; - } - } - if (any_final) { - const std::size_t total_obs_size = batch_size_ * stacked_obs_size_; - result.final_obs_data = new uint8_t[total_obs_size]; - std::memcpy(result.final_obs_data, state_buffer_->get_final_obs_buffer(), total_obs_size); - 
result.has_final_obs.assign(has_final, has_final + batch_size_); - } else { - result.final_obs_data = nullptr; - } - } else { - result.final_obs_data = nullptr; - } + // Clear pending pointers (ownership transferred) + pending_obs_buffer_ = nullptr; + pending_final_obs_ = nullptr; + pending_env_ids_ = nullptr; + pending_rewards_ = nullptr; + pending_terminations_ = nullptr; + pending_truncations_ = nullptr; + pending_lives_ = nullptr; + pending_frame_numbers_ = nullptr; + pending_episode_frame_numbers_ = nullptr; // Reset state buffer for next batch state_buffer_->reset(); @@ -247,7 +296,16 @@ namespace ale::vector { std::unique_ptr state_buffer_; // Buffer for observations and metadata std::vector> envs_; // Environment instances - uint8_t* pending_obs_buffer_; // Buffer allocated in send(), returned in recv() + // Pending buffers allocated in send()/reset(), returned in recv() + uint8_t* pending_obs_buffer_; // Observations buffer + uint8_t* pending_final_obs_; // Final observations buffer (SameStep mode only) + int* pending_env_ids_; // Env IDs metadata buffer + int* pending_rewards_; // Rewards metadata buffer + bool* pending_terminations_; // Terminations metadata buffer + bool* pending_truncations_; // Truncations metadata buffer + int* pending_lives_; // Lives metadata buffer + int* pending_frame_numbers_; // Frame numbers metadata buffer + int* pending_episode_frame_numbers_; // Episode frame numbers metadata buffer /** * Worker thread function that processes environment steps. 
@@ -262,10 +320,6 @@ namespace ale::vector { } const int env_id = action.env_id; - - // Get write slot - pointers are into the pre-allocated output buffer - WriteSlot slot = state_buffer_->allocate_write_slot(env_id); - if (autoreset_mode_ == AutoresetMode::NextStep) { if (action.force_reset || envs_[env_id]->is_episode_over()) { envs_[env_id]->reset(); @@ -273,35 +327,86 @@ namespace ale::vector { envs_[env_id]->step(); } - // Write directly to output buffer (single copy: linearize frame stack) - envs_[env_id]->write_timestep_to(slot.obs_dest, *slot.meta); - + // Get write slot - pointers are into the pre-allocated output buffer (after the reset or step occurs) + WriteSlot slot = state_buffer_->allocate_write_slot(env_id); + envs_[env_id]->write_timestep_to( + slot.obs_dest, + slot.env_id_dest, + slot.reward_dest, + slot.terminated_dest, + slot.truncated_dest, + slot.lives_dest, + slot.frame_number_dest, + slot.episode_frame_number_dest + ); } else if (autoreset_mode_ == AutoresetMode::SameStep) { if (action.force_reset) { envs_[env_id]->reset(); - envs_[env_id]->write_timestep_to(slot.obs_dest, *slot.meta); + + // Get write slot - pointers are into the pre-allocated output buffer (after the force reset) + WriteSlot slot = state_buffer_->allocate_write_slot(env_id); + envs_[env_id]->write_timestep_to( + slot.obs_dest, + slot.env_id_dest, + slot.reward_dest, + slot.terminated_dest, + slot.truncated_dest, + slot.lives_dest, + slot.frame_number_dest, + slot.episode_frame_number_dest + ); } else { envs_[env_id]->step(); + // Get write slot - pointers are into the pre-allocated output buffer (after the step) + WriteSlot slot = state_buffer_->allocate_write_slot(env_id); + if (envs_[env_id]->is_episode_over()) { - // Save final observation before reset + // Write current (final) observation before reset envs_[env_id]->write_observation_to(slot.final_obs_dest); - state_buffer_->mark_slot_has_final_obs(slot.slot_index); - // Capture pre-reset metadata - TimestepMetadata 
pre_reset_meta; - envs_[env_id]->write_metadata_to(pre_reset_meta); + // Capture pre-reset metadata temporarily (for reward/terminated/truncated) + int pre_reward; + bool pre_terminated, pre_truncated; + envs_[env_id]->write_metadata_to( + slot.env_id_dest, + &pre_reward, + &pre_terminated, + &pre_truncated, + slot.lives_dest, + slot.frame_number_dest, + slot.episode_frame_number_dest + ); // Reset and write new observation envs_[env_id]->reset(); - envs_[env_id]->write_timestep_to(slot.obs_dest, *slot.meta); + envs_[env_id]->write_timestep_to( + slot.obs_dest, + slot.env_id_dest, // overwrites with same value + slot.reward_dest, + slot.terminated_dest, + slot.truncated_dest, + slot.lives_dest, // overwrites with reset lives + slot.frame_number_dest, + slot.episode_frame_number_dest + ); // Restore pre-reset reward/terminated/truncated - slot.meta->reward = pre_reset_meta.reward; - slot.meta->terminated = pre_reset_meta.terminated; - slot.meta->truncated = pre_reset_meta.truncated; + *slot.reward_dest = pre_reward; + *slot.terminated_dest = pre_terminated; + *slot.truncated_dest = pre_truncated; } else { - envs_[env_id]->write_timestep_to(slot.obs_dest, *slot.meta); + // No episode over + envs_[env_id]->write_timestep_to( + slot.obs_dest, + slot.env_id_dest, + slot.reward_dest, + slot.terminated_dest, + slot.truncated_dest, + slot.lives_dest, + slot.frame_number_dest, + slot.episode_frame_number_dest + ); } } } else { @@ -317,8 +422,8 @@ namespace ale::vector { } /** - * Set thread affinity for worker threads - */ + * Set thread affinity for worker threads + */ void set_thread_affinity(const int thread_affinity_offset, const int processor_count) { for (size_t tid = 0; tid < workers_.size(); ++tid) { size_t core_id = (thread_affinity_offset + tid) % processor_count; diff --git a/src/ale/vector/preprocessed_env.hpp b/src/ale/vector/preprocessed_env.hpp index c7c1ddea1..4e1fe842b 100644 --- a/src/ale/vector/preprocessed_env.hpp +++ 
b/src/ale/vector/preprocessed_env.hpp @@ -240,17 +240,32 @@ namespace ale::vector { * Avoids allocating intermediate vectors. * * @param obs_dest Pointer to write linearized observation (size: stack_num * obs_size) - * @param meta Reference to metadata struct to populate + * @param env_id_dest Pointer to write env_id + * @param reward_dest Pointer to write reward + * @param terminated_dest Pointer to write terminated flag + * @param truncated_dest Pointer to write truncated flag + * @param lives_dest Pointer to write lives + * @param frame_number_dest Pointer to write frame_number + * @param episode_frame_number_dest Pointer to write episode_frame_number */ - void write_timestep_to(uint8_t* obs_dest, TimestepMetadata& meta) const { - // Write metadata - meta.env_id = env_id_; - meta.reward = reward_; - meta.terminated = game_over_ || ((life_loss_info_ || episodic_life_) && was_life_lost_); - meta.truncated = elapsed_step_ >= max_episode_steps_ && !meta.terminated; - meta.lives = lives_; - meta.frame_number = env_->getFrameNumber(); - meta.episode_frame_number = env_->getEpisodeFrameNumber(); + void write_timestep_to( + uint8_t* obs_dest, + int* env_id_dest, + int* reward_dest, + bool* terminated_dest, + bool* truncated_dest, + int* lives_dest, + int* frame_number_dest, + int* episode_frame_number_dest + ) const { + // Write metadata directly to BatchData arrays + *env_id_dest = env_id_; + *reward_dest = reward_; + *terminated_dest = game_over_ || ((life_loss_info_ || episodic_life_) && was_life_lost_); + *truncated_dest = elapsed_step_ >= max_episode_steps_ && !(*terminated_dest); + *lives_dest = lives_; + *frame_number_dest = env_->getFrameNumber(); + *episode_frame_number_dest = env_->getEpisodeFrameNumber(); // Linearize circular frame_stack directly to destination for (int i = 0; i < stack_num_; ++i) { @@ -282,48 +297,30 @@ namespace ale::vector { /** * Write only metadata (used to capture state before reset in SameStep mode). 
* - * @param meta Reference to metadata struct to populate + * @param env_id_dest Pointer to write env_id + * @param reward_dest Pointer to write reward + * @param terminated_dest Pointer to write terminated flag + * @param truncated_dest Pointer to write truncated flag + * @param lives_dest Pointer to write lives + * @param frame_number_dest Pointer to write frame_number + * @param episode_frame_number_dest Pointer to write episode_frame_number */ - void write_metadata_to(TimestepMetadata& meta) const { - meta.env_id = env_id_; - meta.reward = reward_; - meta.terminated = game_over_ || ((life_loss_info_ || episodic_life_) && was_life_lost_); - meta.truncated = elapsed_step_ >= max_episode_steps_ && !meta.terminated; - meta.lives = lives_; - meta.frame_number = env_->getFrameNumber(); - meta.episode_frame_number = env_->getEpisodeFrameNumber(); - } - - /** - * Get the current observation - */ - Timestep get_timestep() const { - Timestep timestep; - timestep.env_id = env_id_; - - timestep.reward = reward_; - timestep.terminated = game_over_ || ((life_loss_info_ || episodic_life_) && was_life_lost_); - timestep.truncated = elapsed_step_ >= max_episode_steps_ && !timestep.terminated; - - timestep.lives = lives_; - timestep.frame_number = env_->getFrameNumber(); - timestep.episode_frame_number = env_->getEpisodeFrameNumber(); - - // Copy frames from oldest to newest into a single observation - timestep.observation.resize(obs_size_ * stack_num_); - for (int i = 0; i < stack_num_; ++i) { - int src_idx = (frame_stack_idx_ + i) % stack_num_; - std::memcpy( - timestep.observation.data() + i * obs_size_, - frame_stack_.data() + src_idx * obs_size_, - obs_size_ - ); - } - - // Initialize as nullptr and set in AsyncVectorizer if needed - timestep.final_observation = nullptr; - - return timestep; + void write_metadata_to( + int* env_id_dest, + int* reward_dest, + bool* terminated_dest, + bool* truncated_dest, + int* lives_dest, + int* frame_number_dest, + int* 
episode_frame_number_dest + ) const { + *env_id_dest = env_id_; + *reward_dest = reward_; + *terminated_dest = game_over_ || ((life_loss_info_ || episodic_life_) && was_life_lost_); + *truncated_dest = elapsed_step_ >= max_episode_steps_ && !(*terminated_dest); + *lives_dest = lives_; + *frame_number_dest = env_->getFrameNumber(); + *episode_frame_number_dest = env_->getEpisodeFrameNumber(); } /** diff --git a/src/ale/vector/utils.hpp b/src/ale/vector/utils.hpp index c692ce546..a112d8999 100644 --- a/src/ale/vector/utils.hpp +++ b/src/ale/vector/utils.hpp @@ -32,44 +32,21 @@ namespace ale::vector { float paddle_strength; // Strength for paddle-based games (default: 1.0) }; - /** - * Timestep represents the output from an environment step - */ - struct Timestep { - int env_id; // ID of the environment this observation is from - std::vector observation; // Screen pixel data - reward_t reward; // Reward received in this step - bool terminated; // Whether the game ended - bool truncated; // Whether the episode was truncated due to a time limit - int lives; // Remaining lives in the game - int frame_number; // Frame number since the beginning of the game - int episode_frame_number; // Frame number since the beginning of the episode - - std::vector* final_observation; // Screen pixel data for previous episode last observation with Autoresetmode == SameStep - }; - - /** - * Lightweight metadata without observation data. - * Used when observations are written directly to output buffer. - */ - struct TimestepMetadata { - int env_id; // ID of the environment - reward_t reward; // Reward received - bool terminated; // Whether the game ended - bool truncated; // Whether episode was truncated - int lives; // Remaining lives - int frame_number; // Frame number since game start - int episode_frame_number; // Frame number since episode start - }; - /** * WriteSlot provides destinations for workers to write data directly. 
+ * All pointers point into externally allocated BatchData arrays. */ struct WriteSlot { - int slot_index; // Index in the batch - uint8_t* obs_dest; // Pointer to write observation data - TimestepMetadata* meta; // Pointer to write metadata - uint8_t* final_obs_dest; // Pointer for final_obs (SameStep mode) + int slot_index; // Index in the batch + uint8_t* obs_dest; // Pointer to write observation data + int* env_id_dest; // Pointer to write env_id + int* reward_dest; // Pointer to write reward + bool* terminated_dest; // Pointer to write terminated flag + bool* truncated_dest; // Pointer to write truncated flag + int* lives_dest; // Pointer to write lives + int* frame_number_dest; // Pointer to write frame_number + int* episode_frame_number_dest; // Pointer to write episode_frame_number + uint8_t* final_obs_dest; // Pointer for final_obs (SameStep mode) }; /** @@ -162,10 +139,15 @@ namespace ale::vector { num_envs_(num_envs), obs_size_(obs_size), ordered_mode_(batch_size == num_envs), - metadata_(batch_size), - final_obs_buffer_(batch_size * obs_size), - has_final_obs_(batch_size, false), output_obs_buffer_(nullptr), + final_obs_buffer_(nullptr), + env_ids_buffer_(nullptr), + rewards_buffer_(nullptr), + terminations_buffer_(nullptr), + truncations_buffer_(nullptr), + lives_buffer_(nullptr), + frame_numbers_buffer_(nullptr), + episode_frame_numbers_buffer_(nullptr), count_(0), write_idx_(0), sem_ready_(0), @@ -181,6 +163,45 @@ namespace ale::vector { output_obs_buffer_ = obs_buffer; } + /** + * Set the final_obs output buffer for SameStep autoreset mode. + * + * @param final_obs_buffer Pointer to allocated buffer of size batch_size * obs_size + */ + void set_final_obs_buffer(uint8_t* final_obs_buffer) { + final_obs_buffer_ = final_obs_buffer; + } + + /** + * Set the metadata output buffers that workers will write into. + * MUST be called before enqueueing any actions that will use these buffers. 
+ * + * @param env_ids Pointer to allocated array of size batch_size + * @param rewards Pointer to allocated array of size batch_size + * @param terminations Pointer to allocated array of size batch_size + * @param truncations Pointer to allocated array of size batch_size + * @param lives Pointer to allocated array of size batch_size + * @param frame_numbers Pointer to allocated array of size batch_size + * @param episode_frame_numbers Pointer to allocated array of size batch_size + */ + void set_metadata_buffers( + int* env_ids, + int* rewards, + bool* terminations, + bool* truncations, + int* lives, + int* frame_numbers, + int* episode_frame_numbers + ) { + env_ids_buffer_ = env_ids; + rewards_buffer_ = rewards; + terminations_buffer_ = terminations; + truncations_buffer_ = truncations; + lives_buffer_ = lives; + frame_numbers_buffer_ = frame_numbers; + episode_frame_numbers_buffer_ = episode_frame_numbers; + } + /** * Allocate a write slot for a worker thread. * Returns pointers for direct writing into the output buffer. @@ -201,20 +222,26 @@ namespace ale::vector { slot.slot_index = static_cast(write_idx_.fetch_add(1) % batch_size_); } - slot.obs_dest = output_obs_buffer_ + slot.slot_index * obs_size_; - slot.meta = &metadata_[slot.slot_index]; - slot.final_obs_dest = final_obs_buffer_.data() + slot.slot_index * obs_size_; + const int idx = slot.slot_index; - return slot; - } + // Set observation pointers + slot.obs_dest = output_obs_buffer_ + idx * obs_size_; - /** - * Mark that a slot has final observation data (for SameStep autoreset). - * - * @param slot_index The slot index to mark - */ - void mark_slot_has_final_obs(int slot_index) { - has_final_obs_[slot_index] = true; + // Set final_obs pointer (only used in SameStep mode, nullptr in NextStep mode) + slot.final_obs_dest = final_obs_buffer_ != nullptr + ? 
final_obs_buffer_ + idx * obs_size_ + : nullptr; + + // Set metadata pointers (directly into BatchData arrays) + slot.env_id_dest = &env_ids_buffer_[idx]; + slot.reward_dest = &rewards_buffer_[idx]; + slot.terminated_dest = &terminations_buffer_[idx]; + slot.truncated_dest = &truncations_buffer_[idx]; + slot.lives_dest = &lives_buffer_[idx]; + slot.frame_number_dest = &frame_numbers_buffer_[idx]; + slot.episode_frame_number_dest = &episode_frame_numbers_buffer_[idx]; + + return slot; } /** @@ -241,17 +268,18 @@ namespace ale::vector { void reset() { count_.store(0); write_idx_.store(0); - std::fill(has_final_obs_.begin(), has_final_obs_.end(), false); output_obs_buffer_ = nullptr; + final_obs_buffer_ = nullptr; + env_ids_buffer_ = nullptr; + rewards_buffer_ = nullptr; + terminations_buffer_ = nullptr; + truncations_buffer_ = nullptr; + lives_buffer_ = nullptr; + frame_numbers_buffer_ = nullptr; + episode_frame_numbers_buffer_ = nullptr; } // Accessors - TimestepMetadata* get_metadata() { return metadata_.data(); } - const TimestepMetadata* get_metadata() const { return metadata_.data(); } - uint8_t* get_final_obs_buffer() { return final_obs_buffer_.data(); } - const uint8_t* get_final_obs_buffer() const { return final_obs_buffer_.data(); } - uint8_t* get_has_final_obs() { return has_final_obs_.data(); } - const uint8_t* get_has_final_obs() const { return has_final_obs_.data(); } std::size_t get_batch_size() const { return batch_size_; } std::size_t get_obs_size() const { return obs_size_; } @@ -261,13 +289,16 @@ namespace ale::vector { const std::size_t obs_size_; const bool ordered_mode_; - // Internal storage for metadata and final observations - std::vector metadata_; - std::vector final_obs_buffer_; - std::vector has_final_obs_; // uint8_t instead of bool for .data() access - - // External output buffer (set via set_output_buffer) + // External output buffers (set via set_output_buffer / set_final_obs_buffer / set_metadata_buffers) uint8_t* output_obs_buffer_; 
+ uint8_t* final_obs_buffer_; + int* env_ids_buffer_; + int* rewards_buffer_; + bool* terminations_buffer_; + bool* truncations_buffer_; + int* lives_buffer_; + int* frame_numbers_buffer_; + int* episode_frame_numbers_buffer_; // Synchronization std::atomic count_; From 0928fab36e30808acf62d10f9d3861b9bed59cdd Mon Sep 17 00:00:00 2001 From: pseudo-rnd-thoughts Date: Sat, 29 Nov 2025 10:45:55 +0000 Subject: [PATCH 3/8] update --- .../python/ale_vector_python_interface.cpp | 40 +++++++++++++------ src/ale/python/vector_env.py | 12 +++--- src/ale/vector/async_vectorizer.hpp | 14 +++++++ src/ale/vector/utils.hpp | 21 +++++++++- tests/python/test_atari_vector_env.py | 15 ++----- 5 files changed, 72 insertions(+), 30 deletions(-) diff --git a/src/ale/python/ale_vector_python_interface.cpp b/src/ale/python/ale_vector_python_interface.cpp index 9fa56da94..06805cea4 100644 --- a/src/ale/python/ale_vector_python_interface.cpp +++ b/src/ale/python/ale_vector_python_interface.cpp @@ -141,22 +141,36 @@ void init_vector_module(nb::module_& m) { info["frame_number"] = frame_numbers; info["episode_frame_number"] = episode_frame_numbers; - // Handle final_obs for SameStep mode + // Handle final_obs for SameStep mode - only include if any env terminated/truncated if (result.final_observations != nullptr) { - // Wrap the buffer directly - workers have already filled in all slots - nb::capsule final_obs_owner(result.final_observations, [](void *p) noexcept { - delete[] static_cast(p); - }); - - nb::ndarray final_observations; - if (grayscale) { - size_t shape[4] = {(size_t)batch_size, (size_t)stack_num, (size_t)height, (size_t)width}; - final_observations = nb::ndarray(result.final_observations, 4, shape, final_obs_owner); + // Check if any environment actually terminated or truncated + bool any_done = false; + for (size_t i = 0; i < batch_size; i++) { + if (result.terminations[i] || result.truncations[i]) { + any_done = true; + break; + } + } + + if (any_done) { + // Wrap the buffer 
directly - workers have already filled in all slots + nb::capsule final_obs_owner(result.final_observations, [](void *p) noexcept { + delete[] static_cast(p); + }); + + nb::ndarray final_observations; + if (grayscale) { + size_t shape[4] = {(size_t)batch_size, (size_t)stack_num, (size_t)height, (size_t)width}; + final_observations = nb::ndarray(result.final_observations, 4, shape, final_obs_owner); + } else { + size_t shape[5] = {(size_t)batch_size, (size_t)stack_num, (size_t)height, (size_t)width, 3}; + final_observations = nb::ndarray(result.final_observations, 5, shape, final_obs_owner); + } + info["final_obs"] = final_observations; } else { - size_t shape[5] = {(size_t)batch_size, (size_t)stack_num, (size_t)height, (size_t)width, 3}; - final_observations = nb::ndarray(result.final_observations, 5, shape, final_obs_owner); + // No environments terminated - delete the unused buffer + delete[] result.final_observations; } - info["final_obs"] = final_observations; } return nb::make_tuple(observations, rewards, terminations, truncations, info); diff --git a/src/ale/python/vector_env.py b/src/ale/python/vector_env.py index 182e39ab2..f69df4c7e 100644 --- a/src/ale/python/vector_env.py +++ b/src/ale/python/vector_env.py @@ -133,11 +133,11 @@ def __init__( self.batch_size = num_envs if batch_size == 0 else batch_size self.num_envs = num_envs - self.metadata["autoreset_mode"] = ( - autoreset_mode - if isinstance(autoreset_mode, AutoresetMode) - else AutoresetMode(autoreset_mode) - ) + self.autoreset_mode = AutoresetMode(autoreset_mode) + self.metadata["autoreset_mode"] = self.autoreset_mode.value + + assert not (self.autoreset_mode == AutoresetMode.DISABLED and self.batch_size != self.num_envs) + self.observation_space = gymnasium.vector.utils.batch_space( self.single_observation_space, self.batch_size ) @@ -231,6 +231,8 @@ def recv( def xla(self): """Return XLA-compatible functions for JAX integration.""" + assert self.autoreset_mode == AutoresetMode.NEXT_STEP or 
self.autoreset_mode == AutoresetMode.DISABLED + try: import chex import jax diff --git a/src/ale/vector/async_vectorizer.hpp b/src/ale/vector/async_vectorizer.hpp index f00f7dc5d..0f222d79f 100644 --- a/src/ale/vector/async_vectorizer.hpp +++ b/src/ale/vector/async_vectorizer.hpp @@ -61,6 +61,7 @@ namespace ale::vector { batch_size_(batch_size > 0 ? batch_size : num_envs), autoreset_mode_(autoreset_mode), stop_(false), + first_batch_(true), action_queue_(new ActionQueue(num_envs_)), pending_obs_buffer_(nullptr), pending_final_obs_(nullptr), @@ -130,6 +131,12 @@ namespace ale::vector { pending_obs_buffer_ = new uint8_t[total_obs_size]; state_buffer_->set_output_buffer(pending_obs_buffer_); + // Release slots from previous batch (but not on first batch) + if (!first_batch_) { + state_buffer_->release_slots(); + } + first_batch_ = false; + // Allocate metadata buffers pending_env_ids_ = new int[batch_size_]; pending_rewards_ = new int[batch_size_]; @@ -202,6 +209,12 @@ namespace ale::vector { pending_episode_frame_numbers_ ); + // Release slots from previous batch (but not on first batch) + if (!first_batch_) { + state_buffer_->release_slots(); + } + first_batch_ = false; + // In SameStep mode, also allocate final_obs buffer if (autoreset_mode_ == AutoresetMode::SameStep) { pending_final_obs_ = new uint8_t[total_obs_size]; @@ -291,6 +304,7 @@ namespace ale::vector { AutoresetMode autoreset_mode_; // How to reset sub-environments after an episode ends std::atomic stop_; // Signal to stop worker threads + bool first_batch_; // Track if this is the first batch (don't release permits) std::vector workers_; // Worker threads std::unique_ptr action_queue_; // Queue for actions std::unique_ptr state_buffer_; // Buffer for observations and metadata diff --git a/src/ale/vector/utils.hpp b/src/ale/vector/utils.hpp index a112d8999..b7499fc4c 100644 --- a/src/ale/vector/utils.hpp +++ b/src/ale/vector/utils.hpp @@ -151,7 +151,8 @@ namespace ale::vector { count_(0), write_idx_(0), 
sem_ready_(0), - sem_read_(1) {} + sem_read_(1), + sem_slots_(batch_size) {} // Initialize with batch_size permits /** * Set the output buffer that workers will write observations into. @@ -207,11 +208,17 @@ namespace ale::vector { * Returns pointers for direct writing into the output buffer. * * Thread-safe: multiple workers can call simultaneously. + * In unordered mode, blocks if all slots are occupied. * * @param env_id The environment ID requesting a slot * @return WriteSlot with pointers into output buffers */ WriteSlot allocate_write_slot(int env_id) { + // In unordered mode, block if all slots are occupied + if (!ordered_mode_) { + while (!sem_slots_.wait()) {} // Acquire permit, blocks if none available + } + WriteSlot slot; if (ordered_mode_) { @@ -279,6 +286,17 @@ namespace ale::vector { episode_frame_numbers_buffer_ = nullptr; } + /** + * Release all slots for the next batch. + * Called by recv() after transferring buffer ownership to Python. + * This allows waiting workers to proceed and allocate slots. 
+ */ + void release_slots() { + if (!ordered_mode_) { + sem_slots_.signal(batch_size_); // Release batch_size permits + } + } + // Accessors std::size_t get_batch_size() const { return batch_size_; } std::size_t get_obs_size() const { return obs_size_; } @@ -305,6 +323,7 @@ namespace ale::vector { std::atomic write_idx_; moodycamel::LightweightSemaphore sem_ready_; moodycamel::LightweightSemaphore sem_read_; + moodycamel::LightweightSemaphore sem_slots_; // Controls slot availability }; } diff --git a/tests/python/test_atari_vector_env.py b/tests/python/test_atari_vector_env.py index e11c9531a..1e1e588af 100644 --- a/tests/python/test_atari_vector_env.py +++ b/tests/python/test_atari_vector_env.py @@ -629,14 +629,7 @@ def test_same_step_autoreset_mode( if np.any(episode_over): has_autoreset = True - gym_final_obs = np.array( - [ - final_obs if ep_over else obs - for final_obs, obs, ep_over in zip( - gym_info.pop("final_obs"), gym_obs, episode_over - ) - ] - ) + gym_final_obs = gym_info.pop("final_obs") gym_info.pop("final_info") # ALEV doesn't return final info gym_info = { key: value.astype(np.int32) @@ -649,9 +642,9 @@ def test_same_step_autoreset_mode( gym_info, ale_info ), f"{gym_info=}, {ale_info=}, {t=}" - assert obs_equivalence( - gym_final_obs, ale_final_obs, t, autoreset_mode="SAME-STEP" - ), t + for i, ep_over in enumerate(episode_over): + if ep_over: + assert obs_equivalence(gym_final_obs[i], ale_final_obs[i], t, autoreset_mode="SAME-STEP"), t else: gym_info = { key: value.astype(np.int32) From 36c85a0e292aeb9f8d4d03b508b4e9bf25333968 Mon Sep 17 00:00:00 2001 From: pseudo-rnd-thoughts Date: Sat, 29 Nov 2025 13:23:35 +0000 Subject: [PATCH 4/8] refactor --- src/ale/external/ThreadPool.h | 98 --- .../python/ale_vector_python_interface.cpp | 392 ++++++------ .../python/ale_vector_python_interface.hpp | 283 +-------- src/ale/python/ale_vector_xla_interface.cpp | 264 ++++---- src/ale/python/vector_env.py | 12 +- src/ale/vector/CMakeLists.txt | 8 +- 
src/ale/vector/action_queue.hpp | 58 ++ src/ale/vector/async_vectorizer.hpp | 464 -------------- src/ale/vector/env_vectorizer.cpp | 296 +++++++++ src/ale/vector/env_vectorizer.hpp | 140 +++++ src/ale/vector/preprocessed_env.cpp | 267 +++++++++ src/ale/vector/preprocessed_env.hpp | 565 ++++-------------- src/ale/vector/result_staging.hpp | 126 ++++ src/ale/vector/types.hpp | 181 ++++++ src/ale/vector/utils.hpp | 330 ---------- tests/python/test_atari_vector_env.py | 11 +- 16 files changed, 1550 insertions(+), 1945 deletions(-) delete mode 100644 src/ale/external/ThreadPool.h create mode 100644 src/ale/vector/action_queue.hpp delete mode 100644 src/ale/vector/async_vectorizer.hpp create mode 100644 src/ale/vector/env_vectorizer.cpp create mode 100644 src/ale/vector/env_vectorizer.hpp create mode 100644 src/ale/vector/preprocessed_env.cpp create mode 100644 src/ale/vector/result_staging.hpp create mode 100644 src/ale/vector/types.hpp delete mode 100644 src/ale/vector/utils.hpp diff --git a/src/ale/external/ThreadPool.h b/src/ale/external/ThreadPool.h deleted file mode 100644 index 0475bdb24..000000000 --- a/src/ale/external/ThreadPool.h +++ /dev/null @@ -1,98 +0,0 @@ -#ifndef THREAD_POOL_H -#define THREAD_POOL_H - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -class ThreadPool { -public: - ThreadPool(size_t); - template - auto enqueue(F&& f, Args&&... 
args) - -> std::future::type>; - ~ThreadPool(); -private: - // need to keep track of threads so we can join them - std::vector< std::thread > workers; - // the task queue - std::queue< std::function > tasks; - - // synchronization - std::mutex queue_mutex; - std::condition_variable condition; - bool stop; -}; - -// the constructor just launches some amount of workers -inline ThreadPool::ThreadPool(size_t threads) - : stop(false) -{ - for(size_t i = 0;i task; - - { - std::unique_lock lock(this->queue_mutex); - this->condition.wait(lock, - [this]{ return this->stop || !this->tasks.empty(); }); - if(this->stop && this->tasks.empty()) - return; - task = std::move(this->tasks.front()); - this->tasks.pop(); - } - - task(); - } - } - ); -} - -// add new work item to the pool -template -auto ThreadPool::enqueue(F&& f, Args&&... args) - -> std::future::type> -{ - using return_type = typename std::result_of::type; - - auto task = std::make_shared< std::packaged_task >( - std::bind(std::forward(f), std::forward(args)...) 
- ); - - std::future res = task->get_future(); - { - std::unique_lock lock(queue_mutex); - - // don't allow enqueueing after stopping the pool - if(stop) - throw std::runtime_error("enqueue on stopped ThreadPool"); - - tasks.emplace([task](){ (*task)(); }); - } - condition.notify_one(); - return res; -} - -// the destructor joins all threads -inline ThreadPool::~ThreadPool() -{ - { - std::unique_lock lock(queue_mutex); - stop = true; - } - condition.notify_all(); - for(std::thread &worker: workers) - worker.join(); -} - -#endif diff --git a/src/ale/python/ale_vector_python_interface.cpp b/src/ale/python/ale_vector_python_interface.cpp index 06805cea4..74c3ad347 100644 --- a/src/ale/python/ale_vector_python_interface.cpp +++ b/src/ale/python/ale_vector_python_interface.cpp @@ -1,203 +1,257 @@ #include "ale_vector_python_interface.hpp" +#include "ale/vector/env_vectorizer.hpp" + #include #include #include #include #include #include -#include -#include -#include -#include namespace nb = nanobind; +namespace fs = std::filesystem; + +using ale::vector::EnvVectorizer; +using ale::vector::BatchResult; +using ale::vector::AutoresetMode; +using ale::vector::Action; + +namespace { + +/// Helper to create numpy array from raw pointer with capsule ownership +template +nb::ndarray make_numpy_array(T* data, std::vector shape) { + nb::capsule owner(data, [](void* p) noexcept { + delete[] static_cast(p); + }); + return nb::ndarray(data, shape.size(), shape.data(), owner); +} + +/// Convert BatchResult to Python tuple for reset: (observations, info) +nb::tuple wrap_reset_result(EnvVectorizer& vec, BatchResult&& result) { + const std::size_t batch_size = result.batch_size(); + auto [stack_num, height, width, channels] = vec.observation_shape(); + + // Build observation shape + std::vector obs_shape; + if (vec.is_grayscale()) { + obs_shape = {batch_size, static_cast(stack_num), + static_cast(height), static_cast(width)}; + } else { + obs_shape = {batch_size, static_cast(stack_num), + 
static_cast(height), static_cast(width), 3}; + } + + std::vector info_shape = {batch_size}; + + // Create numpy arrays (transfers ownership via release) + auto observations = make_numpy_array(result.release_observations(), obs_shape); + auto env_ids = make_numpy_array(result.release_env_ids(), info_shape); + auto lives = make_numpy_array(result.release_lives(), info_shape); + auto frame_numbers = make_numpy_array(result.release_frame_numbers(), info_shape); + auto episode_frame_numbers = make_numpy_array(result.release_episode_frame_numbers(), info_shape); + + // Clean up unreleased arrays (rewards, terminations, truncations not used in reset) + // BatchResult destructor handles this + + // Build info dict + nb::dict info; + info["env_id"] = env_ids; + info["lives"] = lives; + info["frame_number"] = frame_numbers; + info["episode_frame_number"] = episode_frame_numbers; + + return nb::make_tuple(observations, info); +} + +/// Convert BatchResult to Python tuple for step: (observations, rewards, terminations, truncations, info) +nb::tuple wrap_step_result(EnvVectorizer& vec, BatchResult&& result) { + const std::size_t batch_size = result.batch_size(); + auto [stack_num, height, width, channels] = vec.observation_shape(); + + // Build observation shape + std::vector obs_shape; + if (vec.is_grayscale()) { + obs_shape = {batch_size, static_cast(stack_num), + static_cast(height), static_cast(width)}; + } else { + obs_shape = {batch_size, static_cast(stack_num), + static_cast(height), static_cast(width), 3}; + } + + std::vector info_shape = {batch_size}; + + // Create numpy arrays + auto observations = make_numpy_array(result.release_observations(), obs_shape); + auto rewards = make_numpy_array(result.release_rewards(), info_shape); + auto terminations = make_numpy_array(result.release_terminations(), info_shape); + auto truncations = make_numpy_array(result.release_truncations(), info_shape); + auto env_ids = make_numpy_array(result.release_env_ids(), info_shape); + auto 
lives = make_numpy_array(result.release_lives(), info_shape); + auto frame_numbers = make_numpy_array(result.release_frame_numbers(), info_shape); + auto episode_frame_numbers = make_numpy_array(result.release_episode_frame_numbers(), info_shape); + + // Build info dict + nb::dict info; + info["env_id"] = env_ids; + info["lives"] = lives; + info["frame_number"] = frame_numbers; + info["episode_frame_number"] = episode_frame_numbers; + + // Handle final_obs for SameStep mode + if (result.has_final_obs()) { + // Check if any environment terminated or truncated + bool any_done = false; + bool* term_data = terminations.data(); + bool* trunc_data = truncations.data(); + for (std::size_t i = 0; i < batch_size; ++i) { + if (term_data[i] || trunc_data[i]) { + any_done = true; + break; + } + } + + if (any_done) { + auto final_obs = make_numpy_array(result.release_final_observations(), obs_shape); + info["final_obs"] = final_obs; + } + // If no envs done, final_obs buffer will be cleaned up by BatchResult destructor + } + + return nb::make_tuple(observations, rewards, terminations, truncations, info); +} + +} // anonymous namespace -// Function to add vector environment bindings to an existing module void init_vector_module(nb::module_& m) { - // Define ALEVectorInterface class - nb::class_(m, "ALEVectorInterface") - .def(nb::init(), - nb::arg("rom_path"), - nb::arg("num_envs"), - nb::arg("frame_skip") = 4, - nb::arg("stack_num") = 4, - nb::arg("img_height") = 84, - nb::arg("img_width") = 84, - nb::arg("grayscale") = true, - nb::arg("maxpool") = true, - nb::arg("noop_max") = 30, - nb::arg("use_fire_reset") = true, - nb::arg("episodic_life") = false, - nb::arg("life_loss_info") = false, - nb::arg("reward_clipping") = true, - nb::arg("max_episode_steps") = 108000, - nb::arg("repeat_action_probability") = 0.0f, - nb::arg("full_action_space") = false, - nb::arg("batch_size") = 0, - nb::arg("num_threads") = 0, - nb::arg("thread_affinity_offset") = -1, - nb::arg("autoreset_mode") 
= "NextStep") - .def("reset", [](ale::vector::ALEVectorInterface& self, const std::vector reset_indices, const std::vector reset_seeds) { - // Call C++ reset method with GIL released + nb::class_(m, "ALEVectorInterface") + .def("__init__", [](EnvVectorizer* t, + const fs::path& rom_path, + int num_envs, + int frame_skip, + int stack_num, + int img_height, + int img_width, + bool grayscale, + bool maxpool, + int noop_max, + bool use_fire_reset, + bool episodic_life, + bool life_loss_info, + bool reward_clipping, + int max_episode_steps, + float repeat_action_probability, + bool full_action_space, + int batch_size, + int num_threads, + int thread_affinity_offset, + const std::string& autoreset_mode_str + ) { + AutoresetMode autoreset_mode; + if (autoreset_mode_str == "NextStep") { + autoreset_mode = AutoresetMode::NextStep; + } else if (autoreset_mode_str == "SameStep") { + autoreset_mode = AutoresetMode::SameStep; + } else { + throw std::invalid_argument("Invalid autoreset_mode: " + autoreset_mode_str); + } + + new (t) EnvVectorizer( + rom_path, num_envs, batch_size, num_threads, thread_affinity_offset, + autoreset_mode, img_height, img_width, stack_num, grayscale, + frame_skip, maxpool, noop_max, use_fire_reset, episodic_life, + life_loss_info, reward_clipping, max_episode_steps, + repeat_action_probability, full_action_space + ); + }, + nb::arg("rom_path"), + nb::arg("num_envs"), + nb::arg("frame_skip") = 4, + nb::arg("stack_num") = 4, + nb::arg("img_height") = 84, + nb::arg("img_width") = 84, + nb::arg("grayscale") = true, + nb::arg("maxpool") = true, + nb::arg("noop_max") = 30, + nb::arg("use_fire_reset") = true, + nb::arg("episodic_life") = false, + nb::arg("life_loss_info") = false, + nb::arg("reward_clipping") = true, + nb::arg("max_episode_steps") = 108000, + nb::arg("repeat_action_probability") = 0.0f, + nb::arg("full_action_space") = false, + nb::arg("batch_size") = 0, + nb::arg("num_threads") = 0, + nb::arg("thread_affinity_offset") = -1, + 
nb::arg("autoreset_mode") = "NextStep") + + .def("reset", [](EnvVectorizer& self, + const std::vector& reset_indices, + const std::vector& reset_seeds) { nb::gil_scoped_release release; auto result = self.reset(reset_indices, reset_seeds); nb::gil_scoped_acquire acquire; + return wrap_reset_result(self, std::move(result)); + }) - // Get shape information - const int batch_size = result.batch_size; - const auto obs_shape = self.get_observation_shape(); - const int stack_num = std::get<0>(obs_shape); - const int height = std::get<1>(obs_shape); - const int width = std::get<2>(obs_shape); - const bool grayscale = self.is_grayscale(); - - // Wrap observation buffer - capsule takes ownership - nb::capsule obs_owner(result.observations, [](void *p) noexcept { - delete[] static_cast(p); - }); + .def("send", [](EnvVectorizer& self, + const std::vector& action_ids, + const std::vector& paddle_strengths) { + if (action_ids.size() != paddle_strengths.size()) { + throw std::invalid_argument("action_ids and paddle_strengths must have same size"); + } - nb::ndarray observations; - if (grayscale) { - size_t shape[4] = {(size_t)batch_size, (size_t)stack_num, (size_t)height, (size_t)width}; - observations = nb::ndarray(result.observations, 4, shape, obs_owner); - } else { - size_t shape[5] = {(size_t)batch_size, (size_t)stack_num, (size_t)height, (size_t)width, 3}; - observations = nb::ndarray(result.observations, 5, shape, obs_owner); + std::vector actions; + actions.reserve(action_ids.size()); + for (std::size_t i = 0; i < action_ids.size(); ++i) { + Action a; + a.env_id = static_cast(i); // Will be remapped in send() + a.action_id = action_ids[i]; + a.paddle_strength = paddle_strengths[i]; + a.force_reset = false; + actions.push_back(a); } - // Create capsules - ownership transferred from BatchData - nb::capsule env_ids_owner(result.env_ids, [](void *p) noexcept { delete[] (int*)p; }); - nb::capsule lives_owner(result.lives, [](void *p) noexcept { delete[] (int*)p; }); - 
nb::capsule frame_numbers_owner(result.frame_numbers, [](void *p) noexcept { delete[] (int*)p; }); - nb::capsule episode_frame_numbers_owner(result.episode_frame_numbers, [](void *p) noexcept { delete[] (int*)p; }); - - // Create numpy arrays (zero-copy - direct from BatchData) - size_t info_shape[1] = {(size_t)batch_size}; - auto env_ids = nb::ndarray(result.env_ids, 1, info_shape, env_ids_owner); - auto lives = nb::ndarray(result.lives, 1, info_shape, lives_owner); - auto frame_numbers = nb::ndarray(result.frame_numbers, 1, info_shape, frame_numbers_owner); - auto episode_frame_numbers = nb::ndarray(result.episode_frame_numbers, 1, info_shape, episode_frame_numbers_owner); - - // Create info dict - nb::dict info; - info["env_id"] = env_ids; - info["lives"] = lives; - info["frame_number"] = frame_numbers; - info["episode_frame_number"] = episode_frame_numbers; - - return nb::make_tuple(observations, info); + self.send(actions); }) - .def("send", [](ale::vector::ALEVectorInterface& self, const std::vector action_ids, const std::vector paddle_strengths) { - self.send(action_ids, paddle_strengths); - }) - .def("recv", [](ale::vector::ALEVectorInterface& self) { - // Release GIL while waiting for workers + + .def("recv", [](EnvVectorizer& self) { nb::gil_scoped_release release; auto result = self.recv(); nb::gil_scoped_acquire acquire; + return wrap_step_result(self, std::move(result)); + }) - // Get shape info - const auto obs_shape_info = self.get_observation_shape(); - const int stack_num = std::get<0>(obs_shape_info); - const int height = std::get<1>(obs_shape_info); - const int width = std::get<2>(obs_shape_info); - const int batch_size = result.batch_size; - const bool grayscale = self.is_grayscale(); - - // Wrap obs buffer - capsule takes ownership and will delete[] - nb::capsule obs_owner(result.observations, [](void *p) noexcept { delete[] static_cast(p); }); - nb::ndarray observations; - if (grayscale) { - size_t shape[4] = {(size_t)batch_size, 
(size_t)stack_num, (size_t)height, (size_t)width}; - observations = nb::ndarray(result.observations, 4, shape, obs_owner); - } else { - size_t shape[5] = {(size_t)batch_size, (size_t)stack_num, (size_t)height, (size_t)width, 3}; - observations = nb::ndarray(result.observations, 5, shape, obs_owner); - } - - // Create capsules - ownership transferred from BatchData - nb::capsule rewards_owner(result.rewards, [](void *p) noexcept { delete[] (int*)p; }); - nb::capsule terminations_owner(result.terminations, [](void *p) noexcept { delete[] (bool*)p; }); - nb::capsule truncations_owner(result.truncations, [](void *p) noexcept { delete[] (bool*)p; }); - nb::capsule env_ids_owner(result.env_ids, [](void *p) noexcept { delete[] (int*)p; }); - nb::capsule lives_owner(result.lives, [](void *p) noexcept { delete[] (int*)p; }); - nb::capsule frame_numbers_owner(result.frame_numbers, [](void *p) noexcept { delete[] (int*)p; }); - nb::capsule episode_frame_numbers_owner(result.episode_frame_numbers, [](void *p) noexcept { delete[] (int*)p; }); - - // Create numpy arrays (zero-copy - direct from BatchData) - size_t info_shape[1] = {(size_t)batch_size}; - auto rewards = nb::ndarray(result.rewards, 1, info_shape, rewards_owner); - auto terminations = nb::ndarray(result.terminations, 1, info_shape, terminations_owner); - auto truncations = nb::ndarray(result.truncations, 1, info_shape, truncations_owner); - auto env_ids = nb::ndarray(result.env_ids, 1, info_shape, env_ids_owner); - auto lives = nb::ndarray(result.lives, 1, info_shape, lives_owner); - auto frame_numbers = nb::ndarray(result.frame_numbers, 1, info_shape, frame_numbers_owner); - auto episode_frame_numbers = nb::ndarray(result.episode_frame_numbers, 1, info_shape, episode_frame_numbers_owner); - - // Build info dict - nb::dict info; - info["env_id"] = env_ids; - info["lives"] = lives; - info["frame_number"] = frame_numbers; - info["episode_frame_number"] = episode_frame_numbers; - - // Handle final_obs for SameStep mode 
- only include if any env terminated/truncated - if (result.final_observations != nullptr) { - // Check if any environment actually terminated or truncated - bool any_done = false; - for (size_t i = 0; i < batch_size; i++) { - if (result.terminations[i] || result.truncations[i]) { - any_done = true; - break; - } - } + .def("get_action_set", &EnvVectorizer::action_set) - if (any_done) { - // Wrap the buffer directly - workers have already filled in all slots - nb::capsule final_obs_owner(result.final_observations, [](void *p) noexcept { - delete[] static_cast(p); - }); - - nb::ndarray final_observations; - if (grayscale) { - size_t shape[4] = {(size_t)batch_size, (size_t)stack_num, (size_t)height, (size_t)width}; - final_observations = nb::ndarray(result.final_observations, 4, shape, final_obs_owner); - } else { - size_t shape[5] = {(size_t)batch_size, (size_t)stack_num, (size_t)height, (size_t)width, 3}; - final_observations = nb::ndarray(result.final_observations, 5, shape, final_obs_owner); - } - info["final_obs"] = final_observations; - } else { - // No environments terminated - delete the unused buffer - delete[] result.final_observations; - } - } + .def("get_num_envs", &EnvVectorizer::num_envs) - return nb::make_tuple(observations, rewards, terminations, truncations, info); - }) - .def("get_action_set", &ale::vector::ALEVectorInterface::get_action_set) - .def("get_num_envs", &ale::vector::ALEVectorInterface::get_num_envs) - .def("get_observation_shape", [](ale::vector::ALEVectorInterface& self) { - auto shape = self.get_observation_shape(); + .def("get_observation_shape", [](EnvVectorizer& self) { + auto [stack, h, w, c] = self.observation_shape(); if (self.is_grayscale()) { - return nb::make_tuple(std::get<0>(shape), std::get<1>(shape), std::get<2>(shape)); + return nb::make_tuple(stack, h, w); } else { - return nb::make_tuple(std::get<0>(shape), std::get<1>(shape), std::get<2>(shape), std::get<3>(shape)); + return nb::make_tuple(stack, h, w, c); } }) - 
.def("handle", [](ale::vector::ALEVectorInterface& self) { - // Get the raw pointer to the AsyncVectorizer - auto ptr = self.get_vectorizer(); - // Allocate memory for handle array - uint8_t* handle_data = new uint8_t[sizeof(ptr)]; - std::memcpy(handle_data, &ptr, sizeof(ptr)); + .def("handle", [](EnvVectorizer& self) { + const void* ptr = self.handle(); + std::size_t ptr_size = sizeof(ptr); + + uint8_t* handle_data = new uint8_t[ptr_size]; + std::memcpy(handle_data, &ptr, ptr_size); - // Create capsule for cleanup - nb::capsule handle_owner(handle_data, [](void *p) noexcept { delete[] (uint8_t *) p; }); + nb::capsule owner(handle_data, [](void* p) noexcept { + delete[] static_cast(p); + }); - // Create numpy array - size_t shape[1] = {sizeof(ptr)}; - return nb::ndarray(handle_data, 1, shape, handle_owner); + std::vector shape = {ptr_size}; + return nb::ndarray(handle_data, shape.size(), shape.data(), owner); }); + + // Expose AutoresetMode enum + nb::enum_(m, "AutoresetMode") + .value("NextStep", AutoresetMode::NextStep) + .value("SameStep", AutoresetMode::SameStep); } diff --git a/src/ale/python/ale_vector_python_interface.hpp b/src/ale/python/ale_vector_python_interface.hpp index acb807a6d..7d445f44b 100644 --- a/src/ale/python/ale_vector_python_interface.hpp +++ b/src/ale/python/ale_vector_python_interface.hpp @@ -1,284 +1,11 @@ -#ifndef ALE_VECTOR_INTERFACE_HPP_ -#define ALE_VECTOR_INTERFACE_HPP_ - -#include -#include -#include -#include -#include - -#include "ale/vector/async_vectorizer.hpp" -#include "ale/vector/preprocessed_env.hpp" -#include "ale/vector/utils.hpp" +#ifndef ALE_VECTOR_PYTHON_INTERFACE_HPP_ +#define ALE_VECTOR_PYTHON_INTERFACE_HPP_ #include -#include -#include -#include -#include namespace nb = nanobind; -namespace fs = std::filesystem; - -namespace ale::vector { - - /** - * ALEVectorInterface provides a vectorized interface to the Arcade Learning Environment. 
- * It manages multiple Atari environments running in parallel and allows sending actions - * and receiving observations in batches. - */ - class ALEVectorInterface { - public: - /** - * Constructor - * - * @param rom_path Path to the ROM file - * @param num_envs Number of parallel environments - * @param frame_skip Number of frames to skip between agent decisions (default: 4) - * @param stack_num Number of frames to stack for observations (default: 4) - * @param img_height Height to resize frames to (default: 84) - * @param img_width Width to resize frames to (default: 84) - * @param grayscale Whether to use grayscale observations (default: true) - * @param maxpool If to maxpool over frames (default: true) - * @param noop_max Maximum number of no-ops to perform at reset (default: 30) - * @param use_fire_reset Whether to press FIRE during reset (default: true) - * @param episodic_life Whether to end episodes when a life is lost (default: false) - * @param life_loss_info Whether to return `terminated=True` on a life loss but not reset until `lives==0` - * @param reward_clipping Whether to clip the environment rewards between -1 and 1 - * @param max_episode_steps Maximum number of steps per episode (default: 108000) - * @param repeat_action_probability Probability of repeating the last action (default: 0.0f) - * @param full_action_space Whether to use the full action space (default: false) - * @param batch_size The number of environments to process in a batch (0 means use num_envs, default: 0) - * @param num_threads The number of worker threads to use (0 means use hardware concurrency, default: 0) - * @param thread_affinity_offset The CPU core offset for thread affinity (-1 means no affinity, default: -1) - */ - ALEVectorInterface( - const fs::path &rom_path, - const int num_envs, - const int frame_skip = 4, - const int stack_num = 4, - const int img_height = 84, - const int img_width = 84, - const bool grayscale = true, - const bool maxpool = true, - const int 
noop_max = 30, - const bool use_fire_reset = true, - const bool episodic_life = false, - const bool life_loss_info = false, - const bool reward_clipping = true, - const int max_episode_steps = 108000, - const float repeat_action_probability = 0.0f, - const bool full_action_space = false, - const int batch_size = 0, - const int num_threads = 0, - const int thread_affinity_offset = -1, - const std::string &autoreset_mode = "NextStep" - ) : rom_path_(rom_path), - num_envs_(num_envs), - frame_skip_(frame_skip), - stack_num_(stack_num), - img_height_(img_height), - img_width_(img_width), - grayscale_(grayscale), - obs_format_(grayscale_ ? ObsFormat::Grayscale : ObsFormat::RGB), - maxpool_(maxpool), - noop_max_(noop_max), - use_fire_reset_(use_fire_reset), - episodic_life_(episodic_life), - life_loss_info_(life_loss_info), - reward_clipping_(reward_clipping), - max_episode_steps_(max_episode_steps), - repeat_action_probability_(repeat_action_probability), - full_action_space_(full_action_space), - received_env_ids_(batch_size > 0 ? 
batch_size : num_envs) { - - // Create environment factory - auto env_factory = [this](int env_id) { - return std::make_unique( - env_id, - rom_path_, - img_height_, - img_width_, - frame_skip_, - maxpool_, - obs_format_, - stack_num_, - noop_max_, - use_fire_reset_, - episodic_life_, - life_loss_info_, - reward_clipping_, - max_episode_steps_, - repeat_action_probability_, - full_action_space_, - -1 - ); - }; - - if (autoreset_mode == "NextStep") { - autoreset_mode_ = AutoresetMode::NextStep; - } else if (autoreset_mode == "SameStep") { - autoreset_mode_ = AutoresetMode::SameStep; - } else { - throw std::invalid_argument("Invalid autoreset_mode: " + autoreset_mode + ", expected values: 'NextStep' or 'SameStep'"); - } - - // Create vectorizer - vectorizer_ = std::make_unique( - num_envs, - batch_size, - num_threads, - thread_affinity_offset, - env_factory, - autoreset_mode_ - ); - - // Initialize the action set (assuming all environments have the same action set) - const auto temp_env = env_factory(0); - action_set_ = temp_env->get_action_set(); - } - - /** - * Reset all environments - * - * @param reset_indices Vector of environment indices to be reset - * @param reset_seeds Vector of environment seeds to use - * @return BatchData with initial observations - */ - BatchData reset(const std::vector &reset_indices, const std::vector &reset_seeds) { - vectorizer_->reset(reset_indices, reset_seeds); - return recv(); - } - - /** - * Step environments with actions - * - * @param action_ids Vector of actions ids to take - * @param paddle_strengths Vector of paddle strengths to take - */ - void send(const std::vector& action_ids, const std::vector& paddle_strengths) const { - if (action_ids.size() != paddle_strengths.size()) { - throw std::invalid_argument( - "The size of the action_ids is different from the paddle_strengths, action_ids length=" + std::to_string(action_ids.size()) - + ", paddle_strengths length=" + std::to_string(paddle_strengths.size())); - } - 
std::vector environment_actions; - environment_actions.resize(action_ids.size()); - - for (size_t i = 0; i < action_ids.size(); i++) { - EnvironmentAction env_action; - env_action.env_id = received_env_ids_[i]; - env_action.action_id = action_ids[i]; - env_action.paddle_strength = paddle_strengths[i]; - - environment_actions[i] = env_action; - } - - vectorizer_->send(environment_actions); - } - - /** - * Returns the environment's data for the environments. - * Returns ownership of observation buffer to caller. - */ - BatchData recv() { - BatchData result = vectorizer_->recv(); - for (size_t i = 0; i < result.batch_size; i++) { - received_env_ids_[i] = result.env_ids[i]; - } - return result; - } - - /** - * Get the available actions for the environments - * - * @return Vector of available actions - */ - const ActionVect& get_action_set() const { - return action_set_; - } - - /** - * Get the number of environments - * - * @return Number of environments - */ - const int get_num_envs() const { - return num_envs_; - } - - /** - * Get the dimensions of the observation space - * - * @return Tuple of (stack_num, height, width, 0) if grayscale or (stack_num, height, width, 3) if RGB - */ - const std::tuple get_observation_shape() const { - if (grayscale_) { - return std::make_tuple(stack_num_, img_height_, img_width_, 0); - } else { - return std::make_tuple(stack_num_, img_height_, img_width_, 3); - } - } - - /** - * Check if observations are grayscale - * - * @return true if observations are grayscale, false if RGB - */ - const bool is_grayscale() const { - return grayscale_; - } - - /** - * Get the async_vectorizer's autoreset mode - * - * @return the autoreset mode of the async_vectorizer - */ - const AutoresetMode get_autoreset_mode() const { - return autoreset_mode_; - } - - /** - * Get the size of a single stacked observation in bytes. 
- */ - std::size_t get_stacked_obs_size() const { - return vectorizer_->get_stacked_obs_size(); - } - - /** - * Get the underlying vectorizer - * - * @return pointer for the underlying vectorizer - */ - const AsyncVectorizer* get_vectorizer() const { - return vectorizer_.get(); - } - - private: - fs::path rom_path_; // Path to the ROM file - int num_envs_; // Number of parallel environments - int frame_skip_; // Number of frames to skip - int stack_num_; // Number of frames to stack - int img_height_; // Height of resized frames - int img_width_; // Width of resized frames - bool grayscale_; // Whether to use grayscale observations - ObsFormat obs_format_; // Observation format based on grayscale - bool maxpool_; // If to maxpool over frames - int noop_max_; // Max no-ops on reset - bool use_fire_reset_; // Whether to fire on reset - bool episodic_life_; // End episode on life loss - bool life_loss_info_; // If to provide a termination signal (but not reset) on life loss - bool reward_clipping_; // If to clip rewards between -1 and 1 - int max_episode_steps_; // Max steps per episode - float repeat_action_probability_; // Repeat actions probability for sticky actions - bool full_action_space_; // Use full action space - AutoresetMode autoreset_mode_; - - std::vector received_env_ids_; // Vector of environment ids for the most recently received data - std::unique_ptr vectorizer_; // Vectorizer - ActionVect action_set_; // Set of available actions - }; -} +/// Add vector environment bindings to the module +void init_vector_module(nb::module_& m); -#endif // ALE_VECTOR_INTERFACE_HPP_ +#endif // ALE_VECTOR_PYTHON_INTERFACE_HPP_ diff --git a/src/ale/python/ale_vector_xla_interface.cpp b/src/ale/python/ale_vector_xla_interface.cpp index 061c7cf60..cc2450bcf 100644 --- a/src/ale/python/ale_vector_xla_interface.cpp +++ b/src/ale/python/ale_vector_xla_interface.cpp @@ -1,6 +1,6 @@ #include "xla/ffi/api/ffi.h" -#include "ale/vector/async_vectorizer.hpp" +#include 
"ale/vector/env_vectorizer.hpp" #include #include @@ -29,12 +29,12 @@ ffi::Error XLAResetImpl( ffi::ResultBuffer episode_frame_numbers_buffer ) { // Validate handle buffer size - if (handle_buffer.element_count() != sizeof(ale::vector::AsyncVectorizer*)) { + if (handle_buffer.element_count() != sizeof(ale::vector::EnvVectorizer*)) { return ffi::Error::Internal("Incorrect handle buffer size in reset"); } // Safely extract the vectorizer pointer from the handle buffer - ale::vector::AsyncVectorizer* vectorizer = nullptr; + ale::vector::EnvVectorizer* vectorizer = nullptr; std::memcpy(&vectorizer, handle_buffer.typed_data(), sizeof(vectorizer)); if (!vectorizer) { return ffi::Error::Internal("Invalid vectorizer pointer in reset"); @@ -57,38 +57,28 @@ ffi::Error XLAResetImpl( reset_seeds_buffer.typed_data(), reset_seeds_buffer.typed_data() + reset_seeds_buffer.element_count()); - // Reset the environments - vectorizer->reset(reset_indices, reset_seeds); + // Reset the environments (returns BatchResult directly) + auto result = vectorizer->reset(reset_indices, reset_seeds); - // Receive the observations after reset - auto timesteps = vectorizer->recv(); - - if (timesteps.empty()) { - return ffi::Error::Internal("No timesteps received after step"); - } else if (timesteps.size() != vectorizer->get_batch_size()) { - return ffi::Error::Internal("Number of timesteps is wrong"); - } - - size_t stacked_obs_size = vectorizer->get_stacked_obs_size(); + size_t batch_size = result.batch_size(); + size_t stacked_obs_size = vectorizer->stacked_obs_size(); // Check if the observations buffer is large enough - if (observations_buffer->element_count() != vectorizer->get_batch_size() * stacked_obs_size) { + if (observations_buffer->element_count() != batch_size * stacked_obs_size) { return ffi::Error::Internal("Observations buffer is the wrong size"); } - for (int i = 0; i < vectorizer->get_batch_size(); ++i) { - const auto& timestep = timesteps[i]; - - std::memcpy( - 
observations_buffer->typed_data() + i * stacked_obs_size, - timestep.observation.data(), - stacked_obs_size - ); - env_ids_buffer->typed_data()[i] = timestep.env_id; - lives_buffer->typed_data()[i] = timestep.lives; - frame_numbers_buffer->typed_data()[i] = timestep.frame_number; - episode_frame_numbers_buffer->typed_data()[i] = timestep.episode_frame_number; - } + // Copy data from BatchResult to output buffers + std::memcpy(observations_buffer->typed_data(), result.obs_data(), + batch_size * stacked_obs_size); + std::memcpy(env_ids_buffer->typed_data(), result.env_ids_data(), + batch_size * sizeof(int32_t)); + std::memcpy(lives_buffer->typed_data(), result.lives_data(), + batch_size * sizeof(int32_t)); + std::memcpy(frame_numbers_buffer->typed_data(), result.frame_numbers_data(), + batch_size * sizeof(int32_t)); + std::memcpy(episode_frame_numbers_buffer->typed_data(), result.episode_frame_numbers_data(), + batch_size * sizeof(int32_t)); return ffi::Error::Success(); } @@ -133,7 +123,7 @@ ffi::Error XLAResetGPUImpl( ffi::ResultBuffer episode_frame_numbers_buffer ) { // Validate handle buffer size - if (handle_buffer.element_count() != sizeof(ale::vector::AsyncVectorizer*)) { + if (handle_buffer.element_count() != sizeof(ale::vector::EnvVectorizer*)) { return ffi::Error::Internal("Incorrect handle buffer size in reset (GPU)"); } @@ -153,7 +143,7 @@ ffi::Error XLAResetGPUImpl( } // Extract the vectorizer pointer - ale::vector::AsyncVectorizer* vectorizer = nullptr; + ale::vector::EnvVectorizer* vectorizer = nullptr; std::memcpy(&vectorizer, host_handle.data(), sizeof(vectorizer)); if (!vectorizer) { return ffi::Error::Internal("Invalid vectorizer pointer in reset (GPU)"); @@ -199,73 +189,47 @@ ffi::Error XLAResetGPUImpl( std::vector reset_indices(host_reset_indices.begin(), host_reset_indices.end()); std::vector reset_seeds(host_reset_seeds.begin(), host_reset_seeds.end()); - // Reset the environments (CPU operation) - vectorizer->reset(reset_indices, reset_seeds); 
- - // Receive the observations after reset - auto timesteps = vectorizer->recv(); - - if (timesteps.empty()) { - return ffi::Error::Internal("No timesteps received after reset (GPU)"); - } else if (timesteps.size() != vectorizer->get_batch_size()) { - return ffi::Error::Internal("Number of timesteps is wrong (GPU)"); - } + // Reset the environments (returns BatchResult directly) + auto result = vectorizer->reset(reset_indices, reset_seeds); - size_t stacked_obs_size = vectorizer->get_stacked_obs_size(); - size_t batch_size = vectorizer->get_batch_size(); + size_t batch_size = result.batch_size(); + size_t stacked_obs_size = vectorizer->stacked_obs_size(); // Check if the observations buffer is large enough if (observations_buffer->element_count() != batch_size * stacked_obs_size) { return ffi::Error::Internal("Observations buffer is the wrong size (GPU)"); } - // Prepare host buffers - std::vector host_observations(batch_size * stacked_obs_size); - std::vector host_env_ids(batch_size); - std::vector host_lives(batch_size); - std::vector host_frame_numbers(batch_size); - std::vector host_episode_frame_numbers(batch_size); - - for (size_t i = 0; i < batch_size; ++i) { - const auto& timestep = timesteps[i]; - std::memcpy(host_observations.data() + i * stacked_obs_size, - timestep.observation.data(), stacked_obs_size); - host_env_ids[i] = timestep.env_id; - host_lives[i] = timestep.lives; - host_frame_numbers[i] = timestep.frame_number; - host_episode_frame_numbers[i] = timestep.episode_frame_number; - } - - // Copy results to GPU - err = cudaMemcpyAsync(observations_buffer->typed_data(), host_observations.data(), + // Copy data from BatchResult to GPU buffers via host memory + err = cudaMemcpyAsync(observations_buffer->typed_data(), result.obs_data(), batch_size * stacked_obs_size, cudaMemcpyHostToDevice, stream); if (err != cudaSuccess) { return ffi::Error::Internal(std::string("CUDA memcpy failed (observations H2D): ") + cudaGetErrorString(err)); } - err = 
cudaMemcpyAsync(env_ids_buffer->typed_data(), host_env_ids.data(), + err = cudaMemcpyAsync(env_ids_buffer->typed_data(), result.env_ids_data(), batch_size * sizeof(int32_t), cudaMemcpyHostToDevice, stream); if (err != cudaSuccess) { return ffi::Error::Internal(std::string("CUDA memcpy failed (env_ids H2D): ") + cudaGetErrorString(err)); } - err = cudaMemcpyAsync(lives_buffer->typed_data(), host_lives.data(), + err = cudaMemcpyAsync(lives_buffer->typed_data(), result.lives_data(), batch_size * sizeof(int32_t), cudaMemcpyHostToDevice, stream); if (err != cudaSuccess) { return ffi::Error::Internal(std::string("CUDA memcpy failed (lives H2D): ") + cudaGetErrorString(err)); } - err = cudaMemcpyAsync(frame_numbers_buffer->typed_data(), host_frame_numbers.data(), + err = cudaMemcpyAsync(frame_numbers_buffer->typed_data(), result.frame_numbers_data(), batch_size * sizeof(int32_t), cudaMemcpyHostToDevice, stream); if (err != cudaSuccess) { return ffi::Error::Internal(std::string("CUDA memcpy failed (frame_numbers H2D): ") + cudaGetErrorString(err)); } - err = cudaMemcpyAsync(episode_frame_numbers_buffer->typed_data(), host_episode_frame_numbers.data(), + err = cudaMemcpyAsync(episode_frame_numbers_buffer->typed_data(), result.episode_frame_numbers_data(), batch_size * sizeof(int32_t), cudaMemcpyHostToDevice, stream); if (err != cudaSuccess) { @@ -324,12 +288,12 @@ ffi::Error XLAStepImpl( ffi::ResultBuffer episode_frame_numbers_buffer ) { // Validate handle buffer size - if (handle_buffer.element_count() != sizeof(ale::vector::AsyncVectorizer*)) { + if (handle_buffer.element_count() != sizeof(ale::vector::EnvVectorizer*)) { return ffi::Error::Internal("Incorrect handle buffer size in step"); } // Safely extract the vectorizer pointer from the handle buffer - ale::vector::AsyncVectorizer* vectorizer = nullptr; + ale::vector::EnvVectorizer* vectorizer = nullptr; std::memcpy(&vectorizer, handle_buffer.typed_data(), sizeof(vectorizer)); if (!vectorizer) { return 
ffi::Error::Internal("Invalid vectorizer pointer in step"); @@ -344,7 +308,7 @@ ffi::Error XLAStepImpl( handle_buffer.element_count()); try { - size_t num_envs = vectorizer->get_batch_size(); + size_t num_envs = vectorizer->num_envs(); if (action_id_buffer.element_count() != num_envs) { return ffi::Error::Internal("Action id buffer is the wrong size"); @@ -352,47 +316,48 @@ ffi::Error XLAStepImpl( return ffi::Error::Internal("Paddle strength buffer is the wrong size"); } - std::vector actions(num_envs); + std::vector actions(num_envs); for (size_t i = 0; i < num_envs; ++i) { - actions[i].env_id = i; + actions[i].env_id = static_cast(i); actions[i].action_id = action_id_buffer.typed_data()[i]; actions[i].paddle_strength = paddle_strength_buffer.typed_data()[i]; + actions[i].force_reset = false; } // Step the environments vectorizer->send(actions); - // Receive the timesteps - auto timesteps = vectorizer->recv(); - - if (timesteps.empty()) { - return ffi::Error::Internal("No timesteps received after step"); - } else if (timesteps.size() != vectorizer->get_batch_size()) { - return ffi::Error::Internal("Number of timesteps is wrong"); - } + // Receive the results + auto result = vectorizer->recv(); - size_t stacked_obs_size = vectorizer->get_stacked_obs_size(); + size_t batch_size = result.batch_size(); + size_t stacked_obs_size = vectorizer->stacked_obs_size(); // Check if the observations buffer is large enough - if (observations_buffer->element_count() != vectorizer->get_batch_size() * stacked_obs_size) { + if (observations_buffer->element_count() != batch_size * stacked_obs_size) { return ffi::Error::Internal("Observations buffer is the wrong size"); } - for (int i = 0; i < vectorizer->get_batch_size(); ++i) { - const auto& timestep = timesteps[i]; - - std::memcpy( - observations_buffer->typed_data() + i * stacked_obs_size, - timestep.observation.data(), - stacked_obs_size - ); - rewards_buffer->typed_data()[i] = timestep.reward; - 
terminations_buffer->typed_data()[i] = timestep.terminated; - truncations_buffer->typed_data()[i] = timestep.truncated; - env_id_buffer->typed_data()[i] = timestep.env_id; - lives_buffer->typed_data()[i] = timestep.lives; - frame_numbers_buffer->typed_data()[i] = timestep.frame_number; - episode_frame_numbers_buffer->typed_data()[i] = timestep.episode_frame_number; + // Copy data from BatchResult to output buffers + std::memcpy(observations_buffer->typed_data(), result.obs_data(), + batch_size * stacked_obs_size); + std::memcpy(rewards_buffer->typed_data(), result.rewards_data(), + batch_size * sizeof(int32_t)); + std::memcpy(env_id_buffer->typed_data(), result.env_ids_data(), + batch_size * sizeof(int32_t)); + std::memcpy(lives_buffer->typed_data(), result.lives_data(), + batch_size * sizeof(int32_t)); + std::memcpy(frame_numbers_buffer->typed_data(), result.frame_numbers_data(), + batch_size * sizeof(int32_t)); + std::memcpy(episode_frame_numbers_buffer->typed_data(), result.episode_frame_numbers_data(), + batch_size * sizeof(int32_t)); + + // Copy bools element-wise (bool* to PRED buffer) + const bool* term_data = result.terminations_data(); + const bool* trunc_data = result.truncations_data(); + for (size_t i = 0; i < batch_size; ++i) { + terminations_buffer->typed_data()[i] = term_data[i]; + truncations_buffer->typed_data()[i] = trunc_data[i]; } return ffi::Error::Success(); @@ -444,7 +409,7 @@ ffi::Error XLAStepGPUImpl( ffi::ResultBuffer episode_frame_numbers_buffer ) { // Validate handle buffer size - if (handle_buffer.element_count() != sizeof(ale::vector::AsyncVectorizer*)) { + if (handle_buffer.element_count() != sizeof(ale::vector::EnvVectorizer*)) { return ffi::Error::Internal("Incorrect handle buffer size in step (GPU)"); } @@ -464,7 +429,7 @@ ffi::Error XLAStepGPUImpl( } // Extract the vectorizer pointer - ale::vector::AsyncVectorizer* vectorizer = nullptr; + ale::vector::EnvVectorizer* vectorizer = nullptr; std::memcpy(&vectorizer, 
host_handle.data(), sizeof(vectorizer)); if (!vectorizer) { return ffi::Error::Internal("Invalid vectorizer pointer in step (GPU)"); @@ -482,7 +447,7 @@ ffi::Error XLAStepGPUImpl( } try { - size_t num_envs = vectorizer->get_batch_size(); + size_t num_envs = vectorizer->num_envs(); if (action_id_buffer.element_count() != num_envs) { return ffi::Error::Internal("Action id buffer is the wrong size (GPU)"); @@ -515,110 +480,93 @@ ffi::Error XLAStepGPUImpl( } // Prepare actions for vectorizer - std::vector actions(num_envs); + std::vector actions(num_envs); for (size_t i = 0; i < num_envs; ++i) { - actions[i].env_id = i; + actions[i].env_id = static_cast(i); actions[i].action_id = host_action_ids[i]; actions[i].paddle_strength = host_paddle_strength[i]; + actions[i].force_reset = false; } // Step the environments (CPU operation) vectorizer->send(actions); - // Receive the timesteps - auto timesteps = vectorizer->recv(); - - if (timesteps.empty()) { - return ffi::Error::Internal("No timesteps received after step (GPU)"); - } else if (timesteps.size() != num_envs) { - return ffi::Error::Internal("Number of timesteps is wrong (GPU)"); - } + // Receive the results + auto result = vectorizer->recv(); - size_t stacked_obs_size = vectorizer->get_stacked_obs_size(); + size_t batch_size = result.batch_size(); + size_t stacked_obs_size = vectorizer->stacked_obs_size(); // Check if the observations buffer is large enough - if (observations_buffer->element_count() != num_envs * stacked_obs_size) { + if (observations_buffer->element_count() != batch_size * stacked_obs_size) { return ffi::Error::Internal("Observations buffer is the wrong size (GPU)"); } - // Prepare host buffers (use uint8_t for bool to avoid std::vector specialization issues) - std::vector host_observations(num_envs * stacked_obs_size); - std::vector host_rewards(num_envs); - std::vector host_terminations(num_envs); - std::vector host_truncations(num_envs); - std::vector host_env_ids(num_envs); - std::vector 
host_lives(num_envs); - std::vector host_frame_numbers(num_envs); - std::vector host_episode_frame_numbers(num_envs); - - for (size_t i = 0; i < num_envs; ++i) { - const auto& timestep = timesteps[i]; - std::memcpy(host_observations.data() + i * stacked_obs_size, - timestep.observation.data(), stacked_obs_size); - host_rewards[i] = timestep.reward; - host_terminations[i] = timestep.terminated ? 1 : 0; - host_truncations[i] = timestep.truncated ? 1 : 0; - host_env_ids[i] = timestep.env_id; - host_lives[i] = timestep.lives; - host_frame_numbers[i] = timestep.frame_number; - host_episode_frame_numbers[i] = timestep.episode_frame_number; - } - - // Copy results to GPU - err = cudaMemcpyAsync(observations_buffer->typed_data(), host_observations.data(), - num_envs * stacked_obs_size, + // Copy data from BatchResult to GPU buffers via host memory + err = cudaMemcpyAsync(observations_buffer->typed_data(), result.obs_data(), + batch_size * stacked_obs_size, cudaMemcpyHostToDevice, stream); if (err != cudaSuccess) { return ffi::Error::Internal(std::string("CUDA memcpy failed (observations H2D): ") + cudaGetErrorString(err)); } - err = cudaMemcpyAsync(rewards_buffer->typed_data(), host_rewards.data(), - num_envs * sizeof(int32_t), + err = cudaMemcpyAsync(rewards_buffer->typed_data(), result.rewards_data(), + batch_size * sizeof(int32_t), cudaMemcpyHostToDevice, stream); if (err != cudaSuccess) { return ffi::Error::Internal(std::string("CUDA memcpy failed (rewards H2D): ") + cudaGetErrorString(err)); } - err = cudaMemcpyAsync(terminations_buffer->typed_data(), host_terminations.data(), - num_envs * sizeof(uint8_t), + err = cudaMemcpyAsync(env_id_buffer->typed_data(), result.env_ids_data(), + batch_size * sizeof(int32_t), cudaMemcpyHostToDevice, stream); if (err != cudaSuccess) { - return ffi::Error::Internal(std::string("CUDA memcpy failed (terminations H2D): ") + cudaGetErrorString(err)); + return ffi::Error::Internal(std::string("CUDA memcpy failed (env_ids H2D): ") + 
cudaGetErrorString(err)); } - err = cudaMemcpyAsync(truncations_buffer->typed_data(), host_truncations.data(), - num_envs * sizeof(uint8_t), + err = cudaMemcpyAsync(lives_buffer->typed_data(), result.lives_data(), + batch_size * sizeof(int32_t), cudaMemcpyHostToDevice, stream); if (err != cudaSuccess) { - return ffi::Error::Internal(std::string("CUDA memcpy failed (truncations H2D): ") + cudaGetErrorString(err)); + return ffi::Error::Internal(std::string("CUDA memcpy failed (lives H2D): ") + cudaGetErrorString(err)); } - err = cudaMemcpyAsync(env_id_buffer->typed_data(), host_env_ids.data(), - num_envs * sizeof(int32_t), + err = cudaMemcpyAsync(frame_numbers_buffer->typed_data(), result.frame_numbers_data(), + batch_size * sizeof(int32_t), cudaMemcpyHostToDevice, stream); if (err != cudaSuccess) { - return ffi::Error::Internal(std::string("CUDA memcpy failed (env_ids H2D): ") + cudaGetErrorString(err)); + return ffi::Error::Internal(std::string("CUDA memcpy failed (frame_numbers H2D): ") + cudaGetErrorString(err)); } - err = cudaMemcpyAsync(lives_buffer->typed_data(), host_lives.data(), - num_envs * sizeof(int32_t), + err = cudaMemcpyAsync(episode_frame_numbers_buffer->typed_data(), result.episode_frame_numbers_data(), + batch_size * sizeof(int32_t), cudaMemcpyHostToDevice, stream); if (err != cudaSuccess) { - return ffi::Error::Internal(std::string("CUDA memcpy failed (lives H2D): ") + cudaGetErrorString(err)); + return ffi::Error::Internal(std::string("CUDA memcpy failed (episode_frame_numbers H2D): ") + cudaGetErrorString(err)); } - err = cudaMemcpyAsync(frame_numbers_buffer->typed_data(), host_frame_numbers.data(), - num_envs * sizeof(int32_t), + // Copy bools element-wise to temporary buffer, then to GPU + std::vector host_terminations(batch_size); + std::vector host_truncations(batch_size); + const bool* term_data = result.terminations_data(); + const bool* trunc_data = result.truncations_data(); + for (size_t i = 0; i < batch_size; ++i) { + 
host_terminations[i] = term_data[i] ? 1 : 0; + host_truncations[i] = trunc_data[i] ? 1 : 0; + } + + err = cudaMemcpyAsync(terminations_buffer->typed_data(), host_terminations.data(), + batch_size * sizeof(uint8_t), cudaMemcpyHostToDevice, stream); if (err != cudaSuccess) { - return ffi::Error::Internal(std::string("CUDA memcpy failed (frame_numbers H2D): ") + cudaGetErrorString(err)); + return ffi::Error::Internal(std::string("CUDA memcpy failed (terminations H2D): ") + cudaGetErrorString(err)); } - err = cudaMemcpyAsync(episode_frame_numbers_buffer->typed_data(), host_episode_frame_numbers.data(), - num_envs * sizeof(int32_t), + err = cudaMemcpyAsync(truncations_buffer->typed_data(), host_truncations.data(), + batch_size * sizeof(uint8_t), cudaMemcpyHostToDevice, stream); if (err != cudaSuccess) { - return ffi::Error::Internal(std::string("CUDA memcpy failed (episode_frame_numbers H2D): ") + cudaGetErrorString(err)); + return ffi::Error::Internal(std::string("CUDA memcpy failed (truncations H2D): ") + cudaGetErrorString(err)); } // Check for any CUDA errors diff --git a/src/ale/python/vector_env.py b/src/ale/python/vector_env.py index f69df4c7e..b12a40c79 100644 --- a/src/ale/python/vector_env.py +++ b/src/ale/python/vector_env.py @@ -134,9 +134,12 @@ def __init__( self.batch_size = num_envs if batch_size == 0 else batch_size self.num_envs = num_envs self.autoreset_mode = AutoresetMode(autoreset_mode) - self.metadata["autoreset_mode"] = self.autoreset_mode.value + self.metadata["autoreset_mode"] = self.autoreset_mode - assert not (self.autoreset_mode == AutoresetMode.DISABLED and self.batch_size != self.num_envs) + assert not ( + self.autoreset_mode == AutoresetMode.DISABLED + and self.batch_size != self.num_envs + ) self.observation_space = gymnasium.vector.utils.batch_space( self.single_observation_space, self.batch_size @@ -231,7 +234,10 @@ def recv( def xla(self): """Return XLA-compatible functions for JAX integration.""" - assert self.autoreset_mode == 
AutoresetMode.NEXT_STEP or self.autoreset_mode == AutoresetMode.DISABLED + assert ( + self.autoreset_mode == AutoresetMode.NEXT_STEP + or self.autoreset_mode == AutoresetMode.DISABLED + ) try: import chex diff --git a/src/ale/vector/CMakeLists.txt b/src/ale/vector/CMakeLists.txt index 1ae6f428a..39264c2bc 100644 --- a/src/ale/vector/CMakeLists.txt +++ b/src/ale/vector/CMakeLists.txt @@ -1,6 +1,10 @@ target_sources(ale PRIVATE + types.hpp + action_queue.hpp + result_staging.hpp preprocessed_env.hpp - async_vectorizer.hpp - utils.hpp + preprocessed_env.cpp + env_vectorizer.hpp + env_vectorizer.cpp ) diff --git a/src/ale/vector/action_queue.hpp b/src/ale/vector/action_queue.hpp new file mode 100644 index 000000000..6061f996b --- /dev/null +++ b/src/ale/vector/action_queue.hpp @@ -0,0 +1,58 @@ +#ifndef ALE_VECTOR_ACTION_QUEUE_HPP_ +#define ALE_VECTOR_ACTION_QUEUE_HPP_ + +#include +#include + +#ifndef MOODYCAMEL_DELETE_FUNCTION + #define MOODYCAMEL_DELETE_FUNCTION = delete +#endif + +#include "ale/external/lightweightsemaphore.h" +#include "types.hpp" + +namespace ale::vector { + +/// Lock-free queue for actions to be processed by worker threads. +/// Supports bulk enqueue and single dequeue. +class ActionQueue { +public: + explicit ActionQueue(std::size_t capacity) + : capacity_(capacity), + queue_(capacity), + alloc_idx_(0), + dequeue_idx_(0), + items_available_(0) {} + + /// Enqueue multiple actions at once. Thread-safe. + void enqueue_bulk(const std::vector& actions) { + std::size_t pos = alloc_idx_.fetch_add(actions.size()); + for (std::size_t i = 0; i < actions.size(); ++i) { + queue_[(pos + i) % capacity_] = actions[i]; + } + items_available_.signal(static_cast(actions.size())); + } + + /// Dequeue a single action. Blocks if queue is empty. Thread-safe. 
+ Action dequeue() { + while (!items_available_.wait()) {} + std::size_t idx = dequeue_idx_.fetch_add(1); + return queue_[idx % capacity_]; + } + + /// Approximate number of items in queue + std::size_t size_approx() const { + return alloc_idx_.load() - dequeue_idx_.load(); + } + +private: + std::size_t capacity_; + std::vector queue_; + std::atomic alloc_idx_; + std::atomic dequeue_idx_; + moodycamel::LightweightSemaphore items_available_; +}; + +} // namespace ale::vector + +#endif // ALE_VECTOR_ACTION_QUEUE_HPP_ diff --git a/src/ale/vector/async_vectorizer.hpp b/src/ale/vector/async_vectorizer.hpp deleted file mode 100644 index 0f222d79f..000000000 --- a/src/ale/vector/async_vectorizer.hpp +++ /dev/null @@ -1,464 +0,0 @@ -#ifndef ALE_VECTOR_ASYNC_VECTORIZER_HPP_ -#define ALE_VECTOR_ASYNC_VECTORIZER_HPP_ - -#include -#include -#include -#include -#include -#include - -#include "ale/external/ThreadPool.h" -#include "utils.hpp" -#include "preprocessed_env.hpp" - -#if defined(_WIN32) || defined(WIN32) || defined(_MSC_VER) - #include -#endif - -namespace ale::vector { - /** - * Batch data from recv() - caller takes ownership of allocated buffers. - */ - struct BatchData { - int* env_ids; // Newly allocated, caller owns - uint8_t* observations; // Newly allocated, caller owns - int* rewards; // Newly allocated, caller owns - bool* terminations; // Newly allocated, caller owns - bool* truncations; // Newly allocated, caller owns - int* lives; // Newly allocated, caller owns - int* frame_numbers; // Newly allocated, caller owns - int* episode_frame_numbers; // Newly allocated, caller owns - - uint8_t* final_observations; // nullptr or newly allocated, caller owns - std::size_t batch_size; // Number of results - }; - - /** - * AsyncVectorizer manages a collection of environments that can be stepped in parallel. - * It handles the (async) distribution of actions to environments and collection of observations. 
- */ - class AsyncVectorizer { - public: - /** - * Constructor for AsyncVectorizer - * - * @param num_envs The number of parallel environments to run - * @param batch_size The number of environments to process in a batch (0 means use num_envs) - * @param num_threads The number of worker threads to use (0 means use hardware concurrency) - * @param thread_affinity_offset The CPU core offset for thread affinity (-1 means no affinity) - * @param env_factory Function that creates environment instances - * @param autoreset_mode Specify how to automatically reset the sub-environments after an episode ends - */ - explicit AsyncVectorizer( - const int num_envs, - const int batch_size = 0, - const int num_threads = 0, - const int thread_affinity_offset = -1, - const std::function(int)> &env_factory = nullptr, - const AutoresetMode autoreset_mode = AutoresetMode::NextStep - ) : num_envs_(num_envs), - batch_size_(batch_size > 0 ? batch_size : num_envs), - autoreset_mode_(autoreset_mode), - stop_(false), - first_batch_(true), - action_queue_(new ActionQueue(num_envs_)), - pending_obs_buffer_(nullptr), - pending_final_obs_(nullptr), - pending_env_ids_(nullptr), - pending_rewards_(nullptr), - pending_terminations_(nullptr), - pending_truncations_(nullptr), - pending_lives_(nullptr), - pending_frame_numbers_(nullptr), - pending_episode_frame_numbers_(nullptr) { - - // Create environments - envs_.resize(num_envs_); - for (int i = 0; i < num_envs_; ++i) { - envs_[i] = env_factory(i); - } - stacked_obs_size_ = envs_[0]->get_stacked_obs_size(); - - // Create state buffer with observation size - state_buffer_ = std::make_unique(batch_size_, num_envs_, stacked_obs_size_); - - // Setup worker threads - const std::size_t processor_count = std::thread::hardware_concurrency(); - if (num_threads <= 0) { - num_threads_ = std::min(batch_size_, static_cast(processor_count)); - } else { - num_threads_ = num_threads; - } - - // Start worker threads - for (int i = 0; i < num_threads_; ++i) { - 
workers_.emplace_back([this] { - worker_function(); - }); - } - - // Set thread affinity if requested - if (thread_affinity_offset >= 0) { - set_thread_affinity(thread_affinity_offset, processor_count); - } - } - - /** - * Destructor - stops worker threads and cleans up resources - */ - ~AsyncVectorizer() { - stop_ = true; - // Send empty actions to wake up and terminate all worker threads - const std::vector empty_actions(workers_.size()); - action_queue_->enqueue_bulk(empty_actions); - for (auto& worker : workers_) { - if (worker.joinable()) { - worker.join(); - } - } - } - - /** - * Reset specified environments - * - * @param reset_indices Vector of environment IDs to reset - * @param seeds Vector of seeds to use on reset (use -1 to not change the environment's seed) - */ - void reset(const std::vector& reset_indices, const std::vector& seeds) { - // Allocate output buffers BEFORE enqueueing (prevents race condition) - const std::size_t total_obs_size = batch_size_ * stacked_obs_size_; - pending_obs_buffer_ = new uint8_t[total_obs_size]; - state_buffer_->set_output_buffer(pending_obs_buffer_); - - // Release slots from previous batch (but not on first batch) - if (!first_batch_) { - state_buffer_->release_slots(); - } - first_batch_ = false; - - // Allocate metadata buffers - pending_env_ids_ = new int[batch_size_]; - pending_rewards_ = new int[batch_size_]; - pending_terminations_ = new bool[batch_size_]; - pending_truncations_ = new bool[batch_size_]; - pending_lives_ = new int[batch_size_]; - pending_frame_numbers_ = new int[batch_size_]; - pending_episode_frame_numbers_ = new int[batch_size_]; - state_buffer_->set_metadata_buffers( - pending_env_ids_, - pending_rewards_, - pending_terminations_, - pending_truncations_, - pending_lives_, - pending_frame_numbers_, - pending_episode_frame_numbers_ - ); - - // In SameStep mode, also allocate final_obs buffer - if (autoreset_mode_ == AutoresetMode::SameStep) { - pending_final_obs_ = new uint8_t[total_obs_size]; - 
state_buffer_->set_final_obs_buffer(pending_final_obs_); - } - - // Prepare reset actions - std::vector reset_actions; - reset_actions.reserve(reset_indices.size()); - - for (size_t i = 0; i < reset_indices.size(); ++i) { - const int env_id = reset_indices[i]; - envs_[env_id]->set_seed(seeds[i]); - - ActionSlice action; - action.env_id = env_id; - action.force_reset = true; - - reset_actions.emplace_back(action); - } - - // Enqueue actions - workers can now safely write to buffer - action_queue_->enqueue_bulk(reset_actions); - } - - /** - * Send actions to the sub-environments - * - * @param actions Vector of actions to send to the sub-environments - */ - void send(const std::vector& actions) { - // Allocate output buffers BEFORE enqueueing (prevents race condition) - const std::size_t total_obs_size = batch_size_ * stacked_obs_size_; - - pending_obs_buffer_ = new uint8_t[total_obs_size]; - pending_env_ids_ = new int[batch_size_]; - pending_rewards_ = new int[batch_size_]; - pending_terminations_ = new bool[batch_size_]; - pending_truncations_ = new bool[batch_size_]; - pending_lives_ = new int[batch_size_]; - pending_frame_numbers_ = new int[batch_size_]; - pending_episode_frame_numbers_ = new int[batch_size_]; - - state_buffer_->set_output_buffer(pending_obs_buffer_); - state_buffer_->set_metadata_buffers( - pending_env_ids_, - pending_rewards_, - pending_terminations_, - pending_truncations_, - pending_lives_, - pending_frame_numbers_, - pending_episode_frame_numbers_ - ); - - // Release slots from previous batch (but not on first batch) - if (!first_batch_) { - state_buffer_->release_slots(); - } - first_batch_ = false; - - // In SameStep mode, also allocate final_obs buffer - if (autoreset_mode_ == AutoresetMode::SameStep) { - pending_final_obs_ = new uint8_t[total_obs_size]; - state_buffer_->set_final_obs_buffer(pending_final_obs_); - } - - // Prepare action slices - std::vector action_slices; - action_slices.reserve(actions.size()); - - for (size_t i = 0; i 
< actions.size(); i++) { - const int env_id = actions[i].env_id; - envs_[env_id]->set_action(actions[i]); - - ActionSlice action; - action.env_id = env_id; - action.force_reset = false; - - action_slices.emplace_back(action); - } - - // Enqueue actions - workers can now safely write to buffer - action_queue_->enqueue_bulk(action_slices); - } - - /** - * Receive timesteps from the environments. - * Returns ownership of allocated observation buffer to caller. - * - * @return BatchData containing observation data and metadata - */ - BatchData recv() { - // Wait for all workers to complete - state_buffer_->wait_for_batch(); - - // Build result - transfer ownership of all buffers (no copying!) - BatchData result; - result.observations = pending_obs_buffer_; - result.final_observations = pending_final_obs_; - result.env_ids = pending_env_ids_; - result.rewards = pending_rewards_; - result.terminations = pending_terminations_; - result.truncations = pending_truncations_; - result.lives = pending_lives_; - result.frame_numbers = pending_frame_numbers_; - result.episode_frame_numbers = pending_episode_frame_numbers_; - result.batch_size = batch_size_; - - // Clear pending pointers (ownership transferred) - pending_obs_buffer_ = nullptr; - pending_final_obs_ = nullptr; - pending_env_ids_ = nullptr; - pending_rewards_ = nullptr; - pending_terminations_ = nullptr; - pending_truncations_ = nullptr; - pending_lives_ = nullptr; - pending_frame_numbers_ = nullptr; - pending_episode_frame_numbers_ = nullptr; - - // Reset state buffer for next batch - state_buffer_->reset(); - - return result; - } - - const int get_num_envs() const { - return num_envs_; - } - - const int get_batch_size() const { - return batch_size_; - } - - const int get_stacked_obs_size() const { - return stacked_obs_size_; - } - - const AutoresetMode get_autoreset() const { - return autoreset_mode_; - } - - private: - int num_envs_; // Number of parallel environments - int batch_size_; // Batch size for 
processing - int num_threads_; // Number of worker threads - int stacked_obs_size_; // The observation size (stack-num * width * height * channels) - AutoresetMode autoreset_mode_; // How to reset sub-environments after an episode ends - - std::atomic stop_; // Signal to stop worker threads - bool first_batch_; // Track if this is the first batch (don't release permits) - std::vector workers_; // Worker threads - std::unique_ptr action_queue_; // Queue for actions - std::unique_ptr state_buffer_; // Buffer for observations and metadata - std::vector> envs_; // Environment instances - - // Pending buffers allocated in send()/reset(), returned in recv() - uint8_t* pending_obs_buffer_; // Observations buffer - uint8_t* pending_final_obs_; // Final observations buffer (SameStep mode only) - int* pending_env_ids_; // Env IDs metadata buffer - int* pending_rewards_; // Rewards metadata buffer - bool* pending_terminations_; // Terminations metadata buffer - bool* pending_truncations_; // Truncations metadata buffer - int* pending_lives_; // Lives metadata buffer - int* pending_frame_numbers_; // Frame numbers metadata buffer - int* pending_episode_frame_numbers_; // Episode frame numbers metadata buffer - - /** - * Worker thread function that processes environment steps. - * Writes results directly to pre-allocated output buffer. 
- */ - void worker_function() { - while (!stop_) { - try { - ActionSlice action = action_queue_->dequeue(); - if (stop_) { - break; - } - - const int env_id = action.env_id; - if (autoreset_mode_ == AutoresetMode::NextStep) { - if (action.force_reset || envs_[env_id]->is_episode_over()) { - envs_[env_id]->reset(); - } else { - envs_[env_id]->step(); - } - - // Get write slot - pointers are into the pre-allocated output buffer (after the reset or step occurs) - WriteSlot slot = state_buffer_->allocate_write_slot(env_id); - envs_[env_id]->write_timestep_to( - slot.obs_dest, - slot.env_id_dest, - slot.reward_dest, - slot.terminated_dest, - slot.truncated_dest, - slot.lives_dest, - slot.frame_number_dest, - slot.episode_frame_number_dest - ); - } else if (autoreset_mode_ == AutoresetMode::SameStep) { - if (action.force_reset) { - envs_[env_id]->reset(); - - // Get write slot - pointers are into the pre-allocated output buffer (after the force reset) - WriteSlot slot = state_buffer_->allocate_write_slot(env_id); - envs_[env_id]->write_timestep_to( - slot.obs_dest, - slot.env_id_dest, - slot.reward_dest, - slot.terminated_dest, - slot.truncated_dest, - slot.lives_dest, - slot.frame_number_dest, - slot.episode_frame_number_dest - ); - } else { - envs_[env_id]->step(); - - // Get write slot - pointers are into the pre-allocated output buffer (after the step) - WriteSlot slot = state_buffer_->allocate_write_slot(env_id); - - if (envs_[env_id]->is_episode_over()) { - // Write current (final) observation before reset - envs_[env_id]->write_observation_to(slot.final_obs_dest); - - // Capture pre-reset metadata temporarily (for reward/terminated/truncated) - int pre_reward; - bool pre_terminated, pre_truncated; - envs_[env_id]->write_metadata_to( - slot.env_id_dest, - &pre_reward, - &pre_terminated, - &pre_truncated, - slot.lives_dest, - slot.frame_number_dest, - slot.episode_frame_number_dest - ); - - // Reset and write new observation - envs_[env_id]->reset(); - 
envs_[env_id]->write_timestep_to( - slot.obs_dest, - slot.env_id_dest, // overwrites with same value - slot.reward_dest, - slot.terminated_dest, - slot.truncated_dest, - slot.lives_dest, // overwrites with reset lives - slot.frame_number_dest, - slot.episode_frame_number_dest - ); - - // Restore pre-reset reward/terminated/truncated - *slot.reward_dest = pre_reward; - *slot.terminated_dest = pre_terminated; - *slot.truncated_dest = pre_truncated; - } else { - // No episode over - envs_[env_id]->write_timestep_to( - slot.obs_dest, - slot.env_id_dest, - slot.reward_dest, - slot.terminated_dest, - slot.truncated_dest, - slot.lives_dest, - slot.frame_number_dest, - slot.episode_frame_number_dest - ); - } - } - } else { - throw std::runtime_error("Invalid autoreset mode"); - } - - state_buffer_->mark_complete(); - - } catch (const std::exception& e) { - std::cerr << "Error in worker thread: " << e.what() << std::endl; - } - } - } - - /** - * Set thread affinity for worker threads - */ - void set_thread_affinity(const int thread_affinity_offset, const int processor_count) { - for (size_t tid = 0; tid < workers_.size(); ++tid) { - size_t core_id = (thread_affinity_offset + tid) % processor_count; - -#if defined(__linux__) - cpu_set_t cpuset; - CPU_ZERO(&cpuset); - CPU_SET(core_id, &cpuset); - pthread_setaffinity_np(workers_[tid].native_handle(), sizeof(cpu_set_t), &cpuset); -#elif defined(_WIN32) - DWORD_PTR mask = (static_cast(1) << core_id); - SetThreadAffinityMask(workers_[tid].native_handle(), mask); -#elif defined(__APPLE__) - thread_affinity_policy_data_t policy = { static_cast(core_id) }; - thread_port_t mach_thread = pthread_mach_thread_np(workers_[tid].native_handle()); - thread_policy_set(mach_thread, THREAD_AFFINITY_POLICY, - (thread_policy_t)&policy, THREAD_AFFINITY_POLICY_COUNT); -#endif - } - } - }; -} - -#endif // ALE_VECTOR_ASYNC_VECTORIZER_HPP_ diff --git a/src/ale/vector/env_vectorizer.cpp b/src/ale/vector/env_vectorizer.cpp new file mode 100644 index 
000000000..95cb243cb --- /dev/null +++ b/src/ale/vector/env_vectorizer.cpp @@ -0,0 +1,296 @@ +#include "env_vectorizer.hpp" + +#if defined(__linux__) + #include +#elif defined(_WIN32) + #include +#elif defined(__APPLE__) + #include + #include + #include +#endif + +namespace ale::vector { + +EnvVectorizer::EnvVectorizer( + const fs::path& rom_path, + int num_envs, + int batch_size, + int num_threads, + int thread_affinity_offset, + AutoresetMode autoreset_mode, + int img_height, + int img_width, + int stack_num, + bool grayscale, + int frame_skip, + bool maxpool, + int noop_max, + bool use_fire_reset, + bool episodic_life, + bool life_loss_info, + bool reward_clipping, + int max_episode_steps, + float repeat_action_probability, + bool full_action_space +) : num_envs_(num_envs), + batch_size_(batch_size > 0 ? batch_size : num_envs), + img_height_(img_height), + img_width_(img_width), + stack_num_(stack_num), + grayscale_(grayscale), + autoreset_mode_(autoreset_mode), + last_recv_env_ids_(batch_size_ > 0 ? 
batch_size_ : num_envs) +{ + // Create environments + envs_.reserve(num_envs_); + for (int i = 0; i < num_envs_; ++i) { + envs_.push_back(std::make_unique( + i, rom_path, img_height, img_width, frame_skip, maxpool, + grayscale, stack_num, noop_max, use_fire_reset, episodic_life, + life_loss_info, reward_clipping, max_episode_steps, + repeat_action_probability, full_action_space, -1 + )); + } + + stacked_obs_size_ = envs_[0]->stacked_obs_size(); + action_set_ = envs_[0]->action_set(); + + // Create action queue (capacity = 2x num_envs for safety) + action_queue_ = std::make_unique(num_envs_ * 2); + + // Create result staging + bool same_step = (autoreset_mode_ == AutoresetMode::SameStep); + staging_ = std::make_unique(batch_size_, num_envs_, stacked_obs_size_, same_step); + + // Determine thread count + int hw_threads = static_cast(std::thread::hardware_concurrency()); + if (num_threads <= 0) { + num_threads_ = std::min(batch_size_, hw_threads); + } else { + num_threads_ = std::min(num_threads, hw_threads); + } + + // Start worker threads + workers_.reserve(num_threads_); + for (int i = 0; i < num_threads_; ++i) { + workers_.emplace_back([this, i] { worker_loop(i); }); + } + + // Set thread affinity if requested + if (thread_affinity_offset >= 0) { + set_thread_affinity(thread_affinity_offset); + } +} + +EnvVectorizer::~EnvVectorizer() { + stop_.store(true); + + // Send dummy actions to wake up blocked workers + std::vector wake_actions(workers_.size()); + for (auto& a : wake_actions) { + a.env_id = 0; + a.force_reset = false; + } + action_queue_->enqueue_bulk(wake_actions); + + // Join all workers + for (auto& worker : workers_) { + if (worker.joinable()) { + worker.join(); + } + } +} + +BatchResult EnvVectorizer::reset(const std::vector& env_ids, const std::vector& seeds) { + if (env_ids.size() != seeds.size()) { + throw std::invalid_argument("env_ids and seeds must have same size"); + } + + // Release slots from previous batch (but not on first batch) + if 
(!first_batch_ && !staging_->is_ordered()) { + // For unordered mode, we need to handle this carefully + // The staging buffer was already released in recv() + } + first_batch_ = false; + + // Set seeds and prepare actions + std::vector actions; + actions.reserve(env_ids.size()); + for (std::size_t i = 0; i < env_ids.size(); ++i) { + int env_id = env_ids[i]; + envs_[env_id]->set_seed(seeds[i]); + + Action action; + action.env_id = env_id; + action.action_id = 0; + action.paddle_strength = 1.0f; + action.force_reset = true; + actions.push_back(action); + } + + // Enqueue reset actions + action_queue_->enqueue_bulk(actions); + + // Wait for results + return recv(); +} + +void EnvVectorizer::send(const std::vector& actions) { + if (actions.size() != static_cast(batch_size_)) { + throw std::invalid_argument( + "Expected " + std::to_string(batch_size_) + " actions, got " + std::to_string(actions.size()) + ); + } + + // Map actions to correct environments using last_recv_env_ids + std::vector mapped_actions; + mapped_actions.reserve(actions.size()); + + for (std::size_t i = 0; i < actions.size(); ++i) { + Action mapped = actions[i]; + int actual_env_id = last_recv_env_ids_[i]; + mapped.env_id = actual_env_id; + mapped.force_reset = false; + + // Set action on environment + envs_[actual_env_id]->set_action(mapped.action_id, mapped.paddle_strength); + + mapped_actions.push_back(mapped); + } + + // Enqueue actions + action_queue_->enqueue_bulk(mapped_actions); +} + +BatchResult EnvVectorizer::recv() { + // Wait for batch to complete + staging_->wait_for_batch(); + + // Check for errors + check_error(); + + // Release batch and get results + auto result = staging_->release_batch(); + + // Remember env_ids for next send() + std::memcpy(last_recv_env_ids_.data(), result.env_ids_data(), batch_size_ * sizeof(int)); + + return result; +} + +void EnvVectorizer::worker_loop(int thread_id) { + (void)thread_id; // For potential future use (logging, etc.) 
+ + while (!stop_.load()) { + try { + Action action = action_queue_->dequeue(); + + if (stop_.load()) { + break; + } + + execute_env(action); + + } catch (...) { + set_error(std::current_exception()); + } + } +} + +void EnvVectorizer::execute_env(const Action& action) { + int env_id = action.env_id; + auto& env = *envs_[env_id]; + + if (autoreset_mode_ == AutoresetMode::NextStep) { + // NextStep mode: reset happens before step if episode was over + if (action.force_reset || env.is_episode_over()) { + env.reset(); + } else { + env.step(); + } + + // Stage result + staging_->stage_result(env_id, [&](OutputSlot& slot) { + env.write_to(slot); + }); + + } else { // SameStep mode + if (action.force_reset) { + env.reset(); + + staging_->stage_result(env_id, [&](OutputSlot& slot) { + env.write_to(slot); + }); + + } else { + env.step(); + + staging_->stage_result(env_id, [&](OutputSlot& slot) { + if (env.is_episode_over()) { + // Write final observation before reset + env.write_obs_to(slot.final_obs); + + // Capture pre-reset metadata + env.write_to(slot); + int pre_reward = *slot.reward; + bool pre_terminated = *slot.terminated; + bool pre_truncated = *slot.truncated; + + // Reset and write new observation + env.reset(); + env.write_to(slot); + + // Restore pre-reset reward/terminated/truncated + *slot.reward = pre_reward; + *slot.terminated = pre_terminated; + *slot.truncated = pre_truncated; + } else { + env.write_to(slot); + } + }); + } + } +} + +void EnvVectorizer::set_thread_affinity(int thread_affinity_offset) { + int processor_count = static_cast(std::thread::hardware_concurrency()); + + for (std::size_t i = 0; i < workers_.size(); ++i) { + int core_id = (thread_affinity_offset + static_cast(i)) % processor_count; + +#if defined(__linux__) + cpu_set_t cpuset; + CPU_ZERO(&cpuset); + CPU_SET(core_id, &cpuset); + pthread_setaffinity_np(workers_[i].native_handle(), sizeof(cpu_set_t), &cpuset); +#elif defined(_WIN32) + DWORD_PTR mask = (static_cast(1) << core_id); + 
SetThreadAffinityMask(workers_[i].native_handle(), mask); +#elif defined(__APPLE__) + thread_affinity_policy_data_t policy = { static_cast(core_id) }; + thread_port_t mach_thread = pthread_mach_thread_np(workers_[i].native_handle()); + thread_policy_set(mach_thread, THREAD_AFFINITY_POLICY, + (thread_policy_t)&policy, THREAD_AFFINITY_POLICY_COUNT); +#endif + } +} + +void EnvVectorizer::set_error(std::exception_ptr e) { + std::lock_guard lock(error_mutex_); + if (!has_error_.load()) { + error_ = e; + has_error_.store(true); + } +} + +void EnvVectorizer::check_error() { + if (has_error_.load()) { + std::lock_guard lock(error_mutex_); + if (error_) { + std::rethrow_exception(error_); + } + } +} + +} // namespace ale::vector diff --git a/src/ale/vector/env_vectorizer.hpp b/src/ale/vector/env_vectorizer.hpp new file mode 100644 index 000000000..9e5fb0fd5 --- /dev/null +++ b/src/ale/vector/env_vectorizer.hpp @@ -0,0 +1,140 @@ +#ifndef ALE_VECTOR_ENV_VECTORIZER_HPP_ +#define ALE_VECTOR_ENV_VECTORIZER_HPP_ + +#include +#include +#include +#include +#include +#include +#include + +#include "ale/common/Constants.h" +#include "types.hpp" +#include "action_queue.hpp" +#include "result_staging.hpp" +#include "preprocessed_env.hpp" + +namespace fs = std::filesystem; + +namespace ale::vector { + +class EnvVectorizer { +public: + EnvVectorizer( + const fs::path& rom_path, + int num_envs, + int batch_size = 0, + int num_threads = 0, + int thread_affinity_offset = -1, + AutoresetMode autoreset_mode = AutoresetMode::NextStep, + int img_height = 84, + int img_width = 84, + int stack_num = 4, + bool grayscale = true, + int frame_skip = 4, + bool maxpool = true, + int noop_max = 30, + bool use_fire_reset = true, + bool episodic_life = false, + bool life_loss_info = false, + bool reward_clipping = true, + int max_episode_steps = 108000, + float repeat_action_probability = 0.0f, + bool full_action_space = false + ); + + ~EnvVectorizer(); + + // Non-copyable, non-movable + 
EnvVectorizer(const EnvVectorizer&) = delete; + EnvVectorizer& operator=(const EnvVectorizer&) = delete; + + /// Reset specified environments with given seeds. + /// @param env_ids Environment indices to reset + /// @param seeds Seeds for each environment (-1 to keep current seed) + /// @return Batch of results from batch_size environments + BatchResult reset(const std::vector& env_ids, const std::vector& seeds); + + /// Send actions to environments. + /// actions[i] applies to the environment that was at position i in the last recv() result. + /// @param actions Actions with env_id, action_id, paddle_strength + void send(const std::vector& actions); + + /// Receive results from batch_size environments. + /// In ordered mode: returns results for all envs in order + /// In unordered mode: returns results from first batch_size envs to complete + BatchResult recv(); + + // Accessors + int num_envs() const { return num_envs_; } + int batch_size() const { return batch_size_; } + std::size_t stacked_obs_size() const { return stacked_obs_size_; } + const ActionVect& action_set() const { return action_set_; } + AutoresetMode autoreset_mode() const { return autoreset_mode_; } + bool is_grayscale() const { return grayscale_; } + + /// Get observation shape as tuple (stack_num, height, width) or (stack_num, height, width, 3) + std::tuple observation_shape() const { + return grayscale_ + ? 
std::make_tuple(stack_num_, img_height_, img_width_, 0) + : std::make_tuple(stack_num_, img_height_, img_width_, 3); + } + + /// Get raw pointer for JAX FFI handle + const void* handle() const { return this; } + +private: + // Configuration + int num_envs_; + int batch_size_; + int num_threads_; + int img_height_; + int img_width_; + int stack_num_; + bool grayscale_; + std::size_t stacked_obs_size_; + AutoresetMode autoreset_mode_; + + // Environments + std::vector> envs_; + ActionVect action_set_; + + // Worker threads + std::vector workers_; + std::atomic stop_{false}; + + // Work distribution and result collection + std::unique_ptr action_queue_; + std::unique_ptr staging_; + + // Maps batch position -> env_id for last recv() result + std::vector last_recv_env_ids_; + + // Error handling + std::atomic has_error_{false}; + std::exception_ptr error_; + std::mutex error_mutex_; + + // Track first batch for slot release + bool first_batch_{true}; + + /// Worker thread main loop + void worker_loop(int thread_id); + + /// Execute one environment step or reset + void execute_env(const Action& action); + + /// Set thread CPU affinity + void set_thread_affinity(int thread_affinity_offset); + + /// Record an error from a worker thread + void set_error(std::exception_ptr e); + + /// Check if an error occurred and rethrow + void check_error(); +}; + +} // namespace ale::vector + +#endif // ALE_VECTOR_ENV_VECTORIZER_HPP_ diff --git a/src/ale/vector/preprocessed_env.cpp b/src/ale/vector/preprocessed_env.cpp new file mode 100644 index 000000000..862a086fc --- /dev/null +++ b/src/ale/vector/preprocessed_env.cpp @@ -0,0 +1,267 @@ +#include "preprocessed_env.hpp" + +namespace ale::vector { + +PreprocessedEnv::PreprocessedEnv( + int env_id, + const fs::path& rom_path, + int img_height, + int img_width, + int frame_skip, + bool maxpool, + bool grayscale, + int stack_num, + int noop_max, + bool use_fire_reset, + bool episodic_life, + bool life_loss_info, + bool reward_clipping, + 
int max_episode_steps, + float repeat_action_probability, + bool full_action_space, + int seed +) : env_id_(env_id), + rom_path_(rom_path), + obs_frame_height_(img_height), + obs_frame_width_(img_width), + frame_skip_(frame_skip), + maxpool_(maxpool), + obs_format_(grayscale ? ObsFormat::Grayscale : ObsFormat::RGB), + channels_per_frame_(grayscale ? 1 : 3), + stack_num_(stack_num), + noop_max_(noop_max), + use_fire_reset_(use_fire_reset), + has_fire_action_(false), + episodic_life_(episodic_life), + life_loss_info_(life_loss_info), + reward_clipping_(reward_clipping), + max_episode_steps_(max_episode_steps), + rng_(seed == -1 ? std::random_device{}() : static_cast(seed)), + noop_dist_(0, noop_max > 0 ? noop_max - 1 : 0), + elapsed_steps_(max_episode_steps + 1), + game_over_(false), + lives_(0), + was_life_lost_(false), + reward_(0), + current_action_id_(PLAYER_A_NOOP), + current_paddle_strength_(1.0f), + pending_seed_(-1), + frame_stack_idx_(0) +{ + // Turn off verbosity + Logger::setMode(Logger::Error); + + // Initialize ALE + ale_ = std::make_unique(); + ale_->setFloat("repeat_action_probability", repeat_action_probability); + ale_->setInt("random_seed", seed); + ale_->loadROM(rom_path_); + + // Get action set + if (full_action_space) { + action_set_ = ale_->getLegalActionSet(); + } else { + action_set_ = ale_->getMinimalActionSet(); + } + + // Check if fire action is available (needed for fire_reset) + if (use_fire_reset_) { + has_fire_action_ = false; + for (const auto& a : action_set_) { + if (a == PLAYER_A_FIRE) { + has_fire_action_ = true; + break; + } + } + } + + const ALEScreen& screen = ale_->getScreen(); + raw_frame_height_ = screen.height(); + raw_frame_width_ = screen.width(); + raw_frame_size_ = raw_frame_height_ * raw_frame_width_; + raw_size_ = raw_frame_height_ * raw_frame_width_ * channels_per_frame_; + obs_size_ = obs_frame_height_ * obs_frame_width_ * channels_per_frame_; + + // Initialize the buffers + for (int i = 0; i < 2; ++i) { + 
raw_frames_.emplace_back(raw_size_); + } + frame_stack_ = std::vector(stack_num_ * obs_size_, 0); + frame_stack_idx_ = 0; +} + +void PreprocessedEnv::set_seed(int seed) { + pending_seed_ = seed; +} + +void PreprocessedEnv::set_action(int action_id, float paddle_strength) { + current_action_id_ = action_id; + current_paddle_strength_ = paddle_strength; +} + +void PreprocessedEnv::reset() { + if (pending_seed_ >= 0) { + ale_->setInt("random_seed", pending_seed_); + rng_.seed(pending_seed_); + ale_->loadROM(rom_path_); + pending_seed_ = -1; + } + ale_->reset_game(); + + // Press FIRE if required by the environment + if (use_fire_reset_ && has_fire_action_) { + ale_->act(PLAYER_A_FIRE); + } + + // Perform no-op steps + int noop_steps = noop_dist_(rng_) - static_cast(use_fire_reset_ && has_fire_action_); + while (noop_steps > 0) { + ale_->act(PLAYER_A_NOOP); + if (ale_->game_over()) { + ale_->reset_game(); + } + noop_steps--; + } + + // Clear the frame stack + std::fill(frame_stack_.begin(), frame_stack_.end(), 0); + frame_stack_idx_ = 0; + + // Get the screen data and process it + if (obs_format_ == ObsFormat::Grayscale) { + get_screen_grayscale(raw_frames_[0].data()); + } else { + get_screen_rgb(raw_frames_[0].data()); + } + std::fill(raw_frames_[1].begin(), raw_frames_[1].end(), 0); + + // Process the screen + process_screen(); + + // Update state + elapsed_steps_ = 0; + reward_ = 0; + game_over_ = false; + lives_ = ale_->lives(); + was_life_lost_ = false; + current_action_id_ = PLAYER_A_NOOP; +} + +void PreprocessedEnv::step() { + // Validate action + if (current_action_id_ < 0 || current_action_id_ >= static_cast(action_set_.size())) { + throw std::out_of_range("Invalid action_id: " + std::to_string(current_action_id_) + + ", available actions: " + std::to_string(action_set_.size())); + } + const ale::Action action = action_set_[current_action_id_]; + const float strength = current_paddle_strength_; + + // Execute action for frame_skip frames + reward_t reward = 0; 
+ for (int skip_id = frame_skip_; skip_id > 0; --skip_id) { + reward += ale_->act(action, strength); + + game_over_ = ale_->game_over(); + elapsed_steps_++; + was_life_lost_ = ale_->lives() < lives_ && ale_->lives() > 0; + + if (game_over_ || elapsed_steps_ >= max_episode_steps_ || (episodic_life_ && was_life_lost_)) { + break; + } + + // Captures last two frames for maxpooling + if (skip_id <= 2) { + if (obs_format_ == ObsFormat::Grayscale) { + get_screen_grayscale(raw_frames_[skip_id - 1].data()); + } else { + get_screen_rgb(raw_frames_[skip_id - 1].data()); + } + } + } + + // Update state + process_screen(); + lives_ = ale_->lives(); + reward_ = reward_clipping_ ? std::clamp(reward, -1, 1) : reward; +} + +void PreprocessedEnv::write_to(const OutputSlot& slot) const { + *slot.env_id = env_id_; + *slot.reward = reward_; + *slot.terminated = game_over_ || ((life_loss_info_ || episodic_life_) && was_life_lost_); + *slot.truncated = elapsed_steps_ >= max_episode_steps_ && !(*slot.terminated); + *slot.lives = lives_; + *slot.frame_number = ale_->getFrameNumber(); + *slot.episode_frame_number = ale_->getEpisodeFrameNumber(); + + // Linearize circular frame_stack to destination + for (int i = 0; i < stack_num_; ++i) { + int src_idx = (frame_stack_idx_ + i) % stack_num_; + std::memcpy( + slot.obs + i * obs_size_, + frame_stack_.data() + src_idx * obs_size_, + obs_size_ + ); + } +} + +void PreprocessedEnv::write_obs_to(uint8_t* dest) const { + for (int i = 0; i < stack_num_; ++i) { + int src_idx = (frame_stack_idx_ + i) % stack_num_; + std::memcpy( + dest + i * obs_size_, + frame_stack_.data() + src_idx * obs_size_, + obs_size_ + ); + } +} + +bool PreprocessedEnv::is_episode_over() const { + return game_over_ || elapsed_steps_ >= max_episode_steps_ || (episodic_life_ && was_life_lost_); +} + +void PreprocessedEnv::get_screen_grayscale(uint8_t* buffer) const { + const ALEScreen& screen = ale_->getScreen(); + uint8_t* ale_screen_data = screen.getArray(); + + 
ale_->theOSystem->colourPalette().applyPaletteGrayscale( + buffer, ale_screen_data, raw_frame_size_ + ); +} + +void PreprocessedEnv::get_screen_rgb(uint8_t* buffer) const { + const ALEScreen& screen = ale_->getScreen(); + uint8_t* ale_screen_data = screen.getArray(); + + ale_->theOSystem->colourPalette().applyPaletteRGB( + buffer, ale_screen_data, raw_frame_size_ + ); +} + +void PreprocessedEnv::process_screen() { + // Maxpool raw frames if required + if (maxpool_) { + for (int i = 0; i < raw_size_; ++i) { + raw_frames_[0][i] = std::max(raw_frames_[0][i], raw_frames_[1][i]); + } + } + + // Get pointer to current position in circular buffer + uint8_t* dest_ptr = frame_stack_.data() + (frame_stack_idx_ * obs_size_); + + // Resize directly into the circular buffer or copy if no resize needed + if (obs_frame_height_ != raw_frame_height_ || obs_frame_width_ != raw_frame_width_) { + auto cv2_format = (obs_format_ == ObsFormat::Grayscale) ? CV_8UC1 : CV_8UC3; + cv::Mat src_img(raw_frame_height_, raw_frame_width_, cv2_format, raw_frames_[0].data()); + cv::Mat dst_img(obs_frame_height_, obs_frame_width_, cv2_format, dest_ptr); + cv::resize(src_img, dst_img, dst_img.size(), 0, 0, cv::INTER_AREA); + } else { + // No resize needed, copy directly to circular buffer + std::memcpy(dest_ptr, raw_frames_[0].data(), raw_size_); + } + + // Move to next position in circular buffer + frame_stack_idx_ = (frame_stack_idx_ + 1) % stack_num_; +} + +} // namespace ale::vector diff --git a/src/ale/vector/preprocessed_env.hpp b/src/ale/vector/preprocessed_env.hpp index 4e1fe842b..c736e7676 100644 --- a/src/ale/vector/preprocessed_env.hpp +++ b/src/ale/vector/preprocessed_env.hpp @@ -1,454 +1,137 @@ -#ifndef ALE_VECTOR_ATARI_ENV_HPP_ -#define ALE_VECTOR_ATARI_ENV_HPP_ +#ifndef ALE_VECTOR_PREPROCESSED_ENV_HPP_ +#define ALE_VECTOR_PREPROCESSED_ENV_HPP_ #include #include -#include #include -#include #include +#include #include #include "ale/common/Constants.h" #include "ale/ale_interface.hpp" 
-#include "utils.hpp" +#include "types.hpp" -namespace ale::vector { - - /** - * PreprocessedAtariEnv encapsulates a single Atari environment using the ALE Interface with standard preprocessing and stacking. - */ - class PreprocessedAtariEnv { - public: - /** - * Constructor - * - * @param env_id Unique ID for this environment instance - * @param rom_path Path to the ROM file - * @param obs_height Height to resize frames to for observations - * @param obs_width Width to resize frames to for observations - * @param frame_skip Number of frames for which to repeat the action - * @param maxpool Whether to maxpool observations - * @param obs_format Format of observations (grayscale or RGB) - * @param stack_num Number of frames to stack for observations - * @param noop_max Maximum number of no-ops to perform on resets - * @param use_fire_reset Whether to press FIRE during reset - * @param episodic_life Whether to end episodes when a life is lost - * @param life_loss_info Whether to return `terminated=True` on a life loss but not reset until `lives==0` - * @param reward_clipping Whether to clip the environment rewards between -1 and 1 - * @param max_episode_steps Maximum number of steps per episode before truncating - * @param repeat_action_probability Probability of repeating the last action - * @param full_action_space Whether to use the full action space - * @param seed Random seed - */ - PreprocessedAtariEnv( - const int env_id, - const fs::path &rom_path, - const int obs_height = 84, - const int obs_width = 84, - const int frame_skip = 4, - const bool maxpool = true, - const ObsFormat obs_format = ObsFormat::Grayscale, - const int stack_num = 4, - const int noop_max = 30, - const bool use_fire_reset = true, - const bool episodic_life = false, - const bool life_loss_info = false, - const bool reward_clipping = true, - const int max_episode_steps = 108000, - const float repeat_action_probability = 0.0f, - const bool full_action_space = false, - const int seed = -1 - ) 
: env_id_(env_id), - rom_path_(rom_path), - obs_frame_height_(obs_height), - obs_frame_width_(obs_width), - frame_skip_(frame_skip), - maxpool_(maxpool), - obs_format_(obs_format), - channels_per_frame_(obs_format == ObsFormat::Grayscale ? 1 : 3), - stack_num_(stack_num), - noop_max_(noop_max), - use_fire_reset_(use_fire_reset), - episodic_life_(episodic_life), - life_loss_info_(life_loss_info), - reward_clipping_(reward_clipping), - max_episode_steps_(max_episode_steps), - rng_gen_(seed == -1 ? std::random_device{}() : seed), - elapsed_step_(max_episode_steps + 1), - // Uninitialised variables - game_over_(false), lives_(0), was_life_lost_(false), reward_(0), - current_action_(EnvironmentAction()), current_seed_(0) - { - // Turn off verbosity - Logger::setMode(Logger::Error); - - // Initialize ALE - env_ = std::make_unique(); - env_->setFloat("repeat_action_probability", repeat_action_probability); - env_->setInt("random_seed", seed); - env_->loadROM(rom_path_); - - // Get action set - if (full_action_space) { - action_set_ = env_->getLegalActionSet(); - } else { - action_set_ = env_->getMinimalActionSet(); - } - - // Check if fire action is available (needed for fire_reset) - if (use_fire_reset_) { - has_fire_action_ = false; - for (const auto a: action_set_) { - if (a == PLAYER_A_FIRE) { - has_fire_action_ = true; - break; - } - } - } - - // Initialize random distribution for no-ops - if (noop_max_ > 0) { - noop_generator_ = std::uniform_int_distribution<>(0, noop_max_ - 1); - } else { - // If noop_max is 0, create a distribution that always returns 0 - noop_generator_ = std::uniform_int_distribution<>(0, 0); - } - - const ALEScreen& screen = env_->getScreen(); - raw_frame_height_ = screen.height(); - raw_frame_width_ = screen.width(); - raw_frame_size_ = raw_frame_height_ * raw_frame_width_; - raw_size_ = raw_frame_height_ * raw_frame_width_ * channels_per_frame_; - obs_size_ = obs_frame_height_ * obs_frame_width_ * channels_per_frame_; - - // Initialize the 
buffers - for (int i = 0; i < 2; ++i) { - raw_frames_.emplace_back(raw_size_); - } - frame_stack_ = std::vector(stack_num_ * obs_size_, 0); - frame_stack_idx_ = 0; - } - - void set_seed(const int seed) { - current_seed_ = seed; - } - - /** - * Reset the environment and return the initial observation - */ - void reset() { - if (current_seed_ >= 0) { - env_->setInt("random_seed", current_seed_); - rng_gen_.seed(current_seed_); - - env_->loadROM(rom_path_); - current_seed_ = -1; - } - env_->reset_game(); - - // Press FIRE if required by the environment - if (use_fire_reset_ && has_fire_action_) { - env_->act(PLAYER_A_FIRE); - } - - // Perform no-op steps - int noop_steps = noop_generator_(rng_gen_) - static_cast(use_fire_reset_ && has_fire_action_); - while (noop_steps > 0) { - env_->act(PLAYER_A_NOOP); - if (env_->game_over()) { - env_->reset_game(); - } - noop_steps--; - } - - // Clear the frame stack - std::fill(frame_stack_.begin(), frame_stack_.end(), 0); - frame_stack_idx_ = 0; - - // Get the screen data and process it - if (obs_format_ == ObsFormat::Grayscale) { - get_screen_data_grayscale(raw_frames_[0].data()); - } else { - get_screen_data_rgb(raw_frames_[0].data()); - } - std::fill(raw_frames_[1].begin(), raw_frames_[1].end(), 0); - - // Process the screen - process_screen(); - - // Update state - elapsed_step_ = 0; - reward_ = 0; - game_over_ = false; - lives_ = env_->lives(); - was_life_lost_ = false; - current_action_.action_id = PLAYER_A_NOOP; - } - - /** - * Set the action to be taken in the next step - */ - void set_action(const EnvironmentAction& action) { - current_action_ = action; - } - - /** - * Steps the environment using the current action - */ - void step() { - // Convert the current action to Action and Paddle Strength - const int action_id = current_action_.action_id; - if (action_id < 0 || action_id >= action_set_.size()) { - throw std::out_of_range("Stepping sub-environment with action_id: " + std::to_string(action_id) + ", however, this is 
either less than zero or greater than available actions (" + std::to_string(action_set_.size()) + ")"); - } - const Action action = action_set_[action_id]; - const float strength = current_action_.paddle_strength; - - // Execute action for frame_skip frames - reward_t reward = 0; - for (int skip_id = frame_skip_; skip_id > 0; --skip_id) { - reward += env_->act(action, strength); - - game_over_ = env_->game_over(); - elapsed_step_++; - was_life_lost_ = env_->lives() < lives_ && env_->lives() > 0; - - if (game_over_ || elapsed_step_ >= max_episode_steps_ || (episodic_life_ && was_life_lost_)) { - break; - } - - // Captures last two frames for maxpooling - if (skip_id <= 2) { - if (obs_format_ == ObsFormat::Grayscale) { - get_screen_data_grayscale(raw_frames_[skip_id - 1].data()); - } else { - get_screen_data_rgb(raw_frames_[skip_id - 1].data()); - } - } - } - - // Update state - process_screen(); - lives_ = env_->lives(); - reward_ = reward_clipping_ ? std::clamp(reward, -1, 1) : reward; - } +namespace fs = std::filesystem; - /** - * Write timestep data directly to provided destinations. - * Avoids allocating intermediate vectors. 
- * - * @param obs_dest Pointer to write linearized observation (size: stack_num * obs_size) - * @param env_id_dest Pointer to write env_id - * @param reward_dest Pointer to write reward - * @param terminated_dest Pointer to write terminated flag - * @param truncated_dest Pointer to write truncated flag - * @param lives_dest Pointer to write lives - * @param frame_number_dest Pointer to write frame_number - * @param episode_frame_number_dest Pointer to write episode_frame_number - */ - void write_timestep_to( - uint8_t* obs_dest, - int* env_id_dest, - int* reward_dest, - bool* terminated_dest, - bool* truncated_dest, - int* lives_dest, - int* frame_number_dest, - int* episode_frame_number_dest - ) const { - // Write metadata directly to BatchData arrays - *env_id_dest = env_id_; - *reward_dest = reward_; - *terminated_dest = game_over_ || ((life_loss_info_ || episodic_life_) && was_life_lost_); - *truncated_dest = elapsed_step_ >= max_episode_steps_ && !(*terminated_dest); - *lives_dest = lives_; - *frame_number_dest = env_->getFrameNumber(); - *episode_frame_number_dest = env_->getEpisodeFrameNumber(); - - // Linearize circular frame_stack directly to destination - for (int i = 0; i < stack_num_; ++i) { - const int src_idx = (frame_stack_idx_ + i) % stack_num_; - std::memcpy( - obs_dest + i * obs_size_, - frame_stack_.data() + src_idx * obs_size_, - obs_size_ - ); - } - } - - /** - * Write only observation to destination (for final_obs in SameStep mode). - * - * @param obs_dest Pointer to write linearized observation - */ - void write_observation_to(uint8_t* obs_dest) const { - for (int i = 0; i < stack_num_; ++i) { - const int src_idx = (frame_stack_idx_ + i) % stack_num_; - std::memcpy( - obs_dest + i * obs_size_, - frame_stack_.data() + src_idx * obs_size_, - obs_size_ - ); - } - } - - /** - * Write only metadata (used to capture state before reset in SameStep mode). 
- * - * @param env_id_dest Pointer to write env_id - * @param reward_dest Pointer to write reward - * @param terminated_dest Pointer to write terminated flag - * @param truncated_dest Pointer to write truncated flag - * @param lives_dest Pointer to write lives - * @param frame_number_dest Pointer to write frame_number - * @param episode_frame_number_dest Pointer to write episode_frame_number - */ - void write_metadata_to( - int* env_id_dest, - int* reward_dest, - bool* terminated_dest, - bool* truncated_dest, - int* lives_dest, - int* frame_number_dest, - int* episode_frame_number_dest - ) const { - *env_id_dest = env_id_; - *reward_dest = reward_; - *terminated_dest = game_over_ || ((life_loss_info_ || episodic_life_) && was_life_lost_); - *truncated_dest = elapsed_step_ >= max_episode_steps_ && !(*terminated_dest); - *lives_dest = lives_; - *frame_number_dest = env_->getFrameNumber(); - *episode_frame_number_dest = env_->getEpisodeFrameNumber(); - } - - /** - * Check if the episode is over (terminated or truncated) - */ - const bool is_episode_over() const { - return game_over_ || elapsed_step_ >= max_episode_steps_ || (episodic_life_ && was_life_lost_); - } - - /** - * Get the list of available actions - */ - const ActionVect& get_action_set() const { - return action_set_; - } - - /** - * Get observation size - */ - const int get_stacked_obs_size() const { - return obs_size_ * stack_num_; - } - - /** - * Get channels per frame - */ - const int get_channels_per_frame() const { - return channels_per_frame_; - } - - private: - /** - * Get the current screen data from ALE in grayscale format - */ - void get_screen_data_grayscale(uint8_t* buffer) const { - const ALEScreen& screen = env_->getScreen(); - uint8_t* ale_screen_data = screen.getArray(); - - env_->theOSystem->colourPalette().applyPaletteGrayscale( - buffer, ale_screen_data, raw_frame_size_ - ); - } - - /** - * Get the current screen data from ALE in RGB format - */ - void get_screen_data_rgb(uint8_t* 
buffer) const { - const ALEScreen& screen = env_->getScreen(); - uint8_t* ale_screen_data = screen.getArray(); - - env_->theOSystem->colourPalette().applyPaletteRGB( - buffer, ale_screen_data, raw_frame_size_ - ); - } - - /** - * Process the screen and update the frame stack - */ - void process_screen() { - // Maxpool raw frames if required (different for grayscale and RGB) - if (maxpool_) { - for (int i = 0; i < raw_size_; ++i) { - raw_frames_[0][i] = std::max(raw_frames_[0][i], raw_frames_[1][i]); - } - } - - // Get pointer to current position in circular buffer - uint8_t* dest_ptr = frame_stack_.data() + (frame_stack_idx_ * obs_size_); - - // Resize directly into the circular buffer or copy if no resize needed - if (obs_frame_height_ != raw_frame_height_ || obs_frame_width_ != raw_frame_width_) { - auto cv2_format = (obs_format_ == ObsFormat::Grayscale) ? CV_8UC1 : CV_8UC3; - cv::Mat src_img(raw_frame_height_, raw_frame_width_, cv2_format, raw_frames_[0].data()); - cv::Mat dst_img(obs_frame_height_, obs_frame_width_, cv2_format, dest_ptr); - cv::resize(src_img, dst_img, dst_img.size(), 0, 0, cv::INTER_AREA); - } else { - // No resize needed, copy directly to circular buffer - std::memcpy(dest_ptr, raw_frames_[0].data(), raw_size_); - } - - // Move to next position in circular buffer - frame_stack_idx_ = (frame_stack_idx_ + 1) % stack_num_; - } - - int env_id_; // Unique ID for this environment - fs::path rom_path_; // Path to the ROM file - std::unique_ptr env_; // ALE interface - - ActionVect action_set_; // Available actions - - ObsFormat obs_format_; // Format of observations (grayscale or RGB) - int channels_per_frame_; // The number of channels for each frame based on obs_format - int raw_frame_height_; // The raw frame height - int raw_frame_width_; // The raw frame width - int raw_frame_size_; // The raw frame size (height * width) - int raw_size_; - int obs_frame_height_; // Height to resize frames to for observations - int obs_frame_width_; // Width to 
resize frames to for observations - int obs_size_; // Observation size (height * width * channels) - int stack_num_; // Number of frames to stack for observations - - int frame_skip_; // Number of frames for which to repeat the action - bool maxpool_; // Whether to maxpool observations - int noop_max_; // Maximum number of no-ops at reset - bool use_fire_reset_; // Whether to press FIRE during reset - bool has_fire_action_; // Whether FIRE action is available for reset - bool episodic_life_; // Whether to end episodes when a life is lost - bool life_loss_info_; // If to provide termination signal (but not reset) on life loss - bool reward_clipping_; // If to clip rewards between -1 and 1 - int max_episode_steps_; // Maximum number of steps per episode before truncating - - std::mt19937 rng_gen_; // Random number generator - std::uniform_int_distribution<> noop_generator_; // Distribution for no-op steps - - int elapsed_step_; // Current step in the episode - bool game_over_; // Whether the game is over - int lives_; // Current number of lives - bool was_life_lost_; // If a life is loss from a step - reward_t reward_; // Last reward received - - EnvironmentAction current_action_; // Current action to take - int current_seed_; // Current seed to update - - // Frame buffers - std::vector> raw_frames_; // Raw frame buffers for maxpooling - std::vector frame_stack_; // Stack of recent frames - int frame_stack_idx_; // Frame stack index - }; -} +namespace ale::vector { -#endif // ALE_VECTOR_ATARI_ENV_HPP_ +/// Single ALE environment with standard preprocessing: +/// - Frame skipping with max-pooling +/// - Grayscale or RGB observations +/// - Resize to specified dimensions +/// - Frame stacking +/// - Noop-max on reset +/// - Fire on reset (for games that require it) +/// - Episodic life / life loss info +/// - Reward clipping +class PreprocessedEnv { +public: + PreprocessedEnv( + int env_id, + const fs::path& rom_path, + int img_height = 84, + int img_width = 84, + int 
frame_skip = 4, + bool maxpool = true, + bool grayscale = true, + int stack_num = 4, + int noop_max = 30, + bool use_fire_reset = true, + bool episodic_life = false, + bool life_loss_info = false, + bool reward_clipping = true, + int max_episode_steps = 108000, + float repeat_action_probability = 0.0f, + bool full_action_space = false, + int seed = -1 + ); + + /// Set seed for next reset + void set_seed(int seed); + + /// Set action for next step + void set_action(int action_id, float paddle_strength); + + /// Reset environment + void reset(); + + /// Step environment using previously set action + void step(); + + /// Write current state to output slot + void write_to(const OutputSlot& slot) const; + + /// Write only observation to destination (for final_obs before reset) + void write_obs_to(uint8_t* dest) const; + + /// Check if episode is over + bool is_episode_over() const; + + /// Get available actions + const ActionVect& action_set() const { return action_set_; } + + /// Get stacked observation size in bytes + std::size_t stacked_obs_size() const { return obs_size_ * stack_num_; } + + /// Get channels per frame (1 for grayscale, 3 for RGB) + int channels_per_frame() const { return channels_per_frame_; } + +private: + void get_screen_grayscale(uint8_t* buffer) const; + void get_screen_rgb(uint8_t* buffer) const; + void process_screen(); + + int env_id_; + fs::path rom_path_; + std::unique_ptr ale_; + + ActionVect action_set_; + + // Observation settings + ObsFormat obs_format_; + int channels_per_frame_; + int raw_frame_height_; + int raw_frame_width_; + int raw_frame_size_; + int raw_size_; + int obs_frame_height_; + int obs_frame_width_; + int obs_size_; + int stack_num_; + + // Preprocessing settings + int frame_skip_; + bool maxpool_; + int noop_max_; + bool use_fire_reset_; + bool has_fire_action_; + bool episodic_life_; + bool life_loss_info_; + bool reward_clipping_; + int max_episode_steps_; + + // RNG + std::mt19937 rng_; + 
std::uniform_int_distribution<> noop_dist_; + + // State + int elapsed_steps_; + bool game_over_; + int lives_; + bool was_life_lost_; + int reward_; + int current_action_id_; + float current_paddle_strength_; + int pending_seed_; + + // Frame buffers + std::vector> raw_frames_; + std::vector frame_stack_; + int frame_stack_idx_; +}; + +} // namespace ale::vector + +#endif // ALE_VECTOR_PREPROCESSED_ENV_HPP_ diff --git a/src/ale/vector/result_staging.hpp b/src/ale/vector/result_staging.hpp new file mode 100644 index 000000000..4f2d12cfe --- /dev/null +++ b/src/ale/vector/result_staging.hpp @@ -0,0 +1,126 @@ +#ifndef ALE_VECTOR_RESULT_STAGING_HPP_ +#define ALE_VECTOR_RESULT_STAGING_HPP_ + +#include +#include +#include + +#ifndef MOODYCAMEL_DELETE_FUNCTION + #define MOODYCAMEL_DELETE_FUNCTION = delete +#endif + +#include "ale/external/lightweightsemaphore.h" +#include "types.hpp" + +namespace ale::vector { + +/// Manages result collection with backpressure for async (unordered) mode. +/// +/// Workers call stage_result() after completing work. In unordered mode, +/// if batch_size slots are already filled, the worker blocks until recv() +/// releases the current batch. +class ResultStaging { +public: + ResultStaging(std::size_t batch_size, std::size_t num_envs, std::size_t obs_size, bool same_step_mode) + : batch_size_(batch_size), + num_envs_(num_envs), + obs_size_(obs_size), + ordered_mode_(batch_size == num_envs), + same_step_mode_(same_step_mode), + current_batch_(std::make_unique(batch_size, obs_size, same_step_mode)), + staged_count_(0), + next_slot_(0), + slots_available_(static_cast(batch_size)), + batch_ready_(0) {} + + /// Stage a result from a worker thread. 
+ /// In ordered mode: writes to slot[env_id] + /// In unordered mode: atomically allocates next slot, may block if batch is full + /// + /// @param env_id The environment that produced this result + /// @param write_fn Callback that writes data to the provided OutputSlot + void stage_result(int env_id, const std::function& write_fn) { + std::size_t slot; + + if (ordered_mode_) { + slot = static_cast(env_id); + } else { + // Acquire a slot permit (blocks if batch is full) + while (!slots_available_.wait()) {} + slot = next_slot_.fetch_add(1) % batch_size_; + } + + // Build output slot pointing into current batch + OutputSlot output; + output.obs = current_batch_->obs_data() + slot * obs_size_; + output.env_id = &current_batch_->env_ids_data()[slot]; + output.reward = &current_batch_->rewards_data()[slot]; + output.terminated = &current_batch_->terminations_data()[slot]; + output.truncated = &current_batch_->truncations_data()[slot]; + output.lives = &current_batch_->lives_data()[slot]; + output.frame_number = &current_batch_->frame_numbers_data()[slot]; + output.episode_frame_number = &current_batch_->episode_frame_numbers_data()[slot]; + output.final_obs = same_step_mode_ + ? current_batch_->final_obs_data() + slot * obs_size_ + : nullptr; + + // Let worker write its data + write_fn(output); + + // Signal completion + std::size_t completed = staged_count_.fetch_add(1) + 1; + if (completed == batch_size_) { + batch_ready_.signal(1); + } + } + + /// Wait for batch_size results to be staged. Called by recv(). + void wait_for_batch() { + while (!batch_ready_.wait()) {} + } + + /// Release current batch and prepare for next. + /// Returns the completed batch (transfers ownership). + /// Releases slot permits for blocked workers. 
+ BatchResult release_batch() { + // Take ownership of completed batch + auto result = std::move(*current_batch_); + + // Allocate fresh batch for next round + current_batch_ = std::make_unique(batch_size_, obs_size_, same_step_mode_); + + // Reset counters + staged_count_.store(0); + next_slot_.store(0); + + // Release permits for blocked workers (they'll write to new batch) + if (!ordered_mode_) { + slots_available_.signal(static_cast(batch_size_)); + } + + return result; + } + + std::size_t batch_size() const { return batch_size_; } + std::size_t obs_size() const { return obs_size_; } + bool is_ordered() const { return ordered_mode_; } + +private: + const std::size_t batch_size_; + const std::size_t num_envs_; + const std::size_t obs_size_; + const bool ordered_mode_; + const bool same_step_mode_; + + std::unique_ptr current_batch_; + + std::atomic staged_count_; + std::atomic next_slot_; + + moodycamel::LightweightSemaphore slots_available_; // Permits for staging (unordered only) + moodycamel::LightweightSemaphore batch_ready_; // Signaled when batch is full +}; + +} // namespace ale::vector + +#endif // ALE_VECTOR_RESULT_STAGING_HPP_ diff --git a/src/ale/vector/types.hpp b/src/ale/vector/types.hpp new file mode 100644 index 000000000..4a96756b5 --- /dev/null +++ b/src/ale/vector/types.hpp @@ -0,0 +1,181 @@ +#ifndef ALE_VECTOR_TYPES_HPP_ +#define ALE_VECTOR_TYPES_HPP_ + +#include +#include +#include +#include +#include + +namespace ale::vector { + +/// Autoreset behavior when episode ends +enum class AutoresetMode { + NextStep, // Reset on next step() call (observation is first frame of new episode) + SameStep // Reset immediately, return final_obs separately +}; + +/// Observation format +enum class ObsFormat { + Grayscale, + RGB +}; + +/// Action to execute in an environment +struct Action { + int env_id; + int action_id; + float paddle_strength; + bool force_reset; +}; + +/// Pointers for worker to write environment output directly into batch buffers 
+struct OutputSlot { + uint8_t* obs; + int* env_id; + int* reward; + bool* terminated; + bool* truncated; + int* lives; + int* frame_number; + int* episode_frame_number; + uint8_t* final_obs; // nullptr if not SameStep mode or not needed +}; + +/// Batch of results with ownership semantics +/// Owns all buffers. Supports releasing ownership for Python handoff. +class BatchResult { +public: + BatchResult(std::size_t batch_size, std::size_t obs_size, bool include_final_obs) + : batch_size_(batch_size), + obs_size_(obs_size), + observations_(new uint8_t[batch_size * obs_size]), + env_ids_(new int[batch_size]), + rewards_(new int[batch_size]), + terminations_(new bool[batch_size]), + truncations_(new bool[batch_size]), + lives_(new int[batch_size]), + frame_numbers_(new int[batch_size]), + episode_frame_numbers_(new int[batch_size]), + final_observations_(include_final_obs ? new uint8_t[batch_size * obs_size] : nullptr) {} + + ~BatchResult() { + delete[] observations_; + delete[] env_ids_; + delete[] rewards_; + delete[] terminations_; + delete[] truncations_; + delete[] lives_; + delete[] frame_numbers_; + delete[] episode_frame_numbers_; + delete[] final_observations_; + } + + // Move only + BatchResult(BatchResult&& other) noexcept + : batch_size_(other.batch_size_), + obs_size_(other.obs_size_), + observations_(other.observations_), + env_ids_(other.env_ids_), + rewards_(other.rewards_), + terminations_(other.terminations_), + truncations_(other.truncations_), + lives_(other.lives_), + frame_numbers_(other.frame_numbers_), + episode_frame_numbers_(other.episode_frame_numbers_), + final_observations_(other.final_observations_) { + other.observations_ = nullptr; + other.env_ids_ = nullptr; + other.rewards_ = nullptr; + other.terminations_ = nullptr; + other.truncations_ = nullptr; + other.lives_ = nullptr; + other.frame_numbers_ = nullptr; + other.episode_frame_numbers_ = nullptr; + other.final_observations_ = nullptr; + } + + BatchResult& operator=(BatchResult&& 
other) noexcept { + if (this != &other) { + delete[] observations_; + delete[] env_ids_; + delete[] rewards_; + delete[] terminations_; + delete[] truncations_; + delete[] lives_; + delete[] frame_numbers_; + delete[] episode_frame_numbers_; + delete[] final_observations_; + + batch_size_ = other.batch_size_; + obs_size_ = other.obs_size_; + observations_ = other.observations_; + env_ids_ = other.env_ids_; + rewards_ = other.rewards_; + terminations_ = other.terminations_; + truncations_ = other.truncations_; + lives_ = other.lives_; + frame_numbers_ = other.frame_numbers_; + episode_frame_numbers_ = other.episode_frame_numbers_; + final_observations_ = other.final_observations_; + + other.observations_ = nullptr; + other.env_ids_ = nullptr; + other.rewards_ = nullptr; + other.terminations_ = nullptr; + other.truncations_ = nullptr; + other.lives_ = nullptr; + other.frame_numbers_ = nullptr; + other.episode_frame_numbers_ = nullptr; + other.final_observations_ = nullptr; + } + return *this; + } + + BatchResult(const BatchResult&) = delete; + BatchResult& operator=(const BatchResult&) = delete; + + // Data access for workers to write into + uint8_t* obs_data() { return observations_; } + uint8_t* final_obs_data() { return final_observations_; } + int* env_ids_data() { return env_ids_; } + int* rewards_data() { return rewards_; } + bool* terminations_data() { return terminations_; } + bool* truncations_data() { return truncations_; } + int* lives_data() { return lives_; } + int* frame_numbers_data() { return frame_numbers_; } + int* episode_frame_numbers_data() { return episode_frame_numbers_; } + + // Release ownership - returns pointer and nulls internal pointer + // Caller takes ownership and must delete[] + uint8_t* release_observations() { auto p = observations_; observations_ = nullptr; return p; } + uint8_t* release_final_observations() { auto p = final_observations_; final_observations_ = nullptr; return p; } + int* release_env_ids() { auto p = env_ids_; 
env_ids_ = nullptr; return p; } + int* release_rewards() { auto p = rewards_; rewards_ = nullptr; return p; } + bool* release_terminations() { auto p = terminations_; terminations_ = nullptr; return p; } + bool* release_truncations() { auto p = truncations_; truncations_ = nullptr; return p; } + int* release_lives() { auto p = lives_; lives_ = nullptr; return p; } + int* release_frame_numbers() { auto p = frame_numbers_; frame_numbers_ = nullptr; return p; } + int* release_episode_frame_numbers() { auto p = episode_frame_numbers_; episode_frame_numbers_ = nullptr; return p; } + + std::size_t batch_size() const { return batch_size_; } + std::size_t obs_size() const { return obs_size_; } + bool has_final_obs() const { return final_observations_ != nullptr; } + +private: + std::size_t batch_size_; + std::size_t obs_size_; + uint8_t* observations_; + int* env_ids_; + int* rewards_; + bool* terminations_; + bool* truncations_; + int* lives_; + int* frame_numbers_; + int* episode_frame_numbers_; + uint8_t* final_observations_; +}; + +} // namespace ale::vector + +#endif // ALE_VECTOR_TYPES_HPP_ diff --git a/src/ale/vector/utils.hpp b/src/ale/vector/utils.hpp deleted file mode 100644 index b7499fc4c..000000000 --- a/src/ale/vector/utils.hpp +++ /dev/null @@ -1,330 +0,0 @@ -#ifndef ALE_VECTOR_UTILS_HPP_ -#define ALE_VECTOR_UTILS_HPP_ - -#include -#include -#include -#include - -#ifndef MOODYCAMEL_DELETE_FUNCTION - #define MOODYCAMEL_DELETE_FUNCTION = delete -#endif - -#include "ale/common/Constants.h" -#include "ale/external/lightweightsemaphore.h" - -namespace ale::vector { - - /** - * ActionSlice represents a single action or command to be processed by a worker thread - */ - struct ActionSlice { - int env_id; // ID of the environment to apply the action to - bool force_reset; // Whether to force a reset of the environment - }; - - /** - * EnvironmentAction represents an action to be taken in an environment - */ - struct EnvironmentAction { - int env_id; // ID of the 
environment to apply the action to - int action_id; // ID of the action to take - float paddle_strength; // Strength for paddle-based games (default: 1.0) - }; - - /** - * WriteSlot provides destinations for workers to write data directly. - * All pointers point into externally allocated BatchData arrays. - */ - struct WriteSlot { - int slot_index; // Index in the batch - uint8_t* obs_dest; // Pointer to write observation data - int* env_id_dest; // Pointer to write env_id - int* reward_dest; // Pointer to write reward - bool* terminated_dest; // Pointer to write terminated flag - bool* truncated_dest; // Pointer to write truncated flag - int* lives_dest; // Pointer to write lives - int* frame_number_dest; // Pointer to write frame_number - int* episode_frame_number_dest; // Pointer to write episode_frame_number - uint8_t* final_obs_dest; // Pointer for final_obs (SameStep mode) - }; - - /** - * Observation format enumeration - */ - enum class ObsFormat { - Grayscale, // Single channel grayscale observations - RGB // Three channel RGB observations - }; - - enum class AutoresetMode { - NextStep, // Will reset the sub-environment in the next step if the episode ended in the previous timestep - SameStep // Will reset the sub-environment in the same timestep if the episode ended - }; - - /** - * Lock-free queue for actions to be processed by worker threads - */ - class ActionQueue { - public: - explicit ActionQueue(const std::size_t num_envs) - : alloc_ptr_(0), - done_ptr_(0), - queue_size_(num_envs * 2), - queue_(queue_size_), - sem_(0), - sem_enqueue_(1), - sem_dequeue_(1) {} - - /** - * Enqueue multiple actions at once - */ - void enqueue_bulk(const std::vector& actions) { - while (!sem_enqueue_.wait()) {} - - const uint64_t pos = alloc_ptr_.fetch_add(actions.size()); - for (std::size_t i = 0; i < actions.size(); ++i) { - queue_[(pos + i) % queue_size_] = actions[i]; - } - - sem_.signal(actions.size()); - sem_enqueue_.signal(1); - } - - /** - * Dequeue a single 
action - */ - ActionSlice dequeue() { - while (!sem_.wait()) {} - while (!sem_dequeue_.wait()) {} - - const auto ptr = done_ptr_.fetch_add(1); - const auto ret = queue_[ptr % queue_size_]; - - sem_dequeue_.signal(1); - return ret; - } - - /** - * Get the approximate size of the queue - */ - std::size_t size_approx() const { - return alloc_ptr_ - done_ptr_; - } - - private: - std::atomic alloc_ptr_; // Pointer to next allocation position - std::atomic done_ptr_; // Pointer to next dequeue position - std::size_t queue_size_; // Size of the queue - std::vector queue_; // The actual queue data - moodycamel::LightweightSemaphore sem_; // Semaphore for queue access - moodycamel::LightweightSemaphore sem_enqueue_; // Semaphore for enqueue operations - moodycamel::LightweightSemaphore sem_dequeue_; // Semaphore for dequeue operations - }; - - /** - * StateBuffer manages output buffers for vectorized environment results. - * - * The buffer is set externally before workers begin writing. - * Workers write directly to allocated slots, avoiding intermediate copies. - * - * Two modes of operation: - * 1. Ordered mode (batch_size == num_envs): Slot index equals env_id - * 2. Unordered mode (batch_size != num_envs): Atomic slot allocation - */ - class StateBuffer { - public: - StateBuffer(const std::size_t batch_size, const std::size_t num_envs, const std::size_t obs_size) - : batch_size_(batch_size), - num_envs_(num_envs), - obs_size_(obs_size), - ordered_mode_(batch_size == num_envs), - output_obs_buffer_(nullptr), - final_obs_buffer_(nullptr), - env_ids_buffer_(nullptr), - rewards_buffer_(nullptr), - terminations_buffer_(nullptr), - truncations_buffer_(nullptr), - lives_buffer_(nullptr), - frame_numbers_buffer_(nullptr), - episode_frame_numbers_buffer_(nullptr), - count_(0), - write_idx_(0), - sem_ready_(0), - sem_read_(1), - sem_slots_(batch_size) {} // Initialize with batch_size permits - - /** - * Set the output buffer that workers will write observations into. 
- * MUST be called before enqueueing any actions that will use this buffer. - * - * @param obs_buffer Pointer to allocated buffer of size batch_size * obs_size - */ - void set_output_buffer(uint8_t* obs_buffer) { - output_obs_buffer_ = obs_buffer; - } - - /** - * Set the final_obs output buffer for SameStep autoreset mode. - * - * @param final_obs_buffer Pointer to allocated buffer of size batch_size * obs_size - */ - void set_final_obs_buffer(uint8_t* final_obs_buffer) { - final_obs_buffer_ = final_obs_buffer; - } - - /** - * Set the metadata output buffers that workers will write into. - * MUST be called before enqueueing any actions that will use these buffers. - * - * @param env_ids Pointer to allocated array of size batch_size - * @param rewards Pointer to allocated array of size batch_size - * @param terminations Pointer to allocated array of size batch_size - * @param truncations Pointer to allocated array of size batch_size - * @param lives Pointer to allocated array of size batch_size - * @param frame_numbers Pointer to allocated array of size batch_size - * @param episode_frame_numbers Pointer to allocated array of size batch_size - */ - void set_metadata_buffers( - int* env_ids, - int* rewards, - bool* terminations, - bool* truncations, - int* lives, - int* frame_numbers, - int* episode_frame_numbers - ) { - env_ids_buffer_ = env_ids; - rewards_buffer_ = rewards; - terminations_buffer_ = terminations; - truncations_buffer_ = truncations; - lives_buffer_ = lives; - frame_numbers_buffer_ = frame_numbers; - episode_frame_numbers_buffer_ = episode_frame_numbers; - } - - /** - * Allocate a write slot for a worker thread. - * Returns pointers for direct writing into the output buffer. - * - * Thread-safe: multiple workers can call simultaneously. - * In unordered mode, blocks if all slots are occupied. 
- * - * @param env_id The environment ID requesting a slot - * @return WriteSlot with pointers into output buffers - */ - WriteSlot allocate_write_slot(int env_id) { - // In unordered mode, block if all slots are occupied - if (!ordered_mode_) { - while (!sem_slots_.wait()) {} // Acquire permit, blocks if none available - } - - WriteSlot slot; - - if (ordered_mode_) { - // In ordered mode, slot index equals env_id - slot.slot_index = env_id; - } else { - // In unordered mode, atomically allocate next available slot - slot.slot_index = static_cast(write_idx_.fetch_add(1) % batch_size_); - } - - const int idx = slot.slot_index; - - // Set observation pointers - slot.obs_dest = output_obs_buffer_ + idx * obs_size_; - - // Set final_obs pointer (only used in SameStep mode, nullptr in NextStep mode) - slot.final_obs_dest = final_obs_buffer_ != nullptr - ? final_obs_buffer_ + idx * obs_size_ - : nullptr; - - // Set metadata pointers (directly into BatchData arrays) - slot.env_id_dest = &env_ids_buffer_[idx]; - slot.reward_dest = &rewards_buffer_[idx]; - slot.terminated_dest = &terminations_buffer_[idx]; - slot.truncated_dest = &truncations_buffer_[idx]; - slot.lives_dest = &lives_buffer_[idx]; - slot.frame_number_dest = &frame_numbers_buffer_[idx]; - slot.episode_frame_number_dest = &episode_frame_numbers_buffer_[idx]; - - return slot; - } - - /** - * Mark a slot as complete. Called by worker after writing all data. - * When all slots are complete, signals that batch is ready. - */ - void mark_complete() { - const auto old_count = count_.fetch_add(1); - if (old_count + 1 == batch_size_) { - sem_ready_.signal(1); - } - } - - /** - * Wait for batch to complete. Blocks until all slots are filled. - */ - void wait_for_batch() { - while (!sem_ready_.wait()) {} - } - - /** - * Reset state for next batch. Must be called after collecting results. 
- */ - void reset() { - count_.store(0); - write_idx_.store(0); - output_obs_buffer_ = nullptr; - final_obs_buffer_ = nullptr; - env_ids_buffer_ = nullptr; - rewards_buffer_ = nullptr; - terminations_buffer_ = nullptr; - truncations_buffer_ = nullptr; - lives_buffer_ = nullptr; - frame_numbers_buffer_ = nullptr; - episode_frame_numbers_buffer_ = nullptr; - } - - /** - * Release all slots for the next batch. - * Called by recv() after transferring buffer ownership to Python. - * This allows waiting workers to proceed and allocate slots. - */ - void release_slots() { - if (!ordered_mode_) { - sem_slots_.signal(batch_size_); // Release batch_size permits - } - } - - // Accessors - std::size_t get_batch_size() const { return batch_size_; } - std::size_t get_obs_size() const { return obs_size_; } - - private: - const std::size_t batch_size_; - const std::size_t num_envs_; - const std::size_t obs_size_; - const bool ordered_mode_; - - // External output buffers (set via set_output_buffer / set_final_obs_buffer / set_metadata_buffers) - uint8_t* output_obs_buffer_; - uint8_t* final_obs_buffer_; - int* env_ids_buffer_; - int* rewards_buffer_; - bool* terminations_buffer_; - bool* truncations_buffer_; - int* lives_buffer_; - int* frame_numbers_buffer_; - int* episode_frame_numbers_buffer_; - - // Synchronization - std::atomic count_; - std::atomic write_idx_; - moodycamel::LightweightSemaphore sem_ready_; - moodycamel::LightweightSemaphore sem_read_; - moodycamel::LightweightSemaphore sem_slots_; // Controls slot availability - }; -} - -#endif // ALE_VECTOR_UTILS_HPP_ diff --git a/tests/python/test_atari_vector_env.py b/tests/python/test_atari_vector_env.py index 1e1e588af..c0e28cf4f 100644 --- a/tests/python/test_atari_vector_env.py +++ b/tests/python/test_atari_vector_env.py @@ -434,7 +434,9 @@ def test_batch_size_async( ) async_env_timestep[async_env_ids] += 1 - assert np.all(async_env_timestep > rollout_length / (num_envs * 2)), async_env_timestep + assert np.all( + 
async_env_timestep > rollout_length / (num_envs * 2) + ), async_env_timestep sync_envs.close() async_envs.close() @@ -644,7 +646,12 @@ def test_same_step_autoreset_mode( for i, ep_over in enumerate(episode_over): if ep_over: - assert obs_equivalence(gym_final_obs[i], ale_final_obs[i], t, autoreset_mode="SAME-STEP"), t + assert obs_equivalence( + gym_final_obs[i], + ale_final_obs[i], + t, + autoreset_mode="SAME-STEP", + ), t else: gym_info = { key: value.astype(np.int32) From 5950d70977a5f2199e6459bd16d0485ae7a8f19e Mon Sep 17 00:00:00 2001 From: pseudo-rnd-thoughts Date: Sat, 29 Nov 2025 14:00:36 +0000 Subject: [PATCH 5/8] Add thread and memory sanitizers --- .github/workflows/ci.yml | 36 +++++++++++++++++++++++++++++------- CMakeLists.txt | 28 ++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 7 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index df58705b4..d42439f6b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -41,20 +41,27 @@ jobs: matrix: include: # To minimise the computational resources, we only use a single python version and the final test-wheels for all python versions + + # Thread Sanitizer build on Linux (most critical for vectorized env) - runs-on: ubuntu-latest python: '3.12' triplet: x64-linux-mixed + sanitizer: tsan + cc: clang-18 + cxx: clang++-18 - runs-on: windows-latest python: '3.12' triplet: x64-windows-mixed - - runs-on: macos-14 + - runs-on: macos-latest python: '3.12' triplet: arm64-osx-mixed env: VCPKG_DEFAULT_TRIPLET: ${{ matrix.triplet }} + CC: ${{ matrix.cc || '' }} + CXX: ${{ matrix.cxx || '' }} runs-on: ${{ matrix.runs-on }} steps: @@ -80,13 +87,28 @@ jobs: method: 'network' sub-packages: '["nvcc", "cudart", "visual_studio_integration"]' + - name: Install Clang (for sanitizer builds) + if: matrix.sanitizer + run: | + sudo apt-get update + sudo apt-get install -y clang-18 llvm-18 + - name: Download and unpack ROMs run: ./scripts/download_unpack_roms.sh - name: Build 
- run: python -m pip install --verbose .[test] + run: | + if [ -n "${{ matrix.sanitizer }}" ]; then + echo "Building with Thread Sanitizer..." + python -m pip install --verbose .[test] \ + --config-settings=cmake.args="-DENABLE_SANITIZER=thread" + else + python -m pip install --verbose .[test] + fi - name: Test + env: + TSAN_OPTIONS: ${{ matrix.sanitizer == 'tsan' && 'second_deadlock_stack=1 history_size=7' || '' }} run: python -m pytest build-wheels: @@ -216,19 +238,19 @@ jobs: wheel-name: 'cp313-cp313-win_amd64' arch: AMD64 - - runs-on: macos-14 + - runs-on: macos-latest python: '3.10' wheel-name: 'cp310-cp310-macosx_13_0_arm64' arch: arm64 - - runs-on: macos-14 + - runs-on: macos-latest python: '3.11' wheel-name: 'cp311-cp311-macosx_13_0_arm64' arch: arm64 - - runs-on: macos-14 + - runs-on: macos-latest python: '3.12' wheel-name: 'cp312-cp312-macosx_13_0_arm64' arch: arm64 - - runs-on: macos-14 + - runs-on: macos-latest python: '3.13' wheel-name: 'cp313-cp313-macosx_13_0_arm64' arch: arm64 @@ -272,7 +294,7 @@ jobs: # - runs-on: windows-latest # wheel-name: 'cp313-cp313-win_amd64' # arch: AMD64 - - runs-on: macos-14 + - runs-on: macos-latest wheel-name: 'cp313-cp313-macosx_13_0_arm64' arch: arm64 diff --git a/CMakeLists.txt b/CMakeLists.txt index 479f2e2c6..98d699709 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -24,6 +24,34 @@ if (BUILD_VECTOR_XLA_LIB) add_definitions(-DBUILD_VECTOR_XLA_LIB) endif() +# Sanitizer support for debugging (thread, address, undefined) +# Usage: -DENABLE_SANITIZER=thread or -DENABLE_SANITIZER=address +set(ENABLE_SANITIZER "" CACHE STRING "Enable sanitizer (thread, address, or empty to disable)") +set_property(CACHE ENABLE_SANITIZER PROPERTY STRINGS "" "thread" "address") + +if(ENABLE_SANITIZER) + if(NOT CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND NOT CMAKE_CXX_COMPILER_ID MATCHES "GNU") + message(WARNING "Sanitizers are best supported with Clang or GCC. 
Current compiler: ${CMAKE_CXX_COMPILER_ID}") + endif() + + if(ENABLE_SANITIZER STREQUAL "thread") + message(STATUS "Building with Thread Sanitizer (TSan)") + add_compile_options(-fsanitize=thread -g -O1) + add_link_options(-fsanitize=thread) + # TSan requires PIE on some platforms + if(UNIX AND NOT APPLE) + add_compile_options(-fPIE) + add_link_options(-pie) + endif() + elseif(ENABLE_SANITIZER STREQUAL "address") + message(STATUS "Building with Address Sanitizer (ASan) + Undefined Behavior Sanitizer (UBSan)") + add_compile_options(-fsanitize=address -fsanitize=undefined -fno-omit-frame-pointer -g -O1) + add_link_options(-fsanitize=address -fsanitize=undefined) + else() + message(FATAL_ERROR "Invalid sanitizer '${ENABLE_SANITIZER}'. Choose 'thread' or 'address'.") + endif() +endif() + # Set cmake module path set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake ${CMAKE_MODULE_PATH}) From d6fe76f7db72ba740791d5ff4580ece7d7b81f0d Mon Sep 17 00:00:00 2001 From: pseudo-rnd-thoughts Date: Sat, 29 Nov 2025 21:20:12 +0000 Subject: [PATCH 6/8] Fix CI issues --- CMakeLists.txt | 5 - scripts/run_sanitizers.sh | 149 ++++++++++++++++++++++++++++++ src/ale/vector/env_vectorizer.cpp | 1 + 3 files changed, 150 insertions(+), 5 deletions(-) create mode 100755 scripts/run_sanitizers.sh diff --git a/CMakeLists.txt b/CMakeLists.txt index 98d699709..c92502947 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -38,11 +38,6 @@ if(ENABLE_SANITIZER) message(STATUS "Building with Thread Sanitizer (TSan)") add_compile_options(-fsanitize=thread -g -O1) add_link_options(-fsanitize=thread) - # TSan requires PIE on some platforms - if(UNIX AND NOT APPLE) - add_compile_options(-fPIE) - add_link_options(-pie) - endif() elseif(ENABLE_SANITIZER STREQUAL "address") message(STATUS "Building with Address Sanitizer (ASan) + Undefined Behavior Sanitizer (UBSan)") add_compile_options(-fsanitize=address -fsanitize=undefined -fno-omit-frame-pointer -g -O1) diff --git a/scripts/run_sanitizers.sh 
b/scripts/run_sanitizers.sh new file mode 100755 index 000000000..0ec66c112 --- /dev/null +++ b/scripts/run_sanitizers.sh @@ -0,0 +1,149 @@ +#!/bin/bash +# Helper script for running sanitizers locally +# Usage: ./scripts/run_sanitizers.sh [thread|address|valgrind-memcheck|valgrind-helgrind] + +set -e + +SANITIZER=${1:-thread} +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_ROOT="$(dirname "$SCRIPT_DIR")" + +cd "$PROJECT_ROOT" + +echo "==========================================" +echo "Running sanitizer: $SANITIZER" +echo "==========================================" + +# Ensure ROMs are downloaded +if [ ! -d "src/ale/python/roms" ]; then + echo "Downloading ROMs..." + ./scripts/download_unpack_roms.sh +fi + +case $SANITIZER in + thread) + echo "Building with Thread Sanitizer..." + + # On macOS, use Homebrew LLVM to avoid SIP restrictions + if [[ "$OSTYPE" == "darwin"* ]]; then + if command -v brew &> /dev/null; then + LLVM_PREFIX=$(brew --prefix llvm 2>/dev/null || echo "") + if [ -n "$LLVM_PREFIX" ] && [ -d "$LLVM_PREFIX" ]; then + echo "Using Homebrew LLVM from $LLVM_PREFIX" + export CC="$LLVM_PREFIX/bin/clang" + export CXX="$LLVM_PREFIX/bin/clang++" + + # Pass Homebrew LLVM library paths to CMake + # Use a single cmake.args with semicolon-separated values + pip install --verbose -e .[test] \ + --config-settings=cmake.args="-DENABLE_SANITIZER=thread;-DCMAKE_EXE_LINKER_FLAGS=-L$LLVM_PREFIX/lib -Wl,-rpath,$LLVM_PREFIX/lib;-DCMAKE_SHARED_LINKER_FLAGS=-L$LLVM_PREFIX/lib -Wl,-rpath,$LLVM_PREFIX/lib;-DCMAKE_MODULE_LINKER_FLAGS=-L$LLVM_PREFIX/lib -Wl,-rpath,$LLVM_PREFIX/lib" + else + echo "WARNING: Homebrew LLVM not found. Installing..." + echo "Run: brew install llvm" + echo "" + echo "macOS System Integrity Protection (SIP) blocks system clang's TSan runtime." + echo "You need Homebrew LLVM for Thread Sanitizer to work on macOS." + exit 1 + fi + else + echo "ERROR: Homebrew not found. Thread Sanitizer on macOS requires Homebrew LLVM." 
+ echo "Install Homebrew from https://brew.sh/ then run: brew install llvm" + exit 1 + fi + else + export CC=clang + export CXX=clang++ + + pip install --verbose -e .[test] \ + --config-settings=cmake.args="-DENABLE_SANITIZER=thread" + fi + + echo "" + echo "Running tests with Thread Sanitizer..." + export TSAN_OPTIONS="second_deadlock_stack=1 history_size=7" + + echo "→ Running vector environment tests (most critical for threading)..." + python -m pytest tests/python/test_atari_vector_env.py -v -x + + echo "→ Running all tests..." + python -m pytest -v + ;; + + address) + echo "Building with Address Sanitizer + UBSan..." + export CC=clang + export CXX=clang++ + pip install --verbose -e .[test] \ + --config-settings=cmake.args="-DENABLE_SANITIZER=address" + + echo "" + echo "Running tests with Address Sanitizer..." + export ASAN_OPTIONS="detect_leaks=1:check_initialization_order=1" + export UBSAN_OPTIONS="print_stacktrace=1" + + echo "→ Running vector environment tests..." + python -m pytest tests/python/test_atari_vector_env.py -v -x + + echo "→ Running all tests..." + python -m pytest -v + ;; + + valgrind-memcheck) + echo "Building with debug symbols..." + pip install --verbose -e .[test] \ + --config-settings=cmake.args="-DCMAKE_BUILD_TYPE=RelWithDebInfo" + + echo "" + echo "Running Valgrind Memcheck (memory leak detection)..." + + if [ ! -f ".valgrind-python.supp" ]; then + echo "Warning: .valgrind-python.supp not found. Some false positives may appear." + fi + + valgrind \ + --tool=memcheck \ + --leak-check=full \ + --show-leak-kinds=definite,possible \ + --track-origins=yes \ + --verbose \ + --suppressions=.valgrind-python.supp \ + python -m pytest tests/python/test_atari_vector_env.py::TestVectorEnv::test_reset_step_shapes -v -k "num_envs-1" + ;; + + valgrind-helgrind) + echo "Building with debug symbols..." 
+ pip install --verbose -e .[test] \ + --config-settings=cmake.args="-DCMAKE_BUILD_TYPE=RelWithDebInfo" + + echo "" + echo "Running Valgrind Helgrind (thread error detection)..." + + if [ ! -f ".valgrind-python.supp" ]; then + echo "Warning: .valgrind-python.supp not found. Some false positives may appear." + fi + + valgrind \ + --tool=helgrind \ + --verbose \ + --suppressions=.valgrind-python.supp \ + python -m pytest tests/python/test_atari_vector_env.py::TestVectorEnv::test_batch_size_async -v + ;; + + *) + echo "Unknown sanitizer: $SANITIZER" + echo "" + echo "Usage: $0 [thread|address|valgrind-memcheck|valgrind-helgrind]" + echo "" + echo "Options:" + echo " thread - Thread Sanitizer (detects data races, deadlocks)" + echo " address - Address Sanitizer + UBSan (detects memory errors, UB)" + echo " valgrind-memcheck - Valgrind Memcheck (detects memory leaks)" + echo " valgrind-helgrind - Valgrind Helgrind (detects thread errors)" + exit 1 + ;; +esac + +echo "" +echo "==========================================" +echo "Sanitizer run complete: $SANITIZER" +echo "==========================================" diff --git a/src/ale/vector/env_vectorizer.cpp b/src/ale/vector/env_vectorizer.cpp index 95cb243cb..fe937a2c2 100644 --- a/src/ale/vector/env_vectorizer.cpp +++ b/src/ale/vector/env_vectorizer.cpp @@ -3,6 +3,7 @@ #if defined(__linux__) #include #elif defined(_WIN32) + #define NOMINMAX // Prevent Windows.h from defining min/max macros #include #elif defined(__APPLE__) #include From 84ae69ff0cbfdcb850225748d1efe75a01313e67 Mon Sep 17 00:00:00 2001 From: pseudo-rnd-thoughts Date: Sat, 29 Nov 2025 21:27:32 +0000 Subject: [PATCH 7/8] upstream SIMD optimizations --- src/ale/vector/preprocessed_env.cpp | 44 +++++++++++++++++++++++++++-- 1 file changed, 41 insertions(+), 3 deletions(-) diff --git a/src/ale/vector/preprocessed_env.cpp b/src/ale/vector/preprocessed_env.cpp index 862a086fc..21bfe9781 100644 --- a/src/ale/vector/preprocessed_env.cpp +++ 
b/src/ale/vector/preprocessed_env.cpp @@ -241,9 +241,7 @@ void PreprocessedEnv::get_screen_rgb(uint8_t* buffer) const { void PreprocessedEnv::process_screen() { // Maxpool raw frames if required if (maxpool_) { - for (int i = 0; i < raw_size_; ++i) { - raw_frames_[0][i] = std::max(raw_frames_[0][i], raw_frames_[1][i]); - } + maxpool_frames(raw_frames_[0].data(), raw_frames_[1].data(), raw_size_); } // Get pointer to current position in circular buffer @@ -264,4 +262,44 @@ void PreprocessedEnv::process_screen() { frame_stack_idx_ = (frame_stack_idx_ + 1) % stack_num_; } +/** + * Maxpool two uint8_t frames using SIMD when available + * @param dst Destination buffer (will be modified in-place with max values) + * @param src Source buffer to compare against + * @param size Number of bytes to process + */ +inline void maxpool_frames(uint8_t* dst, const uint8_t* src, int size) { + int i = 0; + +#if defined(__AVX2__) + // Process 32 bytes at a time with AVX2 + for (; i + 32 <= size; i += 32) { + __m256i a = _mm256_loadu_si256(reinterpret_cast(dst + i)); + __m256i b = _mm256_loadu_si256(reinterpret_cast(src + i)); + __m256i max_val = _mm256_max_epu8(a, b); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst + i), max_val); + } +#elif defined(__SSE2__) + // Process 16 bytes at a time with SSE2 + for (; i + 16 <= size; i += 16) { + __m128i a = _mm_loadu_si128(reinterpret_cast(dst + i)); + __m128i b = _mm_loadu_si128(reinterpret_cast(src + i)); + __m128i max_val = _mm_max_epu8(a, b); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + i), max_val); + } +#elif defined(__ARM_NEON) + // Process 16 bytes at a time with NEON + for (; i + 16 <= size; i += 16) { + uint8x16_t a = vld1q_u8(dst + i); + uint8x16_t b = vld1q_u8(src + i); + uint8x16_t max_val = vmaxq_u8(a, b); + vst1q_u8(dst + i, max_val); + } +#endif + // Handle remainder with scalar code + for (; i < size; ++i) { + dst[i] = std::max(dst[i], src[i]); + } +} + } // namespace ale::vector From 
7ae7b6a9d18ff96e9bc89f20e39985ee49acc347 Mon Sep 17 00:00:00 2001 From: pseudo-rnd-thoughts Date: Sun, 30 Nov 2025 16:57:14 +0000 Subject: [PATCH 8/8] Fix the merge --- src/ale/vector/preprocessed_env.cpp | 86 ++++++++++++++--------------- 1 file changed, 43 insertions(+), 43 deletions(-) diff --git a/src/ale/vector/preprocessed_env.cpp b/src/ale/vector/preprocessed_env.cpp index 21bfe9781..15607d541 100644 --- a/src/ale/vector/preprocessed_env.cpp +++ b/src/ale/vector/preprocessed_env.cpp @@ -2,6 +2,46 @@ namespace ale::vector { +/** + * Maxpool two uint8_t frames using SIMD when available + * @param dst Destination buffer (will be modified in-place with max values) + * @param src Source buffer to compare against + * @param size Number of bytes to process + */ +inline void maxpool_frames(uint8_t* dst, const uint8_t* src, int size) { + int i = 0; + +#if defined(__AVX2__) + // Process 32 bytes at a time with AVX2 + for (; i + 32 <= size; i += 32) { + __m256i a = _mm256_loadu_si256(reinterpret_cast(dst + i)); + __m256i b = _mm256_loadu_si256(reinterpret_cast(src + i)); + __m256i max_val = _mm256_max_epu8(a, b); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst + i), max_val); + } +#elif defined(__SSE2__) + // Process 16 bytes at a time with SSE2 + for (; i + 16 <= size; i += 16) { + __m128i a = _mm_loadu_si128(reinterpret_cast(dst + i)); + __m128i b = _mm_loadu_si128(reinterpret_cast(src + i)); + __m128i max_val = _mm_max_epu8(a, b); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + i), max_val); + } +#elif defined(__ARM_NEON) + // Process 16 bytes at a time with NEON + for (; i + 16 <= size; i += 16) { + uint8x16_t a = vld1q_u8(dst + i); + uint8x16_t b = vld1q_u8(src + i); + uint8x16_t max_val = vmaxq_u8(a, b); + vst1q_u8(dst + i, max_val); + } +#endif + // Handle remainder with scalar code + for (; i < size; ++i) { + dst[i] = std::max(dst[i], src[i]); + } +} + PreprocessedEnv::PreprocessedEnv( int env_id, const fs::path& rom_path, @@ -22,13 +62,13 @@ 
PreprocessedEnv::PreprocessedEnv( int seed ) : env_id_(env_id), rom_path_(rom_path), + obs_format_(grayscale ? ObsFormat::Grayscale : ObsFormat::RGB), + channels_per_frame_(grayscale ? 1 : 3), obs_frame_height_(img_height), obs_frame_width_(img_width), + stack_num_(stack_num), frame_skip_(frame_skip), maxpool_(maxpool), - obs_format_(grayscale ? ObsFormat::Grayscale : ObsFormat::RGB), - channels_per_frame_(grayscale ? 1 : 3), - stack_num_(stack_num), noop_max_(noop_max), use_fire_reset_(use_fire_reset), has_fire_action_(false), @@ -262,44 +302,4 @@ void PreprocessedEnv::process_screen() { frame_stack_idx_ = (frame_stack_idx_ + 1) % stack_num_; } -/** - * Maxpool two uint8_t frames using SIMD when available - * @param dst Destination buffer (will be modified in-place with max values) - * @param src Source buffer to compare against - * @param size Number of bytes to process - */ -inline void maxpool_frames(uint8_t* dst, const uint8_t* src, int size) { - int i = 0; - -#if defined(__AVX2__) - // Process 32 bytes at a time with AVX2 - for (; i + 32 <= size; i += 32) { - __m256i a = _mm256_loadu_si256(reinterpret_cast(dst + i)); - __m256i b = _mm256_loadu_si256(reinterpret_cast(src + i)); - __m256i max_val = _mm256_max_epu8(a, b); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst + i), max_val); - } -#elif defined(__SSE2__) - // Process 16 bytes at a time with SSE2 - for (; i + 16 <= size; i += 16) { - __m128i a = _mm_loadu_si128(reinterpret_cast(dst + i)); - __m128i b = _mm_loadu_si128(reinterpret_cast(src + i)); - __m128i max_val = _mm_max_epu8(a, b); - _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + i), max_val); - } -#elif defined(__ARM_NEON) - // Process 16 bytes at a time with NEON - for (; i + 16 <= size; i += 16) { - uint8x16_t a = vld1q_u8(dst + i); - uint8x16_t b = vld1q_u8(src + i); - uint8x16_t max_val = vmaxq_u8(a, b); - vst1q_u8(dst + i, max_val); - } -#endif - // Handle remainder with scalar code - for (; i < size; ++i) { - dst[i] = std::max(dst[i], 
src[i]); - } -} - } // namespace ale::vector