
Commit 51cbd2c

reserve blocks for dummy inputs (#4157)
* reserve blocks for dummy inputs
* fix sliding window
* limit session len
* remove comment
1 parent: 4416bd3

8 files changed: 49 additions & 49 deletions

lmdeploy/pytorch/config.py

Lines changed: 3 additions & 0 deletions
@@ -90,6 +90,9 @@ class CacheConfig:
     num_state_caches: int = None
     states_shapes: List[Tuple] = field(default_factory=list)
 
+    # reserved blocks for dummy inputs, init to 0 for unit test.
+    num_reserved_gpu_blocks: int = 0
+
     # For PD Disaggregation
     role: EngineRole = EngineRole.Hybrid
     migration_backend: MigrationBackend = MigrationBackend.DLSlime
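
Note: the field defaults to 0, so unit tests that build a CacheConfig directly see no behaviour change; only the engine opts in. A dependency-free sketch of that contract (the stand-in dataclass below is illustrative, not the real CacheConfig):

from dataclasses import dataclass


# Stripped-down stand-in for CacheConfig, only to show the new field's default.
# The real class lives in lmdeploy/pytorch/config.py and has many more fields.
@dataclass
class CacheConfigSketch:
    num_gpu_blocks: int
    num_reserved_gpu_blocks: int = 0  # default 0 keeps unit tests unchanged


unit_test_cfg = CacheConfigSketch(num_gpu_blocks=16)  # reserves nothing
engine_cfg = CacheConfigSketch(num_gpu_blocks=16, num_reserved_gpu_blocks=1)

assert unit_test_cfg.num_reserved_gpu_blocks == 0
assert engine_cfg.num_gpu_blocks - engine_cfg.num_reserved_gpu_blocks == 15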

lmdeploy/pytorch/engine/engine.py

Lines changed: 16 additions & 12 deletions
@@ -113,17 +113,20 @@ def _build_scheduler_config(engine_config: PytorchEngineConfig):
 
 def _build_cache_config(engine_config: PytorchEngineConfig):
     """Build cache config."""
-    cache_config = CacheConfig(max_batches=engine_config.max_batch_size,
-                               block_size=engine_config.block_size,
-                               num_cpu_blocks=engine_config.num_cpu_blocks,
-                               num_gpu_blocks=engine_config.num_gpu_blocks,
-                               cache_max_entry_count=engine_config.cache_max_entry_count,
-                               max_prefill_token_num=engine_config.max_prefill_token_num,
-                               enable_prefix_caching=engine_config.enable_prefix_caching,
-                               quant_policy=engine_config.quant_policy,
-                               device_type=engine_config.device_type,
-                               migration_backend=engine_config.migration_backend,
-                               role=engine_config.role)
+    cache_config = CacheConfig(
+        max_batches=engine_config.max_batch_size,
+        block_size=engine_config.block_size,
+        num_cpu_blocks=engine_config.num_cpu_blocks,
+        num_gpu_blocks=engine_config.num_gpu_blocks,
+        cache_max_entry_count=engine_config.cache_max_entry_count,
+        max_prefill_token_num=engine_config.max_prefill_token_num,
+        enable_prefix_caching=engine_config.enable_prefix_caching,
+        quant_policy=engine_config.quant_policy,
+        device_type=engine_config.device_type,
+        migration_backend=engine_config.migration_backend,
+        role=engine_config.role,
+        # reserve 1 blocks for dummy input and padding
+        num_reserved_gpu_blocks=1)
     return cache_config
 
 
@@ -542,7 +545,8 @@ def _response(self, resp: Response, resp_type: ResponseType, data: Any = None, e
     def _get_max_session_len(self):
         """Get max session len."""
         session_len = self.scheduler_config.max_session_len
-        max_tokens = (self.cache_config.num_gpu_blocks * self.cache_config.block_size)
+        num_gpu_blocks = self.cache_config.num_gpu_blocks - self.cache_config.num_reserved_gpu_blocks
+        max_tokens = (num_gpu_blocks * self.cache_config.block_size)
         window_size = self.cache_config.window_size
         if window_size > 0 and window_size <= max_tokens:
             max_tokens = (1 << 63) - 1
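
With the reservation in place, the session-length cap is derived from the usable blocks rather than the raw pool size. A quick worked example of the new arithmetic (the block count and block size are made-up numbers, not repo defaults):

# Illustrative numbers only.
num_gpu_blocks = 1024
num_reserved_gpu_blocks = 1   # the engine reserves one block for dummy inputs/padding
block_size = 64

# Before this commit the cap was 1024 * 64 = 65536 tokens.
max_tokens = (num_gpu_blocks - num_reserved_gpu_blocks) * block_size
print(max_tokens)  # 65472 -- one block's worth of tokens is held back

# As in the hunk above: a sliding window that fits inside the budget lifts the cap.
window_size = 4096
if 0 < window_size <= max_tokens:
    max_tokens = (1 << 63) - 1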

lmdeploy/pytorch/model_inputs.py

Lines changed: 0 additions & 2 deletions
@@ -406,8 +406,6 @@ def new(
         # seq_len + history_length
         kv_seqlens = q_seqlens + history_seqlens
         kv_seqlens -= inputs.num_ignored_history
-        if inputs.is_dummy:
-            kv_seqlens = torch.zeros_like(kv_seqlens)
 
         ret = StepContext(
             input_ids=inputs.input_ids,

lmdeploy/pytorch/paging/block_manager/__init__.py

Lines changed: 6 additions & 2 deletions
@@ -15,8 +15,12 @@ def build_block_manager(cache_config: CacheConfig) -> BaseBlockManager:
     num_cpu_blocks = cache_config.num_cpu_blocks
     num_gpu_blocks = cache_config.num_gpu_blocks
     window_size = cache_config.window_size
+    num_gpu_reserved = cache_config.num_reserved_gpu_blocks
 
     if window_size < 0:
-        return DefaultBlockManager(num_gpu_blocks, num_cpu_blocks)
+        return DefaultBlockManager(num_gpu_blocks, num_cpu_blocks, num_gpu_reserved=num_gpu_reserved)
     else:
-        return WindowBlockManager(num_gpu_blocks, num_cpu_blocks, window_size=window_size)
+        return WindowBlockManager(num_gpu_blocks,
+                                  num_cpu_blocks,
+                                  window_size=window_size,
+                                  num_gpu_reserved=num_gpu_reserved)
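
The factory only threads the new reservation through to whichever manager it picks. A condensed, dependency-free sketch of that dispatch (the tuples stand in for DefaultBlockManager / WindowBlockManager instances, and the sizes are hypothetical):

# Condensed sketch of the dispatch performed by build_block_manager.
def build_block_manager_sketch(num_gpu_blocks, num_cpu_blocks, window_size, num_gpu_reserved):
    if window_size < 0:
        return ('DefaultBlockManager', num_gpu_blocks, num_cpu_blocks, num_gpu_reserved)
    return ('WindowBlockManager', num_gpu_blocks, num_cpu_blocks, window_size, num_gpu_reserved)


print(build_block_manager_sketch(1024, 512, -1, 1))    # no sliding window
print(build_block_manager_sketch(1024, 512, 4096, 1))  # sliding-window model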

lmdeploy/pytorch/paging/block_manager/base_block_manager.py

Lines changed: 7 additions & 25 deletions
@@ -28,31 +28,13 @@ def num_blocks(self):
         return self._num_blocks
 
 
-class PhysicalMemory:
-    """Physical memory blocks."""
-
-    def __init__(self, num_cpu_blocks: int, num_gpu_blocks: int) -> None:
-        self._num_cpu_blocks = num_cpu_blocks
-        self._num_gpu_blocks = num_gpu_blocks
-        self._num_blocks = num_cpu_blocks + num_gpu_blocks
-
-    def num_cpu_blocks(self):
-        """Get num cpu blocks."""
-        return self._num_cpu_blocks
-
-    def num_gpu_blocks(self):
-        """Get num gpu blocks."""
-        return self._num_gpu_blocks
-
-
 class PhysicalAllocator:
     """The physical block allocator.
 
     The allocator won't allocate real memory. It is used to support block manager.
     """
 
-    def __init__(self, memory: PhysicalMemory, num_blocks: int, offset: int = 0):
-        self._mem = memory
+    def __init__(self, num_blocks: int, offset: int = 0):
         self._num_blocks = num_blocks
         self._offset = offset
 
@@ -87,13 +69,13 @@ def get_num_free_blocks(self):
 class LogicalAllocator:
     """The logical block allocator."""
 
-    def __init__(self, num_cpu_blocks: int, num_gpu_blocks: int) -> None:
+    def __init__(self, num_cpu_blocks: int, num_gpu_blocks: int, num_gpu_reserved: int = 0) -> None:
         self._log_mem = LogicalMemory(num_cpu_blocks + num_gpu_blocks)
-        self._phy_mem = PhysicalMemory(num_cpu_blocks, num_gpu_blocks)
 
         self._cpu_mem_offset = num_gpu_blocks
-        self._gpu_allocator = PhysicalAllocator(self._phy_mem, num_gpu_blocks, 0)
-        self._cpu_allocator = PhysicalAllocator(self._phy_mem, num_cpu_blocks, self._cpu_mem_offset)
+        num_gpu_blocks -= num_gpu_reserved
+        self._gpu_allocator = PhysicalAllocator(num_gpu_blocks, num_gpu_reserved)
+        self._cpu_allocator = PhysicalAllocator(num_cpu_blocks, self._cpu_mem_offset)
 
         num_blocks = self._log_mem.num_blocks()
         self._num_blocks = num_blocks
 
@@ -225,11 +207,11 @@ class BaseBlockManager:
         num_cpu_blocks (int): number of cpu blocks.
     """
 
-    def __init__(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
+    def __init__(self, num_gpu_blocks: int, num_cpu_blocks: int, num_gpu_reserved: int = 0) -> None:
         self.num_gpu_blocks = num_gpu_blocks
         self.num_cpu_blocks = num_cpu_blocks
 
-        self.allocator = LogicalAllocator(num_cpu_blocks, num_gpu_blocks)
+        self.allocator = LogicalAllocator(num_cpu_blocks, num_gpu_blocks, num_gpu_reserved)
 
         self.block_tables: Dict[int, BlockTable] = {}
 
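
The reservation itself is just an offset: the GPU allocator's free pool now starts at num_gpu_reserved, so physical block ids below that value are never handed out to regular requests and remain available for dummy/padding inputs, while the CPU allocator keeps its old offset of num_gpu_blocks. A simplified, self-contained allocator illustrating the idea (not the real PhysicalAllocator):

import numpy as np


class TinyPhysicalAllocator:
    """Simplified free-list allocator mirroring the offset-based reservation."""

    def __init__(self, num_blocks: int, offset: int = 0):
        # Block ids start at `offset`; ids below it never enter the free pool.
        self._free_blocks = np.arange(offset, offset + num_blocks, dtype=np.int64)
        self._free_count = num_blocks

    def allocate(self, num_blocks: int) -> np.ndarray:
        assert num_blocks <= self._free_count, 'out of free blocks'
        blocks = self._free_blocks[self._free_count - num_blocks:self._free_count]
        self._free_count -= num_blocks
        return blocks.copy()


num_gpu_blocks, num_gpu_reserved = 8, 1
gpu_alloc = TinyPhysicalAllocator(num_gpu_blocks - num_gpu_reserved, offset=num_gpu_reserved)
print(gpu_alloc.allocate(3))  # [5 6 7]; block id 0 is never returned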

lmdeploy/pytorch/paging/block_manager/window_block_manager.py

Lines changed: 2 additions & 2 deletions
@@ -29,8 +29,8 @@ class WindowBlockManager(DefaultBlockManager):
         num_cpu_blocks (int): number of cpu blocks.
     """
 
-    def __init__(self, num_gpu_blocks: int, num_cpu_blocks: int, window_size: int):
-        super().__init__(num_gpu_blocks, num_cpu_blocks)
+    def __init__(self, num_gpu_blocks: int, num_cpu_blocks: int, window_size: int, num_gpu_reserved: int = 0):
+        super().__init__(num_gpu_blocks, num_cpu_blocks, num_gpu_reserved)
         assert window_size > 0, ('expect window size > 0, '
                                  f'but get window_size = {window_size}')
         self.window_size = window_size

lmdeploy/pytorch/paging/scheduler.py

Lines changed: 2 additions & 2 deletions
@@ -12,7 +12,7 @@
 from ..messages import MessageStatus, SchedulerSequence, SchedulerSession, SequenceManager, SequenceMeta
 from .block_manager import build_block_manager
 from .block_trie import BlockTrie
-from .state_manager import StateManager
+from .state_manager import build_state_manager
 
 logger = get_logger('lmdeploy')
 
@@ -52,7 +52,7 @@ def __init__(
 
         self.block_manager = build_block_manager(cache_config)
         self.block_trie = BlockTrie(self.cache_config, self.block_manager)
-        self.state_manager = StateManager(self.cache_config.num_state_caches)
+        self.state_manager = build_state_manager(self.cache_config)
         self.is_ssm = len(self.cache_config.states_shapes) > 0
 
         self.eviction_helper = self.build_eviction_helper(self.scheduler_config.eviction_type)

lmdeploy/pytorch/paging/state_manager.py

Lines changed: 13 additions & 4 deletions
@@ -1,15 +1,16 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import numpy as np
 
+from lmdeploy.pytorch.config import CacheConfig
 from lmdeploy.pytorch.messages import SchedulerSequence
 
 
 class StateAllocator:
     """State allocator."""
 
-    def __init__(self, num_states: int):
+    def __init__(self, num_states: int, offset: int = 0):
         self.num_states = num_states
-        self._free_states = np.arange(num_states, dtype=np.int64)
+        self._free_states = np.arange(offset, offset + num_states, dtype=np.int64)
         self._free_count = num_states
 
     def allocate(self):
@@ -33,10 +34,10 @@ def get_num_free(self):
 
 class StateManager:
 
-    def __init__(self, num_states: int):
+    def __init__(self, num_states: int, num_reserved: int = 0):
         if num_states is None:
             num_states = 1
-        self.allocator = StateAllocator(num_states)
+        self.allocator = StateAllocator(num_states, offset=num_reserved)
 
     def is_allocated(self, seq: SchedulerSequence):
         """Check if a sequence is allocated."""
@@ -58,3 +59,11 @@ def free(self, seq: SchedulerSequence):
     def get_num_free(self):
         """Get num free."""
         return self.allocator.get_num_free()
+
+
+def build_state_manager(cache_config: CacheConfig) -> StateManager:
+    """Build state manager."""
+    num_states = cache_config.num_state_caches
+    # state is different from block, we always reserve one state for system use
+    num_reserved = 1
+    return StateManager(num_states, num_reserved)
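
The state reservation follows the same offset pattern as the blocks: the free list starts at the reserved offset, so state index 0 is never handed to a request (the slot the code comment above reserves for system use). A quick check of the arange arithmetic with illustrative numbers:

import numpy as np

num_states, num_reserved = 4, 1
free_states = np.arange(num_reserved, num_reserved + num_states, dtype=np.int64)
print(free_states)  # [1 2 3 4]; index 0 stays reserved
assert 0 not in free_states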
