The merit of using TFRecords is clear: a high-throughput input feed keeps the training loop from starving. The precondition is that the TFRecord files must be prepared before the whole process is launched. The general guidelines are easy to understand; however, example code is scattered here and there, so it is not easy to assemble the snippets into something that actually works. The following code was written with that aim in mind, and I hope it is useful for anyone whose work involves deep learning. By the way, do not hesitate to provide feedback on how to improve the code quality.
'''
@author: Yurui Ming (yrming@gmail.com)
'''
import numpy as np
import tensorflow as tf
import os
class TFRecordGenerator(object):
    '''
    Convert a directory of cat/dog images into sharded TFRecord files.

    The images are read and decoded with the TensorFlow 1.x queue-based input
    pipeline and written out as three groups of shards: training,
    cross-validation and testing.
    '''

    # (substring looked for in the image file name, integer label), tested in
    # order: a name containing both substrings gets the first label.
    _LABELS = ((b'cat', 0), (b'dog', 1))

    def __init__(self, params = None):
        '''
        Constructor

        Args:
            params: not used by this class.
        '''
        self._graph = tf.Graph()

    def _int64_feature(self, value):
        '''Wrap a single integer in a tf.train.Feature.'''
        return tf.train.Feature(int64_list = tf.train.Int64List(value = [value]))

    def _bytes_feature(self, value):
        '''Wrap a single bytes string in a tf.train.Feature.'''
        return tf.train.Feature(bytes_list = tf.train.BytesList(value = [value]))

    @staticmethod
    def _resolve_format(img_dir, img_fmt):
        '''Return (glob pattern, is_png) for img_fmt, or raise ValueError.'''
        if not img_fmt:
            raise ValueError('Unspecified image format')
        # 'png' is tested first so that it wins when both names appear in
        # img_fmt, which is what the sequential checks did before.
        if 'png' in img_fmt:
            return os.path.join(img_dir, '*.png'), True
        if 'jpg' in img_fmt:
            return os.path.join(img_dir, '*.jpg'), False
        raise ValueError('Unsupported image format')

    @staticmethod
    def _check_partition(partition):
        '''Raise ValueError unless partition holds at least three non-negative fractions summing to at most 1.'''
        fractions = list(partition)
        # The small tolerance keeps a split valid even when floating-point
        # rounding pushes the sum of its fractions marginally above 1.
        if (len(fractions) < 3 or min(fractions) < 0
                or np.sum(fractions) > 1 + 1e-9):
            raise ValueError('Invalid partition')

    @classmethod
    def _label_from_key(cls, image_key):
        '''Map the key returned by the file reader (the image path, as bytes) to its integer label.'''
        # Only the file name is inspected: looking at the whole path would
        # mislabel every image whenever a directory name contains 'cat' or 'dog'.
        name = os.path.basename(image_key)
        for token, label in cls._LABELS:
            if token in name:
                return label
        raise ValueError('Invalid file name: {}'.format(image_key))

    def _write_split(self, sess, image_key, image_data, count, base_name, prefix, split_unit):
        '''Pull count images from the pipeline and write them to shards of at most split_unit records.'''
        writer = None
        try:
            for i in range(count):
                if i % split_unit == 0:
                    # Start a new shard, named after the index of its first record.
                    if writer is not None:
                        writer.close()
                        writer = None
                    writer = tf.python_io.TFRecordWriter(base_name.format(i))
                key, pixels = sess.run([image_key, image_data])
                feature = {
                    prefix + '/label': self._int64_feature(self._label_from_key(key)),
                    prefix + '/image': self._bytes_feature(pixels.tobytes())
                }
                example = tf.train.Example(features = tf.train.Features(feature = feature))
                writer.write(example.SerializeToString())
        finally:
            # Runs for an empty split (no writer was ever opened) and on errors
            # alike, so a shard is never left open and None is never closed.
            if writer is not None:
                writer.close()

    def Generate(self, img_dir, img_fmt = None, img_shape = [64, 64], partition = [0.8, 0.1, 0.1],
                 train_tfrecord_base_name = 'train{}.tfrecords',
                 xval_tfrecord_base_name = 'xval{}.tfrecords',
                 test_tfrecord_base_name = 'test{}.tfrecords',
                 split_unit = 500):
        '''
        Generate
        Generate TFRecord files
        Three categories of TFRecord files will be generated, namely, training category, cross-validating category and testing category.
        Each record stores an integer label under '<category>/label' and the raw uint8 pixel bytes under '<category>/image',
        where <category> is 'train', 'xval' or 'test'.
        Args:
            img_dir: directory containing the images. The label is derived from each image's file name:
                a name containing 'cat' gives label 0, otherwise a name containing 'dog' gives label 1
            img_fmt: image encoding, a string containing 'jpg' or 'png'; it selects both the '*.jpg' / '*.png'
                file pattern and the decoder
            img_shape: [height, width] every image is resized to after being centrally cropped to a square
                whose side is its shorter dimension; a falsy value stores the decoded image unchanged.
                The default list is never mutated
            partition: fractions of the matched files assigned to training, cross-validating and testing;
                they must be non-negative and sum to at most 1. The default list is never mutated
            train_tfrecord_base_name: base training tfrecord file name paradigm for generating training tfrecord file name
            xval_tfrecord_base_name: base cross-validating tfrecord file name paradigm for generating cross-validating tfrecord file name
            test_tfrecord_base_name: base testing tfrecord file name paradigm for generating testing tfrecord file name
                (each paradigm is formatted with the index of the first record stored in the file)
            split_unit: maximum number of records in each tfrecord file, at least 1
        Raises:
            ValueError: if img_fmt is missing or unsupported, partition or split_unit is invalid,
                or an image file name contains neither 'cat' nor 'dog'
        '''
        # Check every argument before touching TensorFlow, so that a bad call
        # fails without a session or reader threads having been started.
        ptn, is_png = self._resolve_format(img_dir, img_fmt)
        self._check_partition(partition)
        if split_unit < 1:
            raise ValueError('Invalid split_unit')

        with self._graph.as_default():
            filenames = tf.train.match_filenames_once(ptn)
            filename_queue = tf.train.string_input_producer(filenames)
            image_reader = tf.WholeFileReader()
            image_key, image_file = image_reader.read(filename_queue)
            if is_png:
                image_data = tf.image.decode_png(image_file)
            else:
                image_data = tf.image.decode_jpeg(image_file)
            image_data_shape = tf.shape(image_data)
            if img_shape:
                # Crop to a square whose side is the shorter dimension, then
                # scale to the requested shape and go back to uint8 pixels.
                image_data = tf.cond(image_data_shape[0] > image_data_shape[1],
                    lambda: tf.image.resize_image_with_crop_or_pad(image_data, image_data_shape[1], image_data_shape[1]),
                    lambda: tf.image.resize_image_with_crop_or_pad(image_data, image_data_shape[0], image_data_shape[0]))
                image_data = tf.image.resize_images(image_data, img_shape)
                image_data = tf.cast(image_data, tf.uint8)
            init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
            with tf.Session() as sess:
                sess.run(init)
                coord = tf.train.Coordinator()
                threads = tf.train.start_queue_runners(sess = sess, coord = coord)
                try:
                    num_files = len(sess.run(filenames))
                    splits = (
                        ('train', train_tfrecord_base_name),
                        ('xval', xval_tfrecord_base_name),
                        ('test', test_tfrecord_base_name),
                    )
                    for fraction, (prefix, base_name) in zip(partition, splits):
                        self._write_split(sess, image_key, image_data,
                                          int(fraction * num_files),
                                          base_name, prefix, split_unit)
                finally:
                    # Always stop the reader threads, even when a split fails.
                    coord.request_stop()
                    coord.join(threads)
if __name__ == '__main__':
    # Script entry point: build TFRecords from the 'jpg' images in a fixed directory.
    sample_dir = 'C:\\Users\\MSUser\\Downloads\\mytest'
    tf_generator = TFRecordGenerator()
    tf_generator.Generate(sample_dir, img_fmt = 'jpg')
'''
@author: Yurui Ming (yrming@gmail.com)
'''
import numpy as np
import tensorflow as tf
import os
class TFRecordGenerator(object):
    '''
    Convert a directory of cat/dog images into sharded TFRecord files.

    The images are read and decoded with the TensorFlow 1.x queue-based input
    pipeline and written out as three groups of shards: training,
    cross-validation and testing.
    '''

    # (substring looked for in the image file name, integer label), tested in
    # order: a name containing both substrings gets the first label.
    _LABELS = ((b'cat', 0), (b'dog', 1))

    def __init__(self, params = None):
        '''
        Constructor

        Args:
            params: not used by this class.
        '''
        self._graph = tf.Graph()

    def _int64_feature(self, value):
        '''Wrap a single integer in a tf.train.Feature.'''
        return tf.train.Feature(int64_list = tf.train.Int64List(value = [value]))

    def _bytes_feature(self, value):
        '''Wrap a single bytes string in a tf.train.Feature.'''
        return tf.train.Feature(bytes_list = tf.train.BytesList(value = [value]))

    @staticmethod
    def _resolve_format(img_dir, img_fmt):
        '''Return (glob pattern, is_png) for img_fmt, or raise ValueError.'''
        if not img_fmt:
            raise ValueError('Unspecified image format')
        # 'png' is tested first so that it wins when both names appear in
        # img_fmt, which is what the sequential checks did before.
        if 'png' in img_fmt:
            return os.path.join(img_dir, '*.png'), True
        if 'jpg' in img_fmt:
            return os.path.join(img_dir, '*.jpg'), False
        raise ValueError('Unsupported image format')

    @staticmethod
    def _check_partition(partition):
        '''Raise ValueError unless partition holds at least three non-negative fractions summing to at most 1.'''
        fractions = list(partition)
        # The small tolerance keeps a split valid even when floating-point
        # rounding pushes the sum of its fractions marginally above 1.
        if (len(fractions) < 3 or min(fractions) < 0
                or np.sum(fractions) > 1 + 1e-9):
            raise ValueError('Invalid partition')

    @classmethod
    def _label_from_key(cls, image_key):
        '''Map the key returned by the file reader (the image path, as bytes) to its integer label.'''
        # Only the file name is inspected: looking at the whole path would
        # mislabel every image whenever a directory name contains 'cat' or 'dog'.
        name = os.path.basename(image_key)
        for token, label in cls._LABELS:
            if token in name:
                return label
        raise ValueError('Invalid file name: {}'.format(image_key))

    def _write_split(self, sess, image_key, image_data, count, base_name, prefix, split_unit):
        '''Pull count images from the pipeline and write them to shards of at most split_unit records.'''
        writer = None
        try:
            for i in range(count):
                if i % split_unit == 0:
                    # Start a new shard, named after the index of its first record.
                    if writer is not None:
                        writer.close()
                        writer = None
                    writer = tf.python_io.TFRecordWriter(base_name.format(i))
                key, pixels = sess.run([image_key, image_data])
                feature = {
                    prefix + '/label': self._int64_feature(self._label_from_key(key)),
                    prefix + '/image': self._bytes_feature(pixels.tobytes())
                }
                example = tf.train.Example(features = tf.train.Features(feature = feature))
                writer.write(example.SerializeToString())
        finally:
            # Runs for an empty split (no writer was ever opened) and on errors
            # alike, so a shard is never left open and None is never closed.
            if writer is not None:
                writer.close()

    def Generate(self, img_dir, img_fmt = None, img_shape = [64, 64], partition = [0.8, 0.1, 0.1],
                 train_tfrecord_base_name = 'train{}.tfrecords',
                 xval_tfrecord_base_name = 'xval{}.tfrecords',
                 test_tfrecord_base_name = 'test{}.tfrecords',
                 split_unit = 500):
        '''
        Generate
        Generate TFRecord files
        Three categories of TFRecord files will be generated, namely, training category, cross-validating category and testing category.
        Each record stores an integer label under '<category>/label' and the raw uint8 pixel bytes under '<category>/image',
        where <category> is 'train', 'xval' or 'test'.
        Args:
            img_dir: directory containing the images. The label is derived from each image's file name:
                a name containing 'cat' gives label 0, otherwise a name containing 'dog' gives label 1
            img_fmt: image encoding, a string containing 'jpg' or 'png'; it selects both the '*.jpg' / '*.png'
                file pattern and the decoder
            img_shape: [height, width] every image is resized to after being centrally cropped to a square
                whose side is its shorter dimension; a falsy value stores the decoded image unchanged.
                The default list is never mutated
            partition: fractions of the matched files assigned to training, cross-validating and testing;
                they must be non-negative and sum to at most 1. The default list is never mutated
            train_tfrecord_base_name: base training tfrecord file name paradigm for generating training tfrecord file name
            xval_tfrecord_base_name: base cross-validating tfrecord file name paradigm for generating cross-validating tfrecord file name
            test_tfrecord_base_name: base testing tfrecord file name paradigm for generating testing tfrecord file name
                (each paradigm is formatted with the index of the first record stored in the file)
            split_unit: maximum number of records in each tfrecord file, at least 1
        Raises:
            ValueError: if img_fmt is missing or unsupported, partition or split_unit is invalid,
                or an image file name contains neither 'cat' nor 'dog'
        '''
        # Check every argument before touching TensorFlow, so that a bad call
        # fails without a session or reader threads having been started.
        ptn, is_png = self._resolve_format(img_dir, img_fmt)
        self._check_partition(partition)
        if split_unit < 1:
            raise ValueError('Invalid split_unit')

        with self._graph.as_default():
            filenames = tf.train.match_filenames_once(ptn)
            filename_queue = tf.train.string_input_producer(filenames)
            image_reader = tf.WholeFileReader()
            image_key, image_file = image_reader.read(filename_queue)
            if is_png:
                image_data = tf.image.decode_png(image_file)
            else:
                image_data = tf.image.decode_jpeg(image_file)
            image_data_shape = tf.shape(image_data)
            if img_shape:
                # Crop to a square whose side is the shorter dimension, then
                # scale to the requested shape and go back to uint8 pixels.
                image_data = tf.cond(image_data_shape[0] > image_data_shape[1],
                    lambda: tf.image.resize_image_with_crop_or_pad(image_data, image_data_shape[1], image_data_shape[1]),
                    lambda: tf.image.resize_image_with_crop_or_pad(image_data, image_data_shape[0], image_data_shape[0]))
                image_data = tf.image.resize_images(image_data, img_shape)
                image_data = tf.cast(image_data, tf.uint8)
            init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
            with tf.Session() as sess:
                sess.run(init)
                coord = tf.train.Coordinator()
                threads = tf.train.start_queue_runners(sess = sess, coord = coord)
                try:
                    num_files = len(sess.run(filenames))
                    splits = (
                        ('train', train_tfrecord_base_name),
                        ('xval', xval_tfrecord_base_name),
                        ('test', test_tfrecord_base_name),
                    )
                    for fraction, (prefix, base_name) in zip(partition, splits):
                        self._write_split(sess, image_key, image_data,
                                          int(fraction * num_files),
                                          base_name, prefix, split_unit)
                finally:
                    # Always stop the reader threads, even when a split fails.
                    coord.request_stop()
                    coord.join(threads)
if __name__ == '__main__':
    # Script entry point: build TFRecords from the 'jpg' images in a fixed directory.
    sample_dir = 'C:\\Users\\MSUser\\Downloads\\mytest'
    tf_generator = TFRecordGenerator()
    tf_generator.Generate(sample_dir, img_fmt = 'jpg')
No comments:
Post a Comment