Spaces:
Runtime error
| from argparse import Namespace | |
| import gradio as gr | |
| import torch | |
| import torchvision.transforms as transforms | |
| from huggingface_hub import hf_hub_download | |
| from PIL import Image | |
| from models.psp import pSp | |
# Prefer the GPU when CUDA is available; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Preprocessing applied to each incoming sketch before inference:
# resize to 256x256, then convert the PIL image to a float tensor with
# ToTensor (8-bit pixel values are scaled to [0, 1]; no Normalize step).
# NOTE: the misspelled name `transfroms` is intentional to keep -- it is the
# name sketch_recognition() looks up.
transfroms = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ]
)
def tensor2im(var):
    """Convert a single output tensor from the network into a PIL image.

    The tensor is moved to the CPU and detached, then its first axis is moved
    last (``(C, H, W) -> (H, W, C)``). Values are mapped from ``[-1, 1]`` to
    ``[0, 1]``, clipped to that range, scaled by 255 and truncated to
    ``uint8`` before being handed to ``Image.fromarray``.

    Args:
        var: A 3-D tensor. Presumably one image from a ``(N, C, H, W)``
            batch with values roughly in ``[-1, 1]`` -- TODO confirm against
            the pSp output range.

    Returns:
        A ``PIL.Image.Image`` built from the converted array.
    """
    # permute(1, 2, 0) yields the same (H, W, C) layout as
    # transpose(0, 2) followed by transpose(0, 1).
    hwc = var.cpu().detach().permute(1, 2, 0).numpy()
    # [-1, 1] -> [0, 1], with anything outside that range clamped.
    unit = ((hwc + 1) / 2).clip(0, 1)
    return Image.fromarray((unit * 255).astype('uint8'))
def sketch_recognition(img):
    """Run the sketch-to-face network on a user-drawn sketch.

    Args:
        img: The sketch as a NumPy array, as delivered by the Gradio input
            defined in this file (``type="numpy"``, ``image_mode="L"``).

    Returns:
        The first image of the network's output batch, converted to a PIL
        image by ``tensor2im``.
    """
    # NumPy array -> PIL image -> resized tensor via the module-level pipeline.
    sketch = transfroms(Image.fromarray(img))
    # Add a leading batch dimension and move the input to the model's device.
    batch = sketch.unsqueeze(0).to(device)
    # Inference only, so autograd tracking is disabled for the forward pass.
    with torch.no_grad():
        output = net(batch)
    return tensor2im(output[0])
# Fetch the pretrained pSp sketch-to-face checkpoint from the Hugging Face Hub;
# hf_hub_download returns the path of the locally cached file.
path = hf_hub_download('huggan/TediGAN_sketch', 'psp_celebs_sketch_to_face.pt')

# NOTE(review): torch.load unpickles the file, which can execute arbitrary code.
# Only load checkpoints from a trusted source.
ckpt = torch.load(path, map_location=device)

# The checkpoint carries its training options as a mapping. Point
# checkpoint_path at the downloaded file (presumably read by pSp to load the
# weights -- confirm in models/psp.py).
opts = ckpt['opts']
opts.update({"checkpoint_path": path})
# Older pSp checkpoints were saved before these options existed. Fill in the
# same fallbacks the upstream pSp inference script uses, so the model does not
# trip over a missing attribute. Values already present in the checkpoint are
# left untouched.
opts.setdefault('learn_in_w', False)
opts.setdefault('output_size', 1024)
opts = Namespace(**opts)

net = pSp(opts)
# Switch to evaluation mode and move the weights to the selected device.
net.eval()
net.to(device)
# Gradio UI: a 256x256 grayscale drawing canvas whose contents are passed to
# sketch_recognition as a NumPy array; the returned image is shown as output.
# NOTE(review): `gr.inputs` is the legacy Gradio input namespace (deprecated in
# Gradio 3.x and removed in 4.0). Pin a compatible gradio version for this
# Space, or migrate to `gr.Image` with the matching arguments.
iface = gr.Interface(
    fn=sketch_recognition,
    inputs=gr.inputs.Image(
        shape=(256, 256),
        image_mode="L",
        invert_colors=False,
        source="canvas",
        tool="editor",
        type="numpy",
        label=None,
        optional=False,
    ),
    outputs="image",
)
# Keep `iface` bound to the Interface itself and launch it exactly once.
# Chaining `.launch()` onto the constructor would bind `iface` to whatever
# launch() returns rather than the Interface, breaking a later iface.launch().
iface.launch()