diff --git a/shortfin/python/shortfin_apps/llm/components/decoder/decoder.py b/shortfin/python/shortfin_apps/llm/components/decoder/decoder.py index f56e3d02aa..ff7b26bdab 100644 --- a/shortfin/python/shortfin_apps/llm/components/decoder/decoder.py +++ b/shortfin/python/shortfin_apps/llm/components/decoder/decoder.py @@ -9,6 +9,7 @@ import itertools import numpy as np import threading +import math from ..prefill_config import PrefillConfig @@ -120,8 +121,7 @@ def __init__( ): self._page_cache = page_cache self._page_pool = page_pool - self._allocated_pages = [] - self._allocated_page_ids = [] + self._free_pages = [] self._beam_page_ids = [[]] @@ -146,12 +146,24 @@ def allocate( acquire_count = max(count, self._allocation_block_size) if not allocate_block: acquire_count = count + + # do not lookup published tokens as the major performance improvement comes from re-using partially filled pages in prefill phase acquired_cache_info = self._page_cache.allocate( - input_token_ids, acquire_count, req.allocated_cache_info + input_token_ids, + req.allocated_cache_info, + acquire_count, ) + acquired = acquired_cache_info.pages[len(req.allocated_cache_info.pages) :] self._free_pages.extend([p.index for p in acquired]) + pages = req.allocated_cache_info.pages + acquired[:count] req.allocated_cache_info = acquired_cache_info + req.allocated_cache_info.pages = pages + else: + req.allocated_cache_info.num_tokens += len(input_token_ids) + req.allocated_cache_info.tokens.extend(input_token_ids) + free_pages = self._page_cache.get_allocated_pages(self._free_pages[:count]) + req.allocated_cache_info.pages.extend(free_pages) allocation = self._free_pages[:count] self._free_pages = self._free_pages[count:] return allocation, req @@ -191,10 +203,14 @@ def _update_decode_reqs_existing_page( ) new_page = new_pages[0] decode_reqs[i].allocated_cache_info = req.allocated_cache_info - if beam[-1] != new_page: self._page_pool.copy_page_index(beam[-1], new_page) beam[-1] = new_page + else: + decode_reqs[i].allocated_cache_info.num_tokens += len( + next_token_ids[i] + ) + decode_reqs[i].allocated_cache_info.tokens.extend(next_token_ids[i]) used.add(beam[-1]) def update_decode_reqs( @@ -207,23 +223,19 @@ def update_decode_reqs( # TODO: Allocation more requests if len(decode_reqs) < len(tokens): raise ValueError("NEED TO ALLOCATE MORE REQS") + next_token_ids = [] for token in tokens: next_tokens = [token] next_token_ids.append(next_tokens) - if len(select) == 0: return - new_page = (self._position % self._tokens_per_page) == 0 new_beam_page_ids = [[p for p in self._beam_page_ids[b]] for b in select] - old_pages = set(itertools.chain.from_iterable(self._beam_page_ids)) new_pages = set(itertools.chain.from_iterable(new_beam_page_ids)) - free_pages = old_pages - new_pages self._free_pages.extend(free_pages) - if new_page: self._update_decode_reqs_new_page( new_beam_page_ids, next_token_ids, decode_reqs @@ -232,10 +244,8 @@ def update_decode_reqs( self._update_decode_reqs_existing_page( new_beam_page_ids, next_token_ids, decode_reqs ) - self._beam_page_ids = new_beam_page_ids self._position += 1 - # setup decode_reqs for i, ids in enumerate(next_token_ids): decode_reqs[i].input_token_ids = ids @@ -244,8 +254,8 @@ def update_decode_reqs( return decode_reqs[: len(tokens)] def release_pages(self): - self._page_pool.free_pages(self._allocated_pages) - self._allocated_pages = [] + self._page_cache.free_allocated_pages(self._free_pages) + self._free_pages = [] class TokenSelector: @@ -430,10 +440,14 @@ def create_prefill_req(self, input_ids): async def run(self, input_ids): input_length = len(input_ids) - prefill_req = self.create_prefill_req(input_ids) + prefill_req = None + with self._lock: + prefill_req = self.create_prefill_req(input_ids) + # Run Prefill: self._unified_batcher.submit(prefill_req) await prefill_req.done + prefill_req.publish_allocated_pages(publish_incomplete_page=False) token_selector = TokenSelector(self._decode_config) initial_pages = [p.index for p in prefill_req.allocated_cache_info.pages] @@ -482,9 +496,6 @@ async def run(self, input_ids): [req.result_indices for req in to_run], ) - for req in decode_reqs: - req.publish_allocated_pages() - # Remove the reservation: self._unified_batcher.reserve_workload( rid=prefill_req.orig_instance_id, count=0 @@ -496,5 +507,9 @@ async def run(self, input_ids): # Return Results: self._results_callback(completed) - for req in decode_reqs: - req.free_cache_pages() + with self._lock: + for req in decode_reqs: + req.publish_allocated_pages(publish_incomplete_page=True) + req.free_cache_pages() + + page_manager.release_pages() diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/attention_cache_abstract.py b/shortfin/python/shortfin_apps/llm/components/kvcache/attention_cache_abstract.py index a38c6a03f1..029d16c8a9 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/attention_cache_abstract.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/attention_cache_abstract.py @@ -32,12 +32,14 @@ class CacheInfo: """ Metadata about the allocated cache space. - num_tokens: Number of tokens allocated in the cache. + - tokens: List of tokens in the allocation - pages: The actual pages allocated in the cache. - pool: The cache store where this information is stored. - last_cached_node: Optional reference to the last cached node, if applicable. """ num_tokens: int + tokens: List[int] pages: Any # This should be a list of PageInfo or similar objects. pool: CacheStoreAbstract last_cached_node: Any # Optional reference to the last cached node, if applicable. diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py index 7253bdb746..2090add876 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py @@ -97,8 +97,12 @@ def __init__( self._ref_count_lock: None | threading.Lock = ( None if not use_ref_counts else threading.Lock() ) + self._allocated_pages: List[ + PageInfo + ] = [] # global allocated page pool that contains all un-tracked pages def shutdown(self): + self.page_pool.free_pages(self._allocated_pages) available = self.page_pool.available_page_count() total = self.page_pool.total_page_count() if available != total: @@ -144,8 +148,8 @@ def free_pages(self, pages: List[PageInfo]): self.page_pool.free_pages(pages_to_free) - def fork_pages(self, pages: List[PageInfo]) -> List[PageInfo]: - new_pages = pages.copy() + def fork_pages(self, tokens: list[int], cache_info: CacheInfo) -> CacheInfo: + new_pages = cache_info.pages.copy() last_page = new_pages.pop(-1) new_page = self.page_pool.copy_page(last_page) if new_page is None: @@ -153,14 +157,32 @@ def fork_pages(self, pages: List[PageInfo]) -> List[PageInfo]: new_pages.append(new_page) self.increment_pages(new_pages) - return BasePagedAttentionCacheAllocation(new_pages, cache=self) + cache_info.pages = new_pages + cache_info.tokens.extend(tokens) + cache_info.num_tokens += len(tokens) + return cache_info + + def lookup(self, tokens: List[int]) -> CacheInfo: + return CacheInfo( + num_tokens=0, + tokens=[], + pages=[], + pool=self.page_pool, + last_cached_node=None, + ) + + def get_allocated_pages(self, page_ids: List[int]) -> List[PageInfo]: + pages = [] + for page in self._allocated_pages: + if page.index in page_ids: + pages.append(page) + return pages def allocate( self, tokens: List[int], - allocation_block_size: int = 0, cache_info: CacheInfo = None, - lookup: bool = True, + allocation_block_size: int = 0, evict: bool = True, ) -> CacheInfo: """ @@ -187,47 +209,27 @@ def allocate( if cache_info is not None: pages = cache_info.pages + pages num_tokens += cache_info.num_tokens + self._allocated_pages.extend(pages) return CacheInfo( num_tokens=num_tokens, + tokens=tokens, pages=pages, pool=self.page_pool, last_cached_node=None, ) - def extend_allocation( - self, tokens, cache_info, *, extra_token_slots=0 - ) -> CacheInfo: - # assert old tokens are a prefix of incoming tokens - # if we don't have enough pages to hold the tokens, we need to allocate more pages - token_count = len(tokens) + extra_token_slots - pages_needed = math.ceil(token_count / self.tokens_per_page) - if pages_needed > len(cache_info.pages): - new_pages = self.page_pool.acquire_free_pages( - pages_needed - len(cache_info.pages) - ) - if new_pages is None: - msg = ( - f"FATAL CacheAllocationFailure: Failed to allocate {pages_needed - len(self._pages)} pages from `PagePool`.\n" - f"Required pages: {pages_needed}, Available pages: {len(self._cache.page_pool.available_pages)}, Total pages: {self._cache.page_pool.config.alloc_page_count}\n" - f"Consider re-exporting the model with a higher `--device-block-count` value." - ) - logger.error(msg) - raise CacheAllocationFailure(msg) - if self.use_ref_counts: - self.increment_pages(new_pages) - - return CacheInfo( - num_tokens=token_count, - pages=cache_info.pages + tuple(new_pages), - pool=self.page_pool, - last_cached_node=cache_info.last_cached_node, - ) - def publish_pages_for_tokens( - self, tokens, cache_info, *, publish_incomplete_page=False + self, cache_info, *, publish_incomplete_page=False ) -> CacheInfo: return cache_info # no-op for base class def release_pages(self, cache_info: CacheInfo): if cache_info is not None: self.free_pages(cache_info.pages) + + def free_allocated_pages(self, page_ids: List[int]): + pages = [] + for page in self._allocated_pages: + if page.index in page_ids: + pages.append(page) + self.free_pages(pages) diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py b/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py index dc05c3feb0..c023b3e394 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py @@ -60,8 +60,13 @@ def create_child(self, tokens: Tuple[int, ...], page: PageInfo) -> "TrieNode": Returns: The newly created child node """ - new_node = TrieNode(tokens=tokens, page=page, parent=self) - self.children[tokens] = new_node + new_node = None + if tokens in self.children: + # If the child already exists, return it + new_node = self.children[tokens] + else: + new_node = TrieNode(tokens=tokens, page=page, parent=self) + self.children[tokens] = new_node return new_node def unlink(self) -> None: @@ -90,14 +95,10 @@ class TrieCacheInfo(CacheInfo): Contains information about the tokens, pages, and last cached node. Attributes: - tokens: List of tokens in the allocation last_cached_node: Last node in the trie that was cached - cached_pages: List of pages that were already cached - newly_acquired_pages: List of pages that were newly acquired for this allocation number_of_published_pages: Number of pages that have been published to the cache """ - tokens: List[int] number_of_published_pages: int last_cached_node: TrieNode @@ -139,9 +140,13 @@ def __init__(self, page_pool: PagePool, tokens_per_page: int): self.root = TrieNode(tokens=tuple(), page=dummy_page) self.leaves: Set[TrieNode] = set() self._lock: Lock = Lock() - self._allocated_pages: List[PageInfo] = [] + self._duplicated_pages: List[ + PageInfo + ] = ( + [] + ) # pages that are duplicated from existing pages in Trie tree. These pages can be safely freed when calling release_pages. - def fork_pages(self, pages: List[PageInfo], tokens: list[int]) -> TrieCacheInfo: + def fork_pages(self, tokens: list[int], cache_info: TrieCacheInfo) -> TrieCacheInfo: """Fork a sequence of pages into the trie. Share prefixes with existing nodes till N-1 tokens, then create a new node @@ -170,10 +175,10 @@ def fork_pages(self, pages: List[PageInfo], tokens: list[int]) -> TrieCacheInfo: pool=self.page_pool, ) - new_page = self.page_pool.copy_page(pages[-1]) + new_page = self.page_pool.copy_page(cache_info.pages[-1]) if new_page is None: self._evict_pages(1) - new_page = self.page_pool.copy_page(pages[-1]) + new_page = self.page_pool.copy_page(cache_info.pages[-1]) return TrieCacheInfo( tokens=list(tokens), @@ -212,6 +217,30 @@ def match(self, tokens: List[int]) -> Tuple[TrieNode, List[PageInfo]]: return cur, matched_pages + def lookup(self, tokens: List[int]) -> TrieCacheInfo: + """Lookup the cache for the given token sequence. It only returns fully matched pages. + + Args: + tokens: Sequence of tokens to look up + returns: TrieCacheInfo with matched tokens and pages + """ + with self._lock: + page_aligned_token_len = ( + len(tokens) // self.tokens_per_page + ) * self.tokens_per_page + page_aligned_tokens = tokens[:page_aligned_token_len] + cur_node, matched_pages = self.match(page_aligned_tokens) + num_matched_tokens = len(matched_pages) * self.tokens_per_page + matched_tokens = page_aligned_tokens[:num_matched_tokens] + return TrieCacheInfo( + num_tokens=num_matched_tokens, + tokens=matched_tokens, + pages=matched_pages, + last_cached_node=cur_node, + number_of_published_pages=len(matched_pages), + pool=self.page_pool, + ) + def evict_pages(self, max_pages: int) -> int: """Evict up to max_pages pages using LRU strategy. @@ -236,6 +265,8 @@ def evict_pages(self, max_pages: int) -> int: # Evict least recently used nodes while unused_leaf_heap and len(pages_to_evict) < max_pages: _, leaf = heapq.heappop(unused_leaf_heap) + if leaf.page in pages_to_evict: + continue pages_to_evict.append(leaf.page) parent = leaf.parent @@ -253,9 +284,6 @@ def evict_pages(self, max_pages: int) -> int: heapq.heappush(unused_leaf_heap, (parent.access_time, parent)) if pages_to_evict: - logger.debug( - f"TriePagedAttentionCache: Released allocated pages in evict_pages {[p.index for p in pages_to_evict]}" - ) self.page_pool.free_pages(pages_to_evict) return len(pages_to_evict) @@ -263,9 +291,8 @@ def evict_pages(self, max_pages: int) -> int: def allocate( self, tokens: List[int], - allocation_block_size: int = 0, cache_info: TrieCacheInfo = None, - lookup: bool = True, + allocation_block_size: int = 0, evict: bool = True, ) -> TrieCacheInfo: """Acquire pages for a sequence of tokens. @@ -275,8 +302,8 @@ def allocate( Args: tokens: Sequence of tokens needing pages + cache_info: Existing TrieCacheInfo to extend/update, if any allocation_block_size: number of pages to allocate at once, not used if it is 0 - lookup: Whether to look up existing tokens in the cache. evict: Whether to evict old tokens if the cache is full. Returns: @@ -288,32 +315,25 @@ def allocate( with self._lock: tokens = tuple(tokens) n_empty_pages = 0 - cached_pages = [] pages = [] cur_node = self.root - if lookup: - cur_node, matched_pages = self.match(tokens) - logger.debug( - f"TriePagedAttentionCache: Lookup found {len(matched_pages)} cached pages for token length {len(tokens)}" - ) + if not cache_info: + raise ValueError("cache_info cannot be None") - cached_pages = matched_pages - n_cached_tokens = 0 - if matched_pages: - n_cached_tokens = len(matched_pages) * self.tokens_per_page - remaining_length = len(tokens) - n_cached_tokens - n_empty_pages = math.ceil(remaining_length / self.tokens_per_page) - else: - n_empty_pages = math.ceil(len(tokens) / self.tokens_per_page) + cur_node = cache_info.last_cached_node - if not cached_pages and allocation_block_size > 0: - n_empty_pages = allocation_block_size + n_empty_pages = math.ceil(len(tokens) / self.tokens_per_page) + + if allocation_block_size > 0: + n_empty_pages = max(n_empty_pages, allocation_block_size) new_pages = self.page_pool.acquire_free_pages(n_empty_pages) if new_pages is None and evict: # Try eviction - self.evict_pages(n_empty_pages - len(self.page_pool.available_pages)) + number_evicted_pages = self.evict_pages( + n_empty_pages - len(self.page_pool.available_pages) + ) new_pages = self.page_pool.acquire_free_pages(n_empty_pages) if new_pages is None: @@ -321,95 +341,43 @@ def allocate( "Failed to acquire pages even after attempting eviction from LRU leaves" ) - cur_node.ref_count.increment() - pages = cached_pages + new_pages - self._allocated_pages.extend(new_pages) + if new_pages is None: + raise CacheAllocationFailure( + "Failed to acquire pages and eviction is disabled" + ) - num_tokens = len(tokens) - if cache_info: - if ( - cache_info.last_cached_node - and not cache_info.last_cached_node.ref_count.is_empty() + if len(new_pages) > 0: + # some new pages are allocated and will be used to create children of cur_node, hence increment ref_count of cur_node. + # do not increment last_cached_node ref_count when we allocate along the same branch again + if len(cache_info.pages) == 0 or ( + len(cache_info.pages) > 0 + and cur_node.page.index == cache_info.pages[-1].index ): - cache_info.last_cached_node.ref_count.decrement() - pages = cache_info.pages + pages - num_tokens += cache_info.num_tokens + cur_node.ref_count.increment() + + self._allocated_pages.extend(new_pages) + pages = cache_info.pages + new_pages + num_tokens = len(tokens) + cache_info.num_tokens + tokens = cache_info.tokens + list(tokens) + number_of_published_pages = cache_info.number_of_published_pages return TrieCacheInfo( - num_tokens=len(tokens), + num_tokens=num_tokens, tokens=tokens, pages=pages, last_cached_node=cur_node, - number_of_published_pages=len(cached_pages), + number_of_published_pages=number_of_published_pages, pool=self.page_pool, ) - def extend_allocation( - self, tokens: List[int], cache_info: TrieCacheInfo, *, extra_token_slots=0 - ) -> TrieCacheInfo: - """Extend the current allocation to accommodate additional tokens. - - Args: - tokens: New token sequence to extend the allocation to - extra_token_slots: Additional token slots to allocate. - - This allows us to allocate additional space for future token(s). - - Raises: - ValueError: If new tokens don't extend current allocation's tokens - """ - # Verify new tokens extend current tokens - if len(tokens) < len(cache_info.tokens): - raise ValueError("New tokens must be longer than current tokens") - - # Check that current tokens are a prefix of new tokens - if tokens[: len(cache_info.tokens)] != cache_info.tokens: - raise ValueError("New tokens must extend current token sequence") - - # If tokens are identical, no extension needed - if len(tokens) == len(cache_info.tokens): - return cache_info - - # Calculate how many new pages we need - tokens_per_page = self.tokens_per_page - current_pages = len(cache_info.pages) - total_tokens = len(tokens) + extra_token_slots - total_pages_needed = math.ceil(total_tokens / tokens_per_page) - new_pages_needed = total_pages_needed - current_pages - - pages = cache_info.pages - if new_pages_needed > 0: - # Acquire new pages - new_pages = self.page_pool.acquire_free_pages(new_pages_needed) - - if new_pages is None: - # Try eviction if initial allocation fails - self.evict_pages(new_pages_needed - len(self.page_pool.available_pages)) - new_pages = self.page_pool.acquire_free_pages(new_pages_needed) - - if new_pages is None: - raise CacheAllocationFailure( - "Failed to acquire pages for allocation extension even after attempting eviction" - ) - - # Extend our page list - pages.extend(new_pages) - return TrieCacheInfo( - num_tokens=len(tokens), - tokens=deepcopy(tokens), - pages=cache_info.pages, - pool=cache_info.page_pool, - last_cached_node=cache_info.last_cached_node, - number_of_published_pages=cache_info.number_of_pages_to_publish, - ) - def publish_pages_for_tokens( - self, tokens: List[int], cache_info: TrieCacheInfo + self, cache_info: TrieCacheInfo, publish_incomplete_page: bool = False ) -> TrieCacheInfo: """Make pages available in the cache for the specified tokens. Args: - tokens_to_publish: Tokens to publish to the cache cache_info: TrieCacheInfo object containing allocation metadata + publish_incomplete_page: Whether to publish the last page even if it is not full Raises: ValueError: If tokens don't match allocation or exceed available pages @@ -417,17 +385,27 @@ def publish_pages_for_tokens( with self._lock: # If we have more tokens, publish pages up to the incoming tokens. # If incoming has more tokens, replace our tokens with incoming tokens and publish pages up to the incoming tokens. + if not cache_info: + raise ValueError("cache_info cannot be None") updated_tokens = deepcopy(cache_info.tokens) tokens_per_page = self.tokens_per_page - matched_node, matched_pages = self.match(updated_tokens) + number_of_pages_to_publish = len(updated_tokens) // tokens_per_page + matched_node, matched_pages = self.match( + updated_tokens[: number_of_pages_to_publish * tokens_per_page] + ) + + for i, page in enumerate(matched_pages): + if page.index != cache_info.pages[i].index: + if page not in self._duplicated_pages: + self._duplicated_pages.append(cache_info.pages[i]) + if cache_info.last_cached_node.page.index == page.index: + cache_info.last_cached_node.page = page + last_number_of_published_pages = cache_info.number_of_published_pages + if len(matched_pages) > last_number_of_published_pages: last_number_of_published_pages = len(matched_pages) - number_of_pages_to_publish = -( - len(updated_tokens) // -tokens_per_page - ) # ceil division - # Create token blocks for unpublished pages start_token_index = last_number_of_published_pages * tokens_per_page unpublished_tokens = [] @@ -442,17 +420,26 @@ def publish_pages_for_tokens( ) unpublished_pages = cache_info.pages[ - last_number_of_published_pages:number_of_pages_to_publish + last_number_of_published_pages : last_number_of_published_pages + + len(unpublished_tokens) ] + number_of_published_pages = last_number_of_published_pages - number_of_published_pages = 0 + pages = matched_pages # using matched pages instead of cache_info.pages to avoid using the _duplicated pages + last_cached_node = cache_info.last_cached_node cur_node = matched_node - for token_block, page in zip(unpublished_tokens, unpublished_pages): + for token_block, page in zip( + unpublished_tokens[: len(unpublished_pages)], unpublished_pages + ): + if not publish_incomplete_page and len(token_block) < tokens_per_page: + # Do not publish incomplete page + break new_node = cur_node.create_child(token_block, page) - if page in self._allocated_pages: - self._allocated_pages.remove(page) - + if new_node.page.index != page.index: + if page not in self._duplicated_pages: + self._duplicated_pages.append(page) + pages.append(new_node.page) # remove parent node from the leaves. # No need to delete if it was deleted earlier. if cur_node in self.leaves: @@ -462,21 +449,29 @@ def publish_pages_for_tokens( if cur_node is not self.root and cur_node not in self.leaves: self.leaves.add(cur_node) + # we create a new node for each token block, but we only publish full pages, hence last_cached_node is updated only when a full page is published. if len(token_block) == tokens_per_page: number_of_published_pages += 1 + last_cached_node = new_node - # Update reference counts - last_cached_node = cache_info.last_cached_node + # Update reference counts only when we have unpublished tokens if unpublished_tokens: - cur_node.ref_count.increment() - if not last_cached_node.ref_count.is_empty(): - last_cached_node.ref_count.decrement() - last_cached_node = cur_node + last_cached_node.ref_count.increment() + if not cache_info.last_cached_node.ref_count.is_empty(): + cache_info.last_cached_node.ref_count.decrement() + + # Remove published pages from _allocated_pages + for page in pages: + if page in self._allocated_pages: + self._allocated_pages.remove(page) + # if we don't publish the last incomplete page, and len(pages) < len(cache_info.pages), we should return cache_info.pages to avoid losing the reference to the last incomplete page. + if not publish_incomplete_page and len(pages) < len(cache_info.pages): + pages = pages + cache_info.pages[len(pages) :] return TrieCacheInfo( num_tokens=len(updated_tokens), tokens=updated_tokens, - pages=cache_info.pages, + pages=pages, last_cached_node=last_cached_node, number_of_published_pages=number_of_published_pages, pool=self.page_pool, @@ -517,6 +512,17 @@ def free_cache_pages(self): self.page_pool.free_pages(pages_to_free) self.page_pool.free_pages(self._allocated_pages) + self._allocated_pages = [] + + def free_allocated_pages(self, page_ids: List[int]): + pages = [] + for id in page_ids: + for page in self._allocated_pages: + if page.index == id: + pages.append(page) + self._allocated_pages.remove(page) + break + self.page_pool.free_pages(pages) def release_pages(self, cache_info: TrieCacheInfo): """Release the allocation's reference to its pages. @@ -530,13 +536,15 @@ def release_pages(self, cache_info: TrieCacheInfo): if not last_cached_node.ref_count.is_empty(): last_cached_node.ref_count.decrement() - self.page_pool.free_pages(self._allocated_pages) - self._allocated_pages = [] + # free duplicated pages + self.page_pool.free_pages(self._duplicated_pages) + self._duplicated_pages = [] def shutdown(self): self.free_cache_pages() available = self.page_pool.available_page_count() total = self.page_pool.total_page_count() + if available != total: raise ValueError(f"Pages lost: {total - available} of {total} unfreed") diff --git a/shortfin/python/shortfin_apps/llm/components/messages.py b/shortfin/python/shortfin_apps/llm/components/messages.py index 600de5ff11..4370240fa3 100644 --- a/shortfin/python/shortfin_apps/llm/components/messages.py +++ b/shortfin/python/shortfin_apps/llm/components/messages.py @@ -18,6 +18,10 @@ from .kvcache.trie_attention_cache import TriePagedAttentionCache from ...utils import InferenceExecRequest +import logging + +logger = logging.getLogger(__name__) + class InferencePhase(Enum): PREFILL = 1 @@ -98,20 +102,14 @@ def cache_page_indices(self, max_len: int) -> list[int]: def acquire_pages(self): """Acquire pages for this request.""" - self.allocated_cache_info = self._cache.allocate(self.input_token_ids) - self.page_ids = [p.index for p in self.allocated_cache_info.pages] - - def extend_pages(self, extra_token_slots: int): - self.allocated_cache_info = self._cache.extend_pages( - self.input_token_ids, - self.allocated_cache_info, - extra_token_slots=extra_token_slots, - ) + cached_allocation = self._cache.lookup(self.input_token_ids) + token_ids = self.input_token_ids[cached_allocation.num_tokens :] + self.allocated_cache_info = self._cache.allocate(token_ids, cached_allocation) self.page_ids = [p.index for p in self.allocated_cache_info.pages] - def publish_allocated_pages(self): + def publish_allocated_pages(self, publish_incomplete_page: bool = False): self.allocated_cache_info = self._cache.publish_pages_for_tokens( - self.input_token_ids, self.allocated_cache_info + self.allocated_cache_info, publish_incomplete_page=publish_incomplete_page ) def free_cache_pages(self): diff --git a/shortfin/tests/apps/llm/components/invocation_test.py b/shortfin/tests/apps/llm/components/invocation_test.py index 98551b9242..c96517967b 100644 --- a/shortfin/tests/apps/llm/components/invocation_test.py +++ b/shortfin/tests/apps/llm/components/invocation_test.py @@ -112,6 +112,7 @@ def staggered_exec_req_list(cache_ref_count, page_pool): ] req.allocated_cache_info = CacheInfo( num_tokens=len(req.input_token_ids), + tokens=req.input_token_ids, pages=pages, pool=page_pool, last_cached_node=None, diff --git a/shortfin/tests/apps/llm/components/kvcache/new_base_attention_cache_test.py b/shortfin/tests/apps/llm/components/kvcache/new_base_attention_cache_test.py index d6fd437d2c..4516541cec 100644 --- a/shortfin/tests/apps/llm/components/kvcache/new_base_attention_cache_test.py +++ b/shortfin/tests/apps/llm/components/kvcache/new_base_attention_cache_test.py @@ -539,4 +539,4 @@ async def test_fork_pages_allocation_error(cache_ref_count): # Should throw an allocation error when forking with pytest.raises(CacheAllocationFailure): - cache_ref_count.fork_pages(pages) + cache_ref_count.fork_pages([], allocation) diff --git a/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache/new_mock_pool_test.py b/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache/new_mock_pool_test.py index 2d83152939..693d02a4df 100644 --- a/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache/new_mock_pool_test.py +++ b/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache/new_mock_pool_test.py @@ -155,8 +155,11 @@ def published_sequence(trie_cache): """Helper fixture that returns a function to publish token sequences""" def _publish_sequence(tokens: List[int]) -> None: - alloc = trie_cache.allocate(tokens) - alloc_updated = trie_cache.publish_pages_for_tokens(alloc.tokens, alloc) + cached_allocation = trie_cache.lookup(tokens) + alloc = trie_cache.allocate( + tokens[cached_allocation.num_tokens :], cached_allocation + ) + alloc_updated = trie_cache.publish_pages_for_tokens(alloc) trie_cache.release_pages(alloc_updated) return _publish_sequence @@ -211,16 +214,17 @@ def print_node(node, depth=0): @pytest.mark.parametrize("test_sequence", basic_sequences) def test_basic_allocation(trie_cache, test_sequence): """Test basic page allocation without reuse""" - allocation = trie_cache.allocate(test_sequence["tokens"]) + cached_allocation = trie_cache.lookup(test_sequence["tokens"]) + allocation = trie_cache.allocate( + test_sequence["tokens"][cached_allocation.num_tokens :], cached_allocation + ) assert len(allocation.pages) == test_sequence["expected_pages"] assert allocation.number_of_published_pages == 0 assert ( len(allocation.pages) - allocation.number_of_published_pages == test_sequence["expected_pages"] ) - allocation_updated = trie_cache.publish_pages_for_tokens( - allocation.tokens, allocation - ) + allocation_updated = trie_cache.publish_pages_for_tokens(allocation) assert allocation_updated.number_of_published_pages == ( len(test_sequence["tokens"]) // TEST_PAGE_SIZE ) @@ -267,16 +271,18 @@ def test_page_reuse(trie_cache, published_sequence, test_sequences): published_sequence(test_sequences["initial_tokens"]) # Try to reuse - allocation = trie_cache.allocate(test_sequences["reuse_tokens"]) + cached_allocation = trie_cache.lookup(test_sequences["reuse_tokens"]) + allocation = trie_cache.allocate( + test_sequences["reuse_tokens"][cached_allocation.num_tokens :], + cached_allocation, + ) assert len(allocation.pages) == test_sequences["total_pages"] assert allocation.number_of_published_pages == test_sequences["expected_cached"] assert ( len(allocation.pages) - allocation.number_of_published_pages == test_sequences["total_pages"] - test_sequences["expected_cached"] ) - allocation_updated = trie_cache.publish_pages_for_tokens( - allocation.tokens, allocation - ) + allocation_updated = trie_cache.publish_pages_for_tokens(allocation) trie_cache.release_pages(allocation_updated) @@ -310,10 +316,11 @@ def test_lru_eviction(trie_cache, access_count): logger.debug("\nPublishing sequences to keep active:") for i in range(keep_published): tokens = list(range(i * 100, i * 100 + TEST_PAGE_SIZE)) - alloc = trie_cache.allocate(tokens) - alloc_updated = trie_cache.publish_pages_for_tokens( - alloc.tokens[:TEST_PAGE_SIZE], alloc + cached_allocation = trie_cache.lookup(tokens) + alloc = trie_cache.allocate( + tokens[cached_allocation.num_tokens :], cached_allocation ) + alloc_updated = trie_cache.publish_pages_for_tokens(alloc) sequences.append(tokens) logger.debug(f"Published sequence {i} (keeping active)") print_tree_state(trie_cache, " ") @@ -322,10 +329,11 @@ def test_lru_eviction(trie_cache, access_count): logger.debug("\nAdding releasable sequences:") for i in range(keep_published, TEST_POOL_CAPACITY): tokens = list(range(i * 100, i * 100 + TEST_PAGE_SIZE)) - alloc = trie_cache.allocate(tokens) - alloc_updated = trie_cache.publish_pages_for_tokens( - alloc.tokens[:TEST_PAGE_SIZE], alloc + cached_allocation = trie_cache.lookup(tokens) + alloc = trie_cache.allocate( + tokens[cached_allocation.num_tokens :], cached_allocation ) + alloc_updated = trie_cache.publish_pages_for_tokens(alloc) trie_cache.release_pages(alloc_updated) # These can be evicted sequences.append(tokens) logger.debug(f"Added releasable sequence {i}") @@ -338,7 +346,10 @@ def test_lru_eviction(trie_cache, access_count): logger.debug(f"\nAccessing {access_count} sequences to update LRU order:") for i in range(access_count): logger.debug(f"\nAccessing sequence {i}:") - alloc = trie_cache.allocate(sequences[i]) + cached_allocation = trie_cache.lookup(sequences[i]) + alloc = trie_cache.allocate( + sequences[i][cached_allocation.num_tokens :], cached_allocation + ) print_tree_state(trie_cache, " ") trie_cache.release_pages(alloc) logger.debug(f"After releasing allocation {i}:") @@ -353,7 +364,10 @@ def test_lru_eviction(trie_cache, access_count): # Try to allocate new sequence - should evict least recently used unpublished sequence new_tokens = list(range(1000, 1000 + TEST_PAGE_SIZE)) logger.debug(f"\nAttempting to allocate new sequence: {new_tokens}") - new_alloc = trie_cache.allocate(new_tokens) + cached_allocation = trie_cache.lookup(new_tokens) + new_alloc = trie_cache.allocate( + new_tokens[cached_allocation.num_tokens :], cached_allocation + ) logger.debug("\nNew allocation succeeded:") logger.debug("\nCache state after new allocation:") print_tree_state(trie_cache, " ") @@ -363,7 +377,10 @@ def test_lru_eviction(trie_cache, access_count): logger.debug("\nVerifying preserved sequences:") for i in range(max(access_count, keep_published)): logger.debug(f"\nChecking sequence {i}:") - recheck = trie_cache.allocate(sequences[i]) + cached_allocation = trie_cache.lookup(sequences[i]) + recheck = trie_cache.allocate( + sequences[i][cached_allocation.num_tokens :], cached_allocation + ) cached_pages = recheck.number_of_published_pages logger.debug(f"- Cached pages found: {cached_pages}") assert ( @@ -390,7 +407,11 @@ def test_progressive_publish(trie_cache, publish_steps): print_tree_state(trie_cache) logger.debug("\nAcquiring initial allocation...") - alloc = trie_cache.allocate(tokens) + token_list = list(tokens) + cached_allocation = trie_cache.lookup(token_list) + alloc = trie_cache.allocate( + token_list[cached_allocation.num_tokens :], cached_allocation + ) logger.debug(f"Initial allocation pages: {[p.index for p in alloc.pages]}") logger.debug("\nCache state after initial allocation:") print_tree_state(trie_cache) @@ -401,9 +422,7 @@ def test_progressive_publish(trie_cache, publish_steps): # Publish next page logger.debug(f"Publishing up to page {step}") # Replace publishing with tokens - trie_cache.publish_pages_for_tokens( - alloc.tokens[: (step) * TEST_PAGE_SIZE], alloc - ) + trie_cache.publish_pages_for_tokens(alloc) logger.debug("\nCache state after publish:") print_tree_state(trie_cache) @@ -412,7 +431,11 @@ def test_progressive_publish(trie_cache, publish_steps): logger.debug(f"\nAttempting to reuse tokens: {reuse_tokens}") logger.debug(f"Expected cached pages: {step}") - reuse_alloc = trie_cache.allocate(reuse_tokens) + reuse_token_list = list(reuse_tokens) + cached_allocation = trie_cache.lookup(reuse_token_list) + reuse_alloc = trie_cache.allocate( + reuse_token_list[cached_allocation.num_tokens :], cached_allocation + ) logger.debug(f"Reuse allocation total pages: {len(reuse_alloc.pages)}") logger.debug( f"Reuse allocation cached pages: {reuse_alloc.number_of_published_pages}" @@ -446,18 +469,22 @@ def test_reference_counting(trie_cache, ref_count): allocations = [] # Create initial allocation and publish - first_alloc = trie_cache.allocate(tokens) - # Replace publishing with tokens - first_alloc_updated = trie_cache.publish_pages_for_tokens( - first_alloc.tokens, first_alloc + cached_allocation = trie_cache.lookup(tokens) + first_alloc = trie_cache.allocate( + tokens[cached_allocation.num_tokens :], cached_allocation ) + # Replace publishing with tokens + first_alloc_updated = trie_cache.publish_pages_for_tokens(first_alloc) allocations.append(first_alloc_updated) logger.debug("\nInitial allocation created") print_tree_state(trie_cache, " ") # Create additional references for i in range(ref_count - 1): - alloc = trie_cache.allocate(tokens) + cached_allocation = trie_cache.lookup(tokens) + alloc = trie_cache.allocate( + tokens[cached_allocation.num_tokens :], cached_allocation + ) allocations.append(alloc) logger.debug(f"\nCreated reference {i+1}") print_tree_state(trie_cache, " ") @@ -469,10 +496,11 @@ def test_reference_counting(trie_cache, ref_count): fill_tokens = list( range(100 + i * TEST_PAGE_SIZE, 100 + (i + 1) * TEST_PAGE_SIZE) ) - alloc = trie_cache.allocate(fill_tokens) - alloc_updated = trie_cache.publish_pages_for_tokens( - alloc.tokens[:TEST_PAGE_SIZE], alloc + cached_allocation = trie_cache.lookup(fill_tokens) + alloc = trie_cache.allocate( + fill_tokens[cached_allocation.num_tokens :], cached_allocation ) + alloc_updated = trie_cache.publish_pages_for_tokens(alloc) fill_allocations.append(alloc_updated) logger.debug(f"\nFilled cache slot {i+1}/{remaining}") print_tree_state(trie_cache, " ") @@ -480,7 +508,10 @@ def test_reference_counting(trie_cache, ref_count): logger.debug("\nAttempting allocation that should fail...") try: new_tokens = list(range(1000, 1000 + TEST_PAGE_SIZE)) - new_alloc = trie_cache.allocate(new_tokens) + cached_allocation = trie_cache.lookup(new_tokens) + new_alloc = trie_cache.allocate( + new_tokens[cached_allocation.num_tokens :], cached_allocation + ) logger.debug("ERROR: Allocation succeeded when it should have failed!") logger.debug("\nPost-allocation state:") print_tree_state(trie_cache, " ") @@ -506,18 +537,20 @@ def test_reference_counting(trie_cache, ref_count): def test_fork_pages(trie_cache, tokens): """Test that fork_pages correctly creates a forked allocation sharing published pages.""" # Create and publish a sequence - alloc = trie_cache.allocate(tokens) - alloc_updated = trie_cache.publish_pages_for_tokens(alloc.tokens, alloc) - published_pages = list(alloc_updated.pages) + cached_allocation = trie_cache.lookup(tokens) + alloc = trie_cache.allocate( + tokens[cached_allocation.num_tokens :], cached_allocation + ) + alloc_updated = trie_cache.publish_pages_for_tokens(alloc) trie_cache.release_pages(alloc_updated) # Fork the published sequence - forked_alloc = trie_cache.fork_pages(published_pages, tokens) + forked_alloc = trie_cache.fork_pages(tokens, alloc_updated) try: # The forked allocation should reference the same pages assert forked_alloc.tokens == tokens - assert len(forked_alloc.pages) == len(published_pages) - for orig, forked in zip(published_pages, forked_alloc.pages): + assert len(forked_alloc.pages) == len(alloc_updated.pages) + for orig, forked in zip(alloc_updated.pages, forked_alloc.pages): assert orig.index == forked.index finally: trie_cache.release_pages(forked_alloc) diff --git a/shortfin/tests/apps/llm/components/messages_test.py b/shortfin/tests/apps/llm/components/messages_test.py index 4739207e65..db5b55094c 100644 --- a/shortfin/tests/apps/llm/components/messages_test.py +++ b/shortfin/tests/apps/llm/components/messages_test.py @@ -73,7 +73,7 @@ def test_inference_exec_request_reset(mock_void_future): def test_cache_page_indices(mock_void_future, mock_base_cache, dummy_pages): req = LlmInferenceExecRequest(InferencePhase.PREFILL, [1, 2, 3, 4], rid="test123") req._cache = mock_base_cache - req.allocated_cache_info = CacheInfo(4, dummy_pages, None, None) + req.allocated_cache_info = CacheInfo(4, [1, 2, 3, 4], dummy_pages, None, None) cache_page_indices = req.cache_page_indices(2) assert len(cache_page_indices) == 2 @@ -88,7 +88,7 @@ def test_free_cache_pages(mock_void_future, mock_base_cache, dummy_pages): assert not release_called req._cache = mock_base_cache - req.allocated_cache_info = CacheInfo(4, dummy_pages, None, None) + req.allocated_cache_info = CacheInfo(4, [1, 2, 3, 4], dummy_pages, None, None) with patch.object(req._cache, "release_pages") as mock_release_pages: req.free_cache_pages()