Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Error when checking input: expected input_61 to have 4 dimensions, but got array with shape (32, 64, 64, 1, 3) #15

Open
sandhya9173 opened this issue Mar 22, 2020 · 3 comments
Assignees

Comments

@sandhya9173
Copy link

what is utils.py file i Thought its generator discriminator train model but i m getting error in that please can you help me

from future import print_function, division

#from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
from skimage.io import imread_collection
import cv2
import matplotlib.pyplot as plt
import random

import sys

import numpy as np

class GAN():
def init(self):
self.img_rows = 28
self.img_cols = 28
self.channels = 1
self.img_shape = (self.img_rows, self.img_cols, self.channels)
self.latent_dim = 100

    optimizer = Adam(0.0002, 0.5)

    # Build and compile the discriminator
    self.discriminator = self.build_discriminator()
    self.discriminator.compile(loss='binary_crossentropy',
        optimizer=optimizer,
        metrics=['accuracy'])

    # Build the generator
    self.generator = self.build_generator()

    # The generator takes noise as input and generates imgs
    z = Input(shape=(self.latent_dim,))
    img = self.generator(z)

    # For the combined model we will only train the generator
    self.discriminator.trainable = False

    # The discriminator takes generated images as input and determines validity
    validity = self.discriminator(img)

    # The combined model  (stacked generator and discriminator)
    # Trains the generator to fool the discriminator
    self.combined = Model(z, validity)
    self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)


def build_generator(self):

    model = Sequential()

    model.add(Dense(256, input_dim=self.latent_dim))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(1024))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(np.prod(self.img_shape), activation='tanh'))
    model.add(Reshape(self.img_shape))

    model.summary()

    noise = Input(shape=(self.latent_dim,))
    img = model(noise)

    return Model(noise, img)

def build_discriminator(self):

    model = Sequential()

    model.add(Flatten(input_shape=self.img_shape))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(256))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(1, activation='sigmoid'))
    model.summary()

    img = Input(shape=self.img_shape)
    validity = model(img)

    return Model(img, validity)

def train(self, epochs, batch_size=128, sample_interval=50):

    # Load the dataset
    #(X_train, _), (_, _) = mnist.load_data()
    col_dir = 'C:/Users/Ajaykumar/Data_set/hack/dataset_2/Images/*.jpg'
    #creating a collection with the available images
    col = imread_collection(col_dir)
    all_images=np.zeros((450,64,64,3))#64*64 Number of pixels
    X=np.zeros((450,64,64,3))
    for i in range(450):
        #if i%10==0:
            #print(i)
        var=cv2.resize(col[i],(64,64))
        #print(var)
        all_images[i,:,:]=var
        all_images[i,:,:]=all_images[i,:,:]
        #X_train[i,:,:] = all_images[i,:,:]
    #print(np.shape(X_train))

    # Rescale -1 to 1
    all_images = all_images / 127.5 - 1.
    all_images = np.expand_dims(all_images, axis=3)

    # Adversarial ground truths
    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    for epoch in range(epochs):

        # ---------------------
        #  Train Discriminator
        # ---------------------

        # Select a random batch of images
        #random.sample(range(0, dataset.shape[0]), n_samples)
        idx = random.sample(range(0, all_images.shape[0]), batch_size)
        imgs = all_images[idx]

        noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

        # Generate a batch of new images
        gen_imgs = self.generator.predict(noise)

        # Train the discriminator
        d_loss_real = self.discriminator.train_on_batch(imgs, valid)
        d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # ---------------------
        #  Train Generator
        # ---------------------

        noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

        # Train the generator (to have the discriminator label samples as valid)
        g_loss = self.combined.train_on_batch(noise, valid)

        # Plot the progress
        print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

        # If at save interval => save generated image samples
        if epoch % sample_interval == 0:
            self.sample_images(epoch)

def sample_images(self, epoch):
    r, c = 5, 5
    noise = np.random.normal(0, 1, (r * c, self.latent_dim))
    gen_imgs = self.generator.predict(noise)

    # Rescale images 0 - 1
    gen_imgs = 0.5 * gen_imgs + 0.5

    fig, axs = plt.subplots(r, c)
    cnt = 0
    for i in range(r):
        for j in range(c):
            axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
            axs[i,j].axis('off')
            cnt += 1
    fig.savefig("C:/Users/Ajaykumar/Data_set/hack/images/aa/%d.png" % epoch)
    plt.close()

if name == 'main':
gan = GAN()
gan.train(epochs=30000, batch_size=32, sample_interval=200)

@sandhya9173
Copy link
Author

hi
I m added my data set can you please this code is low to high-resolution images code is their are
pic
not

@diegoalejogm
Copy link
Owner

Hey, I'll take a look this weekend.
Hope it is not too long of a wait 👍

@diegoalejogm diegoalejogm self-assigned this Apr 2, 2020
@Player0109
Copy link

Player0109 commented Oct 27, 2020

I think this problem is arrising due to this line:
all_images = np.expand_dims(all_images, axis=3)

This Line is not needed.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

3 participants