feature_extraction.py 3.09 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96
#!/usr/bin/python
# -*- coding: utf-8 -*-
#
"""
    Runs feature extraction algorithms.
    
    Name: feature_extraction.py
    Author: Alessandro dos Santos Ferreira ( santosferreira.alessandro@gmail.com )
"""

import io
import itertools
import os

from interface.interface import InterfaceException as IException

from util.file_utils import File
from util.utils import ImageUtils
from util.utils import TimeUtils

from extractor import Extractor

class FeatureExtractor(object):
    """Run a set of feature extractors over an image dataset and save an ARFF file.

    The dataset is a directory containing one sub-directory per class; every
    non-hidden entry of a class directory is opened as an image and handed to
    each configured extractor.
    """

    def __init__(self, extractors):
        """Store the extractors to run.

        Args:
            extractors: Sequence of callables (typically extractor classes).
                Each is called with no arguments for every image and the
                result's ``run(image)`` must yield three iterables, in order:
                attribute labels, attribute types and attribute values.
        """
        self.extractors = extractors

    def extract_all(self, dataset, output_file=None, classes=None, overwrite=True):
        """Extract features from every image in ``dataset`` and write them as ARFF.

        Args:
            dataset: Path of the dataset directory (one sub-directory per class).
            output_file: Base name of the output file, without extension.
                Defaults to ``File.get_filename(dataset)``. The file is
                created inside ``dataset`` with a ``.arff`` suffix.
            classes: Names of the class sub-directories to process. Defaults
                to the sorted result of ``File.list_dirs(dataset)``.
            overwrite: When falsy and the output file already exists, the
                extraction is skipped and the existing file is reported.

        Returns:
            A ``(output_file, elapsed)`` tuple. ``elapsed`` is the difference
            of two ``TimeUtils.get_time()`` readings, or ``0`` when an
            existing file was reused, so callers can always unpack two values.

        Raises:
            IException: If no extractor is configured, if an image cannot be
                opened, or if no image was found in the dataset.
        """
        if len(self.extractors) == 0:
            raise IException("Please select at least one extractor")

        if output_file is None:
            output_file = File.get_filename(dataset)
        output_file = File.make_path(dataset, output_file + '.arff')

        if not overwrite and os.path.isfile(output_file):
            # Same shape as the normal return value; nothing was computed,
            # so the elapsed time is reported as zero.
            return output_file, 0

        classes = sorted(File.list_dirs(dataset)) if classes is None else classes

        start_time = TimeUtils.get_time()

        labels = None
        types = None
        data = []

        for cl in classes:
            items = os.listdir(File.make_path(dataset, cl))
            print("Processing class %s - %d itens" % (cl, len(items)))

            for item in items:
                # Skip hidden entries such as .DS_Store.
                if item.startswith('.'):
                    continue

                # Built outside the try so the error message can always name the file.
                filepath = File.make_path(dataset, cl, item)
                try:
                    image = File.open_image(filepath, rgb=False)
                except Exception:
                    raise IException("Image %s is possibly corrupt" % filepath)

                parts = self._run_extractors(image)

                if data:
                    # Labels and types are the same for every image; only the
                    # values (third component) are needed from now on.
                    values = list(itertools.chain.from_iterable(parts[2]))
                else:
                    labels, types, values = [list(itertools.chain.from_iterable(part))
                                             for part in parts]

                # The class name is stored as the last attribute of the instance.
                data.append(values + [cl])

        if len(data) == 0:
            raise IException("There are no images in dataset: %s" % dataset)

        self._save_output(File.get_filename(dataset), classes, labels, types, data, output_file)

        end_time = TimeUtils.get_time()

        return output_file, (end_time - start_time)

    def _run_extractors(self, image):
        """Run every extractor on ``image`` and regroup the results by component.

        Returns:
            A list whose items group, across all extractors and in extractor
            order, the labels, the types and the values respectively. It is
            materialised as a list because ``zip`` returns a non-subscriptable
            iterator on Python 3.
        """
        results = [extractor().run(image) for extractor in self.extractors]
        return list(zip(*results))

    def _save_output(self, relation, classes, labels, types, data, output_file):
        """Write the extracted features to ``output_file`` in ARFF format.

        Args:
            relation: Name written on the ``@relation`` line.
            classes: Class names listed as the nominal values of the final
                ``classe`` attribute.
            labels: Attribute names, one per extracted feature.
            types: Attribute types, parallel to ``labels``.
            data: Instances; each is a sequence of feature values followed by
                the class name. Items are converted with ``str``.
            output_file: Path of the file to create or truncate.
        """
        # Text mode: the content is str, which a binary-mode file rejects on
        # Python 3. The context manager closes the file even if a write fails.
        with open(output_file, 'w') as arff:
            arff.write("%s %s\n\n" % ('@relation ', relation))

            for label, t in zip(labels, types):
                arff.write("%s %s %s\n" % ('@attribute', label, t))

            arff.write("%s %s {%s}\n\n" % ('@attribute', 'classe', ', '.join(classes)))

            arff.write('@data\n\n')

            for instance in data:
                arff.write(",".join(map(str, instance)) + "\n")