
Commit 118100c

[QEff. Finetune]: Added fix for pad_to_max_length in tokenization. (#599)
Signed-off-by: meetkuma <[email protected]>
1 parent 35d8fd8 commit 118100c

3 files changed: +22 -6 lines changed

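All three dataset preprocessors called tokenizer.encode(..., pad_to_max_length=True). That flag is deprecated in recent Hugging Face transformers releases in favor of the padding argument, so the commit selects the padding strategy explicitly: pad to max_length when a fixed context_length is configured, and fall back to dynamic padding (padding=True) otherwise. Below is a minimal, standalone sketch of that selection logic, assuming a generic Hugging Face tokenizer; the model name and context_length value are illustrative, not taken from the repository.

# Sketch of the padding selection applied in this commit, using a generic
# Hugging Face tokenizer. "gpt2" and context_length=128 are illustrative
# assumptions, not values from the QEfficient finetune configs.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token

context_length = 128  # may also be None in the finetune flow

# pad_to_max_length=True is deprecated; choose the padding strategy instead.
padding_type = "max_length" if context_length is not None else True

ids = tokenizer.encode(
    "Correct this to standard English: he go to school",
    max_length=context_length,
    padding=padding_type,
    truncation=True,  # added only so this sketch runs warning-free; the repo code omits it
)
print(len(ids))  # 128 with a fixed context_length; the unpadded length otherwise

Choosing padding=True when context_length is None also avoids the behaviour of padding="max_length" with max_length unset, which would pad every sample out to the tokenizer's model_max_length.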

QEfficient/finetune/dataset/alpaca_dataset.py

Lines changed: 7 additions & 2 deletions
@@ -58,10 +58,15 @@ def __getitem__(self, index):
         else:
             prompt = PROMPT_DICT["prompt_input"].format_map(ann)
         example = prompt + ann["output"]
+
+        if self.context_length is not None:
+            padding_type = "max_length"
+        else:
+            padding_type = True
         prompt = torch.tensor(
-            self.tokenizer.encode(prompt, max_length=self.context_length, pad_to_max_length=True), dtype=torch.int64
+            self.tokenizer.encode(prompt, max_length=self.context_length, padding=padding_type), dtype=torch.int64
         )
-        example = self.tokenizer.encode(example, max_length=self.context_length, pad_to_max_length=True)
+        example = self.tokenizer.encode(example, max_length=self.context_length, padding=padding_type)
         example.append(self.tokenizer.eos_token_id)
         example = torch.tensor(example, dtype=torch.int64)
         labels = copy.deepcopy(example)

QEfficient/finetune/dataset/custom_dataset/sample_dataset_preproc.py

Lines changed: 7 additions & 2 deletions
@@ -61,17 +61,22 @@ def apply_prompt_template(sample):
     dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features))

     def tokenize_add_label(sample):
+        if context_length is not None:
+            padding_type = "max_length"
+        else:
+            padding_type = True
+
         input = tokenizer.encode(
             tokenizer.bos_token + sample["input"],
             add_special_tokens=False,
             max_length=context_length,
-            pad_to_max_length=True,
+            padding=padding_type,
         )
         label = tokenizer.encode(
             sample["label"] + tokenizer.pad_token + tokenizer.eos_token,
             add_special_tokens=False,
             max_length=context_length,
-            pad_to_max_length=True,
+            padding=padding_type,
         )

         sample = {

QEfficient/finetune/dataset/grammar_dataset.py

Lines changed: 8 additions & 2 deletions
@@ -44,17 +44,23 @@ def convert_to_features(self, example_batch):
         target_ = example_batch["target"]

         prompt = f"Correct this to standard English: {input_}\n---\nCorrected: "
+
+        if self.context_length is not None:
+            padding_type = "max_length"
+        else:
+            padding_type = True
+
         prompt_ids = self.tokenizer.encode(
             self.tokenizer.bos_token + prompt,
             add_special_tokens=False,
             max_length=self.context_length,
-            pad_to_max_length=True,
+            padding=padding_type,
         )
         label_ids = self.tokenizer.encode(
             target_ + self.tokenizer.eos_token,
             add_special_tokens=False,
             max_length=self.context_length,
-            pad_to_max_length=True,
+            padding=padding_type,
         )

         sample = {
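For reference, the observable difference between the two padding strategies on a single example; this is a standalone check with an arbitrary tokenizer, not part of the change itself:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")  # arbitrary tokenizer that ships with a pad token

fixed = tok.encode("hello world", max_length=8, padding="max_length")
dynamic = tok.encode("hello world", max_length=8, padding=True)

print(len(fixed))    # 8 -> padded up to max_length
print(len(dynamic))  # 4 -> padding=True pads to the longest item in a batch,
                     #      which is a no-op for a single sequence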
