-
Notifications
You must be signed in to change notification settings - Fork 60
[fs_connector][feat]: Add multithreaded worker pool (thread_pool) #178
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
kfirtoledo
wants to merge
2
commits into
llm-d:main
Choose a base branch
from
kfirtoledo:connector_thread
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
140 changes: 140 additions & 0 deletions
140
kv_connectors/llmd_fs_backend/src/csrc/storage/thread_pool.cpp
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,140 @@ | ||
| /* | ||
| * Copyright 2025 The llm-d Authors. | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| #include <torch/extension.h> | ||
| #include <ATen/cuda/CUDAContext.h> | ||
| #include <c10/cuda/CUDAGuard.h> | ||
| #include <cuda_runtime.h> | ||
| #include <iostream> | ||
| #include <thread> | ||
| #include <mutex> | ||
| #include <queue> | ||
| #include <condition_variable> | ||
| #include <atomic> | ||
| #include <sys/syscall.h> | ||
| #include <unistd.h> | ||
| #include <numa.h> | ||
|
|
||
| #include "thread_pool.hpp" | ||
| #include "buffer.hpp" | ||
| #include "debug_utils.hpp" | ||
|
|
||
| // Thread-local index for CUDA streams | ||
| extern thread_local size_t thread_stream_idx; | ||
|
|
||
| // ThreadPool constructor | ||
| ThreadPool::ThreadPool(int threads, size_t pinned_buffer_mb, int tp_rank, int device_id) : m_device_id(device_id) { | ||
| // Initialize PyTorch threading globally (main thread only) | ||
| // at::init_num_threads(); | ||
| // at::set_num_threads(1); | ||
|
|
||
| // Get GPU NUMA node ONCE outside the thread loop | ||
| int gpu_numa = get_gpu_numa_node(device_id); | ||
| std::cout << "[INFO] GPU " << device_id << " mapped to NUMA node " << gpu_numa << "\n"; | ||
|
|
||
| // Get all CPUs in that NUMA node | ||
| auto local_cpus = get_cpus_in_numa_node(gpu_numa); | ||
|
|
||
| if (local_cpus.empty()) { | ||
| std::cerr << "[WARN] No CPUs found for NUMA node " << gpu_numa << ". System may not be NUMA-aware. Using all CPUs.\n"; | ||
| // Populate with all available CPUs as fallback | ||
| int num_cpus = sysconf(_SC_NPROCESSORS_ONLN); | ||
| for (int i = 0; i < num_cpus; ++i) { | ||
| local_cpus.push_back(i); | ||
| } | ||
| } | ||
|
|
||
| // Log available CPUs | ||
| std::cout << "CPUs available for GPU " << device_id << " (NUMA " << gpu_numa << "): "; | ||
| for (int cpu : local_cpus) std::cout << cpu << " "; | ||
| std::cout << "\n"; | ||
|
|
||
| // Create all worker threads | ||
| for (size_t i = 0; i < threads; ++i) { | ||
| // Launch a new worker thread with a lambda that initializes thread resources and processes queued tasks. | ||
| workers.emplace_back([this, i, threads, staging_buffer_mb, tp_rank, device_id, gpu_numa, local_cpus] { | ||
| cudaSetDevice(device_id); | ||
|
|
||
| // Round-robin CPUs within the NUMA node | ||
| // TODO: Re-evaluate whether strict NUMA-based round-robin CPU assignment is optimal for performance. | ||
| int cpu_id = local_cpus[i % local_cpus.size()]; | ||
|
|
||
| cpu_set_t cpuset; | ||
| CPU_ZERO(&cpuset); | ||
| CPU_SET(cpu_id, &cpuset); | ||
|
|
||
| if (pthread_setaffinity_np(pthread_self(), sizeof(cpuset), &cpuset) != 0) { | ||
| std::cerr << "[ERROR] Failed to set affinity for thread " << i << " to CPU " << cpu_id << "\n"; | ||
| } | ||
|
|
||
| DEBUG_PRINT("IO thread " << i << " set CUDA device to " << device_id << ", tp_rank=" << tp_rank << ") pinned to CPU " | ||
| << cpu_id); | ||
|
|
||
| // Attach preallocated staging buffer for this thread | ||
| if (i < g_staging_buffers.size() && g_staging_buffers[i].ptr != nullptr) { | ||
| t_staging_buffer.ptr = g_staging_buffers[i].ptr; | ||
| t_staging_buffer.size = g_staging_buffers[i].size; | ||
| DEBUG_PRINT("IO thread " << i << " attached to preallocated staging buffer " << (t_staging_buffer.size / (1024 * 1024)) | ||
| << " MB"); | ||
| } else { | ||
| std::cerr << "[WARN] IO thread " << i << " has no preallocated staging buffer; it will allocate one on first use.\n"; | ||
| } | ||
|
|
||
| // Each thread gets its own CUDA stream index | ||
| thread_stream_idx = i; | ||
|
|
||
| // Worker loop | ||
| while (true) { | ||
| std::function<void()> task; | ||
| { | ||
| // Lock the task queue before checking it | ||
| std::unique_lock<std::mutex> lock(queue_mutex); | ||
|
|
||
| // Wait until either a new task arrives or the pool is stopping. | ||
| // (wait() unlocks the mutex while sleeping and re-locks it when waking) | ||
| condition.wait(lock, [this] { return stop || !tasks.empty(); }); | ||
|
|
||
| // Exit thread if pool is stopping and no tasks remain | ||
| if (stop && tasks.empty()) return; | ||
|
|
||
| // Fetch next task from the queue | ||
| task = std::move(tasks.front()); | ||
| tasks.pop(); | ||
| } | ||
| try { | ||
| // Execute the task | ||
| task(); | ||
| } catch (const std::exception& e) { | ||
| std::cerr << "[ERROR] Exception in worker thread: " << e.what() << "\n"; | ||
| } catch (...) { | ||
| std::cerr << "[ERROR] Unknown exception in worker thread\n"; | ||
| } | ||
| } | ||
| }); | ||
| } | ||
|
|
||
| std::cout << "[INFO] All " << threads << " I/O threads initialized with staging buffers\n"; | ||
| } | ||
|
|
||
| // ThreadPool destructor | ||
| ThreadPool::~ThreadPool() { | ||
| stop = true; | ||
| condition.notify_all(); | ||
| // Wait for all worker threads to exit | ||
| for (std::thread& worker : workers) { | ||
| worker.join(); | ||
| } | ||
| } |
91 changes: 91 additions & 0 deletions
91
kv_connectors/llmd_fs_backend/src/csrc/storage/thread_pool.hpp
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,91 @@ | ||
| /* | ||
| * Copyright 2025 The llm-d Authors. | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| #pragma once | ||
| #include <vector> | ||
| #include <thread> | ||
| #include <future> | ||
| #include <queue> | ||
| #include <mutex> | ||
| #include <condition_variable> | ||
| #include <atomic> | ||
| #include <functional> | ||
| #include <sys/syscall.h> | ||
| #include <unistd.h> | ||
|
|
||
| #include <cuda_runtime.h> | ||
|
|
||
| #include "buffer.hpp" | ||
| #include "debug_utils.hpp" | ||
|
|
||
| // Thread-local storage used by each I/O thread | ||
| extern thread_local size_t thread_stream_idx; | ||
|
|
||
| // ThreadPool class is a thread pool used for parallel file offloading. Each | ||
| // worker thread handles one file end-to-end: reading or writing the file, | ||
| // staging data through its own thread-local staging buffer, and launching the | ||
| // GPU copy on a dedicated CUDA stream. This enables many files to be processed | ||
| // concurrently with full I/O–GPU overlap. | ||
| class ThreadPool { | ||
| public: | ||
| ThreadPool(int threads, size_t pinned_buffer_mb, int tp_rank, int device_id); | ||
|
|
||
| ~ThreadPool(); | ||
|
|
||
| template <class F> | ||
| auto enqueue(F&& f) -> std::future<std::invoke_result_t<F>>; | ||
|
|
||
| private: | ||
| std::vector<std::thread> workers; // All worker threads | ||
| std::queue<std::function<void()>> tasks; // Queue of pending tasks | ||
|
|
||
| std::mutex queue_mutex; // Protects access to the task queue | ||
| std::condition_variable condition; // Signals workers when tasks are available | ||
|
|
||
| std::atomic<bool> stop{false}; // Tells workers to stop and exit | ||
| int m_device_id; // CUDA device this thread pool is bound to | ||
| }; | ||
|
|
||
| // enqueue: submit a task to the thread pool | ||
| template <class F> | ||
| auto ThreadPool::enqueue(F&& f) -> std::future<std::invoke_result_t<F>> { | ||
| // Get the return type of the submitted task | ||
| using return_type = std::invoke_result_t<F>; | ||
|
|
||
| // Wrap the callable into a packaged_task so we can return a future | ||
| auto task = std::make_shared<std::packaged_task<return_type()>>(std::forward<F>(f)); | ||
|
|
||
| // Future for the caller to wait on | ||
| std::future<return_type> res = task->get_future(); | ||
|
|
||
| { | ||
| std::unique_lock<std::mutex> lock(queue_mutex); | ||
|
|
||
| // Reject new tasks if the pool is shutting down | ||
| if (stop) { | ||
| std::cerr << "[WARN] ThreadPool is stopping. Rejecting new task.\n"; | ||
| return std::future<return_type>(); // empty future | ||
| } | ||
|
|
||
| // Push the task wrapper into the queue | ||
| tasks.emplace([task]() { (*task)(); }); | ||
| } | ||
|
|
||
| // Wake one worker thread to process the task | ||
| condition.notify_one(); | ||
|
|
||
| return res; | ||
| } | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.