@@ -45,8 +45,8 @@ def _build_input_queue(
     not_train = split != 'train'
     per_device_batch_size = int(global_batch_size / N_GPUS)

-    seq_len = 2048  # TODO: define it somewhere else
-    DTYPE = torch.int32  # TODO: decide between int32 and int64.
+    seq_len = self._seq_len  # TODO: define it somewhere else?
+    dtype = torch.int32  # TODO: decide between int32 and int64.

     # Only create and iterate over tf input pipeline in one Python process to
     # avoid creating too many threads.
@@ -66,18 +66,18 @@ def _build_input_queue(
       if RANK == 0:
         batch = next(np_iter)  # pylint: disable=stop-iteration-return
         inputs = torch.as_tensor(
-            batch['inputs'], dtype=DTYPE,
+            batch['inputs'], dtype=dtype,
             device=DEVICE)  # (N_GPUS, global_batch_size, seq_len)
         targets = torch.as_tensor(
-            batch['targets'], dtype=DTYPE,
+            batch['targets'], dtype=dtype,
             device=DEVICE)  # (N_GPUS, global_batch_size, seq_len)

         # Send batch to other devices when using DDP.
         if USE_PYTORCH_DDP:
           if not_train:
             # During eval, the batch size of the remainder might be different.
             per_device_batch_size = torch.tensor(
-                len(targets[0]), dtype=DTYPE, device=DEVICE)
+                len(targets[0]), dtype=dtype, device=DEVICE)
             dist.broadcast(per_device_batch_size, src=0)
           # We don't broadcast the shard for RANK 0.
           dist.broadcast(inputs[1:], src=0)
@@ -90,15 +90,15 @@ def _build_input_queue(
         # Receive batch from rank 0.
         if not_train:
           # During eval, the batch size of the remainder might be different.
-          per_device_batch_size = torch.empty((1,), dtype=DTYPE, device=DEVICE)
+          per_device_batch_size = torch.empty((1,), dtype=dtype, device=DEVICE)
           dist.broadcast(per_device_batch_size, src=0)

         # N_GPUS - 1 since we don't broadcast the shard for RANK 0.
         inputs = torch.empty((N_GPUS - 1, per_device_batch_size, seq_len),
-                             dtype=DTYPE,
+                             dtype=dtype,
                              device=DEVICE)
         targets = torch.empty((N_GPUS - 1, per_device_batch_size, seq_len),
-                              dtype=DTYPE,
+                              dtype=dtype,
                               device=DEVICE)
         dist.broadcast(inputs, src=0)
         dist.broadcast(targets, src=0)
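
The change above keeps the existing DDP pattern untouched: rank 0 materializes the full batch and uses dist.broadcast to send every shard except its own to the other ranks, which first allocate torch.empty buffers of the matching shape and dtype. Below is a minimal, self-contained sketch of that broadcast pattern, not the workload code; the gloo backend, world size of 2, the SEQ_LEN / PER_DEVICE_BATCH constants, the worker function, and the MASTER_ADDR/MASTER_PORT values are assumptions chosen so the script runs standalone on CPU.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

SEQ_LEN = 8           # stand-in for self._seq_len (assumed value)
PER_DEVICE_BATCH = 4  # stand-in for per_device_batch_size (assumed value)


def worker(rank: int, world_size: int) -> None:
  # Hypothetical single-machine rendezvous; port choice is arbitrary.
  os.environ['MASTER_ADDR'] = 'localhost'
  os.environ['MASTER_PORT'] = '29500'
  dist.init_process_group('gloo', rank=rank, world_size=world_size)

  if rank == 0:
    # Rank 0 holds the full batch: one shard per process.
    inputs = torch.arange(
        world_size * PER_DEVICE_BATCH * SEQ_LEN,
        dtype=torch.int32).reshape(world_size, PER_DEVICE_BATCH, SEQ_LEN)
    # Broadcast only the shards for the other ranks (inputs[1:]); rank 0
    # keeps its own shard locally, mirroring the diff above.
    dist.broadcast(inputs[1:], src=0)
    shard = inputs[0]
  else:
    # Other ranks allocate an empty buffer of the expected shape and
    # receive the broadcast, then pick out their own shard.
    inputs = torch.empty((world_size - 1, PER_DEVICE_BATCH, SEQ_LEN),
                         dtype=torch.int32)
    dist.broadcast(inputs, src=0)
    shard = inputs[rank - 1]

  print(f'rank {rank} got shard with sum {int(shard.sum())}')
  dist.destroy_process_group()


if __name__ == '__main__':
  world_size = 2
  mp.spawn(worker, args=(world_size,), nprocs=world_size)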