
[feat] refactor the common embedding document chunking that each provider currently does into an external DocEmbedder helper (ala Chat) #34

@csells

Description

Problem Statement

Current Issues

  1. Duplicated Chunking Logic: All four embeddings model implementations (OpenAIEmbeddingsModel, GoogleEmbeddingsModel, MistralEmbeddingsModel, CohereEmbeddingsModel) contain their own chunking logic in the embedDocuments method.

  2. Inconsistent Batch Sizes:

    • OpenAI: 512 texts (code) vs no hard limit but ~300K tokens (API docs)
    • Google: 100 texts (code) vs 250 texts, 20K tokens (API docs)
    • Mistral: 100 texts (code) vs 512 texts default (API docs)
    • Cohere: 96 texts (hardcoded constant)
  3. Lack of Symmetry: The Chat class provides a convenient wrapper around Agent for managing conversation history, but there's no equivalent wrapper for embeddings operations.

  4. No Caching: Unlike Chat which maintains message history, there's no way to cache computed embeddings for reuse.

  5. Maintenance Burden: Changes to chunking strategy require modifications in 4+ files.

Current Implementation Example

// From OpenAIEmbeddingsModel.embedDocuments():
final effectiveBatchSize = options?.batchSize ?? batchSize ?? 512;
final batches = chunkList(texts, chunkSize: effectiveBatchSize);
// ... chunking logic repeated in each model

Proposed Solution

Create a DocEmbedder class that:

  1. Wraps an Agent instance (similar to how Chat works)
  2. Centralizes all document chunking logic
  3. Maintains a cache of computed embeddings
  4. Provides utility methods for similarity search
  5. Uses provider-specific defaults from web research

High-Level Architecture

┌─────────────┐     ┌─────────────┐
│    Chat     │     │ DocEmbedder │
└──────┬──────┘     └──────┬──────┘
       │                   │
       ▼                   ▼
┌─────────────────────────────────┐
│            Agent                │
└──────┬──────────────────┬───────┘
       │                  │
       ▼                  ▼
┌──────────────┐   ┌─────────────────┐
│  ChatModel   │   │ EmbeddingsModel │
└──────────────┘   └─────────────────┘

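To make the intended usage concrete before the detailed design, here is a minimal sketch; it assumes the constructor and method names defined in the class structure below and is illustrative rather than final.

// Illustrative usage only; the exact API is specified in the sections below.
final documents = [
  'Paris is the capital of France.',
  'Berlin is the capital of Germany.',
];

final agent = Agent('openai:text-embedding-3-small');
final embedder = DocEmbedder(agent);

// Chunking and caching are handled inside the wrapper.
final batch = await embedder.embedDocuments(documents);

// Built-in similarity search over the cached vectors.
final topIndices = await embedder.findMostSimilar(
  'Which city is the capital of France?',
  documents,
);
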
Architecture Design

DocEmbedder Class Structure

/// A document embedder that wraps an agent for embeddings operations.
///
/// This class provides a simple interface for embedding documents with an
/// agent. It maintains a cache of embeddings and provides methods to embed
/// queries and documents with automatic chunking.
class DocEmbedder {
  /// Creates a new document embedder with the given [agent].
  DocEmbedder(this.agent);

  /// The agent that will be used for embeddings operations.
  final Agent agent;

  /// Cache mapping document hashes to their embeddings.
  final cache = <String, List<double>>{};

  /// History of embedding results for usage tracking.
  final results = List<BatchEmbeddingsResult>.empty(growable: true);

  /// The display name of the agent.
  String get displayName => agent.displayName;

  /// Embeds a single query and returns the result.
  Future<EmbeddingsResult> embedQuery(String query);

  /// Embeds multiple documents with automatic chunking.
  Future<BatchEmbeddingsResult> embedDocuments(List<String> documents);

  /// Finds the indices of documents most similar to the query.
  Future<List<int>> findMostSimilar(String query, List<String> documents);

  /// Clears the embeddings cache.
  void clearCache();
}

Hash Function Addition to EmbeddingsModel

// Note: uses utf8 from dart:convert and sha256 from package:crypto/crypto.dart
// (the crypto import is added in Phase 1 below).
abstract class EmbeddingsModel<TOptions extends EmbeddingsModelOptions> {
  // ... existing code ...

  /// Computes a consistent hash for a text document.
  /// 
  /// Uses SHA256 to ensure:
  /// - Consistent hashing across platforms
  /// - Low collision probability
  /// - Fast computation
  static String hashDocument(String document) {
    final bytes = utf8.encode(document);
    final digest = sha256.convert(bytes);
    return digest.toString();
  }
}

Provider-Specific Batch Sizes

Based on web research into current API limits (not the current code, whose values are outdated):

class ProviderDefaults {
  static const Map<String, int> batchSizes = {
    'openai': 2048,      // No hard limit, but practical limit based on tokens
    'google': 250,       // 250 texts per request
    'mistral': 512,      // Default from API docs
    'cohere': 96,        // Maximum texts per call
  };
}

Implementation Details

DocEmbedder Implementation

import 'package:langchain_compat/langchain_compat.dart';

class DocEmbedder {
  DocEmbedder(this.agent);

  final Agent agent;
  final cache = <String, List<double>>{};
  final results = List<BatchEmbeddingsResult>.empty(growable: true);

  String get displayName => agent.displayName;

  Future<EmbeddingsResult> embedQuery(String query) async {
    // Check cache first
    final hash = EmbeddingsModel.hashDocument(query);
    if (cache.containsKey(hash)) {
      return EmbeddingsResult(
        output: cache[hash]!,
        finishReason: FinishReason.stop,
        metadata: {'source': 'cache', 'hash': hash},
      );
    }

    // Compute embedding
    final result = await agent.embedQuery(query);
    
    // Cache the result
    cache[hash] = result.output;
    
    return result;
  }

  Future<BatchEmbeddingsResult> embedDocuments(List<String> documents) async {
    if (documents.isEmpty) {
      return BatchEmbeddingsResult(
        output: [],
        finishReason: FinishReason.stop,
        usage: const LanguageModelUsage(),
      );
    }

    // Separate cached and uncached documents
    final uncachedDocs = <String>[];
    final uncachedIndices = <int>[];
    final allEmbeddings = List<List<double>?>.filled(documents.length, null);

    for (var i = 0; i < documents.length; i++) {
      final hash = EmbeddingsModel.hashDocument(documents[i]);
      if (cache.containsKey(hash)) {
        allEmbeddings[i] = cache[hash]!;
      } else {
        uncachedDocs.add(documents[i]);
        uncachedIndices.add(i);
      }
    }

    // If all documents are cached
    if (uncachedDocs.isEmpty) {
      return BatchEmbeddingsResult(
        output: allEmbeddings.cast<List<double>>(),
        finishReason: FinishReason.stop,
        metadata: {'cached_count': documents.length},
        usage: const LanguageModelUsage(),
      );
    }

    // Get batch size for provider
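    // NOTE: assumes DocEmbedder can reach the agent's provider (e.g. it lives
    // in the same library as Agent); otherwise a public accessor is needed.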
    final providerName = agent._provider.name;
    final batchSize = ProviderDefaults.batchSizes[providerName] ?? 100;

    // Chunk uncached documents
    final chunks = _chunkDocuments(uncachedDocs, batchSize);
    var totalUsage = const LanguageModelUsage();

    // Process each chunk
    for (var i = 0; i < chunks.length; i++) {
      final chunk = chunks[i];
      
      // Create a temporary model for this batch
      final model = agent._provider.createEmbeddingsModel(
        name: agent._embeddingsModelName,
        options: agent.embeddingsModelOptions,
      );

      try {
        // Call the model's embedDocuments (which no longer chunks)
        final chunkResult = await model.embedDocuments(chunk);
        
        // Cache results and fill in allEmbeddings
        for (var j = 0; j < chunk.length; j++) {
          final docIndex = uncachedIndices[i * batchSize + j];
          final hash = EmbeddingsModel.hashDocument(chunk[j]);
          final embedding = chunkResult.output[j];
          
          cache[hash] = embedding;
          allEmbeddings[docIndex] = embedding;
        }

        totalUsage = totalUsage.concat(chunkResult.usage);
      } finally {
        model.dispose();
      }
    }

    final result = BatchEmbeddingsResult(
      output: allEmbeddings.cast<List<double>>(),
      finishReason: FinishReason.stop,
      metadata: {
        'total_documents': documents.length,
        'cached_count': documents.length - uncachedDocs.length,
        'chunks_processed': chunks.length,
      },
      usage: totalUsage,
    );

    results.add(result);
    return result;
  }

  Future<List<int>> findMostSimilar(
    String query,
    List<String> documents,
  ) async {
    // Ensure all documents are embedded
    await embedDocuments(documents);
    
    // Embed the query
    final queryResult = await embedQuery(query);
    final queryEmbedding = queryResult.output;

    // Get embeddings for all documents
    final docEmbeddings = documents.map((doc) {
      final hash = EmbeddingsModel.hashDocument(doc);
      return cache[hash]!;
    }).toList();

    // Use the static method from EmbeddingsModel
    return EmbeddingsModel.getIndexesMostSimilarEmbeddings(
      queryEmbedding,
      docEmbeddings,
    );
  }

  void clearCache() {
    cache.clear();
  }

  List<List<String>> _chunkDocuments(List<String> documents, int chunkSize) {
    final chunks = <List<String>>[];
    for (var i = 0; i < documents.length; i += chunkSize) {
      final end = (i + chunkSize).clamp(0, documents.length);
      chunks.add(documents.sublist(i, end));
    }
    return chunks;
  }
}
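
The implementation above delegates ranking to EmbeddingsModel.getIndexesMostSimilarEmbeddings. For reference, a hedged sketch of a helper with that shape follows; it is a plain cosine-similarity ranking, not the package's actual implementation.

import 'dart:math' as math;

/// Hypothetical stand-in for EmbeddingsModel.getIndexesMostSimilarEmbeddings:
/// returns document indices sorted by descending cosine similarity to the query.
List<int> indexesByCosineSimilarity(
  List<double> query,
  List<List<double>> documents,
) {
  double dot(List<double> a, List<double> b) {
    var sum = 0.0;
    for (var i = 0; i < a.length; i++) {
      sum += a[i] * b[i];
    }
    return sum;
  }

  double norm(List<double> v) => math.sqrt(dot(v, v));

  final queryNorm = norm(query);
  final scores = [
    for (final doc in documents)
      // Zero-norm vectors would divide by zero; real code should guard for that.
      dot(query, doc) / (queryNorm * norm(doc)),
  ];

  final indices = List<int>.generate(documents.length, (i) => i)
    ..sort((a, b) => scores[b].compareTo(scores[a]));
  return indices;
}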

Refactored EmbeddingsModel (example for OpenAI)

// BEFORE: OpenAIEmbeddingsModel.embedDocuments()
Future<BatchEmbeddingsResult> embedDocuments(
  List<String> texts, {
  OpenAIEmbeddingsModelOptions? options,
}) async {
  final effectiveBatchSize = options?.batchSize ?? batchSize ?? 512;
  final batches = chunkList(texts, chunkSize: effectiveBatchSize);
  // ... chunking logic ...
}

// AFTER: OpenAIEmbeddingsModel.embedDocuments()
Future<BatchEmbeddingsResult> embedDocuments(
  List<String> texts, {
  OpenAIEmbeddingsModelOptions? options,
}) async {
  // Remove all chunking logic - just make a single API call
  final effectiveDimensions = options?.dimensions ?? dimensions;
  
  final data = await _client.createEmbedding(
    request: CreateEmbeddingRequest(
      model: EmbeddingModel.modelId(name),
      input: EmbeddingInput.listString(texts),
      dimensions: effectiveDimensions,
      user: options?.user ?? _user,
    ),
  );

  return BatchEmbeddingsResult(
    output: data.data.map((d) => d.embeddingVector).toList(),
    finishReason: FinishReason.stop,
    metadata: {
      'model': name,
      'dimensions': effectiveDimensions,
      'total_texts': texts.length,
    },
    usage: LanguageModelUsage(
      promptTokens: data.usage?.promptTokens,
      totalTokens: data.usage?.totalTokens,
    ),
  );
}

Design Decisions & Rationale

1. Why SHA256 for Hashing?

Chosen: SHA256

  • Cryptographically secure (low collision probability)
  • Fast computation (hardware acceleration on many platforms)
  • Consistent 64-character hex output
  • Part of Dart's crypto package (already a dependency)

Alternatives Considered:

  • MD5: Faster but higher collision risk
  • xxHash: Very fast but requires additional dependency
  • Simple string hash: Too high collision probability

2. Why Cache Hash->Embedding Instead of Full Results?

Chosen: Map<String, List<double>> (hash to embedding)

  • Memory efficient - stores only the vectors
  • Simple lookup by document content
  • Enables deduplication (same text = same hash)

Alternatives Considered:

  • Cache full BatchEmbeddingsResult: Wastes memory on metadata
  • Cache document->embedding: No deduplication for identical texts
  • LRU cache: Added complexity for marginal benefit

3. Why Expose Same Method Names?

Chosen: embedQuery and embedDocuments (same as Agent)

  • Consistency with Chat pattern (send, sendFor, sendStream)
  • Easy mental model for developers
  • Drop-in replacement in many cases

Alternatives Considered:

  • Different names (process, embed): Confusing API differences
  • Single method with overloads: Less clear intent

4. Provider-Specific Defaults

Chosen: Use researched API limits as defaults

  • Optimal performance per provider
  • Avoids rate limiting
  • Future-proof as models improve

Alternatives Considered:

  • Single default for all: Suboptimal performance
  • User-specified always: Poor developer experience

Implementation Steps

Phase 1: Add Hash Function (Must Have)

  1. File: lib/src/embeddings_models/embeddings_model.dart
    • Add import 'package:crypto/crypto.dart';
    • Add static hashDocument method
    • Add tests in test/embeddings_test.dart (see the sketch below)
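
A hedged sketch of what those tests could check (the file, group, and expectations are illustrative):

// test/embeddings_test.dart (illustrative)
group('EmbeddingsModel.hashDocument', () {
  test('is deterministic and hex-encoded', () {
    final a = EmbeddingsModel.hashDocument('Hello world');
    final b = EmbeddingsModel.hashDocument('Hello world');
    expect(a, b); // same input, same hash
    expect(a.length, 64); // SHA-256 digest rendered as 64 hex characters
  });

  test('distinguishes different documents', () {
    expect(
      EmbeddingsModel.hashDocument('Hello world'),
      isNot(EmbeddingsModel.hashDocument('Hello world!')),
    );
  });
});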

Phase 2: Create DocEmbedder (Must Have)

  1. File: lib/src/embeddings_models/doc_embedder.dart

    • Create new file with DocEmbedder class
    • Implement all methods as designed
    • Handle empty inputs gracefully
  2. File: lib/src/embeddings_models/embeddings_models.dart

    • Add export for doc_embedder.dart (one line, shown below)
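
A one-line change, assuming the barrel file follows the existing export style:

export 'doc_embedder.dart';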

Phase 3: Update Provider Defaults (Must Have)

  1. Update batch size defaults in each model's constructor:
    • OpenAIEmbeddingsModel: Keep 512 (reasonable default)
    • GoogleEmbeddingsModel: Change from 100 to 250
    • MistralEmbeddingsModel: Change from 100 to 512
    • CohereEmbeddingsModel: Keep 96 (API limit)

Phase 4: Refactor Models (Must Have)

Remove chunking logic from embedDocuments in:

  1. lib/src/embeddings_models/openai_embeddings/openai_embeddings_model.dart
  2. lib/src/embeddings_models/google_embeddings/google_embeddings_model.dart
  3. lib/src/embeddings_models/mistral_embeddings/mistral_embeddings_model.dart
  4. lib/src/embeddings_models/cohere_embeddings/cohere_embeddings_model.dart

Phase 5: Remove batchSize Parameter (Must Have)

  1. Remove from model constructors (4 files)
  2. Remove from options classes (4 files; sketched below)
  3. Remove batchSize field from EmbeddingsModel base class
  4. Remove batchSize field from EmbeddingsModelOptions base class
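
A hedged sketch of the options-class change (OpenAI variant shown; field names follow their usage elsewhere in this design):

// BEFORE
class OpenAIEmbeddingsModelOptions extends EmbeddingsModelOptions {
  const OpenAIEmbeddingsModelOptions({this.dimensions, this.user, super.batchSize});

  final int? dimensions;
  final String? user;
}

// AFTER: batchSize is gone; chunking is owned by DocEmbedder
class OpenAIEmbeddingsModelOptions extends EmbeddingsModelOptions {
  const OpenAIEmbeddingsModelOptions({this.dimensions, this.user});

  final int? dimensions;
  final String? user;
}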

Phase 6: Testing & Documentation (Must Have)

  1. Create test/doc_embedder_test.dart
  2. Update existing embeddings tests
  3. Create example at example/doc_embedder.dart

Testing Requirements

Unit Tests for DocEmbedder

// test/doc_embedder_test.dart

group('DocEmbedder', () {
  test('should cache query embeddings', () async {
    final agent = Agent('openai:text-embedding-3-small');
    final embedder = DocEmbedder(agent);
    
    // First call should compute
    final result1 = await embedder.embedQuery('Hello world');
    expect(embedder.cache.length, 1);
    
    // Second call should use cache
    final result2 = await embedder.embedQuery('Hello world');
    expect(result2.metadata?['source'], 'cache');
    expect(result1.output, result2.output);
  });

  test('should handle empty document list', () async {
    final agent = Agent('openai:text-embedding-3-small');
    final embedder = DocEmbedder(agent);
    
    final result = await embedder.embedDocuments([]);
    expect(result.output, isEmpty);
    expect(result.usage.totalTokens, 0);
  });

  test('should chunk large document lists', () async {
    final agent = Agent('google:text-embedding-004');
    final embedder = DocEmbedder(agent);
    
    // Create 500 documents (> 250 batch size)
    final docs = List.generate(500, (i) => 'Document $i');
    
    final result = await embedder.embedDocuments(docs);
    expect(result.output.length, 500);
    expect(result.metadata?['chunks_processed'], 2);
  });

  test('should find most similar documents', () async {
    final agent = Agent('openai:text-embedding-3-small');
    final embedder = DocEmbedder(agent);
    
    final documents = [
      'The weather is sunny today',
      'Cats are cute animals',
      'The forecast shows rain tomorrow',
      'Dogs are loyal pets',
    ];
    
    final indices = await embedder.findMostSimilar(
      'What is the weather like?',
      documents,
    );
    
    // Weather-related documents should rank higher
    expect(indices.take(2), containsAll([0, 2]));
  });

  test('should deduplicate identical documents', () async {
    final agent = Agent('openai:text-embedding-3-small');
    final embedder = DocEmbedder(agent);
    
    final docs = ['Hello', 'World', 'Hello', 'World', 'Hello'];
    final result = await embedder.embedDocuments(docs);
    
    // Should only compute 2 unique embeddings
    expect(embedder.cache.length, 2);
    expect(result.output.length, 5);
    
    // Identical documents should have identical embeddings
    expect(result.output[0], result.output[2]);
    expect(result.output[0], result.output[4]);
    expect(result.output[1], result.output[3]);
  });

  test('should handle mixed cached/uncached documents', () async {
    final agent = Agent('openai:text-embedding-3-small');
    final embedder = DocEmbedder(agent);
    
    // Pre-cache some documents
    await embedder.embedDocuments(['A', 'B']);
    expect(embedder.cache.length, 2);
    
    // Mix of cached and new
    final result = await embedder.embedDocuments(['A', 'C', 'B', 'D']);
    expect(result.metadata?['cached_count'], 2);
    expect(embedder.cache.length, 4);
  });

  test('should clear cache', () async {
    final agent = Agent('openai:text-embedding-3-small');
    final embedder = DocEmbedder(agent);
    
    await embedder.embedDocuments(['A', 'B', 'C']);
    expect(embedder.cache.length, 3);
    
    embedder.clearCache();
    expect(embedder.cache.length, 0);
  });
});

Integration Tests

group('DocEmbedder Integration', () {
  test('should work with all providers', () async {
    final providers = ['openai', 'google', 'mistral', 'cohere'];
    
    for (final provider in providers) {
      final agent = Agent(provider);
      final embedder = DocEmbedder(agent);
      
      final result = await embedder.embedQuery('Test query');
      expect(result.output, isNotEmpty);
      expect(result.output.length, greaterThan(100)); // All have 100+ dims
    }
  });

  test('should handle provider-specific batch sizes', () async {
    // Test with Cohere's small batch size
    final agent = Agent('cohere');
    final embedder = DocEmbedder(agent);
    
    // 200 documents > 96 batch size
    final docs = List.generate(200, (i) => 'Document $i');
    final result = await embedder.embedDocuments(docs);
    
    expect(result.output.length, 200);
    expect(result.metadata?['chunks_processed'], greaterThan(2));
  });
});

Edge Cases to Test

  1. Very long documents - Test token limit handling
  2. Special characters - Unicode, emojis in documents (exercised in the sketch after this list)
  3. Empty strings - How providers handle empty input
  4. Null/invalid data - Defensive programming
  5. Provider switching - Using multiple providers with same DocEmbedder
  6. Memory pressure - Large cache scenarios
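
A hedged sketch exercising two of these cases (expectations are placeholders; actual provider behavior, especially for empty strings, should be confirmed first):

test('should embed documents containing unicode and emoji', () async {
  final agent = Agent('openai:text-embedding-3-small');
  final embedder = DocEmbedder(agent);

  final result = await embedder.embedDocuments(['héllo wörld', 'こんにちは', '🚀🔥']);
  expect(result.output.length, 3);
});

test('should define behavior for empty strings', () async {
  final agent = Agent('openai:text-embedding-3-small');
  final embedder = DocEmbedder(agent);

  // Placeholder expectation: providers differ on empty input, so decide whether
  // DocEmbedder should skip, error, or pass empty strings through.
  final result = await embedder.embedDocuments(['non-empty', '']);
  expect(result.output.length, 2);
});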

Migration Guide

Before (Direct Model Usage)

// Old approach - using model directly
final provider = Provider.openai;
final model = provider.createEmbeddingsModel();

// Manual chunking needed
final chunks = chunkList(documents, chunkSize: 512);
final allEmbeddings = <List<double>>[];

for (final chunk in chunks) {
  final result = await model.embedDocuments(chunk);
  allEmbeddings.addAll(result.output);
}

After (Using DocEmbedder)

// New approach - using DocEmbedder
final agent = Agent('openai');
final embedder = DocEmbedder(agent);

// Automatic chunking and caching
final result = await embedder.embedDocuments(documents);
final embeddings = result.output;

// Bonus: Built-in similarity search
final similar = await embedder.findMostSimilar(query, documents);

Example: Semantic Search Application

// example/doc_embedder.dart
import 'package:langchain_compat/langchain_compat.dart';

void main() async {
  // Create embedder
  final agent = Agent('openai:text-embedding-3-small');
  final embedder = DocEmbedder(agent);

  // Knowledge base documents
  final documents = [
    'The capital of France is Paris.',
    'The Eiffel Tower is located in Paris.',
    'The capital of Germany is Berlin.',
    'The Brandenburg Gate is in Berlin.',
    'The capital of Japan is Tokyo.',
    'Mount Fuji is near Tokyo.',
  ];

  // Embed all documents (with caching)
  print('Embedding ${documents.length} documents...');
  final result = await embedder.embedDocuments(documents);
  print('Used ${result.usage.totalTokens} tokens');
  print('Cached ${embedder.cache.length} unique embeddings');

  // Semantic search
  final queries = [
    'What is the capital of France?',
    'Tell me about German landmarks',
    'Where is Mount Fuji?',
  ];

  for (final query in queries) {
    print('\nQuery: $query');
    final indices = await embedder.findMostSimilar(query, documents);
    
    print('Top 3 results:');
    for (var i = 0; i < 3 && i < indices.length; i++) {
      print('  ${i + 1}. ${documents[indices[i]]}');
    }
  }

  // Demonstrate caching
  print('\nRe-embedding documents to show caching...');
  final result2 = await embedder.embedDocuments(documents);
  print('Used ${result2.usage.totalTokens} tokens (should be 0)');
  print('${result2.metadata?['cached_count']} documents were cached');
}

Appendix: File Change Summary

NOTE: These file recommendations are based on the current project folder structure in place at the creation of this design. The current project structure should be taken into account before using these file name recommendations.

Files to Create:

  1. lib/src/embeddings_models/doc_embedder.dart
  2. test/doc_embedder_test.dart
  3. example/doc_embedder.dart

Files to Modify:

  1. lib/src/embeddings_models/embeddings_model.dart - Add hash function
  2. lib/src/embeddings_models/embeddings_models.dart - Export DocEmbedder
  3. lib/src/embeddings_models/*/embeddings_model.dart (4 files) - Remove chunking
  4. lib/src/embeddings_models/*/embeddings_model_options.dart (4 files) - Remove batchSize
  5. lib/src/embeddings_models/embeddings_model_options.dart - Remove batchSize field

Files to Consider:

  1. lib/src/embeddings_models/chunk_list.dart - May become unused after refactoring
  2. Existing tests may need updates after removing chunking

Conclusion

This design provides a clean, maintainable solution that:

  • Eliminates code duplication across embeddings models
  • Provides symmetry with the Chat wrapper pattern
  • Adds valuable caching and utility features
  • Maintains backward compatibility where possible
  • Sets up the codebase for future enhancements

The implementation is straightforward and can be completed in phases, with each phase providing incremental value.
