10 changes: 9 additions & 1 deletion lmdeploy/serve/async_engine.py
@@ -434,9 +434,17 @@ def wakeup(self, tags: Optional[List[str]] = None):
"""Wake up the model.

Args:
tags (List[str]): The tags to wake up. Values must be in `("weights", "kv_cache")`
tags: An optional list of tags to reallocate the engine memory
for specific memory allocations. Values must be in
`("weights", "kv_cache")`. If None, all memory is reallocated.
wake_up should be called with all tags (or None) before the
engine is used again.
"""
self.engine.wakeup(tags)
# for TM backend, sleep/wakeup will reset gateway, therefore we need to rebuild instance

Reviewer comment (Collaborator): The docstring says "tags" must be in ("weights", "kv_cache"). So what does None mean?

if self.backend == 'turbomind' and (tags is None or 'kv_cache' in tags):
self.instances = [self.engine.create_instance() for _ in range(self.instance_num)]
self.free_insts = None

def _get_limiter(self):
if not self.limiter:
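Usage note: a minimal sketch of the staged sleep/wakeup flow at the pipeline level, assuming a TurboMind-backed pipeline; the model path and engine options below are illustrative and not taken from this diff.

```python
from lmdeploy import pipeline, TurbomindEngineConfig

# Hypothetical model path and config; any extra options needed to enable
# sleep mode are not shown in this diff.
pipe = pipeline('internlm/internlm2_5-7b-chat',
                backend_config=TurbomindEngineConfig(tp=1))

pipe.sleep(1)                    # level 1: offload weights to CPU, free the KV cache

# Wake up in stages: weights first, then the KV cache. Per the docstring,
# all tags (or None) must be woken before the engine is used again.
pipe.wakeup(tags=['weights'])
pipe.wakeup(tags=['kv_cache'])   # for the TurboMind backend this also rebuilds the instances

# Equivalent one-shot form: tags=None reallocates everything.
# pipe.wakeup()
```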
4 changes: 2 additions & 2 deletions lmdeploy/serve/openai/api_server.py
@@ -953,15 +953,15 @@ def update_params(request: UpdateParamsRequest, raw_request: Request = None):
@router.post('/sleep', dependencies=[Depends(check_api_key)])
async def sleep(raw_request: Request = None):
level = raw_request.query_params.get('level', '1')
VariableInterface.async_engine.engine.sleep(int(level))
VariableInterface.async_engine.sleep(int(level))
return Response(status_code=200)


@router.post('/wakeup', dependencies=[Depends(check_api_key)])
async def wakeup(raw_request: Request = None):
tags = raw_request.query_params.getlist('tags')
tags = tags or None
VariableInterface.async_engine.engine.wakeup(tags)
VariableInterface.async_engine.wakeup(tags)
return Response(status_code=200)


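A sketch of how the two endpoints might be exercised over HTTP. The server address/port and the Authorization header are assumptions (the header is only needed when API keys are configured); they are not part of this diff.

```python
import requests

base = 'http://0.0.0.0:23333'                       # assumed api_server address
headers = {'Authorization': 'Bearer YOUR_API_KEY'}  # only if an api key is set

# Put the engine to sleep; `level` is read from the query string (default '1').
requests.post(f'{base}/sleep', params={'level': 1}, headers=headers)

# Staged wakeup: repeated `tags` query parameters map to getlist('tags') above.
requests.post(f'{base}/wakeup', params={'tags': ['weights']}, headers=headers)
requests.post(f'{base}/wakeup', params={'tags': ['kv_cache']}, headers=headers)

# Omitting `tags` yields an empty list, which is coerced to None and wakes everything.
requests.post(f'{base}/wakeup', headers=headers)
```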
15 changes: 9 additions & 6 deletions lmdeploy/turbomind/turbomind.py
@@ -165,6 +165,7 @@ def _check_unloaded_tm_params(self):

def _load_weights(self):
"""Load weights."""
self._get_model_params()

with torch.cuda.device(self.devices[0]):
self._tm_model.export()
@@ -206,9 +207,12 @@ def _create_weight_func(device_id):
for future in futures:
future.result()

def _get_model_params(self, model_comm, tm_params: dict):
def _get_model_params(self):
"""Get turbomind model params when loading from hf."""

model_comm = self.model_comm
tm_params = self._tm_model.tm_params

def _get_params(device_id, que):
rank = self.node_id * self.gpu_count + device_id
out = model_comm.get_params(device_id, rank)
@@ -229,6 +233,7 @@ def _get_params(device_id, que):
tm_params[k] = [v]
else:
tm_params[k].append(v)
logger.warning(f'get {len(tm_params)} model params')

def _postprocess_config(self, tm_config: TurbomindModelConfig, engine_config: TurbomindEngineConfig):
"""Postprocess turbomind config by."""
@@ -274,10 +279,6 @@ def _from_hf(self, model_path: str, engine_config: TurbomindEngineConfig):
self._create_weight(model_comm)
# output model
self._tm_model = tm_model
# get tm params
tm_params = tm_model.tm_params
self._get_model_params(model_comm, tm_params)
logger.warning(f'get {len(tm_params)} model params')
return model_comm

def sleep(self, level: int = 1):
Expand All @@ -291,7 +292,8 @@ def wakeup(self, tags: Optional[list[str]] = None):
if tags is None:
tags = ['weights', 'kv_cache']
with ThreadPoolExecutor(max_workers=self.gpu_count) as e:
for _ in e.map(self.model_comm.wakeup, range(self.gpu_count), [tags] * self.gpu_count):
ranks = [self.node_id * self.gpu_count + device_id for device_id in range(self.gpu_count)]
for _ in e.map(self.model_comm.wakeup, range(self.gpu_count), [tags] * self.gpu_count, ranks):
pass

def update_params(self, request: UpdateParamsRequest):
@@ -314,6 +316,7 @@ def _construct(item):
return func(*args).clone()

if not hasattr(self, '_export_iter'):
self._get_model_params()
que = Queue()
tm_model = self._tm_model
tm_model.input_model.model_path = que
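For clarity, the thread-pool call in `wakeup` above unrolls to the sketch below; the wrapper function and variable names are illustrative, with `model_comm` standing for the bound `_turbomind` model object.

```python
def wakeup_all_devices(model_comm, node_id: int, gpu_count: int, tags=None):
    """Sequential equivalent of the e.map(...) call in TurboMind.wakeup (sketch)."""
    tags = tags or ['weights', 'kv_cache']
    for device_id in range(gpu_count):
        # Global rank, computed the same way as in _get_params above.
        rank = node_id * gpu_count + device_id
        # `rank` is the new third argument introduced by this change.
        model_comm.wakeup(device_id, tags, rank)
```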
10 changes: 0 additions & 10 deletions src/turbomind/core/tensor.cc
@@ -69,16 +69,6 @@ void Clear(Ref<Tensor> a_)
Clear(a_, Context::stream());
}

Tensor to_device(const Tensor& src, const Device& device)
{
Tensor dst;
if (src) {
dst = {src.layout(), src.dtype(), device};
Copy(src, dst, Context::stream());
}
return dst;
}

#if 0

void Copy(const Tensor& src, Tensor& dst, Stream& stream)
2 changes: 0 additions & 2 deletions src/turbomind/core/tensor.h
@@ -219,8 +219,6 @@ void Clear(Ref<Tensor> a_, const Stream& stream);

void Clear(Ref<Tensor> a_);

Tensor to_device(const Tensor& src, const Device& device);

#if 0

void Copy(const Tensor& src, Tensor&& dst, Stream& stream);
5 changes: 1 addition & 4 deletions src/turbomind/models/llama/LlamaBatch.cc
@@ -853,7 +853,7 @@ LlamaBatch::~LlamaBatch()
LlamaBatch::LlamaBatch(DataType data_type,
const EngineParam& param,
std::unique_ptr<LlamaV2> model, // ! This is moved
std::unique_ptr<Context> ctx, // ! This is moved
std::shared_ptr<Context> ctx,
std::shared_ptr<Gateway> gateway,
int device_id,
int dp_rank):
@@ -1907,9 +1907,6 @@ void LlamaBatch::DestroyCommunicators()
FreeSymmBuffers();
comm_.h_comm->Sync();

// Destroy device communicator
comm_.d_comm = {};

cudaStreamSynchronize(stream_);
comm_.h_comm->Sync();
}
4 changes: 2 additions & 2 deletions src/turbomind/models/llama/LlamaBatch.h
@@ -113,7 +113,7 @@ class LlamaBatch {
explicit LlamaBatch(DataType data_type,
const EngineParam& param,
std::unique_ptr<LlamaV2> model,
std::unique_ptr<Context> ctx,
std::shared_ptr<Context> ctx,
std::shared_ptr<Gateway> gateway,
int device_id,
int dp_rank);
@@ -226,7 +226,7 @@ class LlamaBatch {

int session_len_; // May be truncated in ctor

std::unique_ptr<Context> context_;
std::shared_ptr<Context> context_;
std::unique_ptr<LlamaV2> model_;
std::unique_ptr<SequenceManager> sequence_manager_;

13 changes: 12 additions & 1 deletion src/turbomind/models/llama/LlamaWeight.cc
@@ -124,9 +124,20 @@ void LlamaWeight::to_device(const core::Device& device)
{
core::ContextGuard guard = context();

auto to_device = [&](Tensor& x) -> Tensor {
auto tmp = std::exchange(x, empty_like(x, device));
Copy(tmp, x);
return tmp;
};

std::vector<Tensor> tmp_cpu_tensors;

auto tensor_ptr_map = get_parameters();
for (auto& [name, tensor_ptr] : tensor_ptr_map) {
*tensor_ptr = core::to_device(*tensor_ptr, device);
auto tmp_tensor = to_device(*tensor_ptr);
if (tmp_tensor.device().type != kDEVICE) {
tmp_cpu_tensors.push_back(tmp_tensor);
}
}
core::Context::stream().Sync();
if (device.type == kCPU) {
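The `to_device` lambda and `tmp_cpu_tensors` above implement a keep-alive pattern: host-side source tensors must outlive the asynchronous copies until the stream sync. The PyTorch sketch below illustrates the same idea only as an analogy; it is not the turbomind implementation and assumes a CUDA device is available.

```python
import torch

def load_weights_async(cpu_state: dict, device: str = 'cuda') -> dict:
    """Copy CPU tensors to the GPU asynchronously, keeping the pinned sources
    alive until the copies have completed (analogous to tmp_cpu_tensors)."""
    keep_alive, gpu_state = [], {}
    for name, cpu_t in cpu_state.items():
        pinned = cpu_t.pin_memory()                 # page-locked source enables async H2D copy
        gpu_t = torch.empty_like(pinned, device=device)
        gpu_t.copy_(pinned, non_blocking=True)      # enqueued on the current CUDA stream
        keep_alive.append(pinned)                   # must outlive the pending copy
        gpu_state[name] = gpu_t
    torch.cuda.synchronize()                        # analogous to core::Context::stream().Sync()
    return gpu_state                                # keep_alive may now be dropped
```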
7 changes: 4 additions & 3 deletions src/turbomind/python/bind.cpp
@@ -580,12 +580,13 @@ PYBIND11_MODULE(_turbomind, m)
"level"_a)
.def(
"wakeup",
[](LlamaTritonModel* model, int deviceId, const std::vector<std::string>& tags) {
model->wakeup(deviceId, tags);
[](LlamaTritonModel* model, int deviceId, const std::vector<std::string>& tags, int rank) {
model->wakeup(deviceId, tags, rank);
},
py::call_guard<py::gil_scoped_release>(),
"device_id"_a,
"tags"_a)
"tags"_a,
"rank"_a)
.def("__str__", &LlamaTritonModel::toString)
.def("__repr__", &LlamaTritonModel::toString)
.def("get_tensor_para_size", &LlamaTritonModel::getTensorParaSize)
55 changes: 39 additions & 16 deletions src/turbomind/triton_backend/llama/LlamaTritonModel.cc
@@ -240,13 +240,16 @@ LlamaTritonModel::~LlamaTritonModel()
{
FT_CHECK(weights_.size() == engines_.size());

gateway_->shutdown();
if (gateway_) {
gateway_->shutdown();
}

for (int device_id = 0; device_id < (int)engines_.size(); ++device_id) {
// Set device id before destructing CUDA resources
CudaDeviceGuard dev_guard(engine_param_.devices[device_id]);
engines_[device_id].reset();
weights_[device_id].reset();
contexts_[device_id].reset();
trim_default_mempool(engine_param_.devices[device_id]);
}
}
@@ -388,9 +391,11 @@ LlamaTritonModel::LlamaTritonModel(DataType dtype,
handleMissingParams();

gateway_ = std::make_shared<Gateway>(engine_param_.outer_dp_size, engine_param_.attn_dp_size, ffi_ctx_factory);
ffi_ctx_factory_ = ffi_ctx_factory;

weights_.resize(engine_param_.devices.size());
engines_.resize(engine_param_.devices.size());
contexts_.resize(engine_param_.devices.size());

const std::string weight_type_str = model_reader["weight_type"].as<std::string>();
if (weight_type_str == "fp16" || weight_type_str == "float16") {
@@ -513,12 +518,15 @@ void LlamaTritonModel::createEngine(int device_id, int rank)
{
CudaDeviceGuard dev_guard(engine_param_.devices[device_id]);

auto ctx = std::make_unique<Context>(engine_param_.devices[device_id]);
auto& ctx = contexts_[device_id];
const bool first_create = (ctx == nullptr);
if (first_create) {
ctx = std::make_shared<Context>(engine_param_.devices[device_id]);
ctx->comm = createCommSplits(rank);
}

core::ContextGuard guard{ctx->core_stream, ctx->allocator, Allocator{kCPUpinned}};

ctx->comm = createCommSplits(rank);

const auto& engine_param = engine_params_.at(rank);

// Get `h_comm` first as ctx will be moved later
@@ -541,9 +549,9 @@ void LlamaTritonModel::createEngine(int device_id, int rank)
try {
const int dp_rank = engine_param.outer_dp_rank * engine_param.attn_dp_size + engine_param.attn_dp_rank;
engines_[device_id] = std::make_unique<Engine>(dtype_,
engine_param_, //
engine_param, //
std::move(model),
std::move(ctx),
ctx,
gateway_,
engine_param_.devices[device_id],
dp_rank);
@@ -560,12 +568,14 @@ void LlamaTritonModel::createEngine(int device_id, int rank)

auto& engine = *engines_[device_id];

try {
engine.Warmup();
}
catch (const std::exception& e) {
TM_LOG_ERROR("[Engine][Warmup] %s", e.what());
throw;
if (first_create) {
try {
engine.Warmup();
}
catch (const std::exception& e) {
TM_LOG_ERROR("[Engine][Warmup] %s", e.what());
throw;
}
}

h_comm->Sync();
@@ -592,14 +602,21 @@ void LlamaTritonModel::sleep(int device_id, int level)
}
else {
// offload weights to CPU
TM_CHECK(moe_param_.experts_per_token == 0) << "level 1 sleep not supported for MoE model";
weights_[device_id]->to_device(kCPU);
}

// free kv cache
engines_[device_id]->FreeBufferAndKVCache();
// free model (kv cache and buffer)
if (device_id == 0) {
gateway_->shutdown();
gateway_.reset();
}
engines_[device_id].reset();
contexts_[device_id]->allocator->trim(0);
trim_default_mempool(engine_param_.devices[device_id]);
}

void LlamaTritonModel::wakeup(int device_id, const std::vector<std::string>& tags)
void LlamaTritonModel::wakeup(int device_id, const std::vector<std::string>& tags, int rank)
{
TM_LOG_DEBUG(__PRETTY_FUNCTION__);

@@ -618,7 +635,13 @@ void LlamaTritonModel::wakeup(int device_id, const std::vector<std::string>& tags)
}

if (keys.find("kv_cache") != keys.end()) {
engines_[device_id]->InitializeBufferAndKVCache();
if (device_id == 0) {
gateway_ =
std::make_shared<Gateway>(engine_param_.outer_dp_size, engine_param_.attn_dp_size, ffi_ctx_factory_);
}
TM_CHECK(contexts_[device_id] != nullptr);
contexts_[device_id]->comm.h_comm->Sync();
createEngine(device_id, rank);
}
}

7 changes: 5 additions & 2 deletions src/turbomind/triton_backend/llama/LlamaTritonModel.h
@@ -57,7 +57,7 @@ class LlamaTritonModel {

void sleep(int device_id, int level);

void wakeup(int device_id, const std::vector<std::string>& tags);
void wakeup(int device_id, const std::vector<std::string>& tags, int rank);

std::string toString();

@@ -86,12 +86,15 @@ class LlamaTritonModel {

std::vector<std::unique_ptr<comm::HostGroupId>> group_ids_;

std::shared_ptr<Gateway> gateway_;
std::shared_ptr<Gateway> gateway_;
std::function<std::shared_ptr<void>()> ffi_ctx_factory_;

// Weights & engine instances for the ranks
std::vector<std::shared_ptr<LlamaWeight>> weights_;
std::vector<std::shared_ptr<Engine>> engines_;

std::vector<std::shared_ptr<Context>> contexts_;

bool is_fp16_;

std::string model_name_;