# gan.py -- fully connected GAN (generator, discriminator, combined model) for MNIST, built with Keras.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
#%matplotlib inline
import keras
from keras.layers import Dense, Dropout, Input
from keras.models import Model,Sequential
from keras.datasets import mnist
from tqdm import tqdm
from keras.layers.advanced_activations import LeakyReLU
from keras.optimizers import adam
def load_data():
    """Load MNIST and prepare the training images for a dense network.

    The training images are cast to float32 and shifted/scaled with
    ``(x - 127.5) / 127.5``, which maps values in 0..255 onto [-1, 1].
    Each training image is then flattened to a single row.

    Returns:
        Tuple ``(x_train, y_train, x_test, y_test)``. Only ``x_train`` is
        rescaled and flattened; ``y_train``, ``x_test`` and ``y_test`` are
        returned exactly as ``mnist.load_data()`` provides them.
    """
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train = (x_train.astype(np.float32) - 127.5) / 127.5
    # Flatten every image to one row, e.g. (60000, 28, 28) -> (60000, 784).
    # Both sizes are derived from the array rather than hard-coded, so the
    # reshape keeps working if the number of samples or image size differs.
    x_train = x_train.reshape(x_train.shape[0], -1)
    return (x_train, y_train, x_test, y_test)
# Load the dataset at module level and report the training-array shape.
X_train, y_train, X_test, y_test = load_data()
print(X_train.shape)
def adam_optimizer():
    """Return a new ``adam`` optimizer with lr=0.0002 and beta_1=0.5.

    A fresh instance is created on every call, so each model compiled with
    it gets its own optimizer object.
    """
    settings = {"lr": 0.0002, "beta_1": 0.5}
    return adam(**settings)
def create_generator():
    """Build and compile the generator network.

    The model is a Sequential stack of Dense layers with 256, 512 and 1024
    units, each followed by ``LeakyReLU(0.2)``, that takes a 100-dimensional
    input and ends in a 784-unit Dense layer with ``tanh`` activation.
    It is compiled with binary cross-entropy loss and the optimizer returned
    by ``adam_optimizer()``.

    Returns:
        The compiled Sequential model.
    """
    # Layers are listed in the order they are stacked.
    stack = [
        Dense(units=256, input_dim=100),
        LeakyReLU(0.2),
        Dense(units=512),
        LeakyReLU(0.2),
        Dense(units=1024),
        LeakyReLU(0.2),
        Dense(units=784, activation='tanh'),
    ]
    model = Sequential(stack)
    model.compile(loss='binary_crossentropy', optimizer=adam_optimizer())
    return model
# Build the generator at module level and print its layer summary.
g = create_generator()
g.summary()
def create_discriminator():
    """Build and compile the discriminator network.

    The model is a Sequential stack that takes a 784-dimensional input and
    passes it through Dense layers with 1024, 512 and 256 units, each
    followed by ``LeakyReLU(0.2)``; the first two are additionally followed
    by ``Dropout(0.3)``. A single sigmoid unit produces the output. It is
    compiled with binary cross-entropy loss and the optimizer returned by
    ``adam_optimizer()``.

    Returns:
        The compiled Sequential model.
    """
    # Layers are listed in the order they are stacked.
    stack = [
        Dense(units=1024, input_dim=784),
        LeakyReLU(0.2),
        Dropout(0.3),
        Dense(units=512),
        LeakyReLU(0.2),
        Dropout(0.3),
        Dense(units=256),
        LeakyReLU(0.2),
        Dense(units=1, activation='sigmoid'),
    ]
    model = Sequential(stack)
    model.compile(loss='binary_crossentropy', optimizer=adam_optimizer())
    return model
# Build the discriminator at module level and print its layer summary.
d = create_discriminator()
d.summary()
def create_gan(discriminator, generator):
    """Chain the generator and discriminator into one combined model.

    Args:
        discriminator: Model applied to the generator's output. Its
            ``trainable`` attribute is set to ``False`` here, which is a
            side effect on the object the caller passed in.
        generator: Model applied to the 100-dimensional input.

    Returns:
        A compiled ``Model`` mapping a ``(100,)`` input through the
        generator and then the discriminator, using binary cross-entropy
        loss.
    """
    # Must be set before the combined model is compiled below.
    discriminator.trainable = False

    noise = Input(shape=(100,))
    verdict = discriminator(generator(noise))
    combined = Model(inputs=noise, outputs=verdict)
    # NOTE(review): this uses the string 'adam' rather than adam_optimizer(),
    # unlike the generator and discriminator -- confirm that is intended.
    combined.compile(loss='binary_crossentropy', optimizer='adam')
    return combined
# Combine the two networks at module level and print the stacked summary.
gan = create_gan(d, g)
gan.summary()