Skip to content

Commit aea352c

Browse files
committed
fix pd
1 parent 364d67c commit aea352c

File tree

1 file changed

+1
-7
lines changed

1 file changed

+1
-7
lines changed

lmdeploy/pytorch/engine/cache_engine.py

Lines changed: 1 addition & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -422,13 +422,7 @@ def p2p_connect(self, remote_engine_id: str, migration_conn_request: List[DistSe
422422

423423
async def migrate(self, migration_execution_inputs: MigrationExecutionBatch):
424424

425-
def get_assignment_len():
426-
head_dim = self.model_config.get_head_size()
427-
num_heads = self.model_config.num_key_value_heads // self.world_size
428-
block_size = self.cache_config.block_size
429-
return head_dim * num_heads * block_size * self.model_config.dtype.itemsize
430-
431-
assignment_len = get_assignment_len()
425+
assignment_len = self.full_gpu_cache.element_size() * self.full_gpu_cache.size(-1)
432426
layer_stride = self.cache_config.num_gpu_blocks * assignment_len
433427

434428
def get_assignment_batch(mr_key, block_ids, assignment_len, layer_stride, remote_layer_stride):

0 commit comments

Comments
 (0)