diff --git a/scripts/image_variations.py b/scripts/image_variations.py index 28bcee42..6c143ba4 100644 --- a/scripts/image_variations.py +++ b/scripts/image_variations.py @@ -19,7 +19,7 @@ def load_model_from_config(config, ckpt, device, verbose=False): print(f"Loading model from {ckpt}") - pl_sd = torch.load(ckpt, map_location=device) + pl_sd = torch.load(ckpt, map_location='cpu') if "global_step" in pl_sd: print(f"Global Step: {pl_sd['global_step']}") sd = pl_sd["state_dict"]