From c4238b4ea0b6996c9ed9cbfbf6c846e1bc6ec2f0 Mon Sep 17 00:00:00 2001 From: Giles Bathgate Date: Sun, 23 Jul 2023 21:43:05 +0200 Subject: [PATCH] Fix discriminator balancing in mnist_gan example (#1332) --- examples/mnist_gan.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/mnist_gan.py b/examples/mnist_gan.py index 3a1f21f16d..9f16140c0c 100644 --- a/examples/mnist_gan.py +++ b/examples/mnist_gan.py @@ -31,7 +31,8 @@ class LinearDisc: self.l4 = Tensor.scaled_uniform(256, 2) def forward(self, x): - x = x.dot(self.l1).leakyrelu(0.2).dropout(0.3) + # balance the discriminator inputs with const bias (.add(1)) + x = x.dot(self.l1).add(1).leakyrelu(0.2).dropout(0.3) x = x.dot(self.l2).leakyrelu(0.2).dropout(0.3) x = x.dot(self.l3).leakyrelu(0.2).dropout(0.3) x = x.dot(self.l4).log_softmax() @@ -39,13 +40,12 @@ class LinearDisc: def make_batch(images): sample = np.random.randint(0, len(images), size=(batch_size)) - image_b = images[sample].reshape(-1, 28*28).astype(np.float32) / 255.0 - image_b = (image_b - 0.5) / 0.5 + image_b = images[sample].reshape(-1, 28*28).astype(np.float32) / 127.5 - 1.0 return Tensor(image_b) -def make_labels(bs, val): +def make_labels(bs, col, val=-2.0): y = np.zeros((bs, 2), np.float32) - y[range(bs), [val] * bs] = -2.0 # Can we do label smoothin? i.e -2.0 changed to -1.98789. + y[range(bs), [col] * bs] = val # Can we do label smoothing? i.e -2.0 changed to -1.98789. return Tensor(y) def train_discriminator(optimizer, data_real, data_fake):