@@ -64,9 +64,9 @@ def test_basic_padding(self):

         result = self.collator(examples)

+        self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"})
         torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]]))
         torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
-        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2], [0, 1, 0]]))
         torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]]))

     def test_completion_mask(self):
@@ -79,9 +79,9 @@ def test_completion_mask(self):

         result = self.collator(examples)

+        self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"})
         torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]]))
         torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
-        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2], [0, 1, 0]]))
         torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3], [-100, 5, -100]]))

     def test_completion_only_loss_disabled(self):
@@ -95,9 +95,9 @@ def test_completion_only_loss_disabled(self):
         result = collator(examples)

         # Labels should not be masked when completion_only_loss=False
+        self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"})
         torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]]))
         torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
-        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2], [0, 1, 0]]))
         torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]]))

     def test_padding_free_mode(self):
@@ -107,72 +107,42 @@ def test_padding_free_mode(self):

         result = collator(examples)

+        self.assertEqual(set(result.keys()), {"input_ids", "position_ids", "labels"})
         torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5]]))
-        torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 1, 1]]))
         torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1]]))
-        torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, 4, 5]]))
+        torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3, -100, 5]]))

     def test_padding_free_with_completion_mask(self):
         """Test padding-free mode with completion masks."""
         collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True)
         examples = [
-            {"input_ids": [1, 2, 3], "completion_mask": [0, 1, 1]},
+            {"input_ids": [1, 2, 3], "completion_mask": [0, 0, 1]},
             {"input_ids": [4, 5], "completion_mask": [1, 1]},
         ]

         result = collator(examples)

+        self.assertEqual(set(result.keys()), {"input_ids", "position_ids", "labels"})
         torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5]]))
-        torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 1, 1]]))
         torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1]]))
-        torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3, 4, 5]]))
+        torch.testing.assert_close(result["labels"], torch.tensor([[-100, -100, 3, -100, 5]]))

-    def test_packing_drops_attention_mask_for_flash_attention(self):
+    def test_packing(self):
         """Test that when using packing with position_ids, attention_mask is dropped with fa2."""
-        collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True, return_position_ids=True)
+        collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True)

         # Simulate packed sequences with position_ids that restart (typical of BFD packing)
         examples = [
-            {
-                "input_ids": [1, 2, 3, 4, 5, 6, 7, 8],  # Packed: [1,2,3] + [4,5] + [6,7,8]
-                "seq_lengths": [3, 2, 3],
-            }
+            {"input_ids": [1, 2, 3, 4, 5, 6], "seq_lengths": [3, 3]},
+            {"input_ids": [7, 8, 9, 10, 11], "seq_lengths": [4, 1]},
         ]

         result = collator(examples)

-        # Verify that attention_mask is NOT present - this allows FlashAttention to use position_ids
-        self.assertNotIn("attention_mask", result, "attention_mask should be dropped for packing with position_ids")
-
-        # Verify essential keys are present
-        self.assertIn("input_ids", result)
-        self.assertIn("position_ids", result)
-        self.assertIn("labels", result)
-
-        # Verify the data is correctly processed
-        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]]))
-        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1, 0, 1, 2]]))
-        torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]]))
-
-    def test_padding_free_without_position_ids_keeps_attention_mask(self):
-        """
-        Test that padding_free mode without explicit position_ids still creates attention_mask.
-        """
-        collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True, return_position_ids=True)
-
-        # Examples without position_ids (not packed)
-        examples = [{"input_ids": [1, 2, 3, 4, 5]}]
-
-        result = collator(examples)
-
-        # Should still have attention_mask since no packed position_ids
-        self.assertIn("attention_mask", result, "attention_mask should be present when no packed position_ids")
-        self.assertIn("position_ids", result)
-        self.assertIn("input_ids", result)
-
-        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5]]))
-        torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 1, 1]]))
-        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 3, 4]]))
+        self.assertEqual(set(result.keys()), {"input_ids", "position_ids", "labels"})
+        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]]))
+        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1, 2, 0, 1, 2, 3, 0]]))
+        torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3, -100, 5, 6, -100, 8, 9, 10, -100]]))

     def test_pad_to_multiple_of(self):
         """Test padding to multiple of specified value."""
@@ -181,9 +151,9 @@ def test_pad_to_multiple_of(self):

         result = collator(examples)

+        self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"})
         torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 0], [4, 5, 0, 0]]))
         torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]]))
-        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0], [0, 1, 0, 0]]))
         torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, -100], [4, 5, -100, -100]]))

     def test_pad_to_multiple_of_and_padding_free(self):
@@ -193,21 +163,21 @@ def test_pad_to_multiple_of_and_padding_free(self):

         result = collator(examples)

+        self.assertEqual(set(result.keys()), {"input_ids", "position_ids", "labels"})
         torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5, 0, 0, 0]]))
-        torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 1, 1, 0, 0, 0]]))
         torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1, 0, 0, 0]]))
-        torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, 4, 5, -100, -100, -100]]))
+        torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3, -100, 5, -100, -100, -100]]))

-    def test_custom_position_ids(self):
-        """Test handling of custom position IDs in examples."""
+    def test_custom_position_ids_but_no_padding_free(self):
+        """Test that custom position_ids are ignored if padding_free is False."""
         self.collator = DataCollatorForLanguageModeling(pad_token_id=0)
         examples = [{"input_ids": [1, 2, 3], "seq_lengths": [1, 2]}, {"input_ids": [4, 5], "seq_lengths": [2]}]

         result = self.collator(examples)

+        self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"})
         torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]]))
         torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
-        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 0, 1], [0, 1, 0]]))
         torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]]))

     def test_single_example(self):
@@ -217,9 +187,9 @@ def test_single_example(self):

         result = self.collator(examples)

+        self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"})
         torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4]]))
         torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 1]]))
-        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 3]]))
         torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, 4]]))

     def test_different_pad_token_id(self):
@@ -229,9 +199,9 @@ def test_different_pad_token_id(self):

         result = collator(examples)

+        self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"})
         torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 999]]))
         torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
-        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2], [0, 1, 0]]))
         torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]]))

     def test_assistant_masks(self):
@@ -246,7 +216,6 @@ def test_assistant_masks(self):

         torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]]))
         torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
-        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2], [0, 1, 0]]))
         torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3], [-100, 5, -100]]))

     def test_single_example_single_doc(self):