使用StyleGAN + CLIP从文本生成面部图像[翻译]

资料仓库  收藏
0 / 900

使用StyleGAN + CLIP从文本生成面部图像

 这次介绍的是StyleClip: StyleGAN+CLIP,从文本进行图像生成的一个模型。是现在最好的图像生成模型和文字模型的结合的。

什么是StyleGAN + CLIP模型?

**  StyleGAN**是一个根据参数生成图像的模型,CLIP是可以输出图像和文本之间相似度的模型。这次,我们将按以下方式结合这两个模型。

 
为StyleGAN提供适当的参数并输出图像(目标图像)。然后,将图像文本输入到CLIP中,找到相似度,并优化参数,以使相似度尽可能高。。然后,获得生成与文本的内容匹配的图像的参数,并且可以从文本生成图像。

这篇文章有机翻 https://aiqianji.com/blog/article/85

 现在,让我们看一下代码实现。

代码

 该代码在Google Colab上运行并发布Github上,单击此“链接”

 现在,让我们从文本生成一张脸部图像。在文本中,输入要生成的内容的文本。在这里,我输入“她是一个有着金色的头发和蓝色的眼睛的迷人女人”。

 将优化循环旋转101次(max_iter设置),然后每两次旋转(img_save_freq设置),将生成的中间图像./pic保存在其中。

import os

import torch

import torchvision

import clip

import numpy as np

from PIL import Image

from stylegan_models import g_all, g_synthesis, g_mapping

from utils import GetFeatureMaps, transform_img, compute_loss

from tqdm import trange

import warnings  

warnings.filterwarnings('ignore')   

 


import os

import shutil

if os.path.isdir('pic'):

     shutil.rmtree('pic')

os.makedirs('pic', exist_ok=True)

 

# 初期設定 

lr = 1e-2 

img_save_freq = 2

max_iter = 101 

ref_img_path = None 

 

device = 'cuda' if torch.cuda.is_available() else 'cpu'

print("USING ", device)

 

clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)

vgg16 = torchvision.models.vgg16(pretrained=True).to(device)

vgg_layers = vgg16.features

 

vgg_layer_name_mapping = {

    '1': "relu1_1",

    '3': "relu1_2",

    '6': "relu2_1",

    '8': "relu2_2",

    # '15': "relu3_3",

    # '22': "relu4_3"

}

 

g_synthesis.eval()

g_synthesis.to(device)

 

latent_shape = (1, 1, 512)

 

normal_generator = torch.distributions.normal.Normal(

    torch.tensor([0.0]),

    torch.tensor([1.]),

)

 

# init_latents = normal_generator.sample(latent_shape).squeeze(-1).to(device)

latents_init = torch.zeros(latent_shape).squeeze(-1).to(device)

latents = torch.nn.Parameter(latents_init, requires_grad=True)

 

optimizer = torch.optim.Adam(

    params=[latents],

    lr=lr,

    betas=(0.9, 0.999),

)

 

def truncation(x, threshold=0.7, max_layer=8):

    avg_latent = torch.zeros(1, x.size(1), 512).to(device)

    interp = torch.lerp(avg_latent, x, threshold)

    do_trunc = (torch.arange(x.size(1)) < max_layer).view(1, -1, 1).to(device)

    return torch.where(do_trunc, interp, x)

 

def tensor_to_pil_img(img):

    img = (img.clamp(-1, 1) + 1) / 2.0

    img = img[0].permute(1, 2, 0).detach().cpu().numpy() * 256

    img = Image.fromarray(img.astype('uint8'))

    return img

 

 

clip_transform = torchvision.transforms.Compose([

    # clip_preprocess.transforms[2],

    clip_preprocess.transforms[4],

])

 

if ref_img_path is None:

    ref_img = None

else:

    ref_img = clip_preprocess(Image.open(ref_img_path)).unsqueeze(0).to(device)

 

clip_normalize = torchvision.transforms.Normalize(

    mean=(0.48145466, 0.4578275, 0.40821073),

    std=(0.26862954, 0.26130258, 0.27577711),

)

 

def compute_clip_loss(img, text):

    # img = clip_transform(img)

    img = torch.nn.functional.upsample_bilinear(img, (224, 224))

    tokenized_text = clip.tokenize([text]).to(device)

    img_logits, _text_logits = clip_model(img, tokenized_text)

    return 1/img_logits * 100

 

def compute_perceptual_loss(gen_img, ref_img