
Commit deae225

[fs_connector][feat]: Add multithreaded worker pool (thread_pool) for fs connector
Signed-off-by: Kfir Toledo <[email protected]>
1 parent f8bb304 commit deae225

File tree

2 files changed: +219 -0 lines changed

Lines changed: 135 additions & 0 deletions
@@ -0,0 +1,135 @@
/*
 * Copyright 2025 The llm-d Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
#include <iostream>
#include <thread>
#include <mutex>
#include <queue>
#include <condition_variable>
#include <atomic>
#include <pthread.h>   // pthread_setaffinity_np
#include <sched.h>     // cpu_set_t, CPU_ZERO, CPU_SET, sched_getcpu
#include <sys/syscall.h>
#include <unistd.h>
#include <numa.h>

#include "thread_pool.hpp"
#include "buffer.hpp"
#include "debug_utils.hpp"

// Thread-local index for CUDA streams
extern thread_local size_t thread_stream_idx;

// ThreadPool constructor
ThreadPool::ThreadPool(int threads, size_t pinned_buffer_mb, int tp_rank, int device_id) : m_device_id(device_id) {
    // Initialize PyTorch threading globally (main thread only)
    // at::init_num_threads();
    // at::set_num_threads(1);

    // Get the GPU's NUMA node once, outside the thread loop
    int gpu_numa = get_gpu_numa_node(device_id);
    std::cout << "[INFO] GPU " << device_id << " mapped to NUMA node " << gpu_numa << "\n";

    // Get all CPUs in that NUMA node
    auto local_cpus = get_cpus_in_numa_node(gpu_numa);

    if (local_cpus.empty()) {
        std::cerr << "[WARN] No CPUs found for NUMA node " << gpu_numa
                  << ". System may not be NUMA-aware. Using all CPUs.\n";
        // Fall back to all online CPUs
        int num_cpus = sysconf(_SC_NPROCESSORS_ONLN);
        for (int i = 0; i < num_cpus; ++i) {
            local_cpus.push_back(i);
        }
    }

    // Log available CPUs
    std::cout << "CPUs available for GPU " << device_id << " (NUMA " << gpu_numa << "): ";
    for (int cpu : local_cpus) std::cout << cpu << " ";
    std::cout << "\n";

    // Create all worker threads
    for (size_t i = 0; i < static_cast<size_t>(threads); ++i) {
        // Launch a new worker thread with a lambda that initializes thread resources and processes queued tasks.
        workers.emplace_back([this, i, threads, pinned_buffer_mb, tp_rank, device_id, gpu_numa, local_cpus] {
            cudaSetDevice(device_id);

            // Round-robin CPUs within the NUMA node
            int cpu_id = local_cpus[i % local_cpus.size()];

            cpu_set_t cpuset;
            CPU_ZERO(&cpuset);
            CPU_SET(cpu_id, &cpuset);

            if (pthread_setaffinity_np(pthread_self(), sizeof(cpuset), &cpuset) != 0) {
                std::cerr << "[ERROR] Failed to set affinity for thread " << i << " to CPU " << cpu_id << "\n";
            }

            int actual_cpu = sched_getcpu();
            pid_t tid = static_cast<pid_t>(syscall(SYS_gettid));
            DEBUG_PRINT("IO thread " << i << " set CUDA device to " << device_id << " (tid=" << tid << ", tp_rank=" << tp_rank
                        << ") pinned to CPU " << cpu_id << " (running on CPU " << actual_cpu << ")");

            // Attach the preallocated pinned buffer for this thread
            if (i < g_pinned_buffers.size() && g_pinned_buffers[i].ptr != nullptr) {
                t_pinned_buffer.ptr = g_pinned_buffers[i].ptr;
                t_pinned_buffer.size = g_pinned_buffers[i].size;
                DEBUG_PRINT("IO thread " << i << " attached to preallocated pinned buffer "
                            << (t_pinned_buffer.size / (1024 * 1024)) << " MB");
            } else {
                std::cerr << "[WARN] IO thread " << i << " has no preallocated pinned buffer\n";
            }

            // Each thread gets its own CUDA stream index
            thread_stream_idx = i;

            // Worker loop
            while (true) {
                std::function<void()> task;
                {
                    // Lock the task queue before checking it
                    std::unique_lock<std::mutex> lock(queue_mutex);

                    // Wait until either a new task arrives or the pool is stopping.
                    // (wait() unlocks the mutex while sleeping and re-locks it when waking)
                    condition.wait(lock, [this] { return stop || !tasks.empty(); });

                    // Exit the thread if the pool is stopping and no tasks remain
                    if (stop && tasks.empty()) return;

                    // Fetch the next task from the queue
                    task = std::move(tasks.front());
                    tasks.pop();
                }
                // Execute the task outside the lock
                task();
            }
        });
    }

    std::cout << "[INFO] All " << threads << " I/O threads initialized with pinned buffers\n";
}

// ThreadPool destructor
ThreadPool::~ThreadPool() {
    {
        // Set the stop flag while holding the queue mutex: if it were set
        // between a worker's predicate check and its wait(), the notify
        // below could be missed and the worker would block forever.
        std::unique_lock<std::mutex> lock(queue_mutex);
        stop = true;
    }
    condition.notify_all();
    // Wait for all worker threads to exit
    for (std::thread& worker : workers) {
        worker.join();
    }
}
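
The NUMA helpers used above (get_gpu_numa_node, get_cpus_in_numa_node) come from buffer.hpp and are not part of this diff. For reference, a minimal sketch of how such helpers can be built on the CUDA runtime, sysfs, and libnuma follows; the names, return conventions, and error handling here are illustrative assumptions, not the actual buffer.hpp implementation.

// Illustrative sketch only (not the buffer.hpp implementation): map a CUDA
// device to its NUMA node via sysfs, and list that node's CPUs via libnuma.
#include <cuda_runtime.h>
#include <numa.h>
#include <algorithm>
#include <cctype>
#include <fstream>
#include <string>
#include <vector>

int get_gpu_numa_node_sketch(int device_id) {
    char pci[32] = {0};
    if (cudaDeviceGetPCIBusId(pci, sizeof(pci), device_id) != cudaSuccess) return -1;
    std::string bus(pci);
    // sysfs uses lowercase hex in PCI addresses
    std::transform(bus.begin(), bus.end(), bus.begin(),
                   [](unsigned char c) { return std::tolower(c); });
    std::ifstream f("/sys/bus/pci/devices/" + bus + "/numa_node");
    int node = -1;
    f >> node;
    return node;  // -1 on non-NUMA systems or on failure
}

std::vector<int> get_cpus_in_numa_node_sketch(int node) {
    std::vector<int> cpus;
    if (node < 0 || numa_available() < 0) return cpus;
    struct bitmask* mask = numa_allocate_cpumask();
    if (numa_node_to_cpus(node, mask) == 0) {
        for (int cpu = 0; cpu < numa_num_configured_cpus(); ++cpu) {
            if (numa_bitmask_isbitset(mask, cpu)) cpus.push_back(cpu);
        }
    }
    numa_free_cpumask(mask);
    return cpus;
}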
Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
/*
 * Copyright 2025 The llm-d Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once
#include <vector>
#include <thread>
#include <future>
#include <queue>
#include <mutex>
#include <condition_variable>
#include <atomic>
#include <functional>
#include <memory>     // std::make_shared
#include <stdexcept>  // std::runtime_error
#include <sys/syscall.h>
#include <unistd.h>

#include <cuda_runtime.h>

#include "buffer.hpp"  // defines PinnedBuffer + get_gpu_numa_node + get_cpus_in_numa_node
#include "debug_utils.hpp"

// Thread-local storage used by each I/O thread
extern thread_local size_t thread_stream_idx;

// ThreadPool class
class ThreadPool {
public:
    ThreadPool(int threads, size_t pinned_buffer_mb, int tp_rank, int device_id);

    ~ThreadPool();

    template <class F>
    auto enqueue(F&& f) -> std::future<std::invoke_result_t<F>>;

private:
    std::vector<std::thread> workers;         // All worker threads
    std::queue<std::function<void()>> tasks;  // Queue of pending tasks

    std::mutex queue_mutex;                   // Protects access to the task queue
    std::condition_variable condition;        // Signals workers when tasks are available

    std::atomic<bool> stop{false};            // Tells workers to stop and exit
    int m_device_id;                          // CUDA device this thread pool is bound to
};

58+
// enqueue: submit a task to the thread pool
template <class F>
auto ThreadPool::enqueue(F&& f) -> std::future<std::invoke_result_t<F>> {
    // Deduce the return type of the submitted task
    using return_type = std::invoke_result_t<F>;

    // Wrap the callable in a packaged_task so we can hand back a future
    auto task = std::make_shared<std::packaged_task<return_type()>>(std::forward<F>(f));

    // Future for the caller to wait on
    std::future<return_type> res = task->get_future();

    {
        std::unique_lock<std::mutex> lock(queue_mutex);

        // Reject new tasks if the pool is shutting down
        if (stop) throw std::runtime_error("enqueue on stopped ThreadPool");

        // Push the task wrapper into the queue
        tasks.emplace([task]() { (*task)(); });
    }

    // Wake one worker thread to process the task
    condition.notify_one();

    return res;
}
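
For context, a hypothetical caller, assuming the pinned buffers have already been preallocated into g_pinned_buffers (see buffer.hpp), would construct one pool per device and submit I/O work through enqueue, which returns a std::future for each task. The parameter values below are illustrative, not taken from this commit.

// Hypothetical usage sketch of the ThreadPool API.
#include "thread_pool.hpp"

void example() {
    // 4 I/O threads, 64 MB pinned buffer each, tp_rank 0, CUDA device 0
    ThreadPool pool(/*threads=*/4, /*pinned_buffer_mb=*/64, /*tp_rank=*/0, /*device_id=*/0);

    // Submit a task; enqueue returns a std::future of the callable's result
    auto fut = pool.enqueue([] {
        // ... perform a file read / H2D copy on this thread's CUDA stream ...
        return true;
    });

    bool ok = fut.get();  // Block until a worker has executed the task
    (void)ok;
    // On scope exit, ~ThreadPool drains any remaining tasks and joins all workers
}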
