From 8bba4c044aa65af9e617096623c78099f3cdaaba Mon Sep 17 00:00:00 2001
From: Gilberto Astolfi <gilbertoastolfi@gmail.com>
Date: Wed, 4 Oct 2017 14:41:06 -0400
Subject: [PATCH] adicionada menu para o keras, criada a classe do keras

---
 src/classification/__init__.py  |   7 ++
 src/classification/cnn_keras.py | 151 ++++++++++++++++++++++++++++++++
 2 files changed, 158 insertions(+)
 create mode 100644 src/classification/cnn_keras.py

diff --git a/src/classification/__init__.py b/src/classification/__init__.py
index 99cb77b..ab7294e 100644
--- a/src/classification/__init__.py
+++ b/src/classification/__init__.py
@@ -10,6 +10,11 @@ try:
 except:
     CNNCaffe = None
 
+try:
+    from .cnn_keras import CNNKeras
+except:
+    CNNKeras = None
+
 __all__ = ["cnn_caffe",
            "classifier",
            "weka_classifiers"]
@@ -22,6 +27,8 @@ from util.config import Config
 _classifier_list = OrderedDict( [ 
                             ["cnn_caffe", Config("Invalid" if CNNCaffe is None else CNNCaffe.__name__,
                                 WekaClassifiers is None and CNNCaffe is not None, bool, meta=CNNCaffe, hidden=CNNCaffe is None)],
+                            ["cnn_keras", Config("Invalid" if CNNKeras is None else CNNKeras.__name__,
+                                CNNKeras is not None, bool, meta=CNNKeras, hidden=CNNKeras is None)],
                             ["weka_classifiers", Config("Invalid" if WekaClassifiers is None else WekaClassifiers.__name__,
                                 WekaClassifiers is not None, bool, meta=WekaClassifiers, hidden=WekaClassifiers is None)]
                         ] )
diff --git a/src/classification/cnn_keras.py b/src/classification/cnn_keras.py
new file mode 100644
index 0000000..d5e01ba
--- /dev/null
+++ b/src/classification/cnn_keras.py
@@ -0,0 +1,151 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+#
+"""
+    Generic classifier with multiple models
+    Models -> (Xception, VGG16, VGG19, ResNet50, InceptionV3, MobileNet)
+    Name: cnn_keras.py
+    Author: Gabriel Kirsten Menezes (gabriel.kirsten@hotmail.com)
+
+"""
+
+import time
+import os
+from keras import applications
+from keras.preprocessing.image import ImageDataGenerator
+from keras import optimizers
+from keras.models import Model
+from keras.layers import Dropout, Flatten, Dense
+from keras.callbacks import ModelCheckpoint
+from sklearn.metrics import confusion_matrix
+
+from classifier import Classifier
+
+from collections import OrderedDict
+
+from util.config import Config
+from util.file_utils import File
+from util.utils import TimeUtils
+
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # Suppress warnings
+START_TIME = time.time()
+
+# =========================================================
+# Constants
+# =========================================================
+
+IMG_WIDTH, IMG_HEIGHT = 256, 256
+TRAIN_DATA_DIR = "../data/train"
+VALIDATION_DATA_DIR = "../data/validation"
+BATCH_SIZE = 16
+EPOCHS = 50
+CLASS_NAMES = ['ferrugemAsiatica', 'folhaSaudavel',
+               'fundo', 'manchaAlvo', 'mildio', 'oidio']
+
+class CNNKeras(Classifier):
+
+    def __init__(self):
+        model = applications.VGG16(
+            weights="imagenet", include_top=False, input_shape=(IMG_WIDTH, IMG_HEIGHT, 3))
+
+        for layer in model.layers:
+            layer.trainable = True
+
+        # Adding custom Layers
+        new_custom_layers = model.output
+        new_custom_layers = Flatten()(new_custom_layers)
+        new_custom_layers = Dense(1024, activation="relu")(new_custom_layers)
+        new_custom_layers = Dropout(0.5)(new_custom_layers)
+        new_custom_layers = Dense(1024, activation="relu")(new_custom_layers)
+        predictions = Dense(6, activation="softmax")(new_custom_layers)
+
+        # creating the final model
+        model_final = Model(inputs=model.input, outputs=predictions)
+
+        # compile the model
+        model_final.compile(loss="categorical_crossentropy",
+                            optimizer=optimizers.SGD(lr=0.0001, momentum=0.9),
+                            metrics=["accuracy"])
+
+    def get_config(self):
+        """Return configuration of classifier. 
+        
+        Returns
+        -------
+        config : OrderedDict
+            Current configs of classifier.
+        """
+        pass
+
+    def set_config(self, configs):
+        """Update configuration of classifier. 
+        
+        Parameters
+        ----------
+        configs : OrderedDict
+            New configs of classifier.
+        """
+        pass
+
+    def get_summary_config(self):
+        """Return fomatted summary of configuration. 
+        
+        Returns
+        -------
+        summary : string
+            Formatted string with summary of configuration.
+        """
+        pass 
+
+    def classify(self, dataset, test_dir, test_data):
+        pass
+
+    def train(self, model_final):
+        
+        # Initiate the train and test generators with data Augumentation
+        train_datagen = ImageDataGenerator(
+            rescale=1. / 255,
+            horizontal_flip=True,
+            fill_mode="nearest",
+            zoom_range=0.3,
+            width_shift_range=0.3,
+            height_shift_range=0.3,
+            rotation_range=30)
+
+        train_generator = train_datagen.flow_from_directory(
+            TRAIN_DATA_DIR,
+            target_size=(IMG_HEIGHT, IMG_WIDTH),
+            batch_size=BATCH_SIZE,
+            shuffle=True,
+            class_mode="categorical")
+
+        test_datagen = ImageDataGenerator(
+            rescale=1. / 255,
+            horizontal_flip=True,
+            fill_mode="nearest",
+            zoom_range=0.3,
+            width_shift_range=0.3,
+            height_shift_range=0.3,
+            rotation_range=30)
+
+        validation_generator = test_datagen.flow_from_directory(
+            VALIDATION_DATA_DIR,
+            target_size=(IMG_HEIGHT, IMG_WIDTH),
+            batch_size=BATCH_SIZE,
+            shuffle=True,
+            class_mode="categorical")
+
+        # Save the model according to the conditions
+        checkpoint = ModelCheckpoint("../models_checkpoints/" + file_name + ".h5", monitor='val_acc',
+                                    verbose=1, save_best_only=True, save_weights_only=False,
+                                    mode='auto', period=1)
+
+        # Train the model
+        model_final.fit_generator(
+            train_generator,
+            steps_per_epoch=train_generator.samples // BATCH_SIZE,
+            epochs=EPOCHS,
+            callbacks=[checkpoint],
+            validation_data=validation_generator,
+            validation_steps=validation_generator.samples // BATCH_SIZE)
+
-- 
GitLab