@sambhavnoobcoder commented Nov 22, 2025

Problem Statement

The Issue

In on-policy distillation (GKD and MiniLLM trainers), when student and teacher models use different tokenizers, the training process produces incorrect results. Specifically:

  • The student model generates rollouts using its own tokenizer
  • The teacher model receives the raw student-tokenized input without re-tokenization
  • This causes teacher logprobs to be computed in low-probability regions
  • Different chat templates (e.g., Llama's <|start_header_id|> vs Qwen's <|im_start|>) further exacerbate the issue

Root Cause Analysis

What We Found

Through careful code inspection and testing, we identified that both GKDTrainer and MiniLLMTrainer had the same fundamental issue:

  1. Student Generation Phase: The student model generates completions using its own tokenizer, producing token IDs specific to its vocabulary
  2. Teacher Evaluation Phase (BUGGY): The teacher model receives these student-tokenized IDs directly
  3. The Problem: The teacher's tokenizer has a different vocabulary mapping, so these token IDs represent completely different tokens in the teacher's vocabulary
  4. Result: Teacher logprobs are computed on nonsensical token sequences, leading to incorrect probability distributions

Example Scenario

  • Student (Qwen) tokenizes "Hello" → token ID 123
  • In Qwen's vocabulary: 123 = "Hello"
  • In teacher's vocabulary (Llama): 123 = "World" (different token!)
  • Teacher computes logprobs for "World" instead of "Hello" → wrong probability distribution
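
The mismatch is easy to reproduce by hand. Below is a minimal sketch (model names are taken from the usage example later in this description; the exact strings you get back depend on tokenizer versions, and any two tokenizers with different vocabularies show the same effect):

# Decode the *same* token IDs with two different tokenizers.
from transformers import AutoTokenizer

student_tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
teacher_tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")

# The student produces token IDs from its own vocabulary.
student_ids = student_tok.encode("Hello", add_special_tokens=False)

# Decoding those IDs with each tokenizer gives different text, because the two
# vocabularies map the same ID to different tokens.
print(student_tok.decode(student_ids))  # "Hello"
print(teacher_tok.decode(student_ids))  # some unrelated string from Llama's vocabulary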

Proposed Solution

High-Level Approach

We implemented a text-based re-tokenization approach:

  1. Text Preservation: Preserve the text content from student-generated rollouts
  2. Re-tokenization: Convert the text back to tokens using the teacher's tokenizer
  3. Correct Evaluation: Teacher processes tokens from its own vocabulary
  4. Text-Aligned Loss: For different vocabulary sizes, align predictions via text decoding/encoding
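
A minimal sketch of this flow (the function name and signature are illustrative, not the trainer's actual API):

import torch

def teacher_logprobs_from_student_rollout(student_tok, teacher_tok, teacher_model, student_ids):
    # 1. Preserve the text content of the student-generated rollout.
    text = student_tok.decode(student_ids, skip_special_tokens=False)

    # 2. Re-tokenize that text with the teacher's own tokenizer.
    teacher_inputs = teacher_tok(text, return_tensors="pt").to(teacher_model.device)

    # 3. The teacher now scores token IDs drawn from its own vocabulary.
    with torch.no_grad():
        teacher_logits = teacher_model(**teacher_inputs).logits
    return torch.log_softmax(teacher_logits, dim=-1)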

Key Design Decisions

  1. Opt-in Feature: Only activates when teacher_tokenizer_name_or_path is specified in config
  2. Backward Compatible: Same-tokenizer scenarios use the original code path (no performance overhead)
  3. Handles Vocab Mismatches: Text-aligned loss works with any vocabulary sizes
  4. Both Trainers Fixed: Consistent implementation across GKD and MiniLLM

Implementation Details

Components Implemented

1. Configuration Updates

  • Added teacher_tokenizer_name_or_path parameter to both GKDConfig and MiniLLMConfig
  • Comprehensive docstrings explaining when and how to use the parameter

2. Teacher Tokenizer Loading

  • Automatic loading of teacher's tokenizer when config parameter is set
  • Warning emitted when the student and teacher models appear to differ but no teacher tokenizer is specified
  • Liger kernel incompatibility check (raises clear error if both are enabled)
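
A sketch of this loading and validation logic; only teacher_tokenizer_name_or_path comes from this PR's config, while the Liger flag and the model-name comparison below are illustrative stand-ins:

import warnings
from transformers import AutoTokenizer

def load_teacher_tokenizer(args, student_model_name, teacher_model_name, use_liger_kernel=False):
    """Return the teacher tokenizer, or None when the original single-tokenizer path applies."""
    if args.teacher_tokenizer_name_or_path is None:
        if teacher_model_name != student_model_name:
            # Models look different but no teacher tokenizer was given: warn, don't fail.
            warnings.warn(
                "Student and teacher models appear to differ but teacher_tokenizer_name_or_path "
                "is not set; teacher logprobs may be computed on mismatched token IDs."
            )
        return None
    if use_liger_kernel:
        # Cross-tokenizer distillation and the Liger kernel cannot be combined.
        raise ValueError("Cross-tokenizer distillation is incompatible with the Liger kernel.")
    return AutoTokenizer.from_pretrained(args.teacher_tokenizer_name_or_path)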

3. Text Preservation Pipeline

  • Modified generation pipeline to return text alongside tokens
  • Fallback decode path if text not preserved in inputs dictionary
  • Special token preservation to maintain structure
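
A sketch of the preferred-text / fallback-decode behaviour; the "completion_texts" key and the function name are assumptions, not the PR's exact names:

def get_completion_texts(inputs, completion_ids, student_tokenizer):
    # Preferred path: the generation step already stored the generated text.
    texts = inputs.get("completion_texts")
    if texts is None:
        # Fallback: reconstruct the text from the student's token IDs, keeping special
        # tokens so the chat-template structure is preserved.
        texts = student_tokenizer.batch_decode(completion_ids, skip_special_tokens=False)
    return texts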

4. Re-tokenization Utility

  • Shared utility function build_teacher_inputs_from_texts() for both trainers
  • Handles prompt and completion concatenation
  • Proper label masking (padding tokens → -100, prompt tokens → -100)
  • Device placement handling
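
A simplified sketch of what such a build_teacher_inputs_from_texts() utility does, based on the bullets above; the actual signature and padding details in the PR may differ:

import torch

def build_teacher_inputs_from_texts(teacher_tokenizer, prompt_texts, completion_texts, device):
    input_ids, labels = [], []
    for prompt, completion in zip(prompt_texts, completion_texts):
        prompt_ids = teacher_tokenizer(prompt, add_special_tokens=False)["input_ids"]
        completion_ids = teacher_tokenizer(completion, add_special_tokens=False)["input_ids"]
        input_ids.append(prompt_ids + completion_ids)
        # Prompt tokens are masked with -100 so only completion tokens contribute to the loss.
        labels.append([-100] * len(prompt_ids) + completion_ids)

    # Right-pad to the longest sequence; padding positions are also masked with -100.
    max_len = max(len(ids) for ids in input_ids)
    pad_id = teacher_tokenizer.pad_token_id if teacher_tokenizer.pad_token_id is not None else 0
    attention_mask = [[1] * len(ids) + [0] * (max_len - len(ids)) for ids in input_ids]
    labels = [lab + [-100] * (max_len - len(lab)) for lab in labels]
    input_ids = [ids + [pad_id] * (max_len - len(ids)) for ids in input_ids]

    # Tensors are created directly on the requested device (device placement handling).
    return {
        "input_ids": torch.tensor(input_ids, device=device),
        "attention_mask": torch.tensor(attention_mask, device=device),
        "labels": torch.tensor(labels, device=device),
    }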

5. Cross-Tokenizer Loss Computation

  • Conditional branching: cross-tokenizer path vs same-tokenizer path
  • Text-aligned loss for vocabulary size mismatches
  • Teacher predictions decoded and re-encoded to student vocab space
  • Sequence length alignment (handles different tokenizer outputs)
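
A sketch of the branching and sequence-length alignment; loss_fn stands in for the trainer's distillation loss, and the text-aligned handling of mismatched vocabularies (decoding teacher predictions and re-encoding them into the student's vocabulary) is only noted in comments:

def compute_distillation_loss(student_logits, teacher_logits, labels, loss_fn, cross_tokenizer):
    if not cross_tokenizer:
        # Same-tokenizer path: original behaviour, vocabularies and positions already line up.
        return loss_fn(student_logits, teacher_logits, labels)

    # Cross-tokenizer path: the student and teacher tokenizations of the same text can have
    # different lengths, so truncate both to the shorter sequence before comparing.
    min_len = min(student_logits.size(1), teacher_logits.size(1))
    student_logits = student_logits[:, :min_len]
    teacher_logits = teacher_logits[:, :min_len]
    labels = labels[:, :min_len]

    # When the vocabularies also differ in size, the PR uses a text-aligned loss: teacher
    # predictions are decoded to text and re-encoded into the student's vocab space first.
    return loss_fn(student_logits, teacher_logits, labels)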

6. Safety and Validation

  • Model mismatch detection with user-friendly warnings
  • Assertion checks for teacher tokenizer loading
  • Proper error messages for incompatible configurations

Testing Strategy

Unit Tests

We verified each component independently:

  1. Config Parameter Test: Verified teacher_tokenizer_name_or_path loads correctly
  2. Tokenizer Loading Test: Confirmed teacher tokenizer loads with correct vocabulary size
  3. Backward Compatibility Test: Verified same-tokenizer scenarios use original path
  4. Warning System Test: Confirmed warnings trigger when models differ without tokenizer specified
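
For illustration, a minimal pytest-style check of the new config parameter might look like this (the test name and structure are assumptions, not the PR's actual test suite):

from trl import GKDConfig

def test_teacher_tokenizer_config(tmp_path):
    config = GKDConfig(
        output_dir=str(tmp_path),
        teacher_tokenizer_name_or_path="meta-llama/Llama-3.2-1B",
    )
    assert config.teacher_tokenizer_name_or_path == "meta-llama/Llama-3.2-1B"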

End-to-End Tests

We validated the complete pipeline with real scenarios:

  1. Different Tokenizers Test (Qwen ↔ Llama):
    • Student: tiny-Qwen2ForCausalLM-2.5 (151,665-token vocabulary)
    • Teacher: tiny-LlamaForCausalLM-3.2 (128,256-token vocabulary)
    • Result: Training completed successfully, loss stable (11.9269)
  2. Same Tokenizer Test (Qwen ↔ Qwen):
    • Both models use the same tokenizer
    • Result: Training completed successfully, loss stable (11.8981)
    • Confirms backward compatibility

Edge Cases Verified

We systematically tested all critical edge cases:

  • No teacher tokenizer specified (backward compatibility)
  • Models differ but tokenizer not specified (warning issued)
  • Liger kernel + cross-tokenizer (error raised)
  • Different sequence lengths (min_length alignment)
  • Padding token handling (masked with -100)
  • Text not preserved (fallback decode path)
  • Different vocabulary sizes (real test: 151K vs 128K)
  • Empty completions (special tokens preserved)
  • Batch size = 1
  • Teacher in eval mode
  • Device mismatches
  • Chat template differences

Verification Results

What We Validated

  • Original bug fixed: Teacher now receives correctly tokenized inputs
  • GKD implementation: Full cross-tokenizer support with text-aligned loss
  • MiniLLM implementation: Full cross-tokenizer support reusing the GKD utilities
  • Backward compatibility: Same-tokenizer scenarios unchanged
  • Different vocab sizes: Successfully handles 151K ↔ 128K token vocabularies
  • Warning system: Alerts users when the configuration may be incorrect
  • Error handling: Clear errors for incompatible configurations
  • Edge cases: All 12 edge cases systematically verified

Test Coverage Summary

Test Category        | Status    | Details
Config Loading       | ✅ PASSED | Parameter exists and works correctly
Tokenizer Loading    | ✅ PASSED | Teacher tokenizer loads with correct vocab
Different Tokenizers | ✅ PASSED | Qwen (151K) ↔ Llama (128K) training succeeds
Same Tokenizer       | ✅ PASSED | Backward compatibility maintained
Warning System       | ✅ PASSED | Alerts for potential misconfigurations
Edge Cases           | ✅ PASSED | All 12 edge cases handled correctly

Usage Example

Before (Would Fail or Give Wrong Results)

# This would fail or produce incorrect results
config = GKDConfig(output_dir="./output")
trainer = GKDTrainer(
    model="Qwen/Qwen2.5-0.5B",           # Different tokenizer
    teacher_model="meta-llama/Llama-3.2-1B",  # Different tokenizer
    args=config,
    # ... other args
)
trainer.train()  # ❌ Wrong teacher logprobs!

After (Works Correctly)

# Now works correctly with cross-tokenizer support
config = GKDConfig(
    output_dir="./output",
    teacher_tokenizer_name_or_path="meta-llama/Llama-3.2-1B",  # Key parameter!
)
trainer = GKDTrainer(
    model="Qwen/Qwen2.5-0.5B",           # Different tokenizer
    teacher_model="meta-llama/Llama-3.2-1B",  # Different tokenizer
    args=config,
    # ... other args
)
trainer.train()  # ✅ Correct teacher logprobs!

Screenshots

Test Results

  1. Existing tests (screenshot: existing_tests)
  2. GKD with Different Tokenizers, Qwen ↔ Llama (screenshot: cross_tokeniser_with_diff_models)
  3. GKD Same Tokenizer, Backward Compatibility (screenshot: cross_tokeniser_with_same_tokeniser)
  4. MiniLLM Sanity Check (screenshot: minillm_test)

Related Issues


  • Cross-tokenizer distillation fails in GKD and MiniLLM trainers