From 7a70202535fab4597f9df4325e705a1887ad5da7 Mon Sep 17 00:00:00 2001 From: lisaliu1 Date: Fri, 3 Oct 2025 16:54:46 +0000 Subject: [PATCH 1/4] update Trie APIs --- .../shortfin/tinystories_llama2_25m_test.py | 28 --- .../kvcache/trie_attention_cache.py | 204 ++++++++++++------ .../new_mock_pool_test.py | 111 ++++++---- 3 files changed, 207 insertions(+), 136 deletions(-) diff --git a/app_tests/integration_tests/llm/shortfin/tinystories_llama2_25m_test.py b/app_tests/integration_tests/llm/shortfin/tinystories_llama2_25m_test.py index 011ad6c90d..f93a0fa9d3 100644 --- a/app_tests/integration_tests/llm/shortfin/tinystories_llama2_25m_test.py +++ b/app_tests/integration_tests/llm/shortfin/tinystories_llama2_25m_test.py @@ -102,19 +102,12 @@ class TestLLMServer: def test_basic_generation( self, server: tuple[Any, int, ServerConfig], - request: pytest.FixtureRequest, ) -> None: """Tests basic text generation capabilities. Args: server: Tuple of (process, port) from server fixture """ - test_id = request.node.callspec.id - if "trie" in test_id: - pytest.skip( - reason="TrieAttentionCache APIs are under development, skip it for now." - ) - process, port, config = server assert process.poll() is None, "Server process terminated unexpectedly" dataset = ( @@ -145,19 +138,12 @@ def test_basic_generation( def test_multi_page_generation( self, server: tuple[Any, int, ServerConfig], - request: pytest.FixtureRequest, ) -> None: """Tests multi-page text generation capabilities. Args: server: Tuple of (process, port) from server fixture """ - test_id = request.node.callspec.id - if "trie" in test_id: - pytest.skip( - reason="TrieAttentionCache APIs are under development, skip it for now." - ) - process, port, config = server assert process.poll() is None, "Server process terminated unexpectedly" dataset = ( @@ -209,10 +195,6 @@ def test_concurrent_generation( pytest.skip( reason="Known issue with chunked prefill in batch case: https://github.com/nod-ai/shark-ai/issues/2235" ) - if "trie" in test_id: - pytest.skip( - reason="TrieAttentionCache APIs are under development, skip it for now." - ) process, port, config = server assert process.poll() is None, "Server process terminated unexpectedly" @@ -257,18 +239,12 @@ def _generate_task(prompt: str, port: int): def test_single_greedy_switch( self, server: tuple[Any, int, ServerConfig], - request: pytest.FixtureRequest, ): """Tests switching to single-beam greedy generation. Args: server: Tuple of (process, port, config) from server fixture """ - test_id = request.node.callspec.id - if "trie" in test_id: - pytest.skip( - reason="TrieAttentionCache APIs are under development, skip it for now." - ) process, port, _ = server assert process.poll() is None, "Server process terminated unexpectedly" @@ -316,10 +292,6 @@ def test_beam_search_switch( pytest.skip( "Beam search with 2 beams isn't compatible with logits returned by GPU argmax model." ) - if "trie" in test_id: - pytest.skip( - reason="TrieAttentionCache APIs are under development, skip it for now." 
- ) process, port, _ = server assert process.poll() is None, "Server process terminated unexpectedly" diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py b/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py index 83a2b67892..c4f6ddf428 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py @@ -56,8 +56,13 @@ def create_child(self, tokens: Tuple[int, ...], page: PageInfo) -> "TrieNode": Returns: The newly created child node """ - new_node = TrieNode(tokens=tokens, page=page, parent=self) - self.children[tokens] = new_node + new_node = None + if tokens in self.children: + # If the child already exists, return it + new_node = self.children[tokens] + else: + new_node = TrieNode(tokens=tokens, page=page, parent=self) + self.children[tokens] = new_node return new_node def unlink(self) -> None: @@ -86,14 +91,10 @@ class TrieCacheInfo(CacheInfo): Contains information about the tokens, pages, and last cached node. Attributes: - tokens: List of tokens in the allocation last_cached_node: Last node in the trie that was cached - cached_pages: List of pages that were already cached - newly_acquired_pages: List of pages that were newly acquired for this allocation number_of_published_pages: Number of pages that have been published to the cache """ - tokens: List[int] number_of_published_pages: int last_cached_node: TrieNode @@ -135,9 +136,13 @@ def __init__(self, page_pool: PagePool, tokens_per_page: int): self.root = TrieNode(tokens=tuple(), page=dummy_page) self.leaves: Set[TrieNode] = set() self._lock: Lock = Lock() - self._allocated_pages: List[PageInfo] = [] + self._duplicated_pages: List[ + PageInfo + ] = ( + [] + ) # pages that are duplicated from existing pages in Trie tree. These pages can be safely freed when calling release_pages. - def fork_pages(self, pages: List[PageInfo], tokens: list[int]) -> TrieCacheInfo: + def fork_pages(self, tokens: list[int], cache_info: TrieCacheInfo) -> TrieCacheInfo: """Fork a sequence of pages into the trie. Share prefixes with existing nodes till N-1 tokens, then create a new node @@ -166,10 +171,10 @@ def fork_pages(self, pages: List[PageInfo], tokens: list[int]) -> TrieCacheInfo: pool=self.page_pool, ) - new_page = self.page_pool.copy_page(pages[-1]) + new_page = self.page_pool.copy_page(cache_info.pages[-1]) if new_page is None: self._evict_pages(1) - new_page = self.page_pool.copy_page(pages[-1]) + new_page = self.page_pool.copy_page(cache_info.pages[-1]) return TrieCacheInfo( tokens=list(tokens), @@ -208,6 +213,30 @@ def match(self, tokens: List[int]) -> Tuple[TrieNode, List[PageInfo]]: return cur, matched_pages + def lookup(self, tokens: List[int]) -> TrieCacheInfo: + """Lookup the cache for the given token sequence. It only returns fully matched pages. 
+
+        Args:
+            tokens: Sequence of tokens to look up
+
+        Returns:
+            TrieCacheInfo with the matched tokens and pages
+        """
+        with self._lock:
+            page_aligned_token_len = (
+                len(tokens) // self.tokens_per_page
+            ) * self.tokens_per_page
+            page_aligned_tokens = tokens[:page_aligned_token_len]
+            cur_node, matched_pages = self.match(page_aligned_tokens)
+            num_matched_tokens = len(matched_pages) * self.tokens_per_page
+            matched_tokens = page_aligned_tokens[:num_matched_tokens]
+            return TrieCacheInfo(
+                num_tokens=num_matched_tokens,
+                tokens=matched_tokens,
+                pages=matched_pages,
+                last_cached_node=cur_node,
+                number_of_published_pages=len(matched_pages),
+                pool=self.page_pool,
+            )
+
     def evict_pages(self, max_pages: int) -> int:
         """Evict up to max_pages pages using LRU strategy.
 
@@ -232,6 +261,8 @@ def evict_pages(self, max_pages: int) -> int:
         # Evict least recently used nodes
         while unused_leaf_heap and len(pages_to_evict) < max_pages:
             _, leaf = heapq.heappop(unused_leaf_heap)
+            if leaf.page in pages_to_evict:
+                continue
             pages_to_evict.append(leaf.page)
 
             parent = leaf.parent
@@ -249,9 +280,6 @@ def evict_pages(self, max_pages: int) -> int:
             heapq.heappush(unused_leaf_heap, (parent.access_time, parent))
 
         if pages_to_evict:
-            logger.debug(
-                f"TriePagedAttentionCache: Released allocated pages in evict_pages {[p.index for p in pages_to_evict]}"
-            )
            self.page_pool.free_pages(pages_to_evict)
 
         return len(pages_to_evict)
@@ -259,9 +287,8 @@ def evict_pages(self, max_pages: int) -> int:
     def allocate(
         self,
         tokens: List[int],
-        allocation_block_size: int = 0,
         cache_info: TrieCacheInfo = None,
-        lookup: bool = True,
+        allocation_block_size: int = 0,
         evict: bool = True,
     ) -> TrieCacheInfo:
         """Acquire pages for a sequence of tokens.
@@ -271,12 +298,12 @@ def allocate(
 
         Args:
             tokens: Sequence of tokens needing pages
-            allocation_block_size: number of pages to allocate at once, not used if it is 0
-            lookup: Whether to look up existing tokens in the cache.
+            cache_info: Existing TrieCacheInfo to extend; must not be None
+            allocation_block_size: minimum number of pages to allocate at once; ignored if it is 0
             evict: Whether to evict old tokens if the cache is full.
 
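+        Example (illustrative sketch; ``cache`` is a TriePagedAttentionCache)::
+
+            cached = cache.lookup(tokens)
+            info = cache.allocate(tokens[cached.num_tokens :], cache_info=cached)
+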
         Returns:
-            TrieCacheInfo: containing meta data for the allocation
+            TrieCacheInfo containing both the cached and the newly allocated pages
 
         Raises:
             CacheAllocationFailure: If unable to allocate required pages
@@ -284,32 +311,25 @@ def allocate(
         with self._lock:
             tokens = tuple(tokens)
             n_empty_pages = 0
-            cached_pages = []
             pages = []
             cur_node = self.root
-            if lookup:
-                cur_node, matched_pages = self.match(tokens)
-                logger.debug(
-                    f"TriePagedAttentionCache: Lookup found {len(matched_pages)} cached pages for token length {len(tokens)}"
-                )
+            if not cache_info:
+                raise ValueError("cache_info cannot be None")
+
+            cur_node = cache_info.last_cached_node
 
-                cached_pages = matched_pages
-                n_cached_tokens = 0
-                if matched_pages:
-                    n_cached_tokens = len(matched_pages) * self.tokens_per_page
-                remaining_length = len(tokens) - n_cached_tokens
-                n_empty_pages = math.ceil(remaining_length / self.tokens_per_page)
-            else:
-                n_empty_pages = math.ceil(len(tokens) / self.tokens_per_page)
+            n_empty_pages = math.ceil(len(tokens) / self.tokens_per_page)
 
-            if not cached_pages and allocation_block_size > 0:
-                n_empty_pages = allocation_block_size
+            if allocation_block_size > 0:
+                n_empty_pages = max(n_empty_pages, allocation_block_size)
 
             new_pages = self.page_pool.acquire_free_pages(n_empty_pages)
 
             if new_pages is None and evict:
                 # Try eviction
-                self.evict_pages(n_empty_pages - len(self.page_pool.available_pages))
+                number_evicted_pages = self.evict_pages(
+                    n_empty_pages - len(self.page_pool.available_pages)
+                )
                 new_pages = self.page_pool.acquire_free_pages(n_empty_pages)
 
             if new_pages is None:
@@ -317,37 +337,43 @@ def allocate(
                     "Failed to acquire pages even after attempting eviction from LRU leaves"
                 )
 
-            cur_node.ref_count.increment()
-            pages = cached_pages + new_pages
-            self._allocated_pages.extend(new_pages)
+            if new_pages is None:
+                raise CacheAllocationFailure(
+                    "Failed to acquire pages and eviction is disabled"
+                )
 
-            num_tokens = len(tokens)
-            if cache_info:
-                if (
-                    cache_info.last_cached_node
-                    and not cache_info.last_cached_node.ref_count.is_empty()
+            if len(new_pages) > 0:
+                # New pages will be attached as children of cur_node, so increment cur_node's ref_count.
+                # Do not increment it when this allocation already extends past cur_node, i.e. when we allocate along the same branch again.
+                if len(cache_info.pages) == 0 or (
+                    len(cache_info.pages) > 0
+                    and cur_node.page.index == cache_info.pages[-1].index
                 ):
-                    cache_info.last_cached_node.ref_count.decrement()
-                pages = cache_info.pages + pages
-                num_tokens += cache_info.num_tokens
+                    cur_node.ref_count.increment()
+
+            self._allocated_pages.extend(new_pages)
+            pages = cache_info.pages + new_pages
+            num_tokens = len(tokens) + cache_info.num_tokens
+            tokens = cache_info.tokens + list(tokens)
+            number_of_published_pages = cache_info.number_of_published_pages
 
             return TrieCacheInfo(
-                num_tokens=len(tokens),
+                num_tokens=num_tokens,
                 tokens=tokens,
                 pages=pages,
                 last_cached_node=cur_node,
-                number_of_published_pages=len(cached_pages),
+                number_of_published_pages=number_of_published_pages,
                 pool=self.page_pool,
             )
 
     def publish_pages_for_tokens(
-        self, tokens: List[int], cache_info: TrieCacheInfo
+        self, cache_info: TrieCacheInfo, publish_incomplete_page: bool = False
     ) -> TrieCacheInfo:
         """Make pages available in the cache for the specified tokens.
 
         Args:
-            tokens_to_publish: Tokens to publish to the cache
             cache_info: TrieCacheInfo object containing allocation metadata
+            publish_incomplete_page: Whether to publish the last page even if it is not full
 
         Raises:
             ValueError: If tokens don't match allocation or exceed available pages
@@ -355,17 +381,27 @@ def publish_pages_for_tokens(
         with self._lock:
             # If we have more tokens, publish pages up to the incoming tokens.
             # If incoming has more tokens, replace our tokens with incoming tokens and publish pages up to the incoming tokens.
+            if not cache_info:
+                raise ValueError("cache_info cannot be None")
             updated_tokens = deepcopy(cache_info.tokens)
             tokens_per_page = self.tokens_per_page
-            matched_node, matched_pages = self.match(updated_tokens)
+            number_of_pages_to_publish = len(updated_tokens) // tokens_per_page
+            matched_node, matched_pages = self.match(
+                updated_tokens[: number_of_pages_to_publish * tokens_per_page]
+            )
+
+            for i, page in enumerate(matched_pages):
+                if page.index != cache_info.pages[i].index:
+                    if page not in self._duplicated_pages:
+                        self._duplicated_pages.append(cache_info.pages[i])
+                    if cache_info.last_cached_node.page.index == page.index:
+                        cache_info.last_cached_node.page = page
+
             last_number_of_published_pages = cache_info.number_of_published_pages
+            if len(matched_pages) > last_number_of_published_pages:
                 last_number_of_published_pages = len(matched_pages)
-            number_of_pages_to_publish = -(
-                len(updated_tokens) // -tokens_per_page
-            )  # ceil division
-
             # Create token blocks for unpublished pages
             start_token_index = last_number_of_published_pages * tokens_per_page
             unpublished_tokens = []
@@ -380,17 +416,26 @@ def publish_pages_for_tokens(
             )
 
             unpublished_pages = cache_info.pages[
-                last_number_of_published_pages:number_of_pages_to_publish
+                last_number_of_published_pages : last_number_of_published_pages
+                + len(unpublished_tokens)
             ]
+            number_of_published_pages = last_number_of_published_pages
 
-            number_of_published_pages = 0
+            pages = matched_pages  # use matched_pages rather than cache_info.pages so that duplicated pages are not carried forward
+            last_cached_node = cache_info.last_cached_node
             cur_node = matched_node
-            for token_block, page in zip(unpublished_tokens, unpublished_pages):
+            for token_block, page in zip(
+                unpublished_tokens[: len(unpublished_pages)], unpublished_pages
+            ):
+                if not publish_incomplete_page and len(token_block) < tokens_per_page:
+                    # Do not publish an incomplete page
+                    break
                 new_node = cur_node.create_child(token_block, page)
-                if page in self._allocated_pages:
-                    self._allocated_pages.remove(page)
-
+                if new_node.page.index != page.index:
+                    if page not in self._duplicated_pages:
+                        self._duplicated_pages.append(page)
+                pages.append(new_node.page)
                 # remove parent node from the leaves.
                 # No need to delete if it was deleted earlier.
                 if cur_node in self.leaves:
@@ -400,21 +445,29 @@ def publish_pages_for_tokens(
                 if cur_node is not self.root and cur_node not in self.leaves:
                     self.leaves.add(cur_node)
 
+                # A new node is created for each token block, but only full pages are published, so last_cached_node advances only when a full page is published.
                 if len(token_block) == tokens_per_page:
                     number_of_published_pages += 1
+                    last_cached_node = new_node
 
-            # Update reference counts
-            last_cached_node = cache_info.last_cached_node
+            # Update reference counts only when we have unpublished tokens
             if unpublished_tokens:
-                cur_node.ref_count.increment()
-                if not last_cached_node.ref_count.is_empty():
-                    last_cached_node.ref_count.decrement()
-                last_cached_node = cur_node
+                last_cached_node.ref_count.increment()
+                if not cache_info.last_cached_node.ref_count.is_empty():
+                    cache_info.last_cached_node.ref_count.decrement()
+
+            # Remove published pages from _allocated_pages
+            for page in pages:
+                if page in self._allocated_pages:
+                    self._allocated_pages.remove(page)
 
+            # If the last incomplete page was not published, len(pages) can be smaller than len(cache_info.pages); append the remaining pages from cache_info.pages so the reference to that incomplete page is not lost.
+            if not publish_incomplete_page and len(pages) < len(cache_info.pages):
+                pages = pages + cache_info.pages[len(pages) :]
             return TrieCacheInfo(
                 num_tokens=len(updated_tokens),
                 tokens=updated_tokens,
-                pages=cache_info.pages,
+                pages=pages,
                 last_cached_node=last_cached_node,
                 number_of_published_pages=number_of_published_pages,
                 pool=self.page_pool,
             )
@@ -455,6 +508,17 @@ def free_cache_pages(self):
 
         self.page_pool.free_pages(pages_to_free)
         self.page_pool.free_pages(self._allocated_pages)
+        self._allocated_pages = []
+
+    def free_allocated_pages(self, page_ids: List[int]):
+        """Free the pages in _allocated_pages whose indices appear in page_ids."""
+        pages = []
+        for id in page_ids:
+            for page in self._allocated_pages:
+                if page.index == id:
+                    pages.append(page)
+                    self._allocated_pages.remove(page)
+                    break
+        self.page_pool.free_pages(pages)
 
     def release_pages(self, cache_info: TrieCacheInfo):
         """Release the allocation's reference to its pages.
@@ -468,13 +532,15 @@ def release_pages(self, cache_info: TrieCacheInfo):
             if not last_cached_node.ref_count.is_empty():
                 last_cached_node.ref_count.decrement()
 
-            self.page_pool.free_pages(self._allocated_pages)
-            self._allocated_pages = []
+            # Free duplicated pages
+            self.page_pool.free_pages(self._duplicated_pages)
+            self._duplicated_pages = []
 
     def shutdown(self):
         self.free_cache_pages()
         available = self.page_pool.available_page_count()
         total = self.page_pool.total_page_count()
+
         if available != total:
             raise ValueError(f"Pages lost: {total - available} of {total} unfreed")
 
diff --git a/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache/new_mock_pool_test.py b/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache/new_mock_pool_test.py
index 2d83152939..693d02a4df 100644
--- a/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache/new_mock_pool_test.py
+++ b/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache/new_mock_pool_test.py
@@ -155,8 +155,11 @@ def published_sequence(trie_cache):
     """Helper fixture that returns a function to publish token sequences"""
 
     def _publish_sequence(tokens: List[int]) -> None:
-        alloc = trie_cache.allocate(tokens)
-        alloc_updated = trie_cache.publish_pages_for_tokens(alloc.tokens, alloc)
+        cached_allocation = trie_cache.lookup(tokens)
+        alloc = trie_cache.allocate(
+            tokens[cached_allocation.num_tokens :], cached_allocation
+        )
+        alloc_updated = trie_cache.publish_pages_for_tokens(alloc)
         trie_cache.release_pages(alloc_updated)
 
     return _publish_sequence
@@ -211,16 +214,17 @@ def print_node(node, depth=0):
 @pytest.mark.parametrize("test_sequence", basic_sequences)
 def test_basic_allocation(trie_cache, test_sequence):
     """Test basic page allocation without reuse"""
-    allocation = 
trie_cache.allocate(test_sequence["tokens"]) + cached_allocation = trie_cache.lookup(test_sequence["tokens"]) + allocation = trie_cache.allocate( + test_sequence["tokens"][cached_allocation.num_tokens :], cached_allocation + ) assert len(allocation.pages) == test_sequence["expected_pages"] assert allocation.number_of_published_pages == 0 assert ( len(allocation.pages) - allocation.number_of_published_pages == test_sequence["expected_pages"] ) - allocation_updated = trie_cache.publish_pages_for_tokens( - allocation.tokens, allocation - ) + allocation_updated = trie_cache.publish_pages_for_tokens(allocation) assert allocation_updated.number_of_published_pages == ( len(test_sequence["tokens"]) // TEST_PAGE_SIZE ) @@ -267,16 +271,18 @@ def test_page_reuse(trie_cache, published_sequence, test_sequences): published_sequence(test_sequences["initial_tokens"]) # Try to reuse - allocation = trie_cache.allocate(test_sequences["reuse_tokens"]) + cached_allocation = trie_cache.lookup(test_sequences["reuse_tokens"]) + allocation = trie_cache.allocate( + test_sequences["reuse_tokens"][cached_allocation.num_tokens :], + cached_allocation, + ) assert len(allocation.pages) == test_sequences["total_pages"] assert allocation.number_of_published_pages == test_sequences["expected_cached"] assert ( len(allocation.pages) - allocation.number_of_published_pages == test_sequences["total_pages"] - test_sequences["expected_cached"] ) - allocation_updated = trie_cache.publish_pages_for_tokens( - allocation.tokens, allocation - ) + allocation_updated = trie_cache.publish_pages_for_tokens(allocation) trie_cache.release_pages(allocation_updated) @@ -310,10 +316,11 @@ def test_lru_eviction(trie_cache, access_count): logger.debug("\nPublishing sequences to keep active:") for i in range(keep_published): tokens = list(range(i * 100, i * 100 + TEST_PAGE_SIZE)) - alloc = trie_cache.allocate(tokens) - alloc_updated = trie_cache.publish_pages_for_tokens( - alloc.tokens[:TEST_PAGE_SIZE], alloc + cached_allocation = trie_cache.lookup(tokens) + alloc = trie_cache.allocate( + tokens[cached_allocation.num_tokens :], cached_allocation ) + alloc_updated = trie_cache.publish_pages_for_tokens(alloc) sequences.append(tokens) logger.debug(f"Published sequence {i} (keeping active)") print_tree_state(trie_cache, " ") @@ -322,10 +329,11 @@ def test_lru_eviction(trie_cache, access_count): logger.debug("\nAdding releasable sequences:") for i in range(keep_published, TEST_POOL_CAPACITY): tokens = list(range(i * 100, i * 100 + TEST_PAGE_SIZE)) - alloc = trie_cache.allocate(tokens) - alloc_updated = trie_cache.publish_pages_for_tokens( - alloc.tokens[:TEST_PAGE_SIZE], alloc + cached_allocation = trie_cache.lookup(tokens) + alloc = trie_cache.allocate( + tokens[cached_allocation.num_tokens :], cached_allocation ) + alloc_updated = trie_cache.publish_pages_for_tokens(alloc) trie_cache.release_pages(alloc_updated) # These can be evicted sequences.append(tokens) logger.debug(f"Added releasable sequence {i}") @@ -338,7 +346,10 @@ def test_lru_eviction(trie_cache, access_count): logger.debug(f"\nAccessing {access_count} sequences to update LRU order:") for i in range(access_count): logger.debug(f"\nAccessing sequence {i}:") - alloc = trie_cache.allocate(sequences[i]) + cached_allocation = trie_cache.lookup(sequences[i]) + alloc = trie_cache.allocate( + sequences[i][cached_allocation.num_tokens :], cached_allocation + ) print_tree_state(trie_cache, " ") trie_cache.release_pages(alloc) logger.debug(f"After releasing allocation {i}:") @@ -353,7 +364,10 @@ 
def test_lru_eviction(trie_cache, access_count): # Try to allocate new sequence - should evict least recently used unpublished sequence new_tokens = list(range(1000, 1000 + TEST_PAGE_SIZE)) logger.debug(f"\nAttempting to allocate new sequence: {new_tokens}") - new_alloc = trie_cache.allocate(new_tokens) + cached_allocation = trie_cache.lookup(new_tokens) + new_alloc = trie_cache.allocate( + new_tokens[cached_allocation.num_tokens :], cached_allocation + ) logger.debug("\nNew allocation succeeded:") logger.debug("\nCache state after new allocation:") print_tree_state(trie_cache, " ") @@ -363,7 +377,10 @@ def test_lru_eviction(trie_cache, access_count): logger.debug("\nVerifying preserved sequences:") for i in range(max(access_count, keep_published)): logger.debug(f"\nChecking sequence {i}:") - recheck = trie_cache.allocate(sequences[i]) + cached_allocation = trie_cache.lookup(sequences[i]) + recheck = trie_cache.allocate( + sequences[i][cached_allocation.num_tokens :], cached_allocation + ) cached_pages = recheck.number_of_published_pages logger.debug(f"- Cached pages found: {cached_pages}") assert ( @@ -390,7 +407,11 @@ def test_progressive_publish(trie_cache, publish_steps): print_tree_state(trie_cache) logger.debug("\nAcquiring initial allocation...") - alloc = trie_cache.allocate(tokens) + token_list = list(tokens) + cached_allocation = trie_cache.lookup(token_list) + alloc = trie_cache.allocate( + token_list[cached_allocation.num_tokens :], cached_allocation + ) logger.debug(f"Initial allocation pages: {[p.index for p in alloc.pages]}") logger.debug("\nCache state after initial allocation:") print_tree_state(trie_cache) @@ -401,9 +422,7 @@ def test_progressive_publish(trie_cache, publish_steps): # Publish next page logger.debug(f"Publishing up to page {step}") # Replace publishing with tokens - trie_cache.publish_pages_for_tokens( - alloc.tokens[: (step) * TEST_PAGE_SIZE], alloc - ) + trie_cache.publish_pages_for_tokens(alloc) logger.debug("\nCache state after publish:") print_tree_state(trie_cache) @@ -412,7 +431,11 @@ def test_progressive_publish(trie_cache, publish_steps): logger.debug(f"\nAttempting to reuse tokens: {reuse_tokens}") logger.debug(f"Expected cached pages: {step}") - reuse_alloc = trie_cache.allocate(reuse_tokens) + reuse_token_list = list(reuse_tokens) + cached_allocation = trie_cache.lookup(reuse_token_list) + reuse_alloc = trie_cache.allocate( + reuse_token_list[cached_allocation.num_tokens :], cached_allocation + ) logger.debug(f"Reuse allocation total pages: {len(reuse_alloc.pages)}") logger.debug( f"Reuse allocation cached pages: {reuse_alloc.number_of_published_pages}" @@ -446,18 +469,22 @@ def test_reference_counting(trie_cache, ref_count): allocations = [] # Create initial allocation and publish - first_alloc = trie_cache.allocate(tokens) - # Replace publishing with tokens - first_alloc_updated = trie_cache.publish_pages_for_tokens( - first_alloc.tokens, first_alloc + cached_allocation = trie_cache.lookup(tokens) + first_alloc = trie_cache.allocate( + tokens[cached_allocation.num_tokens :], cached_allocation ) + # Replace publishing with tokens + first_alloc_updated = trie_cache.publish_pages_for_tokens(first_alloc) allocations.append(first_alloc_updated) logger.debug("\nInitial allocation created") print_tree_state(trie_cache, " ") # Create additional references for i in range(ref_count - 1): - alloc = trie_cache.allocate(tokens) + cached_allocation = trie_cache.lookup(tokens) + alloc = trie_cache.allocate( + tokens[cached_allocation.num_tokens :], 
cached_allocation + ) allocations.append(alloc) logger.debug(f"\nCreated reference {i+1}") print_tree_state(trie_cache, " ") @@ -469,10 +496,11 @@ def test_reference_counting(trie_cache, ref_count): fill_tokens = list( range(100 + i * TEST_PAGE_SIZE, 100 + (i + 1) * TEST_PAGE_SIZE) ) - alloc = trie_cache.allocate(fill_tokens) - alloc_updated = trie_cache.publish_pages_for_tokens( - alloc.tokens[:TEST_PAGE_SIZE], alloc + cached_allocation = trie_cache.lookup(fill_tokens) + alloc = trie_cache.allocate( + fill_tokens[cached_allocation.num_tokens :], cached_allocation ) + alloc_updated = trie_cache.publish_pages_for_tokens(alloc) fill_allocations.append(alloc_updated) logger.debug(f"\nFilled cache slot {i+1}/{remaining}") print_tree_state(trie_cache, " ") @@ -480,7 +508,10 @@ def test_reference_counting(trie_cache, ref_count): logger.debug("\nAttempting allocation that should fail...") try: new_tokens = list(range(1000, 1000 + TEST_PAGE_SIZE)) - new_alloc = trie_cache.allocate(new_tokens) + cached_allocation = trie_cache.lookup(new_tokens) + new_alloc = trie_cache.allocate( + new_tokens[cached_allocation.num_tokens :], cached_allocation + ) logger.debug("ERROR: Allocation succeeded when it should have failed!") logger.debug("\nPost-allocation state:") print_tree_state(trie_cache, " ") @@ -506,18 +537,20 @@ def test_reference_counting(trie_cache, ref_count): def test_fork_pages(trie_cache, tokens): """Test that fork_pages correctly creates a forked allocation sharing published pages.""" # Create and publish a sequence - alloc = trie_cache.allocate(tokens) - alloc_updated = trie_cache.publish_pages_for_tokens(alloc.tokens, alloc) - published_pages = list(alloc_updated.pages) + cached_allocation = trie_cache.lookup(tokens) + alloc = trie_cache.allocate( + tokens[cached_allocation.num_tokens :], cached_allocation + ) + alloc_updated = trie_cache.publish_pages_for_tokens(alloc) trie_cache.release_pages(alloc_updated) # Fork the published sequence - forked_alloc = trie_cache.fork_pages(published_pages, tokens) + forked_alloc = trie_cache.fork_pages(tokens, alloc_updated) try: # The forked allocation should reference the same pages assert forked_alloc.tokens == tokens - assert len(forked_alloc.pages) == len(published_pages) - for orig, forked in zip(published_pages, forked_alloc.pages): + assert len(forked_alloc.pages) == len(alloc_updated.pages) + for orig, forked in zip(alloc_updated.pages, forked_alloc.pages): assert orig.index == forked.index finally: trie_cache.release_pages(forked_alloc) From bbc3d3982b316126696d3e9d5a69e08775c4a6ca Mon Sep 17 00:00:00 2001 From: lisaliu1 Date: Fri, 3 Oct 2025 20:10:57 +0000 Subject: [PATCH 2/4] address reviews --- .../kvcache/base_attention_cache.py | 13 ---- .../kvcache/trie_attention_cache.py | 73 +++++-------------- .../kvcache/new_base_attention_cache_test.py | 13 ---- .../new_mock_pool_test.py | 30 -------- 4 files changed, 18 insertions(+), 111 deletions(-) diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py index 170dc747cb..855f20bfec 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py @@ -68,19 +68,6 @@ def shutdown(self): def free_pages(self, pages: List[PageInfo]): self.page_pool.free_pages(pages) - def fork_pages(self, tokens: list[int], cache_info: CacheInfo) -> CacheInfo: - new_pages = cache_info.pages.copy() - 
last_page = new_pages.pop(-1)
-        new_page = self.page_pool.copy_page(last_page)
-        if new_page is None:
-            raise CacheAllocationFailure()
-
-        new_pages.append(new_page)
-        cache_info.pages = new_pages
-        cache_info.tokens.extend(tokens)
-        cache_info.num_tokens += len(tokens)
-        return cache_info
-
     def lookup(self, tokens: List[int]) -> CacheInfo:
         return CacheInfo(
             num_tokens=0,
diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py b/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py
index c4f6ddf428..f6f6a552f8 100644
--- a/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py
+++ b/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py
@@ -65,6 +65,16 @@ def create_child(self, tokens: Tuple[int, ...], page: PageInfo) -> "TrieNode":
         self.children[tokens] = new_node
         return new_node
 
+    def register_allocation(self):
+        """Increment this node's reference count to record that an allocation has acquired pages that extend this node."""
+        self.ref_count.increment()
+
+    def publish_descendant(self, descendant: "TrieNode") -> None:
+        """Transfer an allocation's reference from this node to a descendant.
+
+        ref_count tracks the allocations that depend on a node. Once descendant
+        nodes have been published for an allocation, the allocation depends on
+        the descendant rather than on this node, so this node's ref_count is
+        decremented and the descendant's is incremented.
+        """
+        if not self.ref_count.is_empty():
+            self.ref_count.decrement()
+        descendant.ref_count.increment()
+
     def unlink(self) -> None:
         """Remove this node from its parent's children."""
         if self.parent is not None:
@@ -142,49 +152,6 @@ def __init__(self, page_pool: PagePool, tokens_per_page: int):
             []
         )  # pages that are duplicated from existing pages in Trie tree. These pages can be safely freed when calling release_pages.
 
-    def fork_pages(self, tokens: list[int], cache_info: TrieCacheInfo) -> TrieCacheInfo:
-        """Fork a sequence of pages into the trie.
-
-        Share prefixes with existing nodes till N-1 tokens, then create a new node
-        for the last token block. This allows sharing of common prefixes
-
-
-        Args:
-            pages: List of PageInfo objects to fork into the trie
-
-        Returns:
-            TrieCacheInfo containing both cached and newly allocated pages
-        """
-        with self._lock:
-            curr, matched_pages = self.match(tokens)
-            curr.ref_count.increment()
-
-            n_cached_tokens = len(matched_pages) * self.tokens_per_page
-            if n_cached_tokens >= len(tokens):
-                # If all tokens are already cached, no need to fork
-                return TrieCacheInfo(
-                    tokens=list(tokens),
-                    num_tokens=len(tokens),
-                    last_cached_node=curr,
-                    pages=matched_pages + [],
-                    number_of_published_pages=len(matched_pages),
-                    pool=self.page_pool,
-                )
-
-            new_page = self.page_pool.copy_page(cache_info.pages[-1])
-            if new_page is None:
-                self._evict_pages(1)
-                new_page = self.page_pool.copy_page(cache_info.pages[-1])
-
-            return TrieCacheInfo(
-                tokens=list(tokens),
-                num_tokens=len(tokens),
-                last_cached_node=curr,
-                pages=matched_pages + [new_page],
-                number_of_published_pages=len(matched_pages),
-                pool=self.page_pool,
-            )
-
     def match(self, tokens: List[int]) -> Tuple[TrieNode, List[PageInfo]]:
         """
         Find the longest prefix match in the trie.
@@ -349,7 +316,7 @@ def allocate( len(cache_info.pages) > 0 and cur_node.page.index == cache_info.pages[-1].index ): - cur_node.ref_count.increment() + cur_node.register_allocation() self._allocated_pages.extend(new_pages) pages = cache_info.pages + new_pages @@ -390,9 +357,10 @@ def publish_pages_for_tokens( updated_tokens[: number_of_pages_to_publish * tokens_per_page] ) + duplicated_page_set = set([page.index for page in self._duplicated_pages]) for i, page in enumerate(matched_pages): if page.index != cache_info.pages[i].index: - if page not in self._duplicated_pages: + if page.index not in duplicated_page_set: self._duplicated_pages.append(cache_info.pages[i]) if cache_info.last_cached_node.page.index == page.index: cache_info.last_cached_node.page = page @@ -452,9 +420,7 @@ def publish_pages_for_tokens( # Update reference counts only when we have unpublished tokens if unpublished_tokens: - last_cached_node.ref_count.increment() - if not cache_info.last_cached_node.ref_count.is_empty(): - cache_info.last_cached_node.ref_count.decrement() + cache_info.last_cached_node.publish_descendant(last_cached_node) # Remove published pages from _allocated_pages for page in pages: @@ -511,13 +477,10 @@ def free_cache_pages(self): self._allocated_pages = [] def free_allocated_pages(self, page_ids: List[int]): - pages = [] - for id in page_ids: - for page in self._allocated_pages: - if page.index == id: - pages.append(page) - self._allocated_pages.remove(page) - break + page_id_set = set(page_ids) + pages = [page for page in self._allocated_pages if page.index in page_id_set] + for page in pages: + self._allocated_pages.remove(page) self.page_pool.free_pages(pages) def release_pages(self, cache_info: TrieCacheInfo): diff --git a/shortfin/tests/apps/llm/components/kvcache/new_base_attention_cache_test.py b/shortfin/tests/apps/llm/components/kvcache/new_base_attention_cache_test.py index a2ea6cc43f..a1aba2c555 100644 --- a/shortfin/tests/apps/llm/components/kvcache/new_base_attention_cache_test.py +++ b/shortfin/tests/apps/llm/components/kvcache/new_base_attention_cache_test.py @@ -258,16 +258,3 @@ def test_free_pages(cache, tokens, expected_pages, case_name): assert ( qsize == total_pages ), f"All pages should be freed for {case_name}, but only freed {qsize}" - - -@pytest.mark.asyncio -async def test_fork_pages_allocation_error(cache_ref_count): - # Use all pages - tokens = list(range(TEST_PAGE_SIZE * TEST_POOL_CAPACITY)) - - allocation = cache_ref_count.allocate(tokens) - pages = allocation.pages - - # Should throw an allocation error when forking - with pytest.raises(CacheAllocationFailure): - cache_ref_count.fork_pages([], allocation) diff --git a/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache/new_mock_pool_test.py b/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache/new_mock_pool_test.py index 693d02a4df..1a2564bf2d 100644 --- a/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache/new_mock_pool_test.py +++ b/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache/new_mock_pool_test.py @@ -524,33 +524,3 @@ def test_reference_counting(trie_cache, ref_count): logger.debug("\nCleaning up allocations...") for alloc in allocations + fill_allocations: trie_cache.release_pages(alloc) - - -@pytest.mark.parametrize( - "tokens", - [ - list(range(TEST_PAGE_SIZE * 2)), - list(range(TEST_PAGE_SIZE * 3)), - list(range(TEST_PAGE_SIZE * 4)), - ], -) -def test_fork_pages(trie_cache, tokens): - """Test that fork_pages correctly creates a forked allocation sharing 
published pages.""" - # Create and publish a sequence - cached_allocation = trie_cache.lookup(tokens) - alloc = trie_cache.allocate( - tokens[cached_allocation.num_tokens :], cached_allocation - ) - alloc_updated = trie_cache.publish_pages_for_tokens(alloc) - trie_cache.release_pages(alloc_updated) - - # Fork the published sequence - forked_alloc = trie_cache.fork_pages(tokens, alloc_updated) - try: - # The forked allocation should reference the same pages - assert forked_alloc.tokens == tokens - assert len(forked_alloc.pages) == len(alloc_updated.pages) - for orig, forked in zip(alloc_updated.pages, forked_alloc.pages): - assert orig.index == forked.index - finally: - trie_cache.release_pages(forked_alloc) From 155bce03bb693e753e45ad81e5ab90102189b39c Mon Sep 17 00:00:00 2001 From: lisaliu1 Date: Fri, 3 Oct 2025 20:13:52 +0000 Subject: [PATCH 3/4] add trie test to accuracy test --- .../integration_tests/llm/shortfin/accuracy/accuracy_test.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/app_tests/integration_tests/llm/shortfin/accuracy/accuracy_test.py b/app_tests/integration_tests/llm/shortfin/accuracy/accuracy_test.py index af045b16e1..2fee413e33 100644 --- a/app_tests/integration_tests/llm/shortfin/accuracy/accuracy_test.py +++ b/app_tests/integration_tests/llm/shortfin/accuracy/accuracy_test.py @@ -244,6 +244,10 @@ def _request_loop( ModelConfig.get(name="local_meta_llama3.1_8b_instruct"), {"prefix_sharing_algorithm": "none"}, ), # noqa: E501 + ( + ModelConfig.get(name="local_meta_llama3.1_8b_instruct"), + {"prefix_sharing_algorithm": "trie"}, + ), # noqa: E501 ], ids=[ "meta_llama3.1_8b_instruct-no_prefix_sharing", From dffc15f157aea42faecb44d1e3c4c80f11929ca4 Mon Sep 17 00:00:00 2001 From: lisaliu1 Date: Fri, 3 Oct 2025 20:21:32 +0000 Subject: [PATCH 4/4] fix accuracy test --- .../integration_tests/llm/shortfin/accuracy/accuracy_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/app_tests/integration_tests/llm/shortfin/accuracy/accuracy_test.py b/app_tests/integration_tests/llm/shortfin/accuracy/accuracy_test.py index 2fee413e33..5985d2f43b 100644 --- a/app_tests/integration_tests/llm/shortfin/accuracy/accuracy_test.py +++ b/app_tests/integration_tests/llm/shortfin/accuracy/accuracy_test.py @@ -251,6 +251,7 @@ def _request_loop( ], ids=[ "meta_llama3.1_8b_instruct-no_prefix_sharing", + "meta_llama3.1_8b_instruct-trie_prefix_sharing", ], indirect=True, )
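
Reviewer note: after this series, callers drive the trie cache through an
explicit lookup/allocate/publish/release sequence instead of a single
allocate() call. The sketch below mirrors the _publish_sequence fixture in
new_mock_pool_test.py; page_pool, TOKENS_PER_PAGE, and prompt_tokens are
illustrative stand-ins, not names defined by these patches.

    from shortfin_apps.llm.components.kvcache.trie_attention_cache import (
        TriePagedAttentionCache,
    )

    cache = TriePagedAttentionCache(
        page_pool=page_pool, tokens_per_page=TOKENS_PER_PAGE
    )

    # lookup() returns only the fully matched prefix pages.
    cached = cache.lookup(prompt_tokens)

    # allocate() extends the matched prefix with fresh pages for the
    # uncached suffix of the sequence.
    info = cache.allocate(prompt_tokens[cached.num_tokens :], cached)

    # publish_pages_for_tokens() inserts full pages into the trie; the
    # trailing partial page is held back unless publish_incomplete_page=True.
    info = cache.publish_pages_for_tokens(info)

    # release_pages() drops this request's reference and frees any pages
    # that turned out to duplicate existing trie nodes.
    cache.release_pages(info)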