Commit 2b30dbf4 authored by Gabriel Kirsten's avatar Gabriel Kirsten

removed  perc traini and perc validation to split dataset and fixed added create model function
parent 7e36a9c6
......@@ -48,7 +48,6 @@ logger.setLevel(logging.WARNING)
# Constants
# =========================================================
IMG_WIDTH, IMG_HEIGHT = 256, 256
weight_path = None
START_TIME = time.time()
......@@ -65,8 +64,6 @@ class CNNPseudoLabel(Classifier):
fine_tuning_rate=100,
transfer_learning=False,
save_weights=True,
perc_train=80,
perc_validation=20,
recreate_dataset=False,
train_data_directory="",
validation_data_directory="",
......@@ -92,10 +89,6 @@ class CNNPseudoLabel(Classifier):
"Transfer Learning", transfer_learning, bool)
self.save_weights = Config(
"Save weights", save_weights, bool)
self.perc_train = Config(
"Perc Train", perc_train, float)
self.perc_validation = Config(
"Perc Validation", perc_validation, float)
self.recreate_dataset = Config(
"Recreate Dataset", recreate_dataset, bool)
self.train_data_directory = Config(
......@@ -108,7 +101,7 @@ class CNNPseudoLabel(Classifier):
"No label data directory", no_label_data_directory, str)
self.model=None
self.pseudo_label=None
self.trained=False
def get_config(self):
......@@ -129,8 +122,6 @@ class CNNPseudoLabel(Classifier):
keras_config["Fine Tuning rate"]=self.fine_tuning_rate
keras_config["Transfer Learning"]=self.transfer_learning
keras_config["Save weights"]=self.save_weights
keras_config["Perc Train"]=self.perc_train
keras_config["Perc Validation"]=self.perc_validation
keras_config["Recreate Dataset"]=self.recreate_dataset
keras_config["Train data directory"]=self.train_data_directory
keras_config["Validation data directory"]=self.validation_data_directory
......@@ -160,10 +151,6 @@ class CNNPseudoLabel(Classifier):
configs["Transfer Learning"], self.transfer_learning)
self.save_weights=Config.nvl_config(
configs["Save weights"], self.save_weights)
self.perc_train=Config.nvl_config(
configs["Perc Train"], self.perc_train)
self.perc_validation=Config.nvl_config(
configs["Perc Validation"], self.perc_validation)
self.recreate_dataset=Config.nvl_config(
configs["Recreate Dataset"], self.recreate_dataset)
self.train_data_directory = Config.nvl_config(
......@@ -193,8 +180,6 @@ class CNNPseudoLabel(Classifier):
keras_config[self.fine_tuning_rate.label]=self.fine_tuning_rate.value
keras_config[self.transfer_learning.label]=self.transfer_learning.value
keras_config[self.save_weights.label]=self.save_weights.value
keras_config[self.perc_train.label]=self.perc_train.value
keras_config[self.perc_validation.label]=self.perc_validation.value
keras_config[self.recreate_dataset.label]=self.recreate_dataset.value
keras_config[self.train_data_directory.label]=self.train_data_directory.value
keras_config[self.validation_data_directory.label]=self.validation_data_directory.value
......@@ -249,21 +234,19 @@ class CNNPseudoLabel(Classifier):
classify_generator=classify_datagen.flow_from_directory(
File.make_path(predict_directory, 'png'),
taet_size=(IMG_HEIGHT, IMG_WIDTH),
target_size=(IMG_HEIGHT, IMG_WIDTH),
batch_size=1,
shuffle=False,
class_mode=None)
try:
K.clear_session()
if self.weight_path is not None:
self.model=load_model(self.weight_path)
path_classes=self.weight_path.replace(
"_model.h5", "_classes.npy")
CLASS_NAMES=np.load(path_classes).item().keys()
if self.pseudo_label.weight_path is not None:
self.create_model()
self.model.load_weights(self.pseudo_label.weight_path)
except Exception, e:
raise IException("Can't load the model in " +
self.weight_path + str(e))
self.pseudo_label.weight_path + str(e))
output_classification=self.model.predict_generator(
classify_generator, classify_generator.samples, verbose=2)
......@@ -287,29 +270,31 @@ class CNNPseudoLabel(Classifier):
If False don't perform new training if there is trained data.
"""
pseudo_label=PseudoLabel(image_width=IMG_WIDTH,
image_height=IMG_HEIGHT,
image_channels=3,
train_data_directory=self.train_data_directory.value,
validation_data_directory=self.validation_data_directory.value,
test_data_directory=self.test_data_directory.value,
no_label_data_directory=self.no_label_data_directory.value,
epochs=self.epochs.value,
batch_size=self.batch_size.value,
pseudo_label_batch_size=self.batch_size.value*2,
transfer_learning={
'use_transfer_learning': self.transfer_learning.value,
'fine_tuning': self.fine_tuning_rate.value
},
architecture=self.architecture.value,
alpha=0.1)
self.model=pseudo_label.model
pseudo_label.fit_with_pseudo_label(
steps_per_epoch=pseudo_label.train_generator.samples // self.batch_size.value,
validation_steps=pseudo_label.validation_generator.samples // self.batch_size.value)
self.create_model()
self.pseudo_label.fit_with_pseudo_label(
steps_per_epoch=self.pseudo_label.train_generator.samples // self.batch_size.value,
validation_steps=self.pseudo_label.validation_generator.samples // self.batch_size.value)
def create_model(self):
self.pseudo_label=PseudoLabel(image_width=IMG_WIDTH,
image_height=IMG_HEIGHT,
image_channels=3,
train_data_directory=self.train_data_directory.value,
validation_data_directory=self.validation_data_directory.value,
test_data_directory=self.test_data_directory.value,
no_label_data_directory=self.no_label_data_directory.value,
epochs=self.epochs.value,
batch_size=self.batch_size.value,
pseudo_label_batch_size=self.batch_size.value*2,
transfer_learning={
'use_transfer_learning': self.transfer_learning.value,
'fine_tuning': self.fine_tuning_rate.value
},
architecture=self.architecture.value,
alpha=0.1)
self.model=self.pseudo_label.model
def must_train(self):
"""Return if classifier must be trained.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment