-
Notifications
You must be signed in to change notification settings - Fork 128
Open
Description
Traceback (most recent call last):
File "E:\New GG\a.py", line 11, in <module>
srgan_generator = torch.load(srgan_checkpoint)['generator'].to(device)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^
KeyError: 'generator'
import torch
from utils import *
from PIL import Image, ImageFont
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Checkpoint file the generator is read from.
srgan_checkpoint = "./RealESRGAN_x4plus_anime_6B.pth"

# map_location remaps stored tensors onto `device`, so a checkpoint saved on a
# GPU can still be loaded on a CPU-only machine.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
_checkpoint = torch.load(srgan_checkpoint, map_location=device)

# The code below needs the checkpoint to hold a ready-to-use model object
# (something with .to() and .eval()) under the 'generator' key.  When that key
# is absent, fail with the same exception type as before but say what the file
# actually contains.
# NOTE(review): presumably this checkpoint stores only weights (a state dict)
# under some other key, in which case the matching network class must be built
# first and the weights loaded into it -- confirm against the checkpoint's
# origin.
if isinstance(_checkpoint, dict) and 'generator' not in _checkpoint:
    raise KeyError(
        "'generator' not found in checkpoint %r; available keys: %s"
        % (srgan_checkpoint, sorted(str(key) for key in _checkpoint))
    )

srgan_generator = _checkpoint['generator'].to(device)
# Drop the reference to the rest of the checkpoint so it can be freed.
del _checkpoint
# Switch the model to evaluation mode before inference.
srgan_generator.eval()
def save_super_resolved_image(input_image_path, output_image_path):
    """Downscale an image by 4, run it through ``srgan_generator`` and save it.

    The image at ``input_image_path`` is converted to RGB, shrunk to a quarter
    of its width and height with bicubic resampling, passed through the
    module-level ``srgan_generator`` and the result is written to
    ``output_image_path``.

    Args:
        input_image_path: Path of the image to read.
        output_image_path: Path the generator output is saved to; the image
            format is chosen by Pillow from the file extension.
    """
    # The context manager closes the underlying file handle; convert() returns
    # an independent in-memory copy, so it stays usable afterwards.
    with Image.open(input_image_path) as source_image:
        input_image = source_image.convert('RGB')

    # Quarter-size low-resolution input.  Clamp to 1 px so images smaller than
    # 4 px on a side do not produce a zero-sized resize target.
    lr_size = (max(1, input_image.width // 4), max(1, input_image.height // 4))
    lr_img = input_image.resize(lr_size, Image.BICUBIC)

    # convert_image comes from `utils`; presumably it turns the PIL image into
    # a normalised tensor -- confirm against utils.  unsqueeze(0) adds the
    # leading batch dimension.
    lr_tensor = convert_image(lr_img, source='pil', target='imagenet-norm')
    lr_tensor = lr_tensor.unsqueeze(0).to(device)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        sr_img_srgan = srgan_generator(lr_tensor)

    # Remove the batch dimension and bring the result back to the CPU.
    sr_img_srgan = sr_img_srgan.squeeze(0).cpu().detach()
    sr_img_srgan = convert_image(sr_img_srgan, source='[-1, 1]', target='pil')
    sr_img_srgan.save(output_image_path)
if __name__ == '__main__':
    # Script entry point: process the fixed input file and write the result
    # next to it.
    input_image_path, output_image_path = "input.png", "output.jpg"
    save_super_resolved_image(input_image_path, output_image_path)
Metadata
Metadata
Assignees
Labels
No labels