Description
Hi,
When `entity_types` is passed to `DataCollator`, `batch['classes_to_id']` is batch-level metadata, i.e. a single copy shared by the whole batch, and it is not indexed per row.
When `entity_types` is not passed to `DataCollator`, `batch['classes_to_id']` is per row, with an entry unique to each row.
So why would `entity_types` ever be passed to `DataCollator`?
For example, when using mmbert with 8192-token input sequences, or any other long-sequence model, small batch sizes can produce a batch in which none of the samples have labels. The loss function then throws an exception and training fails. This might be solved by initializing weights to near zero instead of zero, or by the solution below; larger batch sizes are not always feasible in this case.
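The difference between the two modes can be sketched with a hypothetical, simplified model of the collator. `build_classes_to_id` and the `(start, end, label)` entity format are assumptions for illustration, not the library's actual code. When class maps are built per row from each row's own entities, a batch with no entities gives every row zero classes, which is presumably what leads to the empty label tensor that makes the loss fail; a fixed `entity_types` list keeps every row's class count above zero.

```python
from typing import Dict, List, Optional, Tuple, Union

# Assumed entity format: (start_token, end_token, label), one list per row.
Entity = Tuple[int, int, str]
ClassMap = Dict[str, int]


def build_classes_to_id(
    batch_entities: List[List[Entity]],
    entity_types: Optional[List[str]] = None,
) -> Union[ClassMap, List[ClassMap]]:
    """Hypothetical, simplified model of the two collator behaviors."""
    if entity_types is not None:
        # One shared map for the whole batch: every row sees every type.
        return {label: idx for idx, label in enumerate(entity_types, start=1)}
    # One map per row, built only from labels that actually occur in that row.
    per_row = []
    for entities in batch_entities:
        labels = sorted({label for _start, _end, label in entities})
        per_row.append({label: idx for idx, label in enumerate(labels, start=1)})
    return per_row


# A small batch of long documents in which no row has any annotated entity.
unlabeled_batch: List[List[Entity]] = [[], []]

per_row = build_classes_to_id(unlabeled_batch)
print([len(m) for m in per_row])  # [0, 0]: zero classes in every row

shared = build_classes_to_id(unlabeled_batch, entity_types=["person", "location"])
print(len(shared))  # 2: every row still gets two (negative) classes
```

The 1-based ids are a choice made for this sketch only; the point is the shape difference (shared `dict` vs. `list` of per-row dicts) and the empty class maps in the per-row case.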
In the processor, starting at line 388:
```python
def create_labels(self, batch, blank=None):
    labels_batch = []
    decoder_label_strings = []
    for i in range(len(batch['tokens'])):
        tokens = batch['tokens'][i]
        classes_to_id = batch['classes_to_id'][i]  # <=== BUG HERE
        ner = batch['entities'][i]
        num_classes = len(classes_to_id)
        ...
```
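Line 393 always indexes `batch['classes_to_id']` by row, which fails when `entity_types` is passed and `classes_to_id` is a single dict shared by the whole batch (corrected below):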
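```python
def create_labels(self, batch, blank=None):
    labels_batch = []
    decoder_label_strings = []
    for i in range(len(batch['tokens'])):
        tokens = batch['tokens'][i]
        # With entity_types, the collator stores one dict shared by the batch;
        # without it, classes_to_id is a list with one dict per row.
        if isinstance(batch['classes_to_id'], dict):
            classes_to_id = batch['classes_to_id']
        elif isinstance(batch['classes_to_id'], list):
            classes_to_id = batch['classes_to_id'][i]
        else:
            raise TypeError(f"unexpected type for classes_to_id: {type(batch['classes_to_id'])}")
        ner = batch['entities'][i]
        ...
```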
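The same dict-or-list check could be factored into a small helper so that any other code reading `batch['classes_to_id']` handles both collator modes. This is a minimal sketch: `get_row_classes_to_id` is a hypothetical name, and it assumes the shared form is always a `dict` and the per-row form always a `list`, as in the fix above. It also shows why the original unconditional lookup breaks: the shared dict is keyed by label names, so indexing it with a row number raises `KeyError`.

```python
from typing import Dict, List, Union

ClassMap = Dict[str, int]


def get_row_classes_to_id(classes_to_id: Union[ClassMap, List[ClassMap]], i: int) -> ClassMap:
    """Return the class map for row i under either collator mode (hypothetical helper)."""
    if isinstance(classes_to_id, dict):
        # entity_types was passed: one map shared by every row, so i is ignored.
        return classes_to_id
    if isinstance(classes_to_id, list):
        # entity_types was not passed: one map per row.
        return classes_to_id[i]
    raise TypeError(f"unexpected classes_to_id type: {type(classes_to_id).__name__}")


shared = {"person": 1, "location": 2}
per_row = [{"person": 1}, {"location": 1}]

# The unconditional per-row lookup fails on the shared form, because the dict
# is keyed by label names rather than row indices.
try:
    shared[0]
except KeyError:
    print("KeyError on shared[0]")

print(get_row_classes_to_id(shared, 1))   # {'person': 1, 'location': 2}
print(get_row_classes_to_id(per_row, 1))  # {'location': 1}
```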