diff --git a/bindings/python/src/trainers.rs b/bindings/python/src/trainers.rs
index ef2c31e56..311fbe58b 100644
--- a/bindings/python/src/trainers.rs
+++ b/bindings/python/src/trainers.rs
@@ -181,6 +181,11 @@ macro_rules! setter {
 ///         This can help with reducing polluting your vocabulary with
 ///         highly repetitive tokens like `======` for wikipedia
 ///
+///     initial_tokens (:obj:`List[str]`, `optional`):
+///         A list of multi-character tokens to pre-seed the vocabulary (non-special).
+///         They are added after the alphabet computation and before merges; they may
+///         subsequently be produced by merges and will reuse their pre-assigned ids.
+///         Alias: `seed_tokens`.
 #[pyclass(extends=PyTrainer, module = "tokenizers.trainers", name = "BpeTrainer")]
 pub struct PyBpeTrainer {}
 #[pymethods]
@@ -291,6 +296,16 @@ impl PyBpeTrainer {
         );
     }
 
+    #[getter]
+    fn get_initial_tokens(self_: PyRef<Self>) -> Vec<String> {
+        getter!(self_, BpeTrainer, initial_tokens.clone())
+    }
+
+    #[setter]
+    fn set_initial_tokens(self_: PyRef<Self>, tokens: Vec<String>) {
+        setter!(self_, BpeTrainer, initial_tokens, tokens);
+    }
+
     #[getter]
     fn get_continuing_subword_prefix(self_: PyRef<Self>) -> Option<String> {
         getter!(self_, BpeTrainer, continuing_subword_prefix.clone())
@@ -358,6 +373,14 @@ impl PyBpeTrainer {
                         builder = builder.continuing_subword_prefix(val.extract()?)
                     }
                     "end_of_word_suffix" => builder = builder.end_of_word_suffix(val.extract()?),
+                    "initial_tokens" | "seed_tokens" => {
+                        let toks: Vec<String> = val.extract()?;
+                        builder = builder.initial_tokens(toks);
+                    }
                     _ => println!("Ignored unknown kwargs option {key}"),
                 };
             }
diff --git a/tokenizers/src/models/bpe/trainer.rs b/tokenizers/src/models/bpe/trainer.rs
index cda6aea65..459dd02da 100644
--- a/tokenizers/src/models/bpe/trainer.rs
+++ b/tokenizers/src/models/bpe/trainer.rs
@@ -48,6 +48,7 @@ struct Config {
     continuing_subword_prefix: Option<String>,
     end_of_word_suffix: Option<String>,
     max_token_length: Option<usize>,
+    initial_tokens: Vec<String>,
 }
 
 /// A `BpeTrainerBuilder` can be used to create a `BpeTrainer` with a custom
@@ -69,6 +70,7 @@ impl Default for BpeTrainerBuilder {
                 continuing_subword_prefix: None,
                 end_of_word_suffix: None,
                 max_token_length: None,
+                initial_tokens: vec![],
             },
         }
     }
@@ -137,6 +139,7 @@ impl BpeTrainerBuilder {
         self.config.end_of_word_suffix = Some(suffix);
         self
     }
+
     /// Set max_token_length
     #[must_use]
     pub fn max_token_length(mut self, max_token_length: Option<usize>) -> Self {
@@ -144,6 +147,16 @@
         self
     }
 
+    /// Set the initial multi-character tokens to seed the vocabulary.
+    #[must_use]
+    pub fn initial_tokens<I: Into<String>>(mut self, tokens: Vec<I>) -> Self {
+        self.config.initial_tokens = tokens
+            .into_iter()
+            .map(Into::into)
+            .filter(|s| !s.is_empty())
+            .collect();
+        self
+    }
+
     /// Constructs the final BpeTrainer
     pub fn build(self) -> BpeTrainer {
         BpeTrainer {
@@ -156,6 +169,7 @@
             continuing_subword_prefix: self.config.continuing_subword_prefix,
             end_of_word_suffix: self.config.end_of_word_suffix,
             max_token_length: self.config.max_token_length,
+            initial_tokens: self.config.initial_tokens,
             words: AHashMap::new(),
         }
     }
@@ -199,6 +213,9 @@ pub struct BpeTrainer {
     pub end_of_word_suffix: Option<String>,
     /// An optional parameter to limit the max length of any single token
     pub max_token_length: Option<usize>,
+    /// Initial multi-character tokens to seed the vocabulary (non-special).
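+    /// Unlike special tokens, these can be produced by (and participate in)
+    /// further merges; the field is skipped during serialization when empty.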
+    #[serde(default, skip_serializing_if = "Vec::is_empty")]
+    pub initial_tokens: Vec<String>,
     words: AHashMap<CompactString, u64>,
 }
 
@@ -290,6 +307,13 @@ impl BpeTrainer {
             *alphabet.entry(*c).or_default() = usize::MAX;
         }
 
+        // Ensure every character of the initial tokens is in the alphabet
+        for token in &self.initial_tokens {
+            for c in token.chars() {
+                *alphabet.entry(c).or_default() = usize::MAX;
+            }
+        }
+
         let mut kept = alphabet.iter().collect::<Vec<_>>();
 
         // Compute the number of chars to remove from the alphabet
@@ -310,12 +334,6 @@
         kept.sort_unstable_by_key(|k| *k.0 as u32);
         kept.into_iter().for_each(|(c, _)| {
             let s = c.to_string();
-            /*
-            if !w2id.contains_key(&s) {
-                id2w.push(s.clone());
-                w2id.insert(s, (id2w.len() - 1) as u32);
-            }
-            */
             // u64 hash version
             if !w2id.contains_key(&CompactString::from(&s)) {
                 id2w.push(CompactString::from(&s));
@@ -417,6 +435,215 @@ impl BpeTrainer {
         )
     }
 
+    /// Core merge operation shared by initial merges and regular BPE training
+    fn apply_merge(
+        &self,
+        pair: Pair,
+        positions: AHashSet<usize>,
+        words: &mut [Word],
+        counts: &[u64],
+        word_to_id: &mut AHashMap<CompactString, u32>,
+        id_to_word: &mut Vec<CompactString>,
+        merges: &mut Vec<(Pair, u32)>,
+        pair_counts: &mut AHashMap<Pair, i32>,
+        max_token_length: usize,
+    ) -> AHashMap<Pair, AHashSet<usize>> {
+        // Build the merged token
+        let part_a = &id_to_word[pair.0 as usize];
+        let mut part_b = id_to_word[pair.1 as usize].as_str();
+
+        // Remove continuing_subword_prefix if present
+        if let Some(prefix) = &self.continuing_subword_prefix {
+            if let Some(rest) = part_b.strip_prefix(prefix) {
+                part_b = rest;
+            }
+        }
+
+        let new_token = format!("{}{}", part_a, part_b);
+        let new_token_id = word_to_id
+            .get(&CompactString::from(&new_token))
+            .copied()
+            .unwrap_or(id_to_word.len() as u32);
+
+        if !word_to_id.contains_key(&CompactString::from(&new_token)) {
+            id_to_word.push(CompactString::from(&new_token));
+            word_to_id.insert(CompactString::from(&new_token), new_token_id);
+        }
+
+        merges.push((pair, new_token_id));
+
+        // Apply merge to all word positions
+        let words_len = words.len();
+        struct WordPtr(*mut Word);
+        // Safety: only used to access disjoint indices within the same allocation
+        unsafe impl Sync for WordPtr {}
+        let word_start = WordPtr(words.as_mut_ptr());
+
+        let changes = positions
+            .maybe_par_iter()
+            .flat_map(|&i| {
+                unsafe {
+                    assert!(i < words_len);
+                    let word = word_start.0.add(i);
+                    (*word)
+                        .merge(pair.0, pair.1, new_token_id, max_token_length)
+                        .into_iter()
+                        .map(|c| (c, i))
+                        .collect::<Vec<_>>()
+                }
+            })
+            .collect::<Vec<_>>();
+
+        // Update pair counts and collect new positions to update
+        let mut where_to_update: AHashMap<Pair, AHashSet<usize>> = AHashMap::new();
+        for ((new_pair, change), iw) in changes {
+            let count = change * counts[iw] as i32;
+            *pair_counts.entry(new_pair).or_default() += count;
+            if change > 0 {
+                where_to_update.entry(new_pair).or_default().insert(iw);
+            }
+        }
+
+        where_to_update
+    }
+
+    /// Apply initial token merges to jump-start the vocabulary with seed tokens
+    fn apply_initial_token_merges(
+        &self,
+        words: &mut [Word],
+        counts: &[u64],
+        word_to_id: &mut AHashMap<CompactString, u32>,
+        id_to_word: &mut Vec<CompactString>,
+        merges: &mut Vec<(Pair, u32)>,
+        pair_counts: &mut AHashMap<Pair, i32>,
+        where_to_update: &mut AHashMap<Pair, AHashSet<usize>>,
+        max_token_length: usize,
+    ) {
+        // Generate merges needed to build initial tokens from characters
+        let mut initial_merges = Vec::new();
+        let mut seen_pairs = AHashSet::new();
+
+        for token in &self.initial_tokens {
+            let chars: Vec<char> = token.chars().collect();
+            if chars.len() <= 1 {
+                continue;
+            }
+
+            // For every possible substring of length >= 2
+            for start in 0..chars.len() {
+                for end in start + 2..=chars.len() {
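+                    // e.g. a seed token "abc" yields the merges
+                    // (a,b)->"ab", (b,c)->"bc", (a,bc)->"abc", (ab,c)->"abc";
+                    // `seen_pairs` deduplicates pairs shared between seed tokens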
+                    // substring = chars[start..end]
+                    let substring: String = chars[start..end].iter().collect();
+
+                    // Now split it into two parts: left + right
+                    for split in (start + 1)..end {
+                        let left: String = chars[start..split].iter().collect();
+                        let right: String = chars[split..end].iter().collect();
+
+                        let pair = (left.clone(), right.clone());
+                        if seen_pairs.insert(pair.clone()) {
+                            initial_merges.push((pair.0, pair.1, substring.clone()));
+                        }
+                    }
+                }
+            }
+        }
+
+        // Apply each initial merge unconditionally, so the full merge chain of
+        // every seed token is recorded even if a pair never occurs in the corpus
+        for (left, right, merged) in initial_merges {
+            // Add all three tokens (left, right, merged) to the vocabulary
+            let left_compact = CompactString::from(&left);
+            let right_compact = CompactString::from(&right);
+            let merged_compact = CompactString::from(&merged);
+
+            // Ensure left token is in vocabulary
+            if !word_to_id.contains_key(&left_compact) {
+                id_to_word.push(left_compact.clone());
+                word_to_id.insert(left_compact.clone(), (id_to_word.len() - 1) as u32);
+            }
+
+            // Ensure right token is in vocabulary
+            if !word_to_id.contains_key(&right_compact) {
+                id_to_word.push(right_compact.clone());
+                word_to_id.insert(right_compact.clone(), (id_to_word.len() - 1) as u32);
+            }
+
+            // Ensure merged token is in vocabulary
+            if !word_to_id.contains_key(&merged_compact) {
+                id_to_word.push(merged_compact.clone());
+                word_to_id.insert(merged_compact.clone(), (id_to_word.len() - 1) as u32);
+            }
+
+            let left_id = word_to_id[&left_compact];
+            let right_id = word_to_id[&right_compact];
+            let merged_id = word_to_id[&merged_compact];
+
+            // Record the merge unconditionally
+            merges.push(((left_id, right_id), merged_id));
+
+            // Now apply the merge to the corpus, but only where the pair exists.
+            // The right token may appear with different prefix/suffix combinations
+            let right_variants = vec![
+                right.clone(),
+                if let Some(prefix) = &self.continuing_subword_prefix {
+                    format!("{}{}", prefix, right)
+                } else {
+                    String::new()
+                },
+                if let Some(suffix) = &self.end_of_word_suffix {
+                    format!("{}{}", right, suffix)
+                } else {
+                    String::new()
+                },
+                if let (Some(prefix), Some(suffix)) =
+                    (&self.continuing_subword_prefix, &self.end_of_word_suffix)
+                {
+                    format!("{}{}{}", prefix, right, suffix)
+                } else {
+                    String::new()
+                },
+            ];
+
+            for right_variant in right_variants {
+                if right_variant.is_empty() {
+                    continue;
+                }
+
+                if let Some(&variant_right_id) = word_to_id.get(&CompactString::from(&right_variant)) {
+                    let pair = (left_id, variant_right_id);
+
+                    // Check if this pair exists in our corpus words
+                    if let Some(positions) = where_to_update.remove(&pair) {
+                        // Apply merge to corpus words
+                        let new_where_to_update = self.apply_merge(
+                            pair,
+                            positions,
+                            words,
+                            counts,
+                            word_to_id,
+                            id_to_word,
+                            &mut Vec::new(), // don't record a duplicate merge; it was added above
+                            pair_counts,
+                            max_token_length,
+                        );
+
+                        // Merge the new positions back
+                        for (k, v) in new_where_to_update {
+                            where_to_update.entry(k).or_default().extend(v);
+                        }
+                        break; // found and processed this variant
+                    }
+                }
+            }
+        }
+
+        // Finally, add the initial tokens themselves to the vocabulary
+        for token in &self.initial_tokens {
+            let ct = CompactString::from(token);
+            if !word_to_id.contains_key(&ct) {
+                id_to_word.push(ct.clone());
+                word_to_id.insert(ct, (id_to_word.len() - 1) as u32);
+            }
+        }
+    }
+
     pub fn do_train(
         &self,
         word_counts: &AHashMap<CompactString, u64>,
@@ -451,7 +678,34 @@ impl BpeTrainer {
         // self.update_progress(&progress, words.len(), "Count pairs");
         let (mut pair_counts, mut where_to_update) =
             self.count_pairs(&words, &counts, &progress);
-        // Insert them in the queue
+        self.finalize_progress(&progress, words.len());
+
+        //
+        // 5. Initialize merges vector
+        //
+        let mut merges: Vec<(Pair, u32)> = vec![];
+
+        //
+        // 6. Apply initial token merges if we have seed tokens
+        //
+        if !self.initial_tokens.is_empty() {
+            self.update_progress(&progress, self.initial_tokens.len(), "Apply initial merges");
+            self.apply_initial_token_merges(
+                &mut words,
+                &counts,
+                &mut word_to_id,
+                &mut id_to_word,
+                &mut merges,
+                &mut pair_counts,
+                &mut where_to_update,
+                max_token_length,
+            );
+            self.finalize_progress(&progress, self.initial_tokens.len());
+        }
+
+        //
+        // 7. Build the priority queue from remaining pairs
+        //
         let mut queue = OctonaryHeap::with_capacity(pair_counts.len());
         where_to_update.drain().for_each(|(pair, pos)| {
             let count = pair_counts[&pair];
             if count > 0 {
                 queue.push(Merge {
@@ -463,13 +717,11 @@
                 });
             }
         });
-        self.finalize_progress(&progress, words.len());
 
         //
-        // 5. Do merges
+        // 8. Do regular BPE merges
         //
         self.update_progress(&progress, self.vocab_size, "Compute merges");
-        let mut merges: Vec<(Pair, u32)> = vec![];
         loop {
             // Stop as soon as we have a big enough vocabulary
             if word_to_id.len() >= self.vocab_size {
@@ -490,68 +742,21 @@
                 break;
             }
 
-            let part_a = &id_to_word[top.pair.0 as usize];
-            let mut part_b = id_to_word[top.pair.1 as usize].as_str();
-
-            // Build new token
-            if let Some(prefix) = &self.continuing_subword_prefix {
-                if let Some(rest) = part_b.strip_prefix(prefix) {
-                    part_b = rest;
-                }
-            }
+            // Apply the merge
+            let new_where_to_update = self.apply_merge(
+                top.pair,
+                top.pos,
+                &mut words,
+                &counts,
+                &mut word_to_id,
+                &mut id_to_word,
+                &mut merges,
+                &mut pair_counts,
+                max_token_length,
+            );
 
-            // Insert new token if it does not already exist
-            let new_token = format!("{part_a}{part_b}");
-            let new_token_id = word_to_id
-                .get(&CompactString::from(&new_token))
-                .copied()
-                .unwrap_or(id_to_word.len() as u32);
-            if !word_to_id.contains_key(&CompactString::from(&new_token)) {
-                id_to_word.push(CompactString::from(&new_token));
-                word_to_id.insert(CompactString::from(&new_token), new_token_id);
-            }
-            merges.push((top.pair, new_token_id));
-
-            // Merge the new pair in every words
-            // Safety: This is just a type assertion, the code below may no longer be safe
-            // if the type of `pos` changes
-            let pos: &AHashSet<usize> = &top.pos;
-
-            let words_len = words.len();
-            struct WordPtr(*mut Word);
-            // Safety: We do not actually use this for concurrent access to the same memory,
-            // only to different chunks within the same allocation.
-            unsafe impl Sync for WordPtr {}
-            let word_start = WordPtr(words.as_mut_ptr());
-
-            let changes = pos
-                .maybe_par_iter()
-                .flat_map(|&i| {
-                    // We can merge each of these words in parallel here because each position
-                    // can be there only once (AHashSet). So this is safe.
-                    unsafe {
-                        assert!(i < words_len);
-                        // This is words[i], but avoids needing to go through &T (which triggers UB)
-                        let word = word_start.0.add(i);
-                        // let word: &mut Word = &mut (*word);
-                        (*word)
-                            .merge(top.pair.0, top.pair.1, new_token_id, max_token_length)
-                            .into_iter()
-                            .map(|c| (c, i))
-                            .collect::<Vec<_>>()
-                    }
-                })
-                .collect::<Vec<_>>();
-
-            // Introduce new formed pairs
-            for ((pair, change), iw) in changes {
-                let count = change * counts[iw] as i32;
-                *pair_counts.entry(pair).or_default() += count;
-                if change > 0 {
-                    where_to_update.entry(pair).or_default().insert(iw);
-                }
-            }
-            where_to_update.drain().for_each(|(pair, pos)| {
+            // Add new pairs to the priority queue
+            new_where_to_update.into_iter().for_each(|(pair, pos)| {
                 let count = pair_counts[&pair];
                 if count > 0 {
                     queue.push(Merge {
@@ -569,7 +774,6 @@
         }
         self.finalize_progress(&progress, merges.len());
 
         // Transfer new vocab & options to model
-        //model.vocab = word_to_id;
         model.vocab = word_to_id
             .into_iter()
             // we have to look up the string in id_to_word because the key in word_to_id is a hash
@@ -717,6 +921,7 @@ mod tests {
             .collect();
         assert_eq!(model.merges, expected_merges);
     }
+
     #[test]
     fn bpe_test_max_token_length_16() {
         /* bpe_test_max_token_length series of tests test the max_token_length flag of bpetrainer
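
Usage sketch (illustrative, not part of the patch): with the builder API added above, seeding a trainer from Rust could look like this. `BpeTrainer::builder()` and `vocab_size` are the existing builder entry points; `initial_tokens` is the method introduced here, and the token values are made-up examples.

    use tokenizers::models::bpe::BpeTrainer;

    // Seed the vocabulary with multi-character tokens. Empty strings are
    // filtered out by the builder; single-character tokens are kept but
    // generate no merges (see `apply_initial_token_merges`).
    let trainer = BpeTrainer::builder()
        .vocab_size(30_000)
        .initial_tokens(vec!["======", "->", "::"])
        .build();

From Python, the same option is exposed as `BpeTrainer(initial_tokens=[...])`, with `seed_tokens` accepted as a kwargs alias.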