diff --git a/flair/trainers/trainer.py b/flair/trainers/trainer.py index b77bb21d15..1714efe413 100644 --- a/flair/trainers/trainer.py +++ b/flair/trainers/trainer.py @@ -562,7 +562,7 @@ def train_custom( self.context_stack = context_stack self.dispatch("after_setup", **parameters) - scaler = torch.cuda.amp.GradScaler(enabled=use_amp and flair.device.type != "cpu") + scaler = torch.amp.GradScaler('cuda', enabled=use_amp and flair.device.type != "cpu") final_eval_info = ( "model after last epoch (final-model.pt)"