From 1e828abbc5b7ace30ceac00c79f5dd2ccc9b9d93 Mon Sep 17 00:00:00 2001 From: Hasan Shahriar <7091945+hsleonis@users.noreply.github.com> Date: Wed, 6 Aug 2025 17:38:38 +0200 Subject: [PATCH] Update trainer.py Fix deprecation warning in Flair's trainer code. The issue is that: Flair is using the old PyTorch AMP (Automatic Mixed Precision) API. --- flair/trainers/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flair/trainers/trainer.py b/flair/trainers/trainer.py index b77bb21d15..1714efe413 100644 --- a/flair/trainers/trainer.py +++ b/flair/trainers/trainer.py @@ -562,7 +562,7 @@ def train_custom( self.context_stack = context_stack self.dispatch("after_setup", **parameters) - scaler = torch.cuda.amp.GradScaler(enabled=use_amp and flair.device.type != "cpu") + scaler = torch.amp.GradScaler('cuda', enabled=use_amp and flair.device.type != "cpu") final_eval_info = ( "model after last epoch (final-model.pt)"