#!/usr/bin/python
# -*- coding: utf-8 -*-
#
"""
    Generic classifier with multiple models
    Models -> (Xception, VGG16, VGG19, ResNet50, InceptionV3, MobileNet)

    Name: cnn_keras.py
    Author: Gabriel Kirsten Menezes (gabriel.kirsten@hotmail.com)

"""
import time
import os
import shutil
import random
import json
from collections import OrderedDict

import numpy as np
from PIL import Image
from keras import applications
from keras.preprocessing.image import ImageDataGenerator
from keras import optimizers
from keras.models import Model, load_model
from keras.layers import Dropout, Flatten, Dense
from keras.callbacks import ModelCheckpoint, TensorBoard
from keras import backend as K

from interface.interface import InterfaceException as IException
from classification.classifier import Classifier
from util.config import Config
from util.file_utils import File
from util.utils import TimeUtils

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # Suppress warnings
START_TIME = time.time()

# =========================================================
# Constants
# =========================================================

# Size (in pixels) every image is resized to before being fed to a network.
IMG_WIDTH, IMG_HEIGHT = 256, 256

# Class names used by CNNKeras.classify() to label the predicted indices.
CLASS_NAMES = ['FolhasLargas', 'Gramineas',
               'Soja', 'Solo']

# NOTE(review): module-level name; the class reads and writes
# `self.weight_path`, which is a different variable -- confirm whether this
# global is still needed.
weight_path = None


class CNNKeras(Classifier):
    """ Class for CNN classifiers based on Keras applications """

    def __init__(self, architecture="ResNet50", learning_rate=0.001,
                 momentum=0.9, batch_size=32, epochs=50,
                 fine_tuning_rate=100, transfer_learning=False,
                 save_weights=True, perc_train=80, perc_validation=20,
                 recreate_dataset=False):
        """Constructor of CNNKeras.

        Every argument is wrapped in a Config(label, value, type) object so
        it can be listed and edited through get_config / set_config.

        Parameters
        ----------
        architecture : string, optional, default = "ResNet50"
            Name of the Keras application used as base model.
        learning_rate : float, optional, default = 0.001
            Learning rate given to the SGD optimizer.
        momentum : float, optional, default = 0.9
            Momentum given to the SGD optimizer.
        batch_size : int, optional, default = 32
            Batch size of the training/validation/test generators.
        epochs : int, optional, default = 50
            Number of training epochs.
        fine_tuning_rate : int, optional, default = 100
            Percentage of leading base-model layers that are frozen;
            -1 trains every layer from randomly initialised weights.
        transfer_learning : boolean, optional, default = False
            Stored as a configuration; not read by the code in this class.
        save_weights : boolean, optional, default = True
            Whether the model is written to disk during/after training.
        perc_train : float, optional, default = 80
            Percentage of each class used for training.
        perc_validation : float, optional, default = 20
            Percentage of each class used for validation.
        recreate_dataset : boolean, optional, default = False
            If True the Keras dataset directory is rebuilt even if it exists.
        """
        self.architecture = Config(
            "Architecture", architecture, str)
        self.learning_rate = Config(
            "Learning rate", learning_rate, float)
        self.momentum = Config(
            "Momentum", momentum, float)
        self.batch_size = Config(
            "Batch size", batch_size, int)
        self.epochs = Config(
            "Epochs", epochs, int)
        self.fine_tuning_rate = Config(
            "Fine Tuning Rate", fine_tuning_rate, int)
        self.transfer_learning = Config(
            "Transfer Learning", transfer_learning, bool)
        self.save_weights = Config(
            "Save weights", save_weights, bool)
        self.perc_train = Config(
            "Perc Train", perc_train, float)
        self.perc_validation = Config(
            "Perc Validation", perc_validation, float)
        self.recreate_dataset = Config(
            "Recreate Dataset", recreate_dataset, bool)

        self.file_name = "kerasCNN"

        self.model = None

        # Path of the saved model; classify() reads this attribute, so it
        # must exist even before train() has stored a model.
        self.weight_path = None

        self.trained = False

    def get_config(self):
        """Return configuration of classifier.

        Returns
        -------
        config : OrderedDict
            Current configs of classifier, keyed by configuration name, in
            a fixed order.
        """
        return OrderedDict([
            ("Architecture", self.architecture),
            ("Learning rate", self.learning_rate),
            ("Momentum", self.momentum),
            ("Batch size", self.batch_size),
            ("Epochs", self.epochs),
            ("Fine Tuning rate", self.fine_tuning_rate),
            ("Transfer Learning", self.transfer_learning),
            ("Save weights", self.save_weights),
            ("Perc Train", self.perc_train),
            ("Perc Validation", self.perc_validation),
            ("Recreate Dataset", self.recreate_dataset),
        ])

    def set_config(self, configs):
        """Update configuration of classifier.

        Parameters
        ----------
        configs : OrderedDict
            New configs of classifier. It is indexed with the same keys
            that get_config() produces, so all of them must be present.
        """
        # Each value goes through Config.nvl_config together with the current
        # setting -- presumably falling back to the current one when the new
        # value is empty; confirm against util.config.
        self.architecture = Config.nvl_config(
            configs["Architecture"], self.architecture)
        self.learning_rate = Config.nvl_config(
            configs["Learning rate"], self.learning_rate)
        self.momentum = Config.nvl_config(
            configs["Momentum"], self.momentum)
        self.batch_size = Config.nvl_config(
            configs["Batch size"], self.batch_size)
        self.epochs = Config.nvl_config(
            configs["Epochs"], self.epochs)
        self.fine_tuning_rate = Config.nvl_config(
            configs["Fine Tuning rate"], self.fine_tuning_rate)
        self.transfer_learning = Config.nvl_config(
            configs["Transfer Learning"], self.transfer_learning)
        self.save_weights = Config.nvl_config(
            configs["Save weights"], self.save_weights)
        self.perc_train = Config.nvl_config(
            configs["Perc Train"], self.perc_train)
        self.perc_validation = Config.nvl_config(
            configs["Perc Validation"], self.perc_validation)
        self.recreate_dataset = Config.nvl_config(
            configs["Recreate Dataset"], self.recreate_dataset)

    def get_summary_config(self):
        """Return formatted summary of configuration.

        Returns
        -------
        summary : string
            One "label: value" line per configuration, each terminated by a
            newline.
        """
        configs = (
            self.architecture,
            self.learning_rate,
            self.momentum,
            self.batch_size,
            self.epochs,
            self.fine_tuning_rate,
            self.transfer_learning,
            self.save_weights,
            self.perc_train,
            self.perc_validation,
            self.recreate_dataset,
        )

        # Keyed by label so that a repeated label keeps a single line.
        keras_config = OrderedDict()
        for config in configs:
            keras_config[config.label] = config.value

        return ''.join("%s: %s\n" % (label, str(value))
                       for label, value in keras_config.items())

    def classify(self, dataset, test_dir, test_data, image):
        """Perform the classification.

        The entries of the test directory are mirrored into a "png"
        sub-directory (".tif" files are converted to PNG, everything else is
        symlinked) and the content of that sub-directory is fed to the model.

        Parameters
        ----------
        dataset : string
            Path to image dataset.
        test_dir : string
            Name of the directory, inside dataset, with the images to classify.
        test_data : string
            Not used.
        image : object
            Not used.

        Returns
        -------
        summary : list of string
            List of predicted classes for each instance in test data in ordered way.

        Raises
        ------
        IException
            If the saved model cannot be loaded or no model is available.
        """
        predict_directory = File.make_path(dataset, test_dir)
        png_directory = File.make_path(predict_directory, "png")

        if not os.path.exists(png_directory):
            os.makedirs(png_directory)

        for entry in os.listdir(predict_directory):
            source = File.make_path(predict_directory, entry)
            print(source)
            if os.path.splitext(entry)[-1] == ".tif":
                # Best effort: an unreadable image is reported and skipped.
                try:
                    img = Image.open(source)
                    new_file = os.path.splitext(entry)[0] + ".png"
                    img.save(File.make_path(png_directory, new_file),
                             "PNG", quality=100)
                except Exception as e:
                    print(e)
            else:
                link = File.make_path(png_directory, entry)
                # A previous call may already have created the link, and
                # os.symlink fails on an existing name. The target is made
                # absolute because a relative target is resolved against the
                # link's own directory and would dangle.
                if not os.path.lexists(link):
                    os.symlink(os.path.abspath(source), link)

        classify_datagen = ImageDataGenerator()

        classify_generator = classify_datagen.flow_from_directory(
            png_directory,
            target_size=(IMG_HEIGHT, IMG_WIDTH),
            batch_size=1,
            shuffle=False,
            class_mode=None)

        # Default labels; replaced by the mapping saved next to the model.
        class_names = CLASS_NAMES
        weight_path = getattr(self, "weight_path", None)

        if weight_path is not None:
            try:
                K.clear_session()
                self.model = load_model(weight_path)
                path_classes = weight_path.replace("_model.h5", "_classes.npy")
                # The file holds a pickled dict (class name -> index) written
                # by train(); only load files produced by this application.
                class_indices = np.load(path_classes, allow_pickle=True).item()
                # Order the names by index so a predicted index selects the
                # right name; the dict's own key order is not the index order.
                class_names = sorted(class_indices, key=class_indices.get)
            except Exception as e:
                raise IException("Can't load the model in " +
                                 str(weight_path) + ": " + str(e))

        if self.model is None:
            raise IException("There is no trained model to classify with")

        output_classification = self.model.predict_generator(
            classify_generator, classify_generator.samples, verbose=2)

        predicted_indices = np.argmax(output_classification, axis=1).tolist()

        return [class_names[index] for index in predicted_indices]

    def train(self, dataset, training_data, force=False):
        """Perform the training of classifier.

        Builds the train/validation/test directories, creates the model for
        the configured architecture, trains it with SGD and, when the
        "Save weights" configuration is enabled, stores the model and the
        class-name mapping under "../models_checkpoints/".

        Parameters
        ----------
        dataset : string
            Path to image dataset.
        training_data : string
            Name of ARFF training file. Not used.
        force : boolean, optional, default = False
            Not used: the training is always performed.
        """
        checkpoint_dir = "../models_checkpoints/"

        # select .h5 filename
        fine_tuning_rate = self.fine_tuning_rate.value
        if fine_tuning_rate == 100:
            suffix = '_transfer_learning'
        elif fine_tuning_rate == -1:
            suffix = '_without_transfer_learning'
        else:
            suffix = '_fine_tunning_' + str(fine_tuning_rate)
        self.file_name = str(self.architecture.value) + \
            '_learning_rate' + str(self.learning_rate.value) + suffix

        File.remove_dir(File.make_path(dataset, ".tmp"))

        # The test generator is not used during training.
        train_generator, validation_generator, _ = self.make_dataset(dataset)

        # self.save_weights is a Config object; its .value carries the flag.
        save_weights = self.save_weights.value

        callbacks = []

        # Save the best model (by validation accuracy) during the training
        if save_weights:
            if not os.path.exists(checkpoint_dir):
                os.makedirs(checkpoint_dir)

            callbacks.append(ModelCheckpoint(
                checkpoint_dir + self.file_name + ".h5", monitor='val_acc',
                verbose=1, save_best_only=True, save_weights_only=False,
                mode='auto', period=1))

        self.model = self.select_model_params(train_generator.num_classes)

        callbacks.append(TensorBoard(
            log_dir=checkpoint_dir + "logs_" + self.file_name,
            write_images=False))

        # compile the model
        self.model.compile(loss="categorical_crossentropy",
                           optimizer=optimizers.SGD(
                               lr=self.learning_rate.value,
                               momentum=self.momentum.value),
                           metrics=["accuracy"])

        # At least one step, even when there are fewer samples than a batch.
        batch_size = self.batch_size.value
        steps_per_epoch = max(1, train_generator.samples // batch_size)
        validation_steps = max(1, validation_generator.samples // batch_size)

        # Train the model
        self.model.fit_generator(
            train_generator,
            steps_per_epoch=steps_per_epoch,
            epochs=self.epochs.value,
            callbacks=callbacks,
            validation_data=validation_generator,
            validation_steps=validation_steps)

        if save_weights:
            self.weight_path = checkpoint_dir + self.file_name + "_model.h5"
            self.model.save(self.weight_path)

            # Class name -> index mapping, read back by classify().
            dict_classes = validation_generator.class_indices
            np.save(checkpoint_dir + self.file_name + "_classes.npy",
                    dict_classes)

        # NOTE(review): self.trained is never set to True here, so
        # must_train() keeps returning True -- confirm this is intended.

    def must_train(self):
        """Return if classifier must be trained.

        Returns
        -------
        boolean
            True while self.trained is False, otherwise False.
        """
        return not self.trained

    def must_extract_features(self):
        """Return if classifier must be extracted features.

        Returns
        -------
        False
            Always; this classifier never asks for feature extraction.
        """
        return False

    def select_model_params(self, num_classes):
        """Create the network for the configured architecture.

        The base model is built without its top layers. With a fine tuning
        rate of -1 it starts from random weights and every layer is
        trainable; otherwise it starts from the ImageNet weights and the
        first `fine_tuning_rate` percent of its layers are frozen. A new
        fully connected classifier is then stacked on top of it.

        Parameters
        ----------
        num_classes : int
            Number of units of the final softmax layer.

        Returns
        -------
        model : keras.models.Model
            Model not compiled yet.

        Raises
        ------
        ValueError
            If the configured architecture is not one of the supported ones.
        """
        supported_architectures = ("Xception", "VGG16", "VGG19",
                                   "ResNet50", "InceptionV3", "MobileNet")

        architecture = self.architecture.value
        if architecture not in supported_architectures:
            raise ValueError(
                "Unsupported architecture '%s', expected one of: %s"
                % (architecture, ", ".join(supported_architectures)))

        fine_tuning_rate = self.fine_tuning_rate.value
        without_transfer_learning = fine_tuning_rate == -1

        # Each supported name is also the name of its builder in
        # keras.applications.
        build_base_model = getattr(applications, architecture)
        model = build_base_model(
            weights=None if without_transfer_learning else "imagenet",
            include_top=False,
            # (height, width, channels), the order used by the generators.
            input_shape=(IMG_HEIGHT, IMG_WIDTH, 3))

        if without_transfer_learning:
            for layer in model.layers:
                layer.trainable = True
        else:
            frozen_layers = int(len(model.layers) * (fine_tuning_rate / 100.0))
            for layer in model.layers[:frozen_layers]:
                layer.trainable = False

        # Adding custom Layers
        new_custom_layers = model.output
        new_custom_layers = Flatten()(new_custom_layers)
        new_custom_layers = Dense(1024, activation="relu")(new_custom_layers)
        new_custom_layers = Dropout(0.5)(new_custom_layers)
        new_custom_layers = Dense(1024, activation="relu")(new_custom_layers)
        predictions = Dense(num_classes,
                            activation="softmax")(new_custom_layers)

        # creating the final model
        model = Model(inputs=model.input, outputs=predictions)

        return model

    def make_dataset(self, dataset):
        """Split the dataset and create the Keras generators.

        Inside `dataset` a ".keras_dataset" directory is created with the
        "train", "validation" and "test" sub-directories, each one with a
        directory per class. The files of every class are shuffled and
        distributed according to the "Perc Train" and "Perc Validation"
        configurations; whatever is left goes to the test split. ".tif"
        files are converted to PNG, the other files are symlinked.

        The split is only (re)built when the directory does not exist or the
        "Recreate Dataset" configuration is enabled.

        Parameters
        ----------
        dataset : string
            Path to image dataset, with one directory per class.

        Returns
        -------
        train_generator, validation_generator, test_generator
            Directory iterators created with flow_from_directory.
        """
        KERAS_DATASET_DIR_NAME = ".keras_dataset"
        KERAS_DIR_TRAIN_NAME = "train"
        KERAS_DIR_VALIDATION_NAME = "validation"
        KERAS_DIR_TEST_NAME = "test"
        PERC_TRAIN = self.perc_train.value
        PERC_VALIDATION = self.perc_validation.value

        keras_dataset_dir = File.make_path(dataset, KERAS_DATASET_DIR_NAME)

        def populate_split(split_name, dir_class, root, file_names):
            """Put the given files of a class into one split directory."""
            split_class_dir = File.make_path(
                dataset, KERAS_DATASET_DIR_NAME, split_name, dir_class)
            # Created even without files, so that every split lists the
            # same classes (the generators take classes from directories).
            if not os.path.exists(split_class_dir):
                os.makedirs(split_class_dir)

            for file_name in file_names:
                source = File.make_path(root, file_name)
                base_name, extension = os.path.splitext(file_name)
                if extension == ".tif":
                    img = Image.open(source)
                    img.save(File.make_path(split_class_dir, base_name + ".png"),
                             "PNG", quality=100)
                else:
                    # Absolute target: a relative one would be resolved
                    # against the link's directory and dangle.
                    os.symlink(os.path.abspath(source),
                               File.make_path(split_class_dir, file_name))

        def make_generator(split_name):
            """Create the directory iterator of one split."""
            return ImageDataGenerator().flow_from_directory(
                File.make_path(dataset, KERAS_DATASET_DIR_NAME, split_name),
                target_size=(IMG_HEIGHT, IMG_WIDTH),
                batch_size=self.batch_size.value,
                shuffle=True,
                class_mode="categorical")

        # create keras dir dataset
        if not os.path.exists(keras_dataset_dir) or self.recreate_dataset.value:
            if os.path.exists(keras_dataset_dir):
                shutil.rmtree(keras_dataset_dir)

            # create keras dir train, validation and test
            for split_name in (KERAS_DIR_TRAIN_NAME,
                               KERAS_DIR_VALIDATION_NAME,
                               KERAS_DIR_TEST_NAME):
                os.makedirs(File.make_path(
                    dataset, KERAS_DATASET_DIR_NAME, split_name))

            dir_classes = sorted(File.list_dirs(dataset))

            if KERAS_DATASET_DIR_NAME in dir_classes:
                dir_classes.remove(KERAS_DATASET_DIR_NAME)

            for dir_class in dir_classes:
                root = File.make_path(dataset, dir_class)
                files = os.listdir(root)
                random.shuffle(files)
                quant_files = len(files)
                quant_train = int((quant_files / 100.0) * PERC_TRAIN)
                quant_validation = int((quant_files / 100.0) * PERC_VALIDATION)
                end_validation = quant_train + quant_validation

                print("Processing class %s - %d itens - %d train items - %d validation items" % (dir_class, quant_files, quant_train, quant_validation))

                populate_split(KERAS_DIR_TRAIN_NAME, dir_class, root,
                               files[:quant_train])
                populate_split(KERAS_DIR_VALIDATION_NAME, dir_class, root,
                               files[quant_train:end_validation])
                populate_split(KERAS_DIR_TEST_NAME, dir_class, root,
                               files[end_validation:])

        train_generator = make_generator(KERAS_DIR_TRAIN_NAME)
        validation_generator = make_generator(KERAS_DIR_VALIDATION_NAME)
        test_generator = make_generator(KERAS_DIR_TEST_NAME)

        return train_generator, validation_generator, test_generator