Wednesday, July 18, 2018

A wrapper around batch_normalization

Usually I use Sonnet, but recently, while reading the documentation inlined with its source code, I started to suspect a potential bug in its batch normalization implementation. Turning to the implementation provided by TensorFlow was no better: there are pitfalls here and there.

The following is a wrapper of mine that demonstrates one use case of the routine; I hope you find it useful. And I believe you know how to save and restore the variables, yes? (A short sketch follows the example, just in case.)

Enjoy coding, no matter how frustrating it gets.

import numpy as np
import tensorflow as tf


class MyBatchNorm(object):
    """Batch normalization that also registers its moving-average
    variables in tf.GraphKeys.MOVING_AVERAGE_VARIABLES."""

    def __init__(self):
        # The demo tensor below is channels-last (NHWC), hence axis=-1.
        # tf.layers.BatchNormalization is the public interface to
        # tensorflow.python.layers.normalization.BatchNormalization.
        self._bn = tf.layers.BatchNormalization(
            axis=-1, epsilon=np.finfo(np.float32).eps, momentum=0.9)

    def __call__(self, inputs, is_training=True, test_local_stats=False):
        # `test_local_stats` only mirrors snt.BatchNorm's signature and is
        # unused here. With training=True the layer normalizes with batch
        # statistics and queues its moving-average updates in
        # tf.GraphKeys.UPDATE_OPS (they do not run on their own).
        outputs = self._bn(inputs, training=is_training)

        # tf.layers does not put the moving statistics into
        # MOVING_AVERAGE_VARIABLES, so register them (exactly once) here.
        self._add_variable(self._bn.moving_mean)
        self._add_variable(self._bn.moving_variance)

        return outputs

    def _add_variable(self, var):
        if var not in tf.get_collection(tf.GraphKeys.MOVING_AVERAGE_VARIABLES):
            tf.add_to_collection(tf.GraphKeys.MOVING_AVERAGE_VARIABLES, var)

t = tf.truncated_normal([2, 4, 4, 2])


bn = MyBatchNorm()
bn2 = MyBatchNorm()

n = bn(t)
n2 = bn2(t)

# The updates queued above live only in the UPDATE_OPS collection; tie them
# to the output so they run whenever `n` is evaluated.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    n = tf.identity(n)


with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    n_v, n2_v = sess.run([n, n2])

    print(tf.trainable_variables())       # beta and gamma of both layers
    print(tf.moving_average_variables())  # the four variables registered above
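
As for saving and restoring, here is a minimal sketch under the same TF 1.x graph-mode assumptions; /tmp/my_bn.ckpt is just a made-up checkpoint path. Registering the moving statistics in MOVING_AVERAGE_VARIABLES is what lets us list them in the Saver's var_list alongside the trainable variables.

saver = tf.train.Saver(
    var_list=(tf.trainable_variables() +
              tf.get_collection(tf.GraphKeys.MOVING_AVERAGE_VARIABLES)))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(n)  # the control dependency also runs the moving-average updates
    saver.save(sess, '/tmp/my_bn.ckpt')  # made-up path

# Calling the wrapper again reuses the same layer, so no new variables
# appear; is_training=False normalizes with the restored moving statistics.
n_test = bn(t, is_training=False)

with tf.Session() as sess:
    saver.restore(sess, '/tmp/my_bn.ckpt')
    print(sess.run(n_test))

Note that a plain tf.train.Saver() over all global variables would pick up the moving averages anyway; the explicit var_list just makes the dependence on the collection bookkeeping visible.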
