Skip to content

Commit b4038d0

Browse files
committed
Fix extended iterable unpacking for Python 2.7
Thanks to @NaleRaphael that proposed the change in #27 and so I just had to copy and paste 😃
1 parent 863e5f6 commit b4038d0

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

torch_lr_finder/lr_finder.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -286,8 +286,9 @@ def _validate(self, dataloader):
286286
running_loss = 0
287287
self.model.eval()
288288
with torch.no_grad():
289-
for inputs, labels, *_ in dataloader:
289+
for batch in dataloader:
290290
# Move data to the correct device
291+
inputs, labels, _ = unpack_batch(batch)
291292
inputs, labels = self._move_to_device(inputs, labels)
292293

293294
if isinstance(inputs, tuple) or isinstance(inputs, list):
@@ -469,12 +470,12 @@ def __init__(self, data_loader, auto_reset=True):
469470
def __next__(self):
470471
# Get a new set of inputs and labels
471472
try:
472-
inputs, labels, *_ = next(self._iterator)
473+
inputs, labels, _ = unpack_batch(next(self._iterator))
473474
except StopIteration:
474475
if not self.auto_reset:
475476
raise
476477
self._iterator = iter(self.data_loader)
477-
inputs, labels, *_ = next(self._iterator)
478+
inputs, labels, _ = unpack_batch(next(self._iterator))
478479

479480
return inputs, labels
480481

@@ -483,3 +484,9 @@ def __next__(self):
483484

484485
def get_batch(self):
485486
return next(self)
487+
488+
489+
def unpack_batch(batch):
490+
"""Mimics the functionality of extended iterable unpacking from Py3k (PEP-3132) so
491+
that it can be used in Py2k."""
492+
return batch[0], batch[1], batch[2:]

0 commit comments

Comments
 (0)