
Commit 2683099

fix style
1 parent 5af0fdc commit 2683099

File tree

2 files changed: +56 -37 lines


algoperf/workloads/lm/workload.py

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ class BaseLmWorkload(spec.Workload):
   _seq_len: int = 2048
 
   def __init__(self) -> None:
-    super().__init__()
+    pass
 
   @property
   def target_metric_name(self) -> str:
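
For context, a minimal, self-contained sketch of the pattern the new version of this hunk relies on: a subclass of an abstract base whose own __init__ is a no-op because its state lives in class attributes. The Workload stand-in and the 'ppl' return value are assumptions for illustration only, not taken from the repository.

import abc


class Workload(abc.ABC):  # stand-in for spec.Workload, assumed for this sketch

  @property
  @abc.abstractmethod
  def target_metric_name(self) -> str:
    ...


class BaseLmWorkload(Workload):
  _seq_len: int = 2048

  def __init__(self) -> None:
    pass  # nothing to initialize beyond the class attributes

  @property
  def target_metric_name(self) -> str:
    return 'ppl'  # placeholder value for the sketch
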

dataset/dataset_setup.py

Lines changed: 55 additions & 36 deletions
@@ -80,7 +80,6 @@
 import datasets as hf_datasets
 from transformers import AutoTokenizer
 
-import math
 import functools
 import itertools
 import os
@@ -713,7 +712,9 @@ def download_finewebedu(data_dir, tmp_dir=None):
 
   data_dir = os.path.join(data_dir, 'finewebedu')
   tmp_dir = tmp_dir if tmp_dir is not None else '/tmp'
-  cache_dir = os.path.join(tmp_dir, 'lm') if tmp_dir is not None else os.path.expanduser('~/.cache/huggingface/datasets')
+  cache_dir = os.path.join(tmp_dir,
+                           'lm') if tmp_dir is not None else os.path.expanduser(
+                               '~/.cache/huggingface/datasets')
 
   _maybe_mkdir(data_dir)
   _maybe_mkdir(tmp_dir)
@@ -722,75 +723,93 @@ def download_finewebedu(data_dir, tmp_dir=None):
   os.environ["TMPDIR"] = tmp_dir
 
   ds = hf_datasets.load_dataset(
-      'HuggingFaceFW/fineweb-edu',
-      name='sample-10BT',
-      split='train',
-      cache_dir=cache_dir
-  )
-  # TODO (nico): maybe save intermediate dataset to avoid re-downloading
+      'HuggingFaceFW/fineweb-edu',
+      name='sample-10BT',
+      split='train',
+      cache_dir=cache_dir)
+  # TODO (nico): maybe save intermediate dataset to avoid re-downloading
   # and allow re-chunking with different seq_len?
 
   # Shuffle so that multiproc has shards of similar size.
   ds = ds.shuffle(seed=1996)
 
   seq_len = 2048
-  max_seq_length = seq_len+1
+  max_seq_length = seq_len + 1
   map_setup = dict(batched=True, batch_size=1024, num_proc=8)
 
   # Tokenize
-  tokenizer = AutoTokenizer.from_pretrained('gpt2')
-  logging.info(f"Vocab size of tokenizer = {len(tokenizer)}")
+  lm_tokenizer = AutoTokenizer.from_pretrained('gpt2')
+  logging.info(f"Vocab size of lm_tokenizer = {len(lm_tokenizer)}")
+
   def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
-    add_eos = lambda seq: (seq + tokenizer.eos_token) if seq else seq
+    add_eos = lambda seq: (seq + lm_tokenizer.eos_token) if seq else seq
     add_eos_batched = lambda seqs: [add_eos(seq) for seq in seqs]
-    return tokenizer(
-        add_eos_batched(examples["text"]),
-        return_special_tokens_mask=False,
-        return_attention_mask=False
-    )
-  tokenizer.model_max_length = 1e30 # prevent truncation during tokenization
+    return lm_tokenizer(
+        add_eos_batched(examples["text"]),
+        return_special_tokens_mask=False,
+        return_attention_mask=False)
+
+  lm_tokenizer.model_max_length = 1e30 # prevent truncation during tokenization
   logging.info(f"Tokenizing...")
   tokenized_dataset = ds.map(
-      tokenize,
-      remove_columns=['text', 'id', 'dump', 'url', 'file_path', 'language',
-                      'language_score', 'token_count', 'score', 'int_score'],
-      **map_setup
-  )
-  tokenizer.model_max_length = seq_len
-
+      tokenize,
+      remove_columns=[
+          'text',
+          'id',
+          'dump',
+          'url',
+          'file_path',
+          'language',
+          'language_score',
+          'token_count',
+          'score',
+          'int_score'
+      ],
+      **map_setup)
+  lm_tokenizer.model_max_length = seq_len
+
   tokenized_dataset.save_to_disk(os.path.join(data_dir, f"fwedu_10B_tokenized"))
 
-  # Find how many entries to take from dataset to have VAL_TOKENS in validation set.
-  VAL_TOKENS = 10_000_000
+  # Find how many entries to take from dataset to have val_tokens in validation set.
+  val_tokens = 10_000_000  # TODO: decide this value.
   tokens_accumulated, num_examples_for_val = 0, 0
   for example in tokenized_dataset:
     tokens_accumulated += len(example['input_ids'])
     num_examples_for_val += 1
-    if tokens_accumulated >= VAL_TOKENS:
-      break
+    if tokens_accumulated >= val_tokens:
+      break
   # Split in train and valid.
   val_dataset = tokenized_dataset.select(range(num_examples_for_val))
-  train_dataset = tokenized_dataset.select(range(num_examples_for_val, len(tokenized_dataset)))
+  train_dataset = tokenized_dataset.select(
+      range(num_examples_for_val, len(tokenized_dataset)))
 
   # Concat in chunks of max_seq_len.
   # NOTE: expected token loss by batched concat_chunk. Truncates leftover tokens that don't fill a full max_seq_length chunk.
   def concat_chunck(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
     """Concatenate text and generate chunks of max_seq_length"""
-    concatenated_examples = {k: list(itertools.chain(*examples[k])) for k in examples.keys()}
+    concatenated_examples = {
+        k: list(itertools.chain(*examples[k])) for k in examples.keys()
+    }
     total_length = len(concatenated_examples[list(examples.keys())[0]])
     if total_length >= max_seq_length:
-      total_length = (total_length // max_seq_length) * max_seq_length
+      total_length = (total_length // max_seq_length) * max_seq_length
     result = {
-        k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
-        for k, t in concatenated_examples.items()
+        k: [
+            t[i:i + max_seq_length]
+            for i in range(0, total_length, max_seq_length)
+        ] for k, t in concatenated_examples.items()
     }
     return result
+
   # Concat text in validation and train sets.
   logging.info(f"Concatenating and chunking...")
   val_dataset = val_dataset.map(concat_chunck, **map_setup)
   train_dataset = train_dataset.map(concat_chunck, **map_setup)
-  logging.info(f"Number of tokens in val_dataset: {len(val_dataset) * max_seq_length:_}")
-  logging.info(f"Number of tokens in train_dataset: {len(train_dataset) * max_seq_length:_}")
+  logging.info(
+      f"Number of tokens in val_dataset: {len(val_dataset) * max_seq_length:_}")
+  logging.info(
+      f"Number of tokens in train_dataset: {len(train_dataset) * max_seq_length:_}"
+  )
 
   # Save datasets
   train_dataset.save_to_disk(os.path.join(data_dir, f"train"))
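
The NOTE in concat_chunck about expected token loss is easy to see with a toy example. The sketch below reproduces the same concatenate-then-chunk steps in plain Python, outside the datasets pipeline, with max_seq_length shrunk to 5 for readability; the token values are made up for illustration.

import itertools

max_seq_length = 5
examples = {'input_ids': [[1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11, 12]]}  # 12 tokens

# Flatten, round the total length down to a multiple of max_seq_length,
# then slice into fixed-size windows (as concat_chunck does per batch).
concatenated = {k: list(itertools.chain(*v)) for k, v in examples.items()}
total_length = (len(concatenated['input_ids']) // max_seq_length) * max_seq_length
chunks = [
    concatenated['input_ids'][i:i + max_seq_length]
    for i in range(0, total_length, max_seq_length)
]
print(chunks)  # [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]] -- tokens 11 and 12 are dropped

Once the script has run, the saved splits can be read back with datasets.load_from_disk. The path below is hypothetical; substitute whatever data_dir/finewebedu resolves to in your setup (the train split is written there by the save_to_disk call above).

import os

import datasets as hf_datasets

finewebedu_dir = '/data/finewebedu'  # hypothetical path for this sketch
train_dataset = hf_datasets.load_from_disk(os.path.join(finewebedu_dir, 'train'))
print(len(train_dataset[0]['input_ids']))  # 2049 (seq_len + 1) for a full chunk
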
