Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
c9fa7d2
fix non-deterministic issue seen in multiple beams
lisaliu1 Sep 17, 2025
9568032
Merge remote-tracking branch 'origin/main' into lisal.multi-beams
lisaliu1 Sep 17, 2025
868aa52
clean up
lisaliu1 Sep 17, 2025
f0689eb
use original topk
lisaliu1 Sep 17, 2025
89f73a5
remove commented code
lisaliu1 Sep 17, 2025
0187b7a
Merge branch 'lisal.multi-beams' into lisal.fix-kv-cache-issues
lisaliu1 Sep 18, 2025
ce7725a
fix issues related with base attention cache
lisaliu1 Sep 18, 2025
ecb437a
single batch and single beam trie == base
lisaliu1 Sep 18, 2025
0a1166c
do not free pages in other batch req
lisaliu1 Sep 18, 2025
79e0897
fix page leak in trie
lisaliu1 Sep 19, 2025
00486a4
cleanup
lisaliu1 Sep 19, 2025
9cb9d34
trie works with batch > 1
lisaliu1 Sep 19, 2025
54914ba
resolve conflict
lisaliu1 Sep 19, 2025
5ec78e5
fix issues detected by unit tests
lisaliu1 Sep 20, 2025
d17bb9a
multi beam multi batch and multi work with Trie works
lisaliu1 Sep 22, 2025
fc36d0b
cleanup
lisaliu1 Sep 22, 2025
327648c
Merge remote-tracking branch 'origin/main' into lisal.fix-batch-trie
lisaliu1 Sep 22, 2025
e494be3
move the duplicated pages to shared allocated_pages for handling same…
lisaliu1 Sep 22, 2025
3cf0f83
re-org cache APIs
lisaliu1 Sep 23, 2025
d767eed
cleanup message.py
lisaliu1 Sep 23, 2025
ca9fad0
Merge remote-tracking branch 'origin/main' into lisal.fix-batch-trie
lisaliu1 Sep 23, 2025
7d8b285
remove commentted lines
lisaliu1 Sep 23, 2025
f90b7af
Merge remote-tracking branch 'origin/main' into lisal.fix-batch-trie
lisaliu1 Sep 23, 2025
26e603e
Merge remote-tracking branch 'origin/main' into lisal.fix-batch-trie
lisaliu1 Sep 23, 2025
9bb6d86
Merge remote-tracking branch 'origin/main' into lisal.fix-batch-trie
lisaliu1 Sep 24, 2025
99cf478
remove duplicated_code
lisaliu1 Sep 24, 2025
f9a46f6
Merge remote-tracking branch 'origin/main' into lisal.fix-batch-trie
lisaliu1 Sep 24, 2025
da8994a
remove unused argument in cache API
lisaliu1 Sep 24, 2025
b85fcfa
remove unnecessary locking in decoder
lisaliu1 Sep 24, 2025
5523c75
remove locking in PageManager
lisaliu1 Sep 24, 2025
8dda840
fix creating duplicated trie node and the page leakage issue caused b…
lisaliu1 Sep 24, 2025
4b43ede
remove lookup from allocate in Trie
lisaliu1 Sep 25, 2025
27ac25d
Merge remote-tracking branch 'origin/main' into lisal.fix-batch-trie
lisaliu1 Sep 25, 2025
195fef8
cleanup comments
lisaliu1 Sep 25, 2025
c5a49e0
cleanup APIs
lisaliu1 Sep 25, 2025
486c496
cleanup
lisaliu1 Sep 25, 2025
05e0f9b
Merge remote-tracking branch 'origin/main' into lisal.fix-batch-trie
lisaliu1 Sep 25, 2025
a387265
Merge remote-tracking branch 'origin/main' into lisal.fix-batch-trie
lisaliu1 Sep 25, 2025
ef4e0ef
cleanup a bit
lisaliu1 Sep 25, 2025
4d124c1
publish pages after decode request finishes
lisaliu1 Sep 25, 2025
9216766
evict pages work, but after eviction mismatch happens
lisaliu1 Sep 26, 2025
5b7d3ff
solved duplicated pages issue caused by idential requests, but evict …
lisaliu1 Sep 26, 2025
c4f2d68
only evict during creating prefill request, got inconsistent results …
lisaliu1 Sep 26, 2025
3b0cdae
update decoder
lisaliu1 Sep 26, 2025
f23bf5a
Merge remote-tracking branch 'origin/main' into lisal.fix-batch-trie
lisaliu1 Sep 26, 2025
e4e9f6d
fix ref_count issue with alllocate nodes along the same branch of the…
lisaliu1 Sep 27, 2025
b1697d5
Merge remote-tracking branch 'origin/main' into lisal.fix-batch-trie
lisaliu1 Sep 27, 2025
886ac18
remove evict_pages from lookup
lisaliu1 Sep 28, 2025
c7b36d9
Trie cached works with all tests including one tests for evicting pages
lisaliu1 Sep 28, 2025
1791fe8
remove debugging print
lisaliu1 Sep 28, 2025
366ab83
Merge remote-tracking branch 'origin/main' into lisal.fix-batch-trie
lisaliu1 Sep 29, 2025
2f805ac
Merge remote-tracking branch 'origin/main' into lisal.fix-batch-trie
lisaliu1 Sep 30, 2025
92acd41
remove lookup in decoder loop as performance mainly comes from re-use…
lisaliu1 Sep 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 34 additions & 19 deletions shortfin/python/shortfin_apps/llm/components/decoder/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import itertools
import numpy as np
import threading
import math

from ..prefill_config import PrefillConfig

Expand Down Expand Up @@ -120,8 +121,7 @@ def __init__(
):
self._page_cache = page_cache
self._page_pool = page_pool
self._allocated_pages = []
self._allocated_page_ids = []

self._free_pages = []
self._beam_page_ids = [[]]

Expand All @@ -146,12 +146,24 @@ def allocate(
acquire_count = max(count, self._allocation_block_size)
if not allocate_block:
acquire_count = count

# do not lookup published tokens as the major performance improvement comes from re-using partially filled pages in prefill phase
acquired_cache_info = self._page_cache.allocate(
input_token_ids, acquire_count, req.allocated_cache_info
input_token_ids,
req.allocated_cache_info,
acquire_count,
)

acquired = acquired_cache_info.pages[len(req.allocated_cache_info.pages) :]
self._free_pages.extend([p.index for p in acquired])
pages = req.allocated_cache_info.pages + acquired[:count]
req.allocated_cache_info = acquired_cache_info
req.allocated_cache_info.pages = pages
else:
req.allocated_cache_info.num_tokens += len(input_token_ids)
req.allocated_cache_info.tokens.extend(input_token_ids)
free_pages = self._page_cache.get_allocated_pages(self._free_pages[:count])
req.allocated_cache_info.pages.extend(free_pages)
allocation = self._free_pages[:count]
self._free_pages = self._free_pages[count:]
return allocation, req
Expand Down Expand Up @@ -191,10 +203,14 @@ def _update_decode_reqs_existing_page(
)
new_page = new_pages[0]
decode_reqs[i].allocated_cache_info = req.allocated_cache_info

if beam[-1] != new_page:
self._page_pool.copy_page_index(beam[-1], new_page)
beam[-1] = new_page
else:
decode_reqs[i].allocated_cache_info.num_tokens += len(
next_token_ids[i]
)
decode_reqs[i].allocated_cache_info.tokens.extend(next_token_ids[i])
used.add(beam[-1])

def update_decode_reqs(
Expand All @@ -207,23 +223,19 @@ def update_decode_reqs(
# TODO: Allocation more requests
if len(decode_reqs) < len(tokens):
raise ValueError("NEED TO ALLOCATE MORE REQS")

next_token_ids = []
for token in tokens:
next_tokens = [token]
next_token_ids.append(next_tokens)

if len(select) == 0:
return

new_page = (self._position % self._tokens_per_page) == 0
new_beam_page_ids = [[p for p in self._beam_page_ids[b]] for b in select]

old_pages = set(itertools.chain.from_iterable(self._beam_page_ids))
new_pages = set(itertools.chain.from_iterable(new_beam_page_ids))

free_pages = old_pages - new_pages
self._free_pages.extend(free_pages)

if new_page:
self._update_decode_reqs_new_page(
new_beam_page_ids, next_token_ids, decode_reqs
Expand All @@ -232,10 +244,8 @@ def update_decode_reqs(
self._update_decode_reqs_existing_page(
new_beam_page_ids, next_token_ids, decode_reqs
)

self._beam_page_ids = new_beam_page_ids
self._position += 1

# setup decode_reqs
for i, ids in enumerate(next_token_ids):
decode_reqs[i].input_token_ids = ids
Expand All @@ -244,8 +254,8 @@ def update_decode_reqs(
return decode_reqs[: len(tokens)]

def release_pages(self):
self._page_pool.free_pages(self._allocated_pages)
self._allocated_pages = []
self._page_cache.free_allocated_pages(self._free_pages)
self._free_pages = []


class TokenSelector:
Expand Down Expand Up @@ -430,10 +440,14 @@ def create_prefill_req(self, input_ids):

async def run(self, input_ids):
input_length = len(input_ids)
prefill_req = self.create_prefill_req(input_ids)
prefill_req = None
with self._lock:
prefill_req = self.create_prefill_req(input_ids)

# Run Prefill:
self._unified_batcher.submit(prefill_req)
await prefill_req.done
prefill_req.publish_allocated_pages(publish_incomplete_page=False)

token_selector = TokenSelector(self._decode_config)
initial_pages = [p.index for p in prefill_req.allocated_cache_info.pages]
Expand Down Expand Up @@ -482,9 +496,6 @@ async def run(self, input_ids):
[req.result_indices for req in to_run],
)

for req in decode_reqs:
req.publish_allocated_pages()

# Remove the reservation:
self._unified_batcher.reserve_workload(
rid=prefill_req.orig_instance_id, count=0
Expand All @@ -496,5 +507,9 @@ async def run(self, input_ids):
# Return Results:
self._results_callback(completed)

for req in decode_reqs:
req.free_cache_pages()
with self._lock:
for req in decode_reqs:
req.publish_allocated_pages(publish_incomplete_page=True)
req.free_cache_pages()

page_manager.release_pages()
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,14 @@ class CacheInfo:
"""
Metadata about the allocated cache space.
- num_tokens: Number of tokens allocated in the cache.
- tokens: List of tokens in the allocation
- pages: The actual pages allocated in the cache.
- pool: The cache store where this information is stored.
- last_cached_node: Optional reference to the last cached node, if applicable.
"""

num_tokens: int
tokens: List[int]
pages: Any # This should be a list of PageInfo or similar objects.
pool: CacheStoreAbstract
last_cached_node: Any # Optional reference to the last cached node, if applicable.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,12 @@ def __init__(
self._ref_count_lock: None | threading.Lock = (
None if not use_ref_counts else threading.Lock()
)
self._allocated_pages: List[
PageInfo
] = [] # global allocated page pool that contains all un-tracked pages

def shutdown(self):
self.page_pool.free_pages(self._allocated_pages)
available = self.page_pool.available_page_count()
total = self.page_pool.total_page_count()
if available != total:
Expand Down Expand Up @@ -144,23 +148,41 @@ def free_pages(self, pages: List[PageInfo]):

self.page_pool.free_pages(pages_to_free)

def fork_pages(self, pages: List[PageInfo]) -> List[PageInfo]:
new_pages = pages.copy()
def fork_pages(self, tokens: list[int], cache_info: CacheInfo) -> CacheInfo:
new_pages = cache_info.pages.copy()
last_page = new_pages.pop(-1)
new_page = self.page_pool.copy_page(last_page)
if new_page is None:
raise CacheAllocationFailure()

new_pages.append(new_page)
self.increment_pages(new_pages)
return BasePagedAttentionCacheAllocation(new_pages, cache=self)
cache_info.pages = new_pages
cache_info.tokens.extend(tokens)
cache_info.num_tokens += len(tokens)
return cache_info

def lookup(self, tokens: List[int]) -> CacheInfo:
return CacheInfo(
num_tokens=0,
tokens=[],
pages=[],
pool=self.page_pool,
last_cached_node=None,
)

def get_allocated_pages(self, page_ids: List[int]) -> List[PageInfo]:
pages = []
for page in self._allocated_pages:
if page.index in page_ids:
pages.append(page)
return pages

def allocate(
self,
tokens: List[int],
allocation_block_size: int = 0,
cache_info: CacheInfo = None,
lookup: bool = True,
allocation_block_size: int = 0,
evict: bool = True,
) -> CacheInfo:
"""
Expand All @@ -187,47 +209,27 @@ def allocate(
if cache_info is not None:
pages = cache_info.pages + pages
num_tokens += cache_info.num_tokens
self._allocated_pages.extend(pages)
return CacheInfo(
num_tokens=num_tokens,
tokens=tokens,
pages=pages,
pool=self.page_pool,
last_cached_node=None,
)

def extend_allocation(
self, tokens, cache_info, *, extra_token_slots=0
) -> CacheInfo:
# assert old tokens are a prefix of incoming tokens
# if we don't have enough pages to hold the tokens, we need to allocate more pages
token_count = len(tokens) + extra_token_slots
pages_needed = math.ceil(token_count / self.tokens_per_page)
if pages_needed > len(cache_info.pages):
new_pages = self.page_pool.acquire_free_pages(
pages_needed - len(cache_info.pages)
)
if new_pages is None:
msg = (
f"FATAL CacheAllocationFailure: Failed to allocate {pages_needed - len(self._pages)} pages from `PagePool`.\n"
f"Required pages: {pages_needed}, Available pages: {len(self._cache.page_pool.available_pages)}, Total pages: {self._cache.page_pool.config.alloc_page_count}\n"
f"Consider re-exporting the model with a higher `--device-block-count` value."
)
logger.error(msg)
raise CacheAllocationFailure(msg)
if self.use_ref_counts:
self.increment_pages(new_pages)

return CacheInfo(
num_tokens=token_count,
pages=cache_info.pages + tuple(new_pages),
pool=self.page_pool,
last_cached_node=cache_info.last_cached_node,
)

def publish_pages_for_tokens(
self, tokens, cache_info, *, publish_incomplete_page=False
self, cache_info, *, publish_incomplete_page=False
) -> CacheInfo:
return cache_info # no-op for base class

def release_pages(self, cache_info: CacheInfo):
if cache_info is not None:
self.free_pages(cache_info.pages)

def free_allocated_pages(self, page_ids: List[int]):
pages = []
for page in self._allocated_pages:
if page.index in page_ids:
pages.append(page)
self.free_pages(pages)
Loading
Loading