Relativistic AnimeGAN
Generating Anime Faces using Relativistic GAN
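This notebook trains a DCGAN-style generator and discriminator with the relativistic average least-squares (RaLSGAN) objective from Jolicoeur-Martineau, "The relativistic discriminator: a key element missing from standard GAN" (2018). Instead of judging each image in isolation, the critic $C$ compares real samples $x_r$ against the average score of fakes $x_f = G(z)$, and vice versa. With target $b$ (the smoothed label $b = 0.5$ used below; the paper uses $b = 1$), the training loop implements

$$L_D = \tfrac{1}{2}\,\mathbb{E}\big[(C(x_r) - \mathbb{E}[C(x_f)] - b)^2\big] + \tfrac{1}{2}\,\mathbb{E}\big[(C(x_f) - \mathbb{E}[C(x_r)] + b)^2\big],$$

$$L_G = \tfrac{1}{2}\,\mathbb{E}\big[(C(x_r) - \mathbb{E}[C(x_f)] + b)^2\big] + \tfrac{1}{2}\,\mathbb{E}\big[(C(x_f) - \mathbb{E}[C(x_r)] - b)^2\big].$$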
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch, os
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm.auto import tqdm  # tqdm_notebook is deprecated; tqdm.auto picks the right frontend
class Generator(nn.Module):
    def __init__(self, nz=128, channels=3):
        super(Generator, self).__init__()
        self.nz = nz
        self.channels = channels

        def convlayer(n_input, n_output, k_size=4, stride=2, padding=0):
            block = [
                nn.ConvTranspose2d(n_input, n_output, kernel_size=k_size, stride=stride, padding=padding, bias=False),
                nn.BatchNorm2d(n_output),
                nn.ReLU(inplace=True),
            ]
            return block

        # Upsample the latent vector from 1x1 to a 64x64 RGB image:
        # spatial resolution grows 1 -> 4 -> 8 -> 16 -> 32 -> 64.
        self.model = nn.Sequential(
            *convlayer(self.nz, 1024, 4, 1, 0),
            *convlayer(1024, 512, 4, 2, 1),
            *convlayer(512, 256, 4, 2, 1),
            *convlayer(256, 128, 4, 2, 1),
            *convlayer(128, 64, 4, 2, 1),
            nn.ConvTranspose2d(64, self.channels, 3, 1, 1),
            nn.Tanh()  # squash pixels to [-1, 1]
        )

    def forward(self, z):
        z = z.view(-1, self.nz, 1, 1)
        img = self.model(z)
        return img
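As a quick sanity check (illustrative, not part of the original notebook), the transposed-convolution stack grows the 1x1 latent through 4, 8, 16 and 32 up to 64 pixels:

g = Generator(nz=128)
z = torch.randn(2, 128)              # batch of two latent vectors
assert g(z).shape == (2, 3, 64, 64)  # RGB images in [-1, 1]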
class Discriminator(nn.Module):
    def __init__(self, channels=3):
        super(Discriminator, self).__init__()
        self.channels = channels

        def convlayer(n_input, n_output, k_size=4, stride=2, padding=0, bn=False):
            block = [nn.Conv2d(n_input, n_output, kernel_size=k_size, stride=stride, padding=padding, bias=False)]
            if bn:
                block.append(nn.BatchNorm2d(n_output))
            block.append(nn.LeakyReLU(0.2, inplace=True))
            return block

        # Downsample 64x64 -> 4x4, then reduce to a single raw score.
        self.model = nn.Sequential(
            *convlayer(self.channels, 32, 4, 2, 1),
            *convlayer(32, 64, 4, 2, 1),
            *convlayer(64, 128, 4, 2, 1, bn=True),
            *convlayer(128, 256, 4, 2, 1, bn=True),
            nn.Conv2d(256, 1, 4, 1, 0, bias=False),
        )

    def forward(self, imgs):
        out = self.model(imgs)
        return out.view(-1, 1)
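Four strided convolutions shrink 64x64 inputs to 4x4 feature maps, and the final valid convolution reduces them to one unbounded score per image; there is deliberately no sigmoid, since the relativistic least-squares loss operates on raw critic outputs. A matching shape check (again illustrative):

d = Discriminator()
x = torch.randn(2, 3, 64, 64)   # batch of two dummy images
assert d(x).shape == (2, 1)     # one raw realness score per image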
batch_size = 32
lr = 0.001
beta1 = 0.5        # Adam beta1, the usual DCGAN setting
epochs = 300
real_label = 0.5   # smoothed target used in the relativistic loss below
fake_label = 0     # kept for reference; the fake target enters the loss as -real_label
nz = 128           # size of the latent vector
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class AnimeFacesDataset(Dataset):
    def __init__(self, img_dir, transform1=None, transform2=None):
        self.img_dir = img_dir
        self.img_names = os.listdir(img_dir)
        self.transform1 = transform1
        self.transform2 = transform2

        # Decode, resize and crop every image once up front and cache the
        # results in RAM, so only the cheap random augmentations run per epoch.
        self.imgs = []
        for img_name in self.img_names:
            img = Image.open(os.path.join(img_dir, img_name)).convert('RGB')  # guard against grayscale/RGBA files
            if self.transform1 is not None:
                img = self.transform1(img)
            self.imgs.append(img)

    def __getitem__(self, index):
        img = self.imgs[index]
        if self.transform2 is not None:
            img = self.transform2(img)
        return img

    def __len__(self):
        return len(self.imgs)
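Caching every decoded crop keeps epochs fast but holds the entire dataset in memory. For a corpus that does not fit in RAM, a lazy variant (a sketch, not part of the original notebook) would decode on each access instead:

class LazyAnimeFacesDataset(Dataset):
    """Sketch: decode images on demand instead of caching them in RAM."""
    def __init__(self, img_dir, transform1=None, transform2=None):
        self.img_dir = img_dir
        self.img_names = os.listdir(img_dir)
        self.transform1 = transform1
        self.transform2 = transform2

    def __getitem__(self, index):
        img = Image.open(os.path.join(self.img_dir, self.img_names[index])).convert('RGB')
        if self.transform1 is not None:
            img = self.transform1(img)
        if self.transform2 is not None:
            img = self.transform2(img)
        return img

    def __len__(self):
        return len(self.img_names)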
transform1 = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64)
])

random_transforms = [transforms.RandomRotation(degrees=5)]
transform2 = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomApply(random_transforms, p=0.3),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
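Normalize with mean 0.5 and std 0.5 per channel maps pixels from [0, 1] to [-1, 1], matching the Tanh output range of the generator; the plotting cells below undo it with (img + 1) / 2. A quick round-trip check (illustrative):

t = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
x = torch.rand(3, 4, 4)                   # dummy image tensor in [0, 1]
assert torch.allclose((t(x) + 1) / 2, x)  # (x - 0.5) / 0.5 is inverted by (y + 1) / 2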
train_dataset = AnimeFacesDataset(
    img_dir='../input/data/data/',
    transform1=transform1,
    transform2=transform2
)

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4
)
imgs = next(iter(train_loader))
imgs = imgs.numpy().transpose(0, 2, 3, 1)

fig = plt.figure(figsize=(25, 16))
for ii, img in enumerate(imgs):
    ax = fig.add_subplot(4, 8, ii + 1, xticks=[], yticks=[])
    plt.imshow((img + 1) / 2)
netG = Generator(nz).to(device)
netD = Discriminator().to(device)

criterion = nn.BCELoss()  # defined but unused: the relativistic least-squares loss is computed inline below

optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

fixed_noise = torch.randn(25, nz, 1, 1, device=device)  # unused below; handy for tracking fixed samples across epochs
def show_generated_img():
    noise = torch.randn(1, nz, 1, 1, device=device)
    with torch.no_grad():
        gen_image = netG(noise).cpu().squeeze(0)
    gen_image = gen_image.numpy().transpose(1, 2, 0)
    plt.imshow((gen_image + 1) / 2)  # map [-1, 1] back to [0, 1]
    plt.show()
for epoch in range(epochs):
    for ii, real_images in tqdm(enumerate(train_loader), total=len(train_loader)):
        # ---- Discriminator step ----
        netD.zero_grad()
        real_images = real_images.to(device)
        batch_size = real_images.size(0)
        labels = torch.full((batch_size, 1), real_label, device=device)

        outputR = netD(real_images)    # scores for real images
        noise = torch.randn(batch_size, nz, 1, 1, device=device)
        fake = netG(noise)
        outputF = netD(fake.detach())  # scores for fakes, generator frozen

        # RaLSGAN loss: push real scores above the mean fake score by
        # real_label, and fake scores below the mean real score by the same.
        errD = (torch.mean((outputR - torch.mean(outputF) - labels) ** 2) +
                torch.mean((outputF - torch.mean(outputR) + labels) ** 2)) / 2
        errD.backward()
        optimizerD.step()

        # ---- Generator step ----
        netG.zero_grad()
        outputF = netD(fake)  # fresh scores through the updated discriminator
        # outputR is detached here: it is a constant w.r.t. G, and reusing its
        # graph after optimizerD.step() fails on recent PyTorch versions.
        errG = (torch.mean((outputR.detach() - torch.mean(outputF) + labels) ** 2) +
                torch.mean((outputF - torch.mean(outputR.detach()) - labels) ** 2)) / 2
        errG.backward()
        optimizerG.step()

        if (ii + 1) % (len(train_loader) // 2) == 0:
            print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f'
                  % (epoch + 1, epochs, ii + 1, len(train_loader),
                     errD.item(), errG.item()))

    show_generated_img()
gen_z = torch.randn(32, nz, 1, 1, device=device)
with torch.no_grad():
    gen_images = (netG(gen_z).cpu() + 1) / 2  # map [-1, 1] back to [0, 1]
gen_images = gen_images.numpy().transpose(0, 2, 3, 1)

fig = plt.figure(figsize=(25, 16))
for ii, img in enumerate(gen_images):
    ax = fig.add_subplot(4, 8, ii + 1, xticks=[], yticks=[])
    plt.imshow(img)
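save_image is imported at the top but never used; a minimal sketch (the output filename is arbitrary) that writes the same latent batch to disk as an 8-column grid:

with torch.no_grad():
    samples = (netG(gen_z).cpu() + 1) / 2  # map [-1, 1] back to [0, 1]
save_image(samples, 'generated_samples.png', nrow=8)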