12 changes: 7 additions & 5 deletions lmdeploy/turbomind/turbomind.py
@@ -165,6 +165,7 @@ def _check_unloaded_tm_params(self):

def _load_weights(self):
"""Load weights."""
self._get_model_params()

with torch.cuda.device(self.devices[0]):
self._tm_model.export()
@@ -206,9 +207,12 @@ def _create_weight_func(device_id):
for future in futures:
future.result()

def _get_model_params(self, model_comm, tm_params: dict):
def _get_model_params(self):
"""Get turbomind model params when loading from hf."""

model_comm = self.model_comm
tm_params = self._tm_model.tm_params

def _get_params(device_id, que):
rank = self.node_id * self.gpu_count + device_id
out = model_comm.get_params(device_id, rank)
@@ -229,6 +233,7 @@ def _get_params(device_id, que):
tm_params[k] = [v]
else:
tm_params[k].append(v)
logger.warning(f'get {len(tm_params)} model params')

def _postprocess_config(self, tm_config: TurbomindModelConfig, engine_config: TurbomindEngineConfig):
"""Postprocess turbomind config by."""
@@ -274,10 +279,6 @@ def _from_hf(self, model_path: str, engine_config: TurbomindEngineConfig):
self._create_weight(model_comm)
# output model
self._tm_model = tm_model
# get tm params
tm_params = tm_model.tm_params
self._get_model_params(model_comm, tm_params)
logger.warning(f'get {len(tm_params)} model params')
return model_comm

def sleep(self, level: int = 1):
@@ -314,6 +315,7 @@ def _construct(item):
return func(*args).clone()

if not hasattr(self, '_export_iter'):
self._get_model_params()
que = Queue()
tm_model = self._tm_model
tm_model.input_model.model_path = que
21 changes: 19 additions & 2 deletions src/turbomind/comm/thread_comm.cc
@@ -16,15 +16,27 @@ namespace turbomind::comm {

struct ThreadCommImpl: public HostCommImpl {

constexpr static int kMaxSplits = 32;

class State {
public:
explicit State(int n): n_{n}, channels_(n * n) {}
explicit State(int n): n_{n}, channels_(n * n * kMaxSplits) {}
std::atomic<void*>& channel(int from, int to)
{
return channels_[from * n_ + to];
}

int next_offset()
{
std::lock_guard lock{mutex_};
TM_CHECK(offset_ < channels_.size() / n_);
offset_ += n_;
return offset_;
}

private:
std::mutex mutex_;
int offset_{0};
int n_;
std::deque<std::atomic<void*>> channels_;
};
@@ -33,6 +45,8 @@ struct ThreadCommImpl: public HostCommImpl {

int rank_; // global rank

int offset_{0};

std::vector<int> l2g_;
std::vector<int> g2l_;

@@ -46,6 +60,9 @@ struct ThreadCommImpl: public HostCommImpl {
ThreadCommImpl(std::vector<int> l2g, std::vector<int> g2l, std::shared_ptr<State> state, int rank):
state_{std::move(state)}, rank_{rank}, l2g_{std::move(l2g)}, g2l_{std::move(g2l)}
{
int offset = (this->rank() == 0) ? state_->next_offset() : 0;
comm::Broadcast(this, offset, 0);
offset_ = offset;
}

int rank() const override
@@ -65,7 +82,7 @@

std::atomic<void*>& channel(int from, int to)
{
return state_->channel(from, to);
return state_->channel(from + offset_, to);
}

std::shared_ptr<HostCommImpl> Split(int color, int key) override
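The split-offset change above is easiest to see in isolation: the shared channel table grows to kMaxSplits blocks of n rows, and each newly created split reserves its own block (rank 0 takes the next offset and broadcasts it to the other members). Below is a self-contained sketch of that scheme under simplified assumptions: the State class here is a stand-alone stand-in for the real one, the payload is a plain int, and no HostComm machinery is involved.

// Sketch of the per-split channel offset scheme (simplified; not the actual
// turbomind classes). Each split reserves a disjoint block of n rows, so two
// groups created concurrently never resolve to the same channel slot.
#include <atomic>
#include <cassert>
#include <deque>
#include <iostream>
#include <mutex>

constexpr int kMaxSplits = 32;

class State {
public:
    explicit State(int n): n_{n}, channels_(n * n * kMaxSplits) {}

    std::atomic<void*>& channel(int from, int to)
    {
        return channels_[from * n_ + to];
    }

    // Hand out a fresh block of n_ rows per split, as State::next_offset does.
    int next_offset()
    {
        std::lock_guard lock{mutex_};
        assert(offset_ < static_cast<int>(channels_.size()) / n_);
        offset_ += n_;
        return offset_;
    }

private:
    std::mutex mutex_;
    int offset_{0};
    int n_;
    std::deque<std::atomic<void*>> channels_;
};

int main()
{
    State state{4};  // 4 global ranks sharing one channel table

    // Two splits created back to back: each gets its own row offset.
    int offset_a = state.next_offset();  // 4
    int offset_b = state.next_offset();  // 8

    int msg_a = 1, msg_b = 2;
    state.channel(0 + offset_a, 1).store(&msg_a);  // rank 0 -> 1 in split A
    state.channel(0 + offset_b, 1).store(&msg_b);  // rank 0 -> 1 in split B

    std::cout << *static_cast<int*>(state.channel(0 + offset_a, 1).load()) << ' '
              << *static_cast<int*>(state.channel(0 + offset_b, 1).load()) << '\n';  // 1 2
}

Because the offset is chosen once by rank 0 and broadcast in the constructor, every member of a split indexes channel(from + offset, to) with the same offset, and different splits touch disjoint slots of the shared table.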
80 changes: 41 additions & 39 deletions src/turbomind/models/llama/LlamaBatch.cc
@@ -881,7 +881,7 @@ LlamaBatch::LlamaBatch(DataType data_type,
[this](void* p, ssize_t size) { return SymmFree(p, size, true); },
kDEVICE);

InitializeBufferAndKVCache();
InitializeBufferAndKVCache(param_.empty_init);

// Wait for allocations
check_cuda_error(cudaStreamSynchronize(stream_));
@@ -1791,49 +1791,51 @@ void LlamaBatch::Warmup()
}
}

void LlamaBatch::InitializeBufferAndKVCache()
void LlamaBatch::InitializeBufferAndKVCache(bool skip_kvcache)
{
// initialize kvcache, BatchState and persist buffers
core::ContextGuard guard{context_->core_stream, context_->allocator, Allocator{kCPUpinned}};

const auto cache_block_seq_len = model_->attn_param_.cache_block_seq_len;

const int dbits = byte_size(data_type_, 8);

const auto quant_policy = model_->param_.quant_policy;
const int elem_bits = quant_policy ? quant_policy : dbits;

SequenceManager::BlockConfig block_config{
(int)model_->size_per_head_,
(int)model_->local_kv_head_num_,
cache_block_seq_len,
elem_bits == dbits ? 0 : dbits,
elem_bits,
};

const auto get_free_size = [&] { //
size_t free{}, total{};
check_cuda_error(cudaMemGetInfo(&free, &total));
return AllReduce(model_->comm_->h_tp_group, free, comm::RedOp::kMin);
};

sequence_manager_.reset(new SequenceManager{model_->layer_num_,
block_config,
param_.cache_max_block_count,
param_.cache_chunk_size,
param_.enable_prefix_caching,
tp_rank_,
core::Context::alloc(kDEVICE),
get_free_size});

const size_t max_session_len = sequence_manager_->max_block_count() * cache_block_seq_len;
if (max_session_len < session_len_) {
if (tp_rank_ == 0) {
TM_LOG_WARNING("No enough blocks for `session_len` (%d), `session_len` truncated to %d.",
session_len_,
max_session_len);
if (!skip_kvcache) {
const auto cache_block_seq_len = model_->attn_param_.cache_block_seq_len;

const int dbits = byte_size(data_type_, 8);

const auto quant_policy = model_->param_.quant_policy;
const int elem_bits = quant_policy ? quant_policy : dbits;

SequenceManager::BlockConfig block_config{
(int)model_->size_per_head_,
(int)model_->local_kv_head_num_,
cache_block_seq_len,
elem_bits == dbits ? 0 : dbits,
elem_bits,
};

const auto get_free_size = [&] { //
size_t free{}, total{};
check_cuda_error(cudaMemGetInfo(&free, &total));
return AllReduce(model_->comm_->h_tp_mem_group, free, comm::RedOp::kMin);
};

sequence_manager_.reset(new SequenceManager{model_->layer_num_,
block_config,
param_.cache_max_block_count,
param_.cache_chunk_size,
param_.enable_prefix_caching,
tp_rank_,
core::Context::alloc(kDEVICE),
get_free_size});

const size_t max_session_len = sequence_manager_->max_block_count() * cache_block_seq_len;
if (max_session_len < session_len_) {
if (tp_rank_ == 0) {
TM_LOG_WARNING("No enough blocks for `session_len` (%d), `session_len` truncated to %d.",
session_len_,
max_session_len);
}
session_len_ = max_session_len;
}
session_len_ = max_session_len;
}

FT_CHECK(max_context_token_num_ >= session_len_);
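One detail worth calling out from the hunk above: get_free_size now reduces over h_tp_mem_group instead of h_tp_group, but the logic is unchanged — every TP rank reports its free device memory and all of them adopt the minimum, so they compute an identical block budget. A toy illustration of that agreement step, with assumed numbers and std::min_element standing in for the actual AllReduce(..., RedOp::kMin):

// Toy illustration (assumed values) of the KV-cache sizing agreement across
// a TP group: all ranks must size the cache from the same free amount.
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    // Hypothetical free memory per TP rank, in MiB (the real code queries
    // cudaMemGetInfo on each rank).
    std::vector<std::size_t> free_mib{49152, 47104, 48128, 46080};

    // Min-reduction: if ranks sized the cache independently, they would
    // allocate different block counts and the cache layout would diverge.
    std::size_t agreed = *std::min_element(free_mib.begin(), free_mib.end());

    const std::size_t block_mib = 64;  // assumed per-block footprint
    std::cout << "cache blocks per rank: " << agreed / block_mib << '\n';
}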
2 changes: 1 addition & 1 deletion src/turbomind/models/llama/LlamaBatch.h
@@ -120,7 +120,7 @@ class LlamaBatch {

~LlamaBatch();

void InitializeBufferAndKVCache();
void InitializeBufferAndKVCache(bool skip_kvcache = false);

void FreeBufferAndKVCache();

1 change: 1 addition & 0 deletions src/turbomind/models/llama/context.h
@@ -19,6 +19,7 @@ struct Communicators {
comm::HostComm h_comm;
comm::HostComm h_tp_group;
comm::HostComm h_dp_group;
comm::HostComm h_tp_mem_group;

comm::DeviceComm d_comm;
int d_tp_group;
1 change: 1 addition & 0 deletions src/turbomind/models/llama/llama_params.h
@@ -108,6 +108,7 @@ struct EngineParam {
int mlp_tp_rank;

std::vector<int> devices;
bool empty_init;
};

enum class LoraPolicy : int
15 changes: 11 additions & 4 deletions src/turbomind/triton_backend/llama/LlamaTritonModel.cc
@@ -350,7 +350,8 @@ LlamaTritonModel::LlamaTritonModel(DataType dtype,
engine_param_.mlp_tp_size = engine_reader["mlp_tp_size"].as<int>();
engine_param_.mlp_tp_rank = 0;

engine_param_.devices = engine_reader["devices"].as<std::vector<int>>();
engine_param_.devices = engine_reader["devices"].as<std::vector<int>>();
engine_param_.empty_init = engine_reader["empty_init"].as<bool>(false);

{
auto tp = engine_param_.attn_tp_size;
@@ -497,6 +498,9 @@ Communicators LlamaTritonModel::createCommSplits(int rank)
comm.h_tp_group = comm.h_comm->Split(inner_rank / engine_param_.attn_tp_size, 0);
comm.h_dp_group = comm.h_comm->Split(inner_rank % engine_param_.attn_tp_size, 0);

// The KV-cache manager may be initialized by another thread; using the same h_tp_group there could cause a conflict, so create a dedicated split.
comm.h_tp_mem_group = comm.h_comm->Split(inner_rank / engine_param_.attn_tp_size, 0);

if (comm_size_ > 1) {
comm.d_comm = CreateDeviceCommunicator(communicator_, comm_size_, inner_rank, comm.h_comm);
//
@@ -595,8 +599,10 @@ void LlamaTritonModel::sleep(int device_id, int level)
weights_[device_id]->to_device(kCPU);
}

// free kv cache
engines_[device_id]->FreeBufferAndKVCache();
if (engines_[device_id]) {
// free kv cache
engines_[device_id]->FreeBufferAndKVCache();
}
}

void LlamaTritonModel::wakeup(int device_id, const std::vector<std::string>& tags)
@@ -617,7 +623,8 @@ void LlamaTritonModel::wakeup(int device_id, const std::vector<std::string>& tags)
}
}

if (keys.find("kv_cache") != keys.end()) {
if (keys.find("kv_cache") != keys.end() && engines_[device_id]) {
engines_[device_id]->FreeBufferAndKVCache();
engines_[device_id]->InitializeBufferAndKVCache();
}
}
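Taken together, the pieces above form a deferred KV-cache lifecycle: empty_init read from the engine config skips cache allocation at construction, sleep frees it (guarding on a live engine), and wakeup with the "kv_cache" tag frees and re-creates it. A minimal sketch of that lifecycle under simplified assumptions — a toy Engine class keeps the method names but nothing else from LlamaBatch/LlamaTritonModel:

// Sketch of the deferred KV-cache lifecycle (toy Engine, not the real classes).
#include <iostream>
#include <memory>
#include <set>
#include <string>

class Engine {
public:
    explicit Engine(bool empty_init)
    {
        // Mirrors InitializeBufferAndKVCache(param_.empty_init): with
        // empty_init set, the cache is not allocated at construction time.
        InitializeBufferAndKVCache(/*skip_kvcache=*/empty_init);
    }

    void InitializeBufferAndKVCache(bool skip_kvcache = false)
    {
        if (!skip_kvcache) {
            kv_cache_ready_ = true;
            std::cout << "kv cache allocated\n";
        }
        // persistent buffers would be created here in either case
    }

    void FreeBufferAndKVCache()
    {
        kv_cache_ready_ = false;
        std::cout << "kv cache freed\n";
    }

    bool ready() const { return kv_cache_ready_; }

private:
    bool kv_cache_ready_{false};
};

// Mirrors the guarded wakeup path: only touch the cache when the engine exists
// and the "kv_cache" tag was requested; free first, then re-initialize with the
// default (non-skipping) argument.
void wakeup(const std::unique_ptr<Engine>& engine, const std::set<std::string>& tags)
{
    if (tags.count("kv_cache") && engine) {
        engine->FreeBufferAndKVCache();
        engine->InitializeBufferAndKVCache();
    }
}

int main()
{
    auto engine = std::make_unique<Engine>(/*empty_init=*/true);
    std::cout << "ready after construction: " << engine->ready() << '\n';  // 0

    wakeup(engine, {"kv_cache"});
    std::cout << "ready after wakeup: " << engine->ready() << '\n';  // 1
}

The sketch covers only the cache branch of wakeup; the other tags handled by the real method are outside this diff's scope.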