Saturday, May 20, 2017

How to prepare TFRecords with TensorFlow for training models

The merit of using TFRecords is clear: feeding data at high throughput keeps the training loop from starving. The precondition is that the TFRecord files must be prepared before the whole process is launched. The general guidelines are easy to understand; however, example code is scattered here and there, so it is not easy to assemble the snippets into something that actually works. The following code was written with that aim in mind, and I hope it is useful for anyone's deep-learning work. By the way, do not hesitate to send feedback on how to improve the quality of the code.

'''
@author: Yurui Ming (yrming@gmail.com)
'''
import numpy as np
import tensorflow as tf
import os

class TFRecordGenerator(object):
    '''
    Convert a directory of cat/dog images into TFRecord files.

    The images are read and decoded through a TensorFlow 1.x queue-based
    input pipeline that lives in a graph owned by this object, and are then
    written out as three groups of TFRecord files: training,
    cross-validating and testing.
    '''
    def __init__(self, params = None):
        '''
        Constructor

        Args:
            params: currently unused; kept so existing callers keep working
        '''
        # Private graph, so the pipeline ops do not end up in the caller's default graph.
        self._graph = tf.Graph()

    def _int64_feature(self, value):
        '''Wrap a single integer into a tf.train.Feature.'''
        return tf.train.Feature(int64_list = tf.train.Int64List(value = [value]))

    def _bytes_feature(self, value):
        '''Wrap a single bytes value into a tf.train.Feature.'''
        return tf.train.Feature(bytes_list = tf.train.BytesList(value = [value]))

    def _label_from_key(self, image_key):
        '''
        Map a reader key (the path of the image file, as bytes) to a class label.

        Only the file name is inspected, so a directory called e.g.
        'cats_vs_dogs' cannot turn every image into a cat.

        Returns:
            0 when the file name contains b'cat', 1 when it contains b'dog'
            (b'cat' wins when both are present)
        Raises:
            ValueError: the file name contains neither b'cat' nor b'dog'
        '''
        file_name = os.path.basename(image_key)
        if b'cat' in file_name:
            return 0
        if b'dog' in file_name:
            return 1
        raise ValueError('Invalid file name: {}'.format(image_key))

    def _write_split(self, sess, image_key, image_data, count, prefix, base_name, split_unit):
        '''
        Pull `count` images from the pipeline and write them as one split.

        A new file named base_name.format(i) is started at every record index
        i that is a multiple of split_unit. Nothing is created when count is
        zero, and the open writer is closed even if an error interrupts the loop.

        Args:
            sess: session in which the queue runners are already running
            image_key: tensor yielding the key (file path) of the current image
            image_data: tensor yielding the uint8 pixels of the current image
            count: number of records to write for this split
            prefix: feature-key prefix of the split ('train', 'xval' or 'test')
            base_name: file name paradigm, formatted with the index of the first record of each file
            split_unit: number of records per tfrecord file
        '''
        writer = None
        try:
            for i in range(count):
                if not i % split_unit:
                    if writer is not None:
                        writer.close()
                        writer = None
                    writer = tf.python_io.TFRecordWriter(base_name.format(i))

                # Key and pixels must come from the same run() call so that they belong to the same file.
                key, pixels = sess.run([image_key, image_data])

                feature = {
                    prefix + '/label': self._int64_feature(self._label_from_key(key)),
                    prefix + '/image': self._bytes_feature(pixels.tobytes())
                    }
                example = tf.train.Example(features = tf.train.Features(feature = feature))
                writer.write(example.SerializeToString())
        finally:
            if writer is not None:
                writer.close()

    def Generate(self, img_dir, img_fmt = None, img_shape = (64, 64), partition = (0.8, 0.1, 0.1),
                  train_tfrecord_base_name = 'train{}.tfrecords',
                  xval_tfrecord_base_name = 'xval{}.tfrecords',
                  test_tfrecord_base_name = 'test{}.tfrecords',
                  split_unit = 500):
        '''
        Generate
        Generate TFRecord files
        Three categories of TFRecord files will be generated, namely, training category, cross-validating category and testing category.
        Every record holds two features, '<category>/label' (int64, 0 for cat and 1 for dog) and '<category>/image'
        (the raw uint8 pixel buffer, not an encoded image), where <category> is 'train', 'xval' or 'test'.
        The image shape is not stored in the record, so the reader has to know it.
        Args:
            img_dir: directory containing the images. The label is decided from the file name, which must contain 'cat' or 'dog'
            img_fmt: image format; it must contain 'jpg' (files matching *.jpg) or 'png' (files matching *.png). 'png' wins when both are present
            img_shape: [height, width] every image is resized to after being cropped or padded to a square whose side is the
                shorter image side. A false value (e.g. None) keeps the decoded image untouched
            partition: portions of each category, namely, training, cross-validating and testing. Three non-negative
                values whose sum must not exceed 1; each record count is the portion times the number of files, rounded down
            train_tfrecord_base_name: base training tfrecord file name paradigm for generating training tfrecord file name
            xval_tfrecord_base_name: base cross-validating tfrecord file name paradigm for generating cross-validating tfrecord file name
            test_tfrecord_base_name: base testing tfrecord file name paradigm for generating testing tfrecord file name
            split_unit: number of accumulated records in each tfrecord file, at least 1
        Raises:
            ValueError: img_fmt is missing or unsupported, partition or split_unit is invalid,
                or a file name contains neither 'cat' nor 'dog'
        '''
        # All arguments are validated before any graph op, session or thread is created.
        if not img_fmt:
            raise ValueError('Unspecified image format')

        if 'png' in img_fmt:
            codec = 'png'
        elif 'jpg' in img_fmt:
            codec = 'jpg'
        else:
            raise ValueError('Unsupported image format')

        partition = list(partition)
        # The small tolerance keeps float round-off from rejecting portions that are meant to sum to 1.
        if len(partition) != 3 or any(v < 0 for v in partition) or sum(partition) > 1 + 1e-9:
            raise ValueError('Invalid partition')

        if split_unit < 1:
            raise ValueError('Invalid split unit')

        with self._graph.as_default():
            filenames = tf.train.match_filenames_once(os.path.join(img_dir, '*.' + codec))
            # NOTE(review): the producer runs with its default settings; confirm its
            # ordering/shuffling behaviour against the TensorFlow documentation.
            filename_queue = tf.train.string_input_producer(filenames)
            image_reader = tf.WholeFileReader()
            image_key, image_file = image_reader.read(filename_queue)

            if codec == 'png':
                image_data = tf.image.decode_png(image_file)
            else:
                image_data = tf.image.decode_jpeg(image_file)

            if img_shape:
                # Crop (or pad) to a square with the shorter side first, so that the resize does not distort the image.
                image_data_shape = tf.shape(image_data)
                side = tf.minimum(image_data_shape[0], image_data_shape[1])
                image_data = tf.image.resize_image_with_crop_or_pad(image_data, side, side)
                image_data = tf.image.resize_images(image_data, img_shape)

            # The resize step may yield a floating-point tensor; records store uint8 pixels.
            image_data = tf.cast(image_data, tf.uint8)

            # match_filenames_once keeps its result in a local variable, hence the local initializer.
            init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

            with tf.Session() as sess:
                sess.run(init)

                coord = tf.train.Coordinator()
                threads = tf.train.start_queue_runners(sess = sess, coord = coord)

                try:
                    num_files = len(sess.run(filenames))
                    counts = [int(v * num_files) for v in partition]

                    splits = (
                        ('train', train_tfrecord_base_name),
                        ('xval', xval_tfrecord_base_name),
                        ('test', test_tfrecord_base_name),
                        )
                    for count, (prefix, base_name) in zip(counts, splits):
                        self._write_split(sess, image_key, image_data, count, prefix, base_name, split_unit)
                finally:
                    # Always stop the queue-runner threads, even when writing failed.
                    coord.request_stop()
                    coord.join(threads)
            
if __name__ == '__main__':
    # Demo run: convert the jpg images of a local sample directory, keeping
    # every other option of Generate at its default value.
    sample_dir = 'C:\\Users\\MSUser\\Downloads\\mytest'
    TFRecordGenerator().Generate(sample_dir, 'jpg')

No comments:

Post a Comment