Saturday, May 20, 2017

How to prepare tfrecords utilizing TensorFlow for training models (II)

The following code snippet reads records back out of the tfrecords files prepared in the previous post. I hope it helps beginners past some of the difficulties they may run into.

'''
@author: Yurui Ming (yrming@gmail.com)
'''
import tensorflow as tf
import os
import skimage.io as io

class TFRecordPumper(object):
    '''
    Read batches of decoded images and labels back out of tfrecords files.

    Pump() builds a queue-based input pipeline (filename queue ->
    TFRecordReader -> parse_single_example -> shuffle_batch) in the pumper's
    graph, starts its queue-runner threads and yields evaluated batches.
    The pumper is also a context manager: leaving the ``with`` block (or
    calling close()) stops the queue-runner threads and closes the session
    if the pumper created it.
    '''

    def __init__(self, graph = None, sess = None):
        '''
        Constructor

        Args:
            graph: tf.Graph to build pipelines in. If omitted, the graph of
                ``sess`` is used when a session is given, otherwise a new
                tf.Graph is created.
            sess: tf.Session used to run pipelines. If omitted, a session on
                ``graph`` is created and owned by the pumper (closed by close()).
        '''
        if graph is None:
            # A session can only run ops from its own graph, so default to it
            # instead of building the pipeline into an unrelated fresh graph.
            graph = tf.Graph() if sess is None else sess.graph
        self._graph = graph

        if sess is None:
            self._sess = tf.Session(graph = self._graph)
            self._self_sess = True
        else:
            self._sess = sess
            self._self_sess = False

        # Created lazily by Pump(); initialised here so that closing a pumper
        # that never pumped does not fail with AttributeError.
        self._coord = None
        self._threads = []

    def __enter__(self):
        return self

    def __exit__(self, exc_type = None, exc_value = None, traceback = None):
        '''
        Release resources when leaving a ``with`` block; never swallows the
        exception raised inside it. The arguments default to None so that a
        direct ``pumper.__exit__()`` call keeps working.
        '''
        self.close()
        return False

    def close(self):
        '''
        Stop and join all queue-runner threads started by Pump() and close the
        session if this pumper created it. Safe to call more than once.
        '''
        if self._coord is not None:
            self._coord.request_stop()
            self._coord.join(self._threads)
            self._coord = None
            self._threads = []

        if self._self_sess:
            self._sess.close()
            # Prevent a second close() from closing the session again.
            self._self_sess = False

    @staticmethod
    def _feature_keys(tfr_basename):
        '''
        Return (label_key, image_key) for the split named in tfr_basename.

        'train', 'xval' and 'test' are checked in that order and give
        '<split>/label' and '<split>/image'; any other basename gives plain
        'label' and 'image'.
        '''
        for split in ('train', 'xval', 'test'):
            if split in tfr_basename:
                return split + '/label', split + '/image'
        return 'label', 'image'

    def Pump(self, tfr_dir, tfr_basename, batch_size = 2, features = None, img_shape = None,
             capacity = 10, num_threads = 1, min_after_dequeue = 5, num_batches = 1):
        '''
        Generator yielding shuffled batches read from
        ``<tfr_dir>/<tfr_basename>*.tfrecords``.

        Args:
            tfr_dir: directory containing the tfrecords files
            tfr_basename: basename prefix used to collect the tfrecords files;
                it also selects the feature keys (see _feature_keys)
            batch_size: number of records per yielded batch
            features: optional dict of feature specs merged over the defaults
                {label_key: tf.FixedLenFeature([], tf.int64),
                 image_key: tf.FixedLenFeature([], tf.string)}, so entries
                here override or extend them
            img_shape: optional shape each decoded uint8 image is reshaped to
            capacity, num_threads, min_after_dequeue: forwarded to
                tf.train.shuffle_batch
            num_batches: how many batches to yield (default 1); None keeps
                yielding until the coordinator stops or the input is exhausted

        Yields:
            [images, labels] as returned by ``sess.run`` for one batch
        '''
        label_key, image_key = self._feature_keys(tfr_basename)

        parse_features = {label_key: tf.FixedLenFeature([], tf.int64),
                          image_key: tf.FixedLenFeature([], tf.string)}
        if features is not None:
            parse_features.update(features)

        with self._graph.as_default():
            ptn = os.path.join(tfr_dir, tfr_basename + "*.tfrecords")
            filenames = tf.train.match_filenames_once(ptn)
            tf_record_filename_queue = tf.train.string_input_producer(filenames)

            # TFRecordReader reads TFRecord files, each of which may hold
            # more than one serialized example.
            tf_record_reader = tf.TFRecordReader()
            _, tf_record_serialized = tf_record_reader.read(tf_record_filename_queue)

            tf_record_features = tf.parse_single_example(tf_record_serialized,
                                                         features = parse_features)

            # Force the serialized image to a scalar string so decode_raw
            # yields a flat byte vector; uint8 because the channel values were
            # stored as raw bytes in 0-255.
            tf_record_image = tf.reshape(tf_record_features[image_key], [])
            tf_record_image = tf.decode_raw(tf_record_image, tf.uint8)

            # Restore the original image layout when the caller knows it.
            if img_shape:
                tf_record_image = tf.reshape(tf_record_image, img_shape)

            tf_record_label = tf_record_features[label_key]

            images, labels = tf.train.shuffle_batch([tf_record_image, tf_record_label],
                                                    batch_size = batch_size,
                                                    capacity = capacity,
                                                    min_after_dequeue = min_after_dequeue,
                                                    num_threads = num_threads)

            # match_filenames_once keeps its file list in a local variable,
            # hence the local initializer.
            # NOTE(review): the global initializer also re-initialises any
            # model variables when a shared session/graph is passed in.
            init_op = tf.group(tf.global_variables_initializer(),
                               tf.local_variables_initializer())
            self._sess.run(init_op)

            if self._coord is None:
                self._coord = tf.train.Coordinator()
            # Keep (rather than replace) threads from earlier Pump() calls so
            # close() joins all of them.
            self._threads.extend(
                tf.train.start_queue_runners(sess = self._sess, coord = self._coord))

        # Yield outside the as_default() context so the suspended generator
        # does not leave self._graph installed as the caller's default graph.
        coord = self._coord
        produced = 0
        while num_batches is None or produced < num_batches:
            if coord.should_stop():
                break
            try:
                batch = self._sess.run([images, labels])
            except tf.errors.OutOfRangeError:
                break
            yield batch
            produced += 1
                
if __name__ == '__main__':
    # Demo: pull one batch of 64x64 RGB images from the 'xval' split found in
    # the current directory; use 'train' or 'test' to inspect another split.
    tf_pumper = TFRecordPumper()

    images, labels = next(tf_pumper.Pump('', 'xval', img_shape = [64, 64, 3]))

    # io.imshow draws onto the current matplotlib axes, so drawing every image
    # before a single io.show() leaves only the last one visible; show them
    # one at a time instead.
    for image, label in zip(images, labels):
        print('label: {}'.format(label))
        io.imshow(image)
        io.show()

    # Stop the queue-runner threads and close the session the pumper owns.
    tf_pumper.__exit__()

No comments:

Post a Comment