70 changes: 70 additions & 0 deletions pkg/tokenization/prefixstore/block_hasher.go
@@ -0,0 +1,70 @@
/*
Copyright 2025 The llm-d Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package prefixstore

import (
	"encoding/binary"
	"fmt"

	"github.com/cespare/xxhash/v2"
)

// BlockHasher handles the computation of block hashes for chunked text data.
// It maintains state for sequential hash computation, where each block's hash
// depends on the previous block's hash.
type BlockHasher struct {
	digest       *xxhash.Digest
	previousHash uint64
}

// NewBlockHasher returns a BlockHasher ready to hash a new sequence of blocks.
func NewBlockHasher() *BlockHasher {
	return &BlockHasher{
		digest:       xxhash.New(),
		previousHash: 0,
	}
}

// Reset resets the hasher state for a new sequence of blocks.
func (h *BlockHasher) Reset() {
	h.digest.Reset()
	h.previousHash = 0
}

// ComputeBlockHash computes the hash for a block of text data.
// The hash depends on both the current block content and the previous block's hash.
func (h *BlockHasher) ComputeBlockHash(data []byte) (uint64, error) {
	h.digest.Reset()

	// Include previous hash to create a chain
	if err := binary.Write(h.digest, binary.LittleEndian, h.previousHash); err != nil {
		return 0, fmt.Errorf("failed to write previous hash: %w", err)
	}

	// Include current block data
	if _, err := h.digest.Write(data); err != nil {
		return 0, fmt.Errorf("failed to write block data: %w", err)
	}

	blockHash := h.digest.Sum64()
	h.previousHash = blockHash

	return blockHash, nil
}

// GetPreviousHash returns the hash of the most recently computed block.
func (h *BlockHasher) GetPreviousHash() uint64 {
	return h.previousHash
}
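
For reviewers, a minimal example-test sketch of the new type (not part of the diff; it would live in a `_test.go` file in the same package, and it assumes the `Reset` behavior shown above): it exercises the chaining property the doc comments describe.

```go
package prefixstore

import "fmt"

// ExampleBlockHasher shows that equal blocks fed in the same order produce
// equal chained hashes, and that Reset starts an independent chain.
func ExampleBlockHasher() {
	a := NewBlockHasher()
	b := NewBlockHasher()

	h1, _ := a.ComputeBlockHash([]byte("the quick brown "))
	h2, _ := b.ComputeBlockHash([]byte("the quick brown "))
	fmt.Println(h1 == h2) // identical first blocks, identical hashes

	h3, _ := a.ComputeBlockHash([]byte("fox jumps over"))
	h4, _ := b.ComputeBlockHash([]byte("dog sleeps in "))
	fmt.Println(h3 == h4) // chains diverge once the inputs differ

	a.Reset()
	h5, _ := a.ComputeBlockHash([]byte("the quick brown "))
	fmt.Println(h5 == h1) // a fresh chain reproduces the first hash

	// Output:
	// true
	// false
	// true
}
```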
30 changes: 7 additions & 23 deletions pkg/tokenization/prefixstore/lru_store.go
@@ -17,11 +17,9 @@ limitations under the License.
 package prefixstore
 
 import (
-	"encoding/binary"
 	"fmt"
 	"sync"
 
-	"github.com/cespare/xxhash/v2"
 	"github.com/daulet/tokenizers"
 	lru "github.com/hashicorp/golang-lru/v2"
 )
@@ -109,8 +107,7 @@ func (c *LRUTokenStore) AddTokenization(modelName string, prompt string, tokens
 
 	promptBytes := []byte(prompt)
 	tokenIdxIterator := 0
-	previousHash := uint64(0)
-	digest := xxhash.New()
+	hasher := NewBlockHasher()
 
 	// Chunk the text into blocks and populate the cache
 	for start := 0; start < len(promptBytes); start += c.blockSize {
@@ -120,17 +117,11 @@
 		}
 
 		// Compute the hash for the current block
-		digest.Reset()
-		if err := binary.Write(digest, binary.LittleEndian, previousHash); err != nil {
-			return fmt.Errorf("failed to add token: %w", err)
-		}
-		if _, err := digest.Write(promptBytes[start:end]); err != nil {
-			return fmt.Errorf("failed to add token: %w", err)
-		}
-
-		blockHash := digest.Sum64()
-		previousHash = blockHash
+		blockHash, err := hasher.ComputeBlockHash(promptBytes[start:end])
+		if err != nil {
+			return fmt.Errorf("failed to compute block hash: %w", err)
+		}
 
 		// Only add tokens with [_, high] offset associated with the chunk range.
 		// If a token's [low, _] index is less than the start, it is OK as long as
 		// the above condition is satisfied.
@@ -165,8 +156,7 @@ func (c *LRUTokenStore) FindLongestContainedTokens(prompt, modelName string) []uint32
 	containedTokens := []uint32{}
 
 	promptBytes := []byte(prompt)
-	previousHash := uint64(0)
-	digest := xxhash.New()
+	hasher := NewBlockHasher()
 
 	// Chunk the text into blocks and populate the cache
 	for i := 0; i < len(promptBytes); i += c.blockSize {
@@ -176,17 +166,11 @@
 		}
 
 		// Compute the hash for the current block
-		digest.Reset()
-		if err := binary.Write(digest, binary.LittleEndian, previousHash); err != nil {
-			break
-		}
-		if _, err := digest.Write(promptBytes[i:end]); err != nil {
-			break
-		}
-
-		blockHash := digest.Sum64()
-		previousHash = blockHash
+		blockHash, err := hasher.ComputeBlockHash(promptBytes[i:end])
+		if err != nil {
+			break
+		}
 
 		block, ok := cache.Get(blockHash)
 		if !ok {
 			break // early-stop
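
For context, a self-contained sketch of the hashing scheme both loops now delegate to (plain xxhash, no repo imports; `blockHashes` and the 16-byte block size are illustrative, not from the PR): each block hash is xxhash64 over the previous hash in little-endian followed by the block bytes, so prompts that share a prefix share their leading block hashes, which is what lets `FindLongestContainedTokens` stop at the first cache miss.

```go
package main

import (
	"encoding/binary"
	"fmt"

	"github.com/cespare/xxhash/v2"
)

// blockHashes mirrors the chaining scheme in block_hasher.go:
// hash_i = xxhash64(littleEndian(hash_{i-1}) || block_i).
func blockHashes(prompt []byte, blockSize int) []uint64 {
	hashes := make([]uint64, 0, (len(prompt)+blockSize-1)/blockSize)
	previous := uint64(0)
	digest := xxhash.New()
	for start := 0; start < len(prompt); start += blockSize {
		end := start + blockSize
		if end > len(prompt) {
			end = len(prompt)
		}
		digest.Reset()
		_ = binary.Write(digest, binary.LittleEndian, previous) // xxhash.Digest writes never fail
		_, _ = digest.Write(prompt[start:end])
		previous = digest.Sum64()
		hashes = append(hashes, previous)
	}
	return hashes
}

func main() {
	const blockSize = 16 // illustrative; the store's block size is configurable
	a := blockHashes([]byte("shared prefix...then A"), blockSize)
	b := blockHashes([]byte("shared prefix...then B"), blockSize)
	fmt.Println(a[0] == b[0]) // true: identical first block, identical hash
	fmt.Println(a[1] == b[1]) // false: the chain diverges at the second block
}
```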