 import datasets as hf_datasets
 from transformers import AutoTokenizer

-import math
 import functools
 import itertools
 import os
@@ -713,7 +712,9 @@ def download_finewebedu(data_dir, tmp_dir=None):

   data_dir = os.path.join(data_dir, 'finewebedu')
   tmp_dir = tmp_dir if tmp_dir is not None else '/tmp'
-  cache_dir = os.path.join(tmp_dir, 'lm') if tmp_dir is not None else os.path.expanduser('~/.cache/huggingface/datasets')
+  cache_dir = os.path.join(tmp_dir,
+                           'lm') if tmp_dir is not None else os.path.expanduser(
+                               '~/.cache/huggingface/datasets')

   _maybe_mkdir(data_dir)
   _maybe_mkdir(tmp_dir)
@@ -722,75 +723,93 @@ def download_finewebedu(data_dir, tmp_dir=None):
   os.environ["TMPDIR"] = tmp_dir

   ds = hf_datasets.load_dataset(
-      'HuggingFaceFW/fineweb-edu',
-      name='sample-10BT',
-      split='train',
-      cache_dir=cache_dir
-  )
-  # TODO (nico): maybe save intermediate dataset to avoid re-downloading
+      'HuggingFaceFW/fineweb-edu',
+      name='sample-10BT',
+      split='train',
+      cache_dir=cache_dir)
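+  # 'sample-10BT' is the ~10B-token sample configuration of FineWeb-Edu.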
+  # TODO (nico): maybe save intermediate dataset to avoid re-downloading
   # and allow re-chunking with different seq_len?

   # Shuffle so that multiproc has shards of similar size.
   ds = ds.shuffle(seed=1996)

   seq_len = 2048
-  max_seq_length = seq_len + 1
+  max_seq_length = seq_len + 1
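+  # (The extra token presumably allows the usual one-token shift between model inputs and next-token targets.)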
   map_setup = dict(batched=True, batch_size=1024, num_proc=8)

   # Tokenize
-  tokenizer = AutoTokenizer.from_pretrained('gpt2')
-  logging.info(f"Vocab size of tokenizer = {len(tokenizer)}")
+  lm_tokenizer = AutoTokenizer.from_pretrained('gpt2')
+  logging.info(f"Vocab size of lm_tokenizer = {len(lm_tokenizer)}")
+
   def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
-    add_eos = lambda seq: (seq + tokenizer.eos_token) if seq else seq
+    add_eos = lambda seq: (seq + lm_tokenizer.eos_token) if seq else seq
     add_eos_batched = lambda seqs: [add_eos(seq) for seq in seqs]
-    return tokenizer(
-        add_eos_batched(examples["text"]),
-        return_special_tokens_mask=False,
-        return_attention_mask=False
-    )
-  tokenizer.model_max_length = 1e30  # prevent truncation during tokenization
+    return lm_tokenizer(
+        add_eos_batched(examples["text"]),
+        return_special_tokens_mask=False,
+        return_attention_mask=False)
+
+  lm_tokenizer.model_max_length = 1e30  # prevent truncation during tokenization
   logging.info(f"Tokenizing...")
   tokenized_dataset = ds.map(
-      tokenize,
-      remove_columns=['text', 'id', 'dump', 'url', 'file_path', 'language',
-                      'language_score', 'token_count', 'score', 'int_score'],
-      **map_setup
-  )
-  tokenizer.model_max_length = seq_len
-
+      tokenize,
+      remove_columns=[
+          'text',
+          'id',
+          'dump',
+          'url',
+          'file_path',
+          'language',
+          'language_score',
+          'token_count',
+          'score',
+          'int_score'
+      ],
+      **map_setup)
+  lm_tokenizer.model_max_length = seq_len
+
   tokenized_dataset.save_to_disk(os.path.join(data_dir, f"fwedu_10B_tokenized"))
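+  # This intermediate copy can be reloaded with hf_datasets.load_from_disk,
+  # e.g. to re-chunk with a different seq_len without re-tokenizing.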

-  # Find how many entries to take from dataset to have VAL_TOKENS in validation set.
-  VAL_TOKENS = 10_000_000
+  # Find how many entries to take from dataset to have val_tokens in validation set.
+  val_tokens = 10_000_000  # TODO: decide this value.
   tokens_accumulated, num_examples_for_val = 0, 0
   for example in tokenized_dataset:
     tokens_accumulated += len(example['input_ids'])
     num_examples_for_val += 1
-    if tokens_accumulated >= VAL_TOKENS:
-      break
+    if tokens_accumulated >= val_tokens:
+      break
   # Split in train and valid.
   val_dataset = tokenized_dataset.select(range(num_examples_for_val))
-  train_dataset = tokenized_dataset.select(range(num_examples_for_val, len(tokenized_dataset)))
+  train_dataset = tokenized_dataset.select(
+      range(num_examples_for_val, len(tokenized_dataset)))

   # Concat in chunks of max_seq_length.
   # NOTE: batched concat_chunk is expected to lose some tokens: it truncates leftover tokens that don't fill a full max_seq_length chunk.
   def concat_chunk(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
778789 """Concatenate text and generate chunks of max_seq_length"""
779- concatenated_examples = {k : list (itertools .chain (* examples [k ])) for k in examples .keys ()}
790+ concatenated_examples = {
791+ k : list (itertools .chain (* examples [k ])) for k in examples .keys ()
792+ }
780793 total_length = len (concatenated_examples [list (examples .keys ())[0 ]])
781794 if total_length >= max_seq_length :
782- total_length = (total_length // max_seq_length ) * max_seq_length
795+ total_length = (total_length // max_seq_length ) * max_seq_length
783796 result = {
784- k : [t [i : i + max_seq_length ] for i in range (0 , total_length , max_seq_length )]
785- for k , t in concatenated_examples .items ()
797+ k : [
798+ t [i :i + max_seq_length ]
799+ for i in range (0 , total_length , max_seq_length )
800+ ] for k , t in concatenated_examples .items ()
786801 }
787802 return result
803+
788804 # Concat text in validation and train sets.
789805 logging .info (f"Concatenating and chunking..." )
   val_dataset = val_dataset.map(concat_chunk, **map_setup)
   train_dataset = train_dataset.map(concat_chunk, **map_setup)
-  logging.info(f"Number of tokens in val_dataset: {len(val_dataset) * max_seq_length:_}")
-  logging.info(f"Number of tokens in train_dataset: {len(train_dataset) * max_seq_length:_}")
+  logging.info(
+      f"Number of tokens in val_dataset: {len(val_dataset) * max_seq_length:_}")
+  logging.info(
+      f"Number of tokens in train_dataset: {len(train_dataset) * max_seq_length:_}"
+  )

   # Save datasets
   train_dataset.save_to_disk(os.path.join(data_dir, f"train"))
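
For reference, a minimal sketch of how the saved train split could be read back with datasets.load_from_disk; the data_dir path here is a placeholder for wherever download_finewebedu wrote its output:

import os

import datasets as hf_datasets

data_dir = '/path/to/data/finewebedu'  # placeholder: the 'finewebedu' dir created above
train_dataset = hf_datasets.load_from_disk(os.path.join(data_dir, 'train'))
print(train_dataset)  # Dataset with an 'input_ids' column of tokenized chunks
print(len(train_dataset[0]['input_ids']))  # max_seq_length = seq_len + 1 = 2049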