The following code snippet reads back records from the TFRecord files prepared in the previous post. I hope it helps beginners get past some common difficulties.
'''
@author: Yurui Ming (yrming@gmail.com)
'''
import tensorflow as tf
import os
import skimage.io as io
class TFRecordPumper(object):
    '''
    Pump batches of (image, label) records out of TFRecord files.

    Each call to Pump() builds a TF1 queue-based input pipeline
    (match_filenames_once -> string_input_producer -> TFRecordReader ->
    parse_single_example -> shuffle_batch) in self._graph, starts its
    queue-runner threads and yields batches evaluated in self._sess.

    The threads keep running until close() is called; the object can also be
    used as a context manager (``with TFRecordPumper() as pumper: ...``).
    '''

    def __init__(self, graph=None, sess=None):
        '''
        Constructor

        Args:
            graph: tf.Graph to build pipelines in. Defaults to sess.graph when
                a session is given (a session can only run ops of its own
                graph), otherwise to a new tf.Graph().
            sess: tf.Session to run pipelines in. When None, a session on
                `graph` is created here and closed by close().
        '''
        if graph is None:
            graph = sess.graph if sess is not None else tf.Graph()
        self._graph = graph
        if sess is None:
            self._sess = tf.Session(graph=self._graph)
            self._self_sess = True
        else:
            self._sess = sess
            self._self_sess = False
        # One (coordinator, threads) pair per Pump() call; stopped by close().
        self._runners = []

    def __enter__(self):
        return self

    def __exit__(self, exc_type=None, exc_value=None, traceback=None):
        '''
        Release resources via close(). The exception arguments default to None
        so the method can still be called as ``__exit__()``.
        '''
        self.close()
        # Never swallow an exception raised inside the with-block.
        return False

    def close(self):
        '''
        Stop and join all queue-runner threads started by Pump(), then close
        the session if this object created it. Stopped runners are forgotten,
        so a repeated call does not join them again.

        Raises:
            Any exception a queue-runner thread reported to its coordinator
            (re-raised by tf.train.Coordinator.join).
        '''
        runners, self._runners = self._runners, []
        # Request every stop first so all pipelines shut down concurrently.
        for coord, _ in runners:
            coord.request_stop()
        try:
            for coord, threads in runners:
                coord.join(threads)
        finally:
            if self._self_sess:
                self._sess.close()

    @staticmethod
    def _record_keys(tfr_basename):
        '''
        Return (label_key, image_key) for tfr_basename.

        The first of 'train', 'xval', 'test' contained in the basename becomes
        the key prefix (e.g. 'train/label', 'train/image'); otherwise the plain
        'label' / 'image' keys are used.
        '''
        for split in ('train', 'xval', 'test'):
            if split in tfr_basename:
                return split + '/label', split + '/image'
        return 'label', 'image'

    def Pump(self, tfr_dir, tfr_basename, batch_size=2, features=None, img_shape=None,
             capacity=10, num_threads=1, min_after_dequeue=5, num_batches=1):
        '''
        Pump
        pumping out tfrecords

        Args:
            tfr_dir: directory containing the tfrecords files
            tfr_basename: basename prefix; every file matching
                <tfr_dir>/<tfr_basename>*.tfrecords is read. It also selects
                the feature keys (see _record_keys).
            batch_size: number of records in each yielded batch
            features: feature spec for tf.parse_single_example. It must contain
                the label and image keys selected by tfr_basename. When None,
                the image is a scalar tf.string and the label a scalar tf.int64.
            img_shape: shape the decoded uint8 image bytes are reshaped to.
                Effectively required: tf.train.shuffle_batch rejects tensors
                whose shapes are not fully defined.
            capacity: capacity of the shuffle queue
            num_threads: number of threads enqueuing into the shuffle queue
            min_after_dequeue: minimum number of records left in the shuffle
                queue after a dequeue (must be smaller than capacity)
            num_batches: number of batches to yield; None keeps yielding until
                the input queues are closed. Defaults to 1 (a single batch).

        Yields:
            [images, labels]: the evaluated shuffle_batch outputs; images is
            uint8 with shape [batch_size] + img_shape.

        Raises (when the generator is advanced):
            ValueError: if `features` lacks the selected label/image key.
            Any error a queue-runner thread reported to the coordinator, when
            that error is what closed the queues.
        '''
        label_key, image_key = self._record_keys(tfr_basename)
        if features is None:
            features = {
                label_key: tf.FixedLenFeature([], tf.int64),
                image_key: tf.FixedLenFeature([], tf.string),
            }
        missing = [key for key in (label_key, image_key) if key not in features]
        if missing:
            raise ValueError('features is missing required key(s): %s' % missing)

        with self._graph.as_default():
            # Remember existing variables so only this pipeline's get initialized.
            known = {id(v) for v in tf.global_variables() + tf.local_variables()}
            ptn = os.path.join(tfr_dir, tfr_basename + "*.tfrecords")
            filenames = tf.train.match_filenames_once(ptn)
            filename_queue = tf.train.string_input_producer(filenames)
            # TFRecordReader reads one serialized example at a time; a single
            # file may hold many of them.
            reader = tf.TFRecordReader()
            _, serialized = reader.read(filename_queue)
            parsed = tf.parse_single_example(serialized, features=features)
            # Force a scalar string so decode_raw yields a flat byte vector;
            # uint8 because every channel value lies in 0-255.
            image = tf.decode_raw(tf.reshape(parsed[image_key], []), tf.uint8)
            if img_shape:
                # Restore the shape the image had when it was written.
                image = tf.reshape(image, img_shape)
            label = parsed[label_key]
            images, labels = tf.train.shuffle_batch(
                [image, label],
                batch_size=batch_size,
                capacity=capacity,
                min_after_dequeue=min_after_dequeue,
                num_threads=num_threads)
            # Initialize only the variables created above (match_filenames_once
            # stores its file list in a variable) so variables a caller already
            # holds in a shared graph/session are not reset.
            new_vars = [v for v in tf.global_variables() + tf.local_variables()
                        if id(v) not in known]
            init_op = tf.variables_initializer(new_vars)

        # Run outside any default-graph/default-session context so nothing is
        # left pushed on TF's default stacks while the generator is suspended.
        self._sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=self._sess, coord=coord)
        self._runners.append((coord, threads))

        produced = 0
        while num_batches is None or produced < num_batches:
            try:
                batch = self._sess.run([images, labels])
            except (tf.errors.OutOfRangeError, tf.errors.CancelledError):
                # The queues were closed: surface a thread error if one caused
                # it, otherwise the input is simply exhausted.
                coord.raise_requested_exception()
                return
            produced += 1
            yield batch
if __name__ == '__main__':
    # Demo: show one shuffled batch of 64x64, 3-channel cross-validation images.
    # Use 'train' or 'test' instead of 'xval' to view another split.
    tf_pumper = TFRecordPumper()
    try:
        images, _ = next(tf_pumper.Pump('', 'xval', img_shape=[64, 64, 3]))
        for image in images:
            io.imshow(image)
            io.show()
    finally:
        # Stop the queue-runner threads and close the pumper's session.
        tf_pumper.__exit__()
'''
@author: Yurui Ming (yrming@gmail.com)
'''
import tensorflow as tf
import os
import skimage.io as io
class TFRecordPumper(object):
    '''
    Pump batches of (image, label) records out of TFRecord files.

    Each call to Pump() builds a TF1 queue-based input pipeline
    (match_filenames_once -> string_input_producer -> TFRecordReader ->
    parse_single_example -> shuffle_batch) in self._graph, starts its
    queue-runner threads and yields batches evaluated in self._sess.

    The threads keep running until close() is called; the object can also be
    used as a context manager (``with TFRecordPumper() as pumper: ...``).
    '''

    def __init__(self, graph=None, sess=None):
        '''
        Constructor

        Args:
            graph: tf.Graph to build pipelines in. Defaults to sess.graph when
                a session is given (a session can only run ops of its own
                graph), otherwise to a new tf.Graph().
            sess: tf.Session to run pipelines in. When None, a session on
                `graph` is created here and closed by close().
        '''
        if graph is None:
            graph = sess.graph if sess is not None else tf.Graph()
        self._graph = graph
        if sess is None:
            self._sess = tf.Session(graph=self._graph)
            self._self_sess = True
        else:
            self._sess = sess
            self._self_sess = False
        # One (coordinator, threads) pair per Pump() call; stopped by close().
        self._runners = []

    def __enter__(self):
        return self

    def __exit__(self, exc_type=None, exc_value=None, traceback=None):
        '''
        Release resources via close(). The exception arguments default to None
        so the method can still be called as ``__exit__()``.
        '''
        self.close()
        # Never swallow an exception raised inside the with-block.
        return False

    def close(self):
        '''
        Stop and join all queue-runner threads started by Pump(), then close
        the session if this object created it. Stopped runners are forgotten,
        so a repeated call does not join them again.

        Raises:
            Any exception a queue-runner thread reported to its coordinator
            (re-raised by tf.train.Coordinator.join).
        '''
        runners, self._runners = self._runners, []
        # Request every stop first so all pipelines shut down concurrently.
        for coord, _ in runners:
            coord.request_stop()
        try:
            for coord, threads in runners:
                coord.join(threads)
        finally:
            if self._self_sess:
                self._sess.close()

    @staticmethod
    def _record_keys(tfr_basename):
        '''
        Return (label_key, image_key) for tfr_basename.

        The first of 'train', 'xval', 'test' contained in the basename becomes
        the key prefix (e.g. 'train/label', 'train/image'); otherwise the plain
        'label' / 'image' keys are used.
        '''
        for split in ('train', 'xval', 'test'):
            if split in tfr_basename:
                return split + '/label', split + '/image'
        return 'label', 'image'

    def Pump(self, tfr_dir, tfr_basename, batch_size=2, features=None, img_shape=None,
             capacity=10, num_threads=1, min_after_dequeue=5, num_batches=1):
        '''
        Pump
        pumping out tfrecords

        Args:
            tfr_dir: directory containing the tfrecords files
            tfr_basename: basename prefix; every file matching
                <tfr_dir>/<tfr_basename>*.tfrecords is read. It also selects
                the feature keys (see _record_keys).
            batch_size: number of records in each yielded batch
            features: feature spec for tf.parse_single_example. It must contain
                the label and image keys selected by tfr_basename. When None,
                the image is a scalar tf.string and the label a scalar tf.int64.
            img_shape: shape the decoded uint8 image bytes are reshaped to.
                Effectively required: tf.train.shuffle_batch rejects tensors
                whose shapes are not fully defined.
            capacity: capacity of the shuffle queue
            num_threads: number of threads enqueuing into the shuffle queue
            min_after_dequeue: minimum number of records left in the shuffle
                queue after a dequeue (must be smaller than capacity)
            num_batches: number of batches to yield; None keeps yielding until
                the input queues are closed. Defaults to 1 (a single batch).

        Yields:
            [images, labels]: the evaluated shuffle_batch outputs; images is
            uint8 with shape [batch_size] + img_shape.

        Raises (when the generator is advanced):
            ValueError: if `features` lacks the selected label/image key.
            Any error a queue-runner thread reported to the coordinator, when
            that error is what closed the queues.
        '''
        label_key, image_key = self._record_keys(tfr_basename)
        if features is None:
            features = {
                label_key: tf.FixedLenFeature([], tf.int64),
                image_key: tf.FixedLenFeature([], tf.string),
            }
        missing = [key for key in (label_key, image_key) if key not in features]
        if missing:
            raise ValueError('features is missing required key(s): %s' % missing)

        with self._graph.as_default():
            # Remember existing variables so only this pipeline's get initialized.
            known = {id(v) for v in tf.global_variables() + tf.local_variables()}
            ptn = os.path.join(tfr_dir, tfr_basename + "*.tfrecords")
            filenames = tf.train.match_filenames_once(ptn)
            filename_queue = tf.train.string_input_producer(filenames)
            # TFRecordReader reads one serialized example at a time; a single
            # file may hold many of them.
            reader = tf.TFRecordReader()
            _, serialized = reader.read(filename_queue)
            parsed = tf.parse_single_example(serialized, features=features)
            # Force a scalar string so decode_raw yields a flat byte vector;
            # uint8 because every channel value lies in 0-255.
            image = tf.decode_raw(tf.reshape(parsed[image_key], []), tf.uint8)
            if img_shape:
                # Restore the shape the image had when it was written.
                image = tf.reshape(image, img_shape)
            label = parsed[label_key]
            images, labels = tf.train.shuffle_batch(
                [image, label],
                batch_size=batch_size,
                capacity=capacity,
                min_after_dequeue=min_after_dequeue,
                num_threads=num_threads)
            # Initialize only the variables created above (match_filenames_once
            # stores its file list in a variable) so variables a caller already
            # holds in a shared graph/session are not reset.
            new_vars = [v for v in tf.global_variables() + tf.local_variables()
                        if id(v) not in known]
            init_op = tf.variables_initializer(new_vars)

        # Run outside any default-graph/default-session context so nothing is
        # left pushed on TF's default stacks while the generator is suspended.
        self._sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=self._sess, coord=coord)
        self._runners.append((coord, threads))

        produced = 0
        while num_batches is None or produced < num_batches:
            try:
                batch = self._sess.run([images, labels])
            except (tf.errors.OutOfRangeError, tf.errors.CancelledError):
                # The queues were closed: surface a thread error if one caused
                # it, otherwise the input is simply exhausted.
                coord.raise_requested_exception()
                return
            produced += 1
            yield batch
if __name__ == '__main__':
    # Demo: show one shuffled batch of 64x64, 3-channel cross-validation images.
    # Use 'train' or 'test' instead of 'xval' to view another split.
    tf_pumper = TFRecordPumper()
    try:
        images, _ = next(tf_pumper.Pump('', 'xval', img_shape=[64, 64, 3]))
        for image in images:
            io.imshow(image)
            io.show()
    finally:
        # Stop the queue-runner threads and close the pumper's session.
        tf_pumper.__exit__()
No comments:
Post a Comment