diff --git a/pkg/tokenization/prefixstore/block_hasher.go b/pkg/tokenization/prefixstore/block_hasher.go
new file mode 100644
index 00000000..ff01b478
--- /dev/null
+++ b/pkg/tokenization/prefixstore/block_hasher.go
@@ -0,0 +1,72 @@
+/*
+Copyright 2025 The llm-d Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package prefixstore
+
+import (
+	"encoding/binary"
+	"fmt"
+
+	"github.com/cespare/xxhash/v2"
+)
+
+// BlockHasher handles the computation of block hashes for chunked text data.
+// It maintains state for sequential hash computation where each block's hash
+// depends on the previous block's hash.
+type BlockHasher struct {
+	digest       *xxhash.Digest
+	previousHash uint64
+}
+
+// NewBlockHasher returns a BlockHasher ready to hash a new sequence of blocks.
+func NewBlockHasher() *BlockHasher {
+	return &BlockHasher{
+		digest:       xxhash.New(),
+		previousHash: 0,
+	}
+}
+
+// Reset resets the hasher state for a new sequence of blocks.
+func (h *BlockHasher) Reset() {
+	h.digest.Reset()
+	h.previousHash = 0
+}
+
+// ComputeBlockHash computes the hash for a block of text data.
+// The hash depends on both the current block content and the previous block's hash.
+func (h *BlockHasher) ComputeBlockHash(data []byte) (uint64, error) {
+	h.digest.Reset()
+
+	// Include the previous hash to chain this block to its predecessors.
+	if err := binary.Write(h.digest, binary.LittleEndian, h.previousHash); err != nil {
+		return 0, fmt.Errorf("failed to write previous hash: %w", err)
+	}
+
+	// Include the current block data.
+	if _, err := h.digest.Write(data); err != nil {
+		return 0, fmt.Errorf("failed to write block data: %w", err)
+	}
+
+	blockHash := h.digest.Sum64()
+	h.previousHash = blockHash
+
+	return blockHash, nil
+}
+
+// GetPreviousHash returns the current previous hash value.
+func (h *BlockHasher) GetPreviousHash() uint64 {
+	return h.previousHash
+}
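Reviewer note (not part of the patch): a minimal test sketch of the two properties the prefix cache relies on — identical block sequences must reproduce identical hash chains, and Reset must restart the chain from the zero seed. The file name block_hasher_test.go and the test body are illustrative only; nothing beyond the API added above is assumed.

// block_hasher_test.go (hypothetical, for illustration)
package prefixstore

import "testing"

func TestBlockHasherChain(t *testing.T) {
	blocks := [][]byte{[]byte("hello "), []byte("world")}

	a := NewBlockHasher()
	b := NewBlockHasher()
	for _, blk := range blocks {
		ha, err := a.ComputeBlockHash(blk)
		if err != nil {
			t.Fatal(err)
		}
		hb, err := b.ComputeBlockHash(blk)
		if err != nil {
			t.Fatal(err)
		}
		// Identical block sequences must produce identical chains;
		// otherwise cache writers and readers could never agree on keys.
		if ha != hb {
			t.Fatalf("chains diverged: %#x != %#x", ha, hb)
		}
	}

	// Reset must restart the chain from the zero seed: the first block
	// after Reset hashes to the same value as on a fresh hasher.
	a.Reset()
	fresh := NewBlockHasher()
	ha, err := a.ComputeBlockHash(blocks[0])
	if err != nil {
		t.Fatal(err)
	}
	hf, err := fresh.ComputeBlockHash(blocks[0])
	if err != nil {
		t.Fatal(err)
	}
	if ha != hf {
		t.Fatalf("Reset did not restart the chain: %#x != %#x", ha, hf)
	}
}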
diff --git a/pkg/tokenization/prefixstore/lru_store.go b/pkg/tokenization/prefixstore/lru_store.go
index b4cf8b36..ae30601a 100644
--- a/pkg/tokenization/prefixstore/lru_store.go
+++ b/pkg/tokenization/prefixstore/lru_store.go
@@ -17,11 +17,9 @@ limitations under the License.
 package prefixstore
 
 import (
-	"encoding/binary"
 	"fmt"
 	"sync"
 
-	"github.com/cespare/xxhash/v2"
 	"github.com/daulet/tokenizers"
 	lru "github.com/hashicorp/golang-lru/v2"
 )
@@ -109,8 +107,7 @@ func (c *LRUTokenStore) AddTokenization(modelName string, prompt string, tokens
 	promptBytes := []byte(prompt)
 	tokenIdxIterator := 0
 
-	previousHash := uint64(0)
-	digest := xxhash.New()
+	hasher := NewBlockHasher()
 
 	// Chunk the text into blocks and populate the cache
 	for start := 0; start < len(promptBytes); start += c.blockSize {
@@ -120,17 +117,11 @@
 		}
 
 		// Compute the hash for the current block
-		digest.Reset()
-		if err := binary.Write(digest, binary.LittleEndian, previousHash); err != nil {
-			return fmt.Errorf("failed to add token: %w", err)
-		}
-		if _, err := digest.Write(promptBytes[start:end]); err != nil {
-			return fmt.Errorf("failed to add token: %w", err)
+		blockHash, err := hasher.ComputeBlockHash(promptBytes[start:end])
+		if err != nil {
+			return fmt.Errorf("failed to compute block hash: %w", err)
 		}
 
-		blockHash := digest.Sum64()
-		previousHash = blockHash
-
 		// Only add tokens with [_, high] offset associated with the chunk range.
 		// If a token's [low, _] index is less than the start, it is OK as long as
 		// the above condition is satisfied.
@@ -165,8 +156,7 @@ func (c *LRUTokenStore) FindLongestContainedTokens(prompt, modelName string) []u
 	containedTokens := []uint32{}
 	promptBytes := []byte(prompt)
 
-	previousHash := uint64(0)
-	digest := xxhash.New()
+	hasher := NewBlockHasher()
 
 	// Chunk the text into blocks and populate the cache
 	for i := 0; i < len(promptBytes); i += c.blockSize {
@@ -176,17 +166,11 @@ func (c *LRUTokenStore) FindLongestContainedTokens(prompt, modelName string) []u
 		}
 
 		// Compute the hash for the current block
-		digest.Reset()
-		if err := binary.Write(digest, binary.LittleEndian, previousHash); err != nil {
-			break
-		}
-		if _, err := digest.Write(promptBytes[i:end]); err != nil {
+		blockHash, err := hasher.ComputeBlockHash(promptBytes[i:end])
+		if err != nil {
 			break
 		}
 
-		blockHash := digest.Sum64()
-		previousHash = blockHash
-
 		block, ok := cache.Get(blockHash)
 		if !ok {
 			break // early-stop
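Reviewer note (not part of the patch): after this refactor, both call sites reduce to the same chunk-and-hash shape. A hypothetical helper, hashChain, makes that shared pattern explicit — blockSize stands in for c.blockSize, and blocks must be fed strictly in order because the hasher carries chain state between calls:

// hashChain is a hypothetical helper, shown only to illustrate the
// pattern both AddTokenization and FindLongestContainedTokens follow.
package prefixstore

func hashChain(promptBytes []byte, blockSize int) ([]uint64, error) {
	hasher := NewBlockHasher()
	hashes := make([]uint64, 0, (len(promptBytes)+blockSize-1)/blockSize)

	for start := 0; start < len(promptBytes); start += blockSize {
		end := start + blockSize
		if end > len(promptBytes) {
			end = len(promptBytes)
		}
		// Each call folds the previous block's hash into this block's
		// hash, so hashes[i] identifies the whole prefix up to block i.
		h, err := hasher.ComputeBlockHash(promptBytes[start:end])
		if err != nil {
			return nil, err
		}
		hashes = append(hashes, h)
	}
	return hashes, nil
}

Because the hasher is stateful, it must not be reused across prompts without calling Reset (or allocating a fresh one per prompt, as both call sites do).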