import warnings

# Silence Python-level warnings for the whole notebook to keep cell output tidy.
warnings.filterwarnings('ignore')
MLP-GAN
computervision
deeplearning
keras
python
tensorflow
Implementation of a vanilla GAN built from multilayer perceptrons, using Keras and TensorFlow.
Project Repository: https://github.com/soumik12345/Adventures-with-GANS
::: {#cell-3 .cell _cell_guid="b1076dfc-b9ad-4769-8c92-a6c4dae69d19" _uuid="8f2839f25d086af736a60e9eeb907d3b93b6e0e5" execution_count=2}
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Dense, LeakyReLU, Dropout
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
from keras.utils.vis_utils import model_to_dot
from IPython.display import SVG
from tqdm import tqdm
:::
# MNIST image geometry (28x28 grayscale).
IMAGE_WIDTH = 28
IMAGE_HEIGHT = 28
IMAGE_CHANNELS = 1

# Number of real images drawn per training step (the discriminator sees
# BATCH_SIZE real + BATCH_SIZE generated images per step).
BATCH_SIZE = 128

# Size of the Gaussian noise vector fed to the generator.
LATENT_DIMENSION = 100

IMAGE_SHAPE = (IMAGE_WIDTH, IMAGE_HEIGHT, IMAGE_CHANNELS)

# Number of training iterations (each one processes a single batch).
EPOCHS = 10000
def load_data():
    """Load the MNIST training images, scaled to [-1, 1].

    Labels and the test split are discarded. Pixel values are mapped from
    [0, 255] to [-1, 1] so real images share the range of the generator's
    ``tanh`` output, and a trailing channel axis is appended.

    Returns:
        ``float32`` array of shape ``(num_images, height, width, 1)``.
    """
    (x_train, _), (_, _) = mnist.load_data()
    # Cast first: dividing the raw uint8 array by a Python float would produce
    # float64, doubling memory for no benefit (Keras trains in float32).
    x_train = x_train.astype(np.float32) / 127.5 - 1.0
    return np.expand_dims(x_train, axis=-1)
# Load the normalized training set and display its shape.
x_train = load_data()
x_train.shape
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step
(60000, 28, 28, 1)
def build_generator(latent_dimension, image_shape, optimizer):
    """Build and compile the MLP generator.

    Three fully connected hidden layers (256, 512, 1024 units) with
    LeakyReLU(0.2) activations map a latent noise vector to a *flattened*
    image of ``prod(image_shape)`` values in [-1, 1] (``tanh`` output).

    Args:
        latent_dimension: Length of the input noise vector.
        image_shape: Shape of one image; only its total size is used.
        optimizer: Keras optimizer passed to ``compile``.

    Returns:
        A compiled ``Sequential`` model named ``'Generator'``.
    """
    # Plain int: Dense expects a Python integer unit count, not a numpy scalar.
    flat_image_size = int(np.prod(image_shape))
    generator = Sequential([
        Dense(256, input_dim=latent_dimension),
        LeakyReLU(0.2),
        Dense(512),
        LeakyReLU(0.2),
        Dense(1024),
        LeakyReLU(0.2),
        Dense(flat_image_size, activation='tanh'),
    ], name='Generator')
    generator.compile(loss='binary_crossentropy', optimizer=optimizer)
    return generator
def build_discriminator(image_shape, optimizer):
    """Build and compile the MLP discriminator.

    Takes a *flattened* image of ``prod(image_shape)`` values and passes it
    through Dense layers of 1024, 512 and 256 units with LeakyReLU(0.2).
    Dropout(0.3) follows the first two hidden layers. A single sigmoid unit
    outputs the probability that the input is real.

    Args:
        image_shape: Shape of one image; only its total size is used.
        optimizer: Keras optimizer passed to ``compile``.

    Returns:
        A compiled ``Sequential`` model named ``'Discriminator'``.
    """
    # Plain int for input_dim rather than a numpy scalar.
    flat_image_size = int(np.prod(image_shape))
    discriminator = Sequential([
        Dense(1024, input_dim=flat_image_size),
        LeakyReLU(0.2),
        Dropout(0.3),
        Dense(512),
        LeakyReLU(0.2),
        Dropout(0.3),
        Dense(256),
        LeakyReLU(0.2),
        Dense(1, activation='sigmoid'),
    ], name='Discriminator')
    discriminator.compile(loss='binary_crossentropy', optimizer=optimizer)
    return discriminator
def build_gan(generator, discriminator, latent_dimension, optimizer):
    """Stack generator and discriminator into the combined GAN model.

    The discriminator is marked non-trainable *before* this model is compiled.
    Training the GAN therefore only updates the generator's weights: the
    summary reports the discriminator's parameters as non-trainable. The
    discriminator was already compiled on its own while trainable, so its
    standalone training is unaffected.

    Args:
        generator: Model mapping noise of length ``latent_dimension`` to images.
        discriminator: Model mapping images to a real/fake probability.
        latent_dimension: Length of the noise vector input.
        optimizer: Keras optimizer for the combined model.

    Returns:
        A compiled functional ``Model`` named ``'GAN'``. Its ``train_on_batch``
        returns ``[loss, accuracy]`` because of the ``accuracy`` metric.
    """
    discriminator.trainable = False
    gan_input = Input(shape=(latent_dimension,))
    fake_images = generator(gan_input)
    gan_output = discriminator(fake_images)
    gan = Model(gan_input, gan_output, name='GAN')
    gan.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
    return gan
# Adam with learning rate 2e-4 and beta_1 = 0.5.
optimizer = Adam(0.0002, 0.5)
WARNING:tensorflow:From /opt/conda/lib/python3.6/site-packages/tensorflow/python/ops/resource_variable_ops.py:435: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
generator = build_generator(LATENT_DIMENSION, IMAGE_SHAPE, optimizer)
generator.summary()
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense (Dense) (None, 256) 25856
_________________________________________________________________
leaky_re_lu (LeakyReLU) (None, 256) 0
_________________________________________________________________
dense_1 (Dense) (None, 512) 131584
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU) (None, 512) 0
_________________________________________________________________
dense_2 (Dense) (None, 1024) 525312
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU) (None, 1024) 0
_________________________________________________________________
dense_3 (Dense) (None, 784) 803600
=================================================================
Total params: 1,486,352
Trainable params: 1,486,352
Non-trainable params: 0
_________________________________________________________________
# Render the generator architecture as an SVG diagram via pydot / Graphviz `dot`.
# NOTE(review): model_to_dot is imported from standalone `keras` while the models
# are built with `tensorflow.keras`; `tensorflow.keras.utils.model_to_dot` is the
# matching helper and may be needed on newer versions.
SVG(model_to_dot(generator, show_shapes=True, show_layer_names=True).create(prog='dot', format='svg'))
discriminator = build_discriminator(IMAGE_SHAPE, optimizer)
discriminator.summary()
WARNING:tensorflow:From /opt/conda/lib/python3.6/site-packages/tensorflow/python/keras/layers/core.py:143: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_4 (Dense) (None, 1024) 803840
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU) (None, 1024) 0
_________________________________________________________________
dropout (Dropout) (None, 1024) 0
_________________________________________________________________
dense_5 (Dense) (None, 512) 524800
_________________________________________________________________
leaky_re_lu_4 (LeakyReLU) (None, 512) 0
_________________________________________________________________
dropout_1 (Dropout) (None, 512) 0
_________________________________________________________________
dense_6 (Dense) (None, 256) 131328
_________________________________________________________________
leaky_re_lu_5 (LeakyReLU) (None, 256) 0
_________________________________________________________________
dense_7 (Dense) (None, 1) 257
=================================================================
Total params: 1,460,225
Trainable params: 1,460,225
Non-trainable params: 0
_________________________________________________________________
# Render the discriminator architecture as an SVG diagram.
SVG(model_to_dot(discriminator, show_shapes=True, show_layer_names=True).create(prog='dot', format='svg'))
# Give the combined model its own Adam instance. Sharing one optimizer between
# the discriminator and the GAN mixes their iteration counters and slot state,
# and newer TensorFlow optimizers (2.11+) raise an error when a single instance
# is applied to different variable sets.
gan = build_gan(generator, discriminator, LATENT_DIMENSION, Adam(0.0002, 0.5))
gan.summary()
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) (None, 100) 0
_________________________________________________________________
Generator (Sequential) (None, 784) 1486352
_________________________________________________________________
Discriminator (Sequential) (None, 1) 1460225
=================================================================
Total params: 2,946,577
Trainable params: 1,486,352
Non-trainable params: 1,460,225
_________________________________________________________________
# Render the combined GAN architecture as an SVG diagram.
SVG(model_to_dot(gan, show_shapes=True, show_layer_names=True).create(prog='dot', format='svg'))
def plot_images(nrows, ncols, figsize, generator):
    """Show a grid of images sampled from ``generator``.

    Draws ``nrows * ncols`` latent vectors from a standard normal distribution,
    runs them through the generator and displays each result as a grayscale
    image with tick marks hidden.

    Args:
        nrows: Number of grid rows.
        ncols: Number of grid columns.
        figsize: Matplotlib figure size ``(width, height)`` in inches.
        generator: Model mapping ``(n, LATENT_DIMENSION)`` noise to flattened
            images of ``IMAGE_HEIGHT * IMAGE_WIDTH`` values.
    """
    n_images = nrows * ncols
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, where the
    # default would return a bare Axes object without `.flat`.
    _, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize, squeeze=False)
    plt.setp(axes.flat, xticks=[], yticks=[])
    noise = np.random.normal(0, 1, (n_images, LATENT_DIMENSION))
    # imshow expects (rows, cols) == (height, width).
    generated_images = generator.predict(noise).reshape(n_images, IMAGE_HEIGHT, IMAGE_WIDTH)
    for image, ax in zip(generated_images, axes.flat):
        ax.imshow(image, cmap='gray')
    plt.show()
generator_loss_history, discriminator_loss_history = [], []

# Each "epoch" here is a single training step on one batch, not a full pass
# over the dataset.
for epoch in tqdm(range(1, EPOCHS + 1)):
    # Select a random batch of images from training data and flatten them to
    # match the MLP discriminator's input (-1 avoids hard-coding 28 * 28).
    index = np.random.randint(0, x_train.shape[0], BATCH_SIZE)
    batch_images = x_train[index].reshape(BATCH_SIZE, -1)

    # Adversarial noise
    noise = np.random.normal(0, 1, (BATCH_SIZE, LATENT_DIMENSION))

    # Generate fake images
    generated_images = generator.predict(noise)

    # Construct a batch of real images followed by fake images
    x = np.concatenate([batch_images, generated_images])

    # Discriminator labels: fakes are 0, reals are 0.9 rather than 1.0
    # (one-sided label smoothing).
    y_discriminator = np.zeros(2 * BATCH_SIZE)
    y_discriminator[:BATCH_SIZE] = 0.9

    # Train the discriminator to distinguish between real and fake data.
    # NOTE(review): in Keras, `trainable` is captured when a model is compiled,
    # so these toggles likely have no effect on the already-compiled models.
    discriminator.trainable = True
    discriminator_loss = discriminator.train_on_batch(x, y_discriminator)
    discriminator_loss_history.append(discriminator_loss)
    discriminator.trainable = False

    # Train the generator through the combined model: it is rewarded when the
    # frozen discriminator labels its output as real (target 1).
    generator_loss = gan.train_on_batch(noise, np.ones(BATCH_SIZE))
    generator_loss_history.append(generator_loss)

    if epoch % 1000 == 0:
        plot_images(1, 8, (16, 4), generator)
WARNING:tensorflow:From /opt/conda/lib/python3.6/site-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.cast instead.
# Show a final 2x8 grid of samples, then persist the trained generator.
plot_images(2, 8, (16, 6), generator)
generator.save('./generator.h5')