import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
class Generator(nn.Module):
    """Transposed-convolution generator that turns latent noise into an image.

    Four ``ConvTranspose2d`` layers are applied in sequence. The first three
    are followed by ReLU and the last by tanh, so every output value lies in
    [-1, 1]. For an input of shape ``(N, noise_dim, 1, 1)`` the spatial size
    grows 1 -> 4 -> 7 -> 14 -> 28, giving an output of ``(N, 1, 28, 28)``.
    """

    def __init__(self, noise_dim):
        """Build the layers.

        Args:
            noise_dim: Number of channels of the latent input tensor.
        """
        super().__init__()
        self.convT1 = nn.ConvTranspose2d(
            noise_dim, 128, kernel_size=4, stride=1, padding=0
        )
        self.convT2 = nn.ConvTranspose2d(
            128, 64, kernel_size=3, stride=2, padding=1
        )
        self.convT3 = nn.ConvTranspose2d(
            64, 32, kernel_size=4, stride=2, padding=1
        )
        self.convT4 = nn.ConvTranspose2d(
            32, 1, kernel_size=4, stride=2, padding=1
        )

    def forward(self, input):
        """Map a batch of latent tensors to images squashed by tanh."""
        hidden = input
        # The three hidden layers share the same ReLU activation.
        for upsample in (self.convT1, self.convT2, self.convT3):
            hidden = F.relu(upsample(hidden))
        # The output layer uses tanh instead of ReLU.
        return torch.tanh(self.convT4(hidden))
class Discriminator(nn.Module):
    """Convolutional critic that scores single-channel images.

    Three strided ``Conv2d`` layers, each followed by LeakyReLU with slope
    0.2, feed a final convolution that has no activation, so the module
    returns raw logits. For an input of shape ``(N, 1, 28, 28)`` the spatial
    size shrinks 28 -> 14 -> 7 -> 4 -> 1, giving an output of ``(N, 1, 1, 1)``.
    """

    def __init__(self):
        """Build the four convolution layers."""
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(128, 1, kernel_size=4, stride=1, padding=0)

    def forward(self, input):
        """Return one unnormalised real/fake logit per input image."""
        features = input
        # The three hidden layers share the same LeakyReLU activation.
        for downsample in (self.conv1, self.conv2, self.conv3):
            features = F.leaky_relu(downsample(features), negative_slope=0.2)
        # No sigmoid here: the output stays a logit.
        return self.conv4(features)
if __name__ == '__main__':
    # Script entry point: trains the Generator/Discriminator pair on MNIST
    # with binary cross-entropy on logits, prints the mean losses for every
    # epoch, and shows five generated samples every tenth epoch and after the
    # final one.
    import os  # local import: only the script entry point needs it

    # Never request more DataLoader workers than the machine has CPUs;
    # PyTorch warns that over-subscription can slow down or freeze loading.
    workers = min(4, os.cpu_count() or 1)
    batch_size = 128
    noise_dim = 100  # Size of z latent vector
    num_epochs = 100

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print('Using device:', device)

    # Rescale pixels from [0, 1] to [-1, 1] so real images share the range of
    # the generator's tanh output. Normalize(0.5, 0.5) computes
    # (x - 0.5) / 0.5 == 2x - 1 and, unlike a lambda, can be pickled when
    # DataLoader workers are started with the "spawn" method.
    dataset = torchvision.datasets.MNIST(
        root='./data', train=True, download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ]),
    )
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True, num_workers=workers
    )
    # Real number of batches per epoch, including the final partial batch.
    num_batches = len(dataloader)

    netG = Generator(noise_dim).to(device)
    netD = Discriminator().to(device)

    # convention for real and fake labels
    real_label = 1.0
    fake_label = 0.0

    # The discriminator returns logits, so the sigmoid lives inside the loss.
    criterion = nn.BCEWithLogitsLoss()
    # NOTE(review): both learning rates are well above the 2e-4 used in the
    # DCGAN paper with beta1=0.5 -- confirm they are intentional.
    optimizerD = optim.Adam(netD.parameters(), lr=0.002, betas=(0.5, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=0.001, betas=(0.5, 0.999))

    for epoch in range(num_epochs):
        ep_G_loss = 0.0
        ep_D_loss = 0.0

        for realX, _ in dataloader:
            realX = realX.to(device)
            b_size = realX.size(0)
            real_targets = torch.full(
                (b_size,), real_label, dtype=torch.float, device=device
            )
            fake_targets = torch.full(
                (b_size,), fake_label, dtype=torch.float, device=device
            )

            # Update D: maximize log(D(x)) + log(1 - D(G(z)))
            netD.zero_grad()
            ### Pass real data to netD
            output = netD(realX).view(-1)
            errD_real = criterion(output, real_targets)
            errD_real.backward()
            ### Pass fake data to netD
            noiseX = torch.randn(b_size, noise_dim, 1, 1, device=device)
            fake = netG(noiseX)
            # detach() keeps the discriminator step from back-propagating
            # into the generator.
            output = netD(fake.detach()).view(-1)
            errD_fake = criterion(output, fake_targets)
            errD_fake.backward()
            optimizerD.step()

            # Update G: maximize log(D(G(z)))
            netG.zero_grad()
            # fake labels are real for generator
            output = netD(fake).view(-1)
            errG = criterion(output, real_targets)
            errG.backward()
            optimizerG.step()

            ep_G_loss += errG.item()
            ep_D_loss += errD_real.item() + errD_fake.item()

        print(
            'epoch:', epoch, '/', num_epochs,
            '; G_loss:', ep_G_loss / num_batches,
            '; D_loss:', ep_D_loss / num_batches
        )

        if epoch % 10 == 0 or epoch == num_epochs - 1:
            noiseX = torch.randn(5, noise_dim, 1, 1, device=device)
            # Sampling for display needs no gradients.
            with torch.no_grad():
                fake = netG(noiseX)
            # Map tanh output from [-1, 1] back to [0, 1] for imshow.
            images = ((fake.cpu() + 1) * 0.5).reshape(-1, 28, 28).numpy()
            fig, ax = plt.subplots(1, 5)
            for axis, image in zip(ax, images):
                axis.imshow(image, cmap='gray', vmin=0, vmax=1)
                axis.axis('off')
            plt.show()
            # Close this specific figure so figures do not pile up.
            plt.close(fig)
Using device: cuda
/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py:477: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary. cpuset_checked))
epoch: 0 / 100 ; G_loss: 4.025594746837249 ; D_loss: 0.541045660198374
epoch: 1 / 100 ; G_loss: 2.2992326023741665 ; D_loss: 0.7424446938432052 epoch: 2 / 100 ; G_loss: 1.9783648646667473 ; D_loss: 0.8379886933069072 epoch: 3 / 100 ; G_loss: 1.9486756179577265 ; D_loss: 0.8742384373918812 epoch: 4 / 100 ; G_loss: 1.9037742032550085 ; D_loss: 0.9338736828798667 epoch: 5 / 100 ; G_loss: 1.8381108779173632 ; D_loss: 0.898899404077321 epoch: 6 / 100 ; G_loss: 1.9588675692331237 ; D_loss: 1.0968680807387319 epoch: 7 / 100 ; G_loss: 1.5689463105976071 ; D_loss: 0.9765682847390318 epoch: 8 / 100 ; G_loss: 1.7169391168997838 ; D_loss: 0.9516355206505356 epoch: 9 / 100 ; G_loss: 1.7341617144580581 ; D_loss: 0.9522137177678255 epoch: 10 / 100 ; G_loss: 1.7555983998836615 ; D_loss: 0.9518681125572095
epoch: 11 / 100 ; G_loss: 1.7360371795729694 ; D_loss: 0.957454572407863 epoch: 12 / 100 ; G_loss: 1.7592343311024528 ; D_loss: 0.9510215567663695 epoch: 13 / 100 ; G_loss: 1.7311208406901066 ; D_loss: 1.030192855028157 epoch: 14 / 100 ; G_loss: 1.7508108870595949 ; D_loss: 0.9457079759584024 epoch: 15 / 100 ; G_loss: 1.7789086168393111 ; D_loss: 0.942578107277807 epoch: 16 / 100 ; G_loss: 1.7989460865401814 ; D_loss: 0.9404596544674828 epoch: 17 / 100 ; G_loss: 1.78328772704316 ; D_loss: 0.9470169762006173 epoch: 18 / 100 ; G_loss: 1.829394889056173 ; D_loss: 0.9338661862425824 epoch: 19 / 100 ; G_loss: 1.8181682500021101 ; D_loss: 1.2435397097545777 epoch: 20 / 100 ; G_loss: 1.711103537143805 ; D_loss: 0.9201265703409146
epoch: 21 / 100 ; G_loss: 1.833623936288377 ; D_loss: 0.8992209309696132 epoch: 22 / 100 ; G_loss: 1.8796356265616214 ; D_loss: 0.8930723124271275 epoch: 23 / 100 ; G_loss: 1.8995193031608548 ; D_loss: 0.8997794571850035 epoch: 24 / 100 ; G_loss: 1.9316799252206445 ; D_loss: 0.8911645908672841 epoch: 25 / 100 ; G_loss: 1.941901578098281 ; D_loss: 0.8878339567563982 epoch: 26 / 100 ; G_loss: 1.9575555241770215 ; D_loss: 0.8912391072282424 epoch: 27 / 100 ; G_loss: 1.9775260340454233 ; D_loss: 0.8791944426285405 epoch: 28 / 100 ; G_loss: 1.9883574495712917 ; D_loss: 0.8786238700024083 epoch: 29 / 100 ; G_loss: 2.0078131208817163 ; D_loss: 0.8791541205798714 epoch: 30 / 100 ; G_loss: 2.0560114472849755 ; D_loss: 0.8606156882567283
epoch: 31 / 100 ; G_loss: 2.0541510882540646 ; D_loss: 0.863074517227773 epoch: 32 / 100 ; G_loss: 2.1322783022837637 ; D_loss: 1.2848561377587298 epoch: 33 / 100 ; G_loss: 1.9544953981525877 ; D_loss: 0.8356986431420869 epoch: 34 / 100 ; G_loss: 2.038284544634004 ; D_loss: 0.8263622351691254 epoch: 35 / 100 ; G_loss: 2.0833821836699786 ; D_loss: 0.8282250597850125 epoch: 36 / 100 ; G_loss: 2.0887566478843365 ; D_loss: 0.836319170287277 epoch: 37 / 100 ; G_loss: 2.1106385954170146 ; D_loss: 0.8263147123412699 epoch: 38 / 100 ; G_loss: 2.1214524405634303 ; D_loss: 0.8324364782112021 epoch: 39 / 100 ; G_loss: 2.1276926568940153 ; D_loss: 0.8409088224682033 epoch: 40 / 100 ; G_loss: 2.1308112964670882 ; D_loss: 0.8327076374911345
epoch: 41 / 100 ; G_loss: 2.131626808363148 ; D_loss: 0.8408964784003985 epoch: 42 / 100 ; G_loss: 2.1721376113147817 ; D_loss: 0.8258880305016397 epoch: 43 / 100 ; G_loss: 2.1578446550246997 ; D_loss: 0.8229683301069288 epoch: 44 / 100 ; G_loss: 2.177697014095437 ; D_loss: 0.8259326925294267 epoch: 45 / 100 ; G_loss: 2.194676972607262 ; D_loss: 0.8179371532402996 epoch: 46 / 100 ; G_loss: 1.8628630738816836 ; D_loss: 1.482671215695141 epoch: 47 / 100 ; G_loss: 1.9911826800586832 ; D_loss: 0.802754440655311 epoch: 48 / 100 ; G_loss: 2.1275273613695407 ; D_loss: 0.7862693232794603 epoch: 49 / 100 ; G_loss: 2.177294879132866 ; D_loss: 0.7863244345873339 epoch: 50 / 100 ; G_loss: 2.21672241606264 ; D_loss: 0.789539947691891
epoch: 51 / 100 ; G_loss: 2.2275218026250854 ; D_loss: 0.7933149426124799 epoch: 52 / 100 ; G_loss: 2.22218901428402 ; D_loss: 0.8077221612294769 epoch: 53 / 100 ; G_loss: 2.2164184894317236 ; D_loss: 0.8100897765153239 epoch: 54 / 100 ; G_loss: 2.2335894105271397 ; D_loss: 0.7955706094065284 epoch: 55 / 100 ; G_loss: 2.240376371476385 ; D_loss: 0.8128903179915032 epoch: 56 / 100 ; G_loss: 2.234919564336793 ; D_loss: 0.8055409416237957 epoch: 57 / 100 ; G_loss: 2.2421333145382056 ; D_loss: 0.8070204915462905 epoch: 58 / 100 ; G_loss: 2.2626537757042127 ; D_loss: 0.8084046615558302 epoch: 59 / 100 ; G_loss: 2.2467449298526487 ; D_loss: 0.8024482933533752 epoch: 60 / 100 ; G_loss: 2.026891887698301 ; D_loss: 2.1036019780422976
epoch: 61 / 100 ; G_loss: 1.9626058108276792 ; D_loss: 0.797756417757935 epoch: 62 / 100 ; G_loss: 2.1288901776330085 ; D_loss: 0.7584441274913967 epoch: 63 / 100 ; G_loss: 2.2040701368425646 ; D_loss: 0.7601017936363689 epoch: 64 / 100 ; G_loss: 2.2573280049185467 ; D_loss: 0.7619534540189128 epoch: 65 / 100 ; G_loss: 2.2657612470480113 ; D_loss: 0.7767149130375977 epoch: 66 / 100 ; G_loss: 2.2908432669619208 ; D_loss: 0.7709830653750234 epoch: 67 / 100 ; G_loss: 2.2890092284760923 ; D_loss: 0.7892235971223085 epoch: 68 / 100 ; G_loss: 2.2841589216493134 ; D_loss: 0.795194176877411 epoch: 69 / 100 ; G_loss: 2.294537895000898 ; D_loss: 0.7916370539679232 epoch: 70 / 100 ; G_loss: 2.2995976207093296 ; D_loss: 0.7932863374296416
epoch: 71 / 100 ; G_loss: 2.2877162328133216 ; D_loss: 0.8016409057423345 epoch: 72 / 100 ; G_loss: 2.2639855719529667 ; D_loss: 0.8041345239258729 epoch: 73 / 100 ; G_loss: 2.2871644779657707 ; D_loss: 0.7964961169621884 epoch: 74 / 100 ; G_loss: 2.298080288956308 ; D_loss: 0.7952589884591408 epoch: 75 / 100 ; G_loss: 2.3144368096294565 ; D_loss: 0.7932604627094717 epoch: 76 / 100 ; G_loss: 2.319773532386519 ; D_loss: 0.7938909871010189 epoch: 77 / 100 ; G_loss: 2.336074157148345 ; D_loss: 0.7911147294072514 epoch: 78 / 100 ; G_loss: 2.8829208509505326 ; D_loss: 2.608052038170312 epoch: 79 / 100 ; G_loss: 1.8106857831152077 ; D_loss: 0.8609801237909203 epoch: 80 / 100 ; G_loss: 2.0663085607891407 ; D_loss: 0.7559228680518448
epoch: 81 / 100 ; G_loss: 2.1988770847137156 ; D_loss: 0.7310278601307645 epoch: 82 / 100 ; G_loss: 2.28019866372785 ; D_loss: 0.7315544899966981 epoch: 83 / 100 ; G_loss: 2.3286844176104946 ; D_loss: 0.7390283906243296 epoch: 84 / 100 ; G_loss: 2.3573762165684986 ; D_loss: 0.7349227956918061 epoch: 85 / 100 ; G_loss: 2.3894227648902144 ; D_loss: 0.747650688338993 epoch: 86 / 100 ; G_loss: 2.3921624722643795 ; D_loss: 0.7551322772653184 epoch: 87 / 100 ; G_loss: 2.41464268295174 ; D_loss: 0.7602824414801649 epoch: 88 / 100 ; G_loss: 2.4098722081408543 ; D_loss: 0.7749648282511367 epoch: 89 / 100 ; G_loss: 2.4061698116298413 ; D_loss: 0.7687831522785445 epoch: 90 / 100 ; G_loss: 2.4115507567030754 ; D_loss: 0.7742021458271222
epoch: 91 / 100 ; G_loss: 2.409960554451005 ; D_loss: 0.7709544651114788 epoch: 92 / 100 ; G_loss: 2.4242513600068216 ; D_loss: 0.7755984421341847 epoch: 93 / 100 ; G_loss: 2.1453172233505886 ; D_loss: 1.5187836557232048 epoch: 94 / 100 ; G_loss: 2.2327247835122623 ; D_loss: 0.7408213297971803 epoch: 95 / 100 ; G_loss: 2.3568259867338033 ; D_loss: 0.7247299123396221 epoch: 96 / 100 ; G_loss: 2.425844599039127 ; D_loss: 0.7328249701163453 epoch: 97 / 100 ; G_loss: 2.4568663583352017 ; D_loss: 0.7322179988090299 epoch: 98 / 100 ; G_loss: 2.4941434651358514 ; D_loss: 0.7371676288959053 epoch: 99 / 100 ; G_loss: 2.488156497478485 ; D_loss: 0.7438272241120919