#!/usr/bin/python
# -*- coding: utf-8 -*-
#
"""
    Generic classifier with multiple models
    Models -> (Xception, VGG16, VGG19, ResNet50, InceptionV3, MobileNet)

    Name: cnn_keras.py
    Author: Gabriel Kirsten Menezes (gabriel.kirsten@hotmail.com)

"""
import time
import os
import shutil
import random
import numpy as np
import json
import logging
import sys

from PIL import Image
from keras import applications
from keras.preprocessing.image import ImageDataGenerator
from keras import optimizers
from keras.models import Model, load_model
from keras.layers import Dropout, Flatten, Dense
from keras.callbacks import ModelCheckpoint, TensorBoard
from keras import backend as K

from interface.interface import InterfaceException as IException
from classification.classifier import Classifier

from collections import OrderedDict

from util.config import Config
from util.file_utils import File
from util.utils import TimeUtils


logger = logging.getLogger('PIL')
logger.setLevel(logging.WARNING)

START_TIME = time.time()

# =========================================================
# Constants
# =========================================================
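# Every image is resized to IMG_WIDTH x IMG_HEIGHT by the data generators below.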

IMG_WIDTH, IMG_HEIGHT = 256, 256
weight_path = None


class CNNKeras(Classifier):
    """ Class for CNN classifiers based on Keras applications """

    def __init__(self, architecture="ResNet50", learning_rate=0.001, momentum=0.9, batch_size=32, epochs=50, fine_tuning_rate=100, transfer_learning=False, save_weights=True, perc_train=80, perc_validation=20, recreate_dataset=False):
        """
            Constructor of CNNKeras
        """

        self.architecture = Config(
            "Architecture", architecture, str)
        self.learning_rate = Config(
            "Learning rate", learning_rate, float)
        self.momentum = Config(
            "Momentum", momentum, float)
        self.batch_size = Config(
            "Batch size", batch_size, int)
        self.epochs = Config(
            "Epochs", epochs, int)
        self.fine_tuning_rate = Config(
            "Fine Tuning Rate", fine_tuning_rate, int)
        self.transfer_learning = Config(
            "Transfer Learning", transfer_learning, bool)
        self.save_weights = Config(
            "Save weights", save_weights, bool)
        self.perc_train = Config(
            "Perc Train", perc_train, float)
        self.perc_validation = Config(
            "Perc Validation", perc_validation, float)
        self.recreate_dataset = Config(
            "Recreate Dataset", recreate_dataset, bool)
        self.file_name = "kerasCNN"

        self.model = None

        self.weight_path = weight_path

        self.trained = False

    def get_config(self):
        """Return configuration of classifier.

        Returns
        -------
        config : OrderedDict
            Current configs of classifier.
        """
        keras_config = OrderedDict()

        keras_config["Architecture"] = self.architecture
        keras_config["Learning rate"] = self.learning_rate
        keras_config["Momentum"] = self.momentum
        keras_config["Batch size"] = self.batch_size
        keras_config["Epochs"] = self.epochs
        keras_config["Fine Tuning rate"] = self.fine_tuning_rate
        keras_config["Transfer Learning"] = self.transfer_learning
        keras_config["Save weights"] = self.save_weights
        keras_config["Perc Train"] = self.perc_train
        keras_config["Perc Validation"] = self.perc_validation
        keras_config["Recreate Dataset"] = self.recreate_dataset
        return keras_config

    def set_config(self, configs):
        """Update configuration of classifier.

        Parameters
        ----------
        configs : OrderedDict
            New configs of classifier.
        """
        self.architecture = Config.nvl_config(configs["Architecture"], self.architecture)
        self.learning_rate = Config.nvl_config(configs["Learning rate"], self.learning_rate)
        self.momentum = Config.nvl_config(configs["Momentum"], self.momentum)
        self.batch_size = Config.nvl_config(configs["Batch size"], self.batch_size)
        self.epochs = Config.nvl_config(configs["Epochs"], self.epochs)
        self.fine_tuning_rate = Config.nvl_config(configs["Fine Tuning rate"], self.fine_tuning_rate)
        self.transfer_learning = Config.nvl_config(configs["Transfer Learning"], self.transfer_learning)
        self.save_weights = Config.nvl_config(configs["Save weights"], self.save_weights)
        self.perc_train = Config.nvl_config(configs["Perc Train"], self.perc_train)
        self.perc_validation = Config.nvl_config(configs["Perc Validation"], self.perc_validation)
        self.recreate_dataset = Config.nvl_config(configs["Recreate Dataset"], self.recreate_dataset)

    def get_summary_config(self):
        """Return fomatted summary of configuration.

        Returns
        -------
        summary : string
            Formatted string with summary of configuration.
        """
        keras_config = OrderedDict()

        keras_config[self.architecture.label] = self.architecture.value
        keras_config[self.learning_rate.label] = self.learning_rate.value
        keras_config[self.momentum.label] = self.momentum.value
        keras_config[self.batch_size.label] = self.batch_size.value
        keras_config[self.epochs.label] = self.epochs.value
        keras_config[self.fine_tuning_rate.label] = self.fine_tuning_rate.value
        keras_config[self.transfer_learning.label] = self.transfer_learning.value
        keras_config[self.save_weights.label] = self.save_weights.value
        keras_config[self.perc_train.label] = self.perc_train.value
        keras_config[self.perc_validation.label] = self.perc_validation.value
        keras_config[self.recreate_dataset.label] = self.recreate_dataset.value
        summary = ''
        for config in keras_config:
            summary += "%s: %s\n" % (config, str(keras_config[config]))

        return summary

    def classify(self, dataset, test_dir, test_data, image):
        """"Perform the classification.

        Parameters
        ----------
        dataset : string
            Path to image dataset.
        test_dir : string
            Name of the directory, inside dataset, that holds the images to classify.
        test_data : string
            Not used.
        image : optional
            Not used.

        Returns
        -------
        summary : list of string
            Predicted class label for each image, in the order the images are read.
        """

        predict_directory = File.make_path(dataset, test_dir)

        # Convert .tif images to .png copies (and symlink everything else) into a 'png' subfolder
        if not os.path.exists(File.make_path(predict_directory, "png")):
            os.makedirs(File.make_path(predict_directory, "png"))

        for file in os.listdir(predict_directory):
            print(File.make_path(predict_directory, file))
            if os.path.splitext(file)[-1] == ".tif":
                try:
                    img = Image.open(File.make_path(predict_directory, file))
                    #img.thumbnail(img.size)
                    new_file = os.path.splitext(file)[0] + ".png"
                    img.save(File.make_path(predict_directory, 'png', new_file), "PNG", quality=100)
                except Exception as e:
                    print(e)
            else:
                os.symlink(File.make_path(predict_directory, file),
                           File.make_path(predict_directory, 'png', file))

        classify_datagen = ImageDataGenerator()

        classify_generator = classify_datagen.flow_from_directory(
            File.make_path(predict_directory, 'png'),
            target_size=(IMG_HEIGHT, IMG_WIDTH),
            batch_size=1,
            shuffle=False,
            class_mode=None)

        try:
            # self.model.load_weights(
            #     "../models_checkpoints/" + self.file_name + ".h5")
            K.clear_session()
            if self.weight_path is not None:
                self.model = load_model(self.weight_path)
                path_classes = self.weight_path.replace("_model.h5", "_classes.npy")
                # class_indices maps class name -> index; order the names by index so
                # the argmax results below map back to the right label.
                class_dictionary = np.load(path_classes).item()
                CLASS_NAMES = sorted(class_dictionary, key=class_dictionary.get)
        except Exception as e:
            raise IException("Can't load the model in " +
                             str(self.weight_path) + ": " + str(e))

        output_classification = self.model.predict_generator(
            classify_generator, classify_generator.samples, verbose=2)

        one_hot_output = np.argmax(output_classification, axis=1)

        one_hot_output = one_hot_output.tolist()

        for index in range(0, len(one_hot_output)):
            one_hot_output[index] = CLASS_NAMES[one_hot_output[index]]
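        # Illustrative example with hypothetical classes: if CLASS_NAMES is
        # ["soil", "weed"] and the argmax indices were [1, 0, 1], the list returned
        # below is ["weed", "soil", "weed"], one label per image in generator order.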

        return one_hot_output

    def train(self, dataset, training_data, force=False):
        """Perform the training of classifier.

        Parameters
        ----------
        dataset : string
            Path to image dataset.
        training_data : string
            Name of ARFF training file.
        force : boolean, optional, default = False
            If False, do not perform new training when trained data already exists.
        """

        # select .h5 filename
        if self.fine_tuning_rate.value == 100:
            self.file_name = str(self.architecture.value) + \
                '_learning_rate' + str(self.learning_rate.value) + \
                '_transfer_learning'
        elif self.fine_tuning_rate.value == -1:
            self.file_name = str(self.architecture.value) + \
                '_learning_rate' + str(self.learning_rate.value) + \
                '_without_transfer_learning'
        else:
            self.file_name = str(self.architecture.value) + \
                '_learning_rate' + str(self.learning_rate.value) + \
                '_fine_tunning_' + str(self.fine_tuning_rate.value)

        File.remove_dir(File.make_path(dataset, ".tmp"))

        train_generator, validation_generator, test_generator = self.make_dataset(dataset)

        # Save the best model (highest validation accuracy) during training
        callbacks_list = []
        if self.save_weights.value:
            if not os.path.exists("../models_checkpoints/"):
                os.makedirs("../models_checkpoints/")

            checkpoint = ModelCheckpoint("../models_checkpoints/" + self.file_name + ".h5",
                                         monitor='val_acc', verbose=1, save_best_only=True,
                                         save_weights_only=False, mode='auto', period=1)
            callbacks_list.append(checkpoint)

        self.model = self.select_model_params(train_generator.num_classes)

        tensorboard = TensorBoard(log_dir="../models_checkpoints/logs_" + self.file_name, write_images=False)
        #tensorboard.set_model(self.model)
        # compile the model
        self.model.compile(loss="categorical_crossentropy",
                           optimizer=optimizers.SGD(
                               lr=self.learning_rate.value, momentum=self.momentum.value),
                           metrics=["accuracy"])

        # Train the model
        self.model.fit_generator(
            train_generator,
            steps_per_epoch=train_generator.samples // self.batch_size.value,
            epochs=self.epochs.value,
            callbacks=callbacks_list + [tensorboard],
            validation_data=validation_generator,
            validation_steps=validation_generator.samples // self.batch_size.value)

        if self.save_weights.value:
            #self.model.save_weights(
            #    "../models_checkpoints/" + self.file_name + ".h5")
            self.model.save(
                "../models_checkpoints/" + self.file_name + "_model.h5")
            self.weight_path = "../models_checkpoints/" + self.file_name + "_model.h5"

            dict_classes = validation_generator.class_indices
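            # Persist the class-name -> index mapping so classify() can translate
            # predicted indices back into class labels.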
            np.save("../models_checkpoints/" + self.file_name + "_classes.npy", dict_classes)


    def must_train(self):
        """Return if classifier must be trained.

        Returns
        -------
        boolean
            True if the classifier has not been trained yet.
        """
        return not self.trained

    def must_extract_features(self):
        """Return if classifier must be extracted features.

        Returns
        -------
        boolean
            Always False, no feature extraction step is required.
        """
        return False

    def select_model_params(self, num_classes):
        if self.fine_tuning_rate.value != -1:
            if self.architecture.value == "Xception":
                model = applications.Xception(
                    weights="imagenet", include_top=False, input_shape=(IMG_WIDTH, IMG_HEIGHT, 3))
            elif self.architecture.value == "VGG16":
                model = applications.VGG16(
                    weights="imagenet", include_top=False, input_shape=(IMG_WIDTH, IMG_HEIGHT, 3))
            elif self.architecture.value == "VGG19":
                model = applications.VGG19(
                    weights="imagenet", include_top=False, input_shape=(IMG_WIDTH, IMG_HEIGHT, 3))
            elif self.architecture.value == "ResNet50":
                model = applications.ResNet50(
                    weights="imagenet", include_top=False, input_shape=(IMG_WIDTH, IMG_HEIGHT, 3))
            elif self.architecture.value == "InceptionV3":
                model = applications.InceptionV3(
                    weights="imagenet", include_top=False, input_shape=(IMG_WIDTH, IMG_HEIGHT, 3))
            elif self.architecture.value == "MobileNet":
                model = applications.MobileNet(
                    weights="imagenet", include_top=False, input_shape=(IMG_WIDTH, IMG_HEIGHT, 3))

            for layer in model.layers[:int(len(model.layers) * (self.fine_tuning_rate.value / 100.0))]:
                layer.trainable = False
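            # fine_tuning_rate is the percentage of the pre-trained base that stays
            # frozen: 100 freezes every ImageNet layer (pure transfer learning), 50
            # would freeze only the first half, and -1 (the else branch below) trains
            # the whole network from scratch.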

        else:  # without transfer learning
            if self.architecture.value == "Xception":
                model = applications.Xception(
                    weights=None, include_top=False, input_shape=(IMG_WIDTH, IMG_HEIGHT, 3))
            elif self.architecture.value == "VGG16":
                model = applications.VGG16(
                    weights=None, include_top=False, input_shape=(IMG_WIDTH, IMG_HEIGHT, 3))
            elif self.architecture.value == "VGG19":
                model = applications.VGG19(
                    weights=None, include_top=False, input_shape=(IMG_WIDTH, IMG_HEIGHT, 3))
            elif self.architecture.value == "ResNet50":
                model = applications.ResNet50(
                    weights=None, include_top=False, input_shape=(IMG_WIDTH, IMG_HEIGHT, 3))
            elif self.architecture.value == "InceptionV3":
                model = applications.InceptionV3(
                    weights=None, include_top=False, input_shape=(IMG_WIDTH, IMG_HEIGHT, 3))
            elif self.architecture.value == "MobileNet":
                model = applications.MobileNet(
                    weights=None, include_top=False, input_shape=(IMG_WIDTH, IMG_HEIGHT, 3))
            for layer in model.layers:
                layer.trainable = True

        # Adding custom Layers
        new_custom_layers = model.output
        new_custom_layers = Flatten()(new_custom_layers)
        new_custom_layers = Dense(1024, activation="relu")(new_custom_layers)
        new_custom_layers = Dropout(0.5)(new_custom_layers)
        new_custom_layers = Dense(1024, activation="relu")(new_custom_layers)
        predictions = Dense(num_classes,
                            activation="softmax")(new_custom_layers)

        # creating the final model
        model = Model(inputs=model.input, outputs=predictions)

        return model

    def make_dataset(self, dataset):
        """Split the dataset into Keras-style train/validation/test directories
        and return one flow_from_directory generator per split."""

        # Build the split directories with symbolic links to the original images
        # (.tif files are converted to .png copies).
        KERAS_DATASET_DIR_NAME = ".keras_dataset"
        #KERAS_DATASET_DIR_NAME = File.make_path("..", os.path.split(dataset)[-1] + "_keras_dataset")
        KERAS_DIR_TRAIN_NAME = "train"
        KERAS_DIR_VALIDATION_NAME = "validation"
        KERAS_DIR_TEST_NAME = "test"
        PERC_TRAIN = self.perc_train.value
        PERC_VALIDATION = self.perc_validation.value


        # create keras dir dataset
        if not os.path.exists(File.make_path(dataset, KERAS_DATASET_DIR_NAME)) or self.recreate_dataset.value:
            if os.path.exists(File.make_path(dataset, KERAS_DATASET_DIR_NAME)):
                shutil.rmtree(File.make_path(dataset, KERAS_DATASET_DIR_NAME))

            os.makedirs(File.make_path(dataset, KERAS_DATASET_DIR_NAME))

            # create keras dir train
            if not os.path.exists(File.make_path(dataset, KERAS_DATASET_DIR_NAME, KERAS_DIR_TRAIN_NAME)):
                os.makedirs(File.make_path(
                    dataset, KERAS_DATASET_DIR_NAME, KERAS_DIR_TRAIN_NAME))

            # create keras dir validation
            if not os.path.exists(File.make_path(dataset, KERAS_DATASET_DIR_NAME, KERAS_DIR_VALIDATION_NAME)):
                os.makedirs(File.make_path(
                    dataset, KERAS_DATASET_DIR_NAME, KERAS_DIR_VALIDATION_NAME))

            # create keras dir test
            if not os.path.exists(File.make_path(dataset, KERAS_DATASET_DIR_NAME, KERAS_DIR_TEST_NAME)):
                os.makedirs(File.make_path(
                    dataset, KERAS_DATASET_DIR_NAME, KERAS_DIR_TEST_NAME))

            dir_classes = sorted(File.list_dirs(dataset))

            if KERAS_DATASET_DIR_NAME in dir_classes:
                dir_classes.remove(KERAS_DATASET_DIR_NAME)

            for dir_class in dir_classes:
                root = File.make_path(dataset, dir_class)
                files = os.listdir(root)
                random.shuffle(files)
                quant_files = len(files)
                quant_train = int((quant_files / 100.0) * PERC_TRAIN)
                quant_validation = int((quant_files / 100.0) * PERC_VALIDATION)

                files_train = files[0:quant_train]
                files_validation = files[quant_train:quant_train+quant_validation]
                files_test = files[quant_train+quant_validation:quant_files]
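                # Files left over after the train and validation slices (the last
                # 100 - PERC_TRAIN - PERC_VALIDATION percent) become the test split.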
                print("Processing class %s - %d itens - %d train items - %d validation items" % (dir_class, quant_files, quant_train, quant_validation))


                for file in files_train:
                    dir_class_train = File.make_path(dataset, KERAS_DATASET_DIR_NAME, KERAS_DIR_TRAIN_NAME, dir_class)
                    if not os.path.exists(dir_class_train):
                        os.makedirs(dir_class_train)
    
                    if os.path.splitext(file)[-1] == ".tif":
                        img = Image.open(File.make_path(root, file))
                        #img.thumbnail(img.size)
                        new_file = os.path.splitext(file)[0]+".png"
                        img.save(File.make_path(dir_class_train, new_file), "PNG", quality=100)
                    else:
                        os.symlink(File.make_path(root, file), File.make_path(dir_class_train, file))

                for file in files_validation:
                    dir_class_validation = File.make_path(dataset, KERAS_DATASET_DIR_NAME, KERAS_DIR_VALIDATION_NAME, dir_class)
                    if not os.path.exists(dir_class_validation):
                        os.makedirs(dir_class_validation)
                    
                    if os.path.splitext(file)[-1] == ".tif":
                        img = Image.open(File.make_path(root, file))
                        #img.thumbnail(img.size)
                        new_file = os.path.splitext(file)[0]+".png"
                        img.save(File.make_path(dir_class_validation, new_file), "PNG", quality=100)
                    else:
                        os.symlink(File.make_path(root, file), File.make_path(dir_class_validation, file))

                for file in files_test:
                    dir_class_test = File.make_path(dataset, KERAS_DATASET_DIR_NAME, KERAS_DIR_TEST_NAME, dir_class)
                    if not os.path.exists(dir_class_test):
                        os.makedirs(dir_class_test)
                    
                    if os.path.splitext(file)[-1] == ".tif":
                        img = Image.open(File.make_path(root, file))
                        #img.thumbnail(img.size)
                        new_file = os.path.splitext(file)[0]+".png"
                        img.save(File.make_path(dir_class_test, new_file), "PNG", quality=100)
                    else:
                        os.symlink(File.make_path(root, file), File.make_path(dir_class_test, file))

        train_datagen = ImageDataGenerator()

        train_generator = train_datagen.flow_from_directory(
            File.make_path(dataset, KERAS_DATASET_DIR_NAME, KERAS_DIR_TRAIN_NAME),
            target_size=(IMG_HEIGHT, IMG_WIDTH),
            batch_size=self.batch_size.value,
            shuffle=True,
            class_mode="categorical")


        validation_datagen = ImageDataGenerator()

        validation_generator = validation_datagen.flow_from_directory(
            File.make_path(dataset, KERAS_DATASET_DIR_NAME, KERAS_DIR_VALIDATION_NAME),
            target_size=(IMG_HEIGHT, IMG_WIDTH),
            batch_size=self.batch_size.value,
            shuffle=True,
            class_mode="categorical")

        test_datagen = ImageDataGenerator()

        test_generator = test_datagen.flow_from_directory(
            File.make_path(dataset, KERAS_DATASET_DIR_NAME, KERAS_DIR_TEST_NAME),
            target_size=(IMG_HEIGHT, IMG_WIDTH),
            batch_size=self.batch_size.value,
            shuffle=True,
            class_mode="categorical")

        return train_generator, validation_generator, test_generator