@@ -286,8 +286,9 @@ def _validate(self, dataloader):
286
286
running_loss = 0
287
287
self .model .eval ()
288
288
with torch .no_grad ():
289
- for inputs , labels , * _ in dataloader :
289
+ for batch in dataloader :
290
290
# Move data to the correct device
291
+ inputs , labels , _ = unpack_batch (batch )
291
292
inputs , labels = self ._move_to_device (inputs , labels )
292
293
293
294
if isinstance (inputs , tuple ) or isinstance (inputs , list ):
@@ -469,12 +470,12 @@ def __init__(self, data_loader, auto_reset=True):
469
470
def __next__ (self ):
470
471
# Get a new set of inputs and labels
471
472
try :
472
- inputs , labels , * _ = next (self ._iterator )
473
+ inputs , labels , _ = unpack_batch ( next (self ._iterator ) )
473
474
except StopIteration :
474
475
if not self .auto_reset :
475
476
raise
476
477
self ._iterator = iter (self .data_loader )
477
- inputs , labels , * _ = next (self ._iterator )
478
+ inputs , labels , _ = unpack_batch ( next (self ._iterator ) )
478
479
479
480
return inputs , labels
480
481
@@ -483,3 +484,9 @@ def __next__(self):
483
484
484
485
def get_batch (self ):
485
486
return next (self )
487
+
488
+
489
+ def unpack_batch (batch ):
490
+ """Mimics the functionality of extended iterable unpacking from Py3k (PEP-3132) so
491
+ that it can be used in Py2k."""
492
+ return batch [0 ], batch [1 ], batch [2 :]
0 commit comments