使用StyleGAN + CLIP从文本生成面部图像
这次介绍的是StyleClip: StyleGAN+CLIP,从文本进行图像生成的一个模型。是现在最好的图像生成模型和文字模型的结合的。
什么是StyleGAN + CLIP模型?
** StyleGAN**是一个根据参数生成图像的模型,CLIP是可以输出图像和文本之间相似度的模型。这次,我们将按以下方式结合这两个模型。
为StyleGAN提供适当的参数并输出图像(目标图像)。然后,将图像和文本输入到CLIP中,找到相似度,并优化参数,以使相似度尽可能高。。然后,获得生成与文本的内容匹配的图像的参数,并且可以从文本生成图像。
这篇文章有机翻 https://aiqianji.com/blog/article/85
现在,让我们看一下代码实现。
代码
该代码在Google Colab上运行并发布在Github上,单击此“链接”
现在,让我们从文本生成一张脸部图像。在文本中,输入要生成的内容的文本。在这里,我输入“她是一个有着金色的头发和蓝色的眼睛的迷人女人”。
将优化循环旋转101次(max_iter
设置),然后每两次旋转(img_save_freq
设置),将生成的中间图像./pic
保存在其中。
import os
import torch
import torchvision
import clip
import numpy as np
from PIL import Image
from stylegan_models import g_all, g_synthesis, g_mapping
from utils import GetFeatureMaps, transform_img, compute_loss
from tqdm import trange
import warnings
warnings.filterwarnings('ignore')
import os
import shutil
if os.path.isdir('pic'):
shutil.rmtree('pic')
os.makedirs('pic', exist_ok=True)
# 初期設定
lr = 1e-2
img_save_freq = 2
max_iter = 101
ref_img_path = None
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("USING ", device)
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
vgg16 = torchvision.models.vgg16(pretrained=True).to(device)
vgg_layers = vgg16.features
vgg_layer_name_mapping = {
'1': "relu1_1",
'3': "relu1_2",
'6': "relu2_1",
'8': "relu2_2",
# '15': "relu3_3",
# '22': "relu4_3"
}
g_synthesis.eval()
g_synthesis.to(device)
latent_shape = (1, 1, 512)
normal_generator = torch.distributions.normal.Normal(
torch.tensor([0.0]),
torch.tensor([1.]),
)
# init_latents = normal_generator.sample(latent_shape).squeeze(-1).to(device)
latents_init = torch.zeros(latent_shape).squeeze(-1).to(device)
latents = torch.nn.Parameter(latents_init, requires_grad=True)
optimizer = torch.optim.Adam(
params=[latents],
lr=lr,
betas=(0.9, 0.999),
)
def truncation(x, threshold=0.7, max_layer=8):
avg_latent = torch.zeros(1, x.size(1), 512).to(device)
interp = torch.lerp(avg_latent, x, threshold)
do_trunc = (torch.arange(x.size(1)) < max_layer).view(1, -1, 1).to(device)
return torch.where(do_trunc, interp, x)
def tensor_to_pil_img(img):
img = (img.clamp(-1, 1) + 1) / 2.0
img = img[0].permute(1, 2, 0).detach().cpu().numpy() * 256
img = Image.fromarray(img.astype('uint8'))
return img
clip_transform = torchvision.transforms.Compose([
# clip_preprocess.transforms[2],
clip_preprocess.transforms[4],
])
if ref_img_path is None:
ref_img = None
else:
ref_img = clip_preprocess(Image.open(ref_img_path)).unsqueeze(0).to(device)
clip_normalize = torchvision.transforms.Normalize(
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711),
)
def compute_clip_loss(img, text):
# img = clip_transform(img)
img = torch.nn.functional.upsample_bilinear(img, (224, 224))
tokenized_text = clip.tokenize([text]).to(device)
img_logits, _text_logits = clip_model(img, tokenized_text)
return 1/img_logits * 100
def compute_perceptual_loss(gen_img, ref_img