I usually use Sonnet. Recently, however, a look through the documentation in-lined with its source code made me think there is a potential bug in the implementation. But when I turned to the implementation provided by TensorFlow, it was no better: there are lots of pitfalls here and there.
The following is a wrapper I wrote to demonstrate a use case of the routine; I hope it will be useful. And I believe you know how to save and restore the variables, yes?
Enjoy coding, no matter how frustrating it gets.
import numpy as np
import tensorflow as tf
import sonnet as snt
from tensorflow.python.layers import normalization
class MyBatchNorm(object):
    """Thin wrapper around the ``tf.layers`` ``BatchNormalization`` layer.

    Every call normalizes its input with the wrapped layer and then makes
    sure the layer's moving mean and moving variance are listed in the
    ``tf.GraphKeys.MOVING_AVERAGE_VARIABLES`` collection.
    """

    def __init__(self):
        # axis=1 selects the second dimension as the normalized axis
        # (presumably the channel axis of NCHW input -- confirm with callers).
        # epsilon is the float32 machine epsilon.
        self._bn = normalization.BatchNormalization(
            axis=1,
            epsilon=np.finfo(np.float32).eps,
            momentum=0.9,
        )

    def __call__(self, inputs, is_training=True, test_local_stats=False):
        """Apply batch normalization to ``inputs``.

        Args:
            inputs: Tensor passed straight to the wrapped layer.
            is_training: Forwarded to the layer as its ``training`` argument.
            test_local_stats: Accepted but NOT used by this wrapper
                (presumably kept for signature compatibility with Sonnet's
                ``BatchNorm`` -- confirm before relying on it).

        Returns:
            The output of the wrapped ``BatchNormalization`` layer.
        """
        outputs = self._bn(inputs, training=is_training)
        # Register the moving statistics only after the layer has been
        # applied, when its `moving_mean` / `moving_variance` are available.
        self._add_variable(self._bn.moving_mean)
        self._add_variable(self._bn.moving_variance)
        return outputs

    def _add_variable(self, var):
        """Add ``var`` to MOVING_AVERAGE_VARIABLES unless already present."""
        # The membership check keeps repeated calls from adding duplicates.
        if var not in tf.get_collection(tf.GraphKeys.MOVING_AVERAGE_VARIABLES):
            tf.add_to_collection(tf.GraphKeys.MOVING_AVERAGE_VARIABLES, var)
# Demo: apply two independent batch-norm wrappers to the same random input.
t = tf.truncated_normal([2, 4, 4, 2])
bn = MyBatchNorm()
bn2 = MyBatchNorm()
n = bn(t)
n2 = bn2(t)

# Make fetching `n` depend on every op currently in the UPDATE_OPS collection
# (presumably the batch-norm moving-statistics updates -- confirm), so those
# ops run whenever `n` is evaluated.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    n = tf.identity(n)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    n_v, n2_v = sess.run([n, n2])
    # Show which variables ended up in the trainable and moving-average
    # collections.
    print(tf.trainable_variables())
    print(tf.moving_average_variables())
The following is a wrapper I wrote to demonstrate a use case of the routine; I hope it will be useful. And I believe you know how to save and restore the variables, yes?
Enjoy coding, no matter how frustrating it gets.
import numpy as np
import tensorflow as tf
import sonnet as snt
from tensorflow.python.layers import normalization
class MyBatchNorm(object):
    """Thin wrapper around the ``tf.layers`` ``BatchNormalization`` layer.

    Every call normalizes its input with the wrapped layer and then makes
    sure the layer's moving mean and moving variance are listed in the
    ``tf.GraphKeys.MOVING_AVERAGE_VARIABLES`` collection.
    """

    def __init__(self):
        # axis=1 selects the second dimension as the normalized axis
        # (presumably the channel axis of NCHW input -- confirm with callers).
        # epsilon is the float32 machine epsilon.
        self._bn = normalization.BatchNormalization(
            axis=1,
            epsilon=np.finfo(np.float32).eps,
            momentum=0.9,
        )

    def __call__(self, inputs, is_training=True, test_local_stats=False):
        """Apply batch normalization to ``inputs``.

        Args:
            inputs: Tensor passed straight to the wrapped layer.
            is_training: Forwarded to the layer as its ``training`` argument.
            test_local_stats: Accepted but NOT used by this wrapper
                (presumably kept for signature compatibility with Sonnet's
                ``BatchNorm`` -- confirm before relying on it).

        Returns:
            The output of the wrapped ``BatchNormalization`` layer.
        """
        outputs = self._bn(inputs, training=is_training)
        # Register the moving statistics only after the layer has been
        # applied, when its `moving_mean` / `moving_variance` are available.
        self._add_variable(self._bn.moving_mean)
        self._add_variable(self._bn.moving_variance)
        return outputs

    def _add_variable(self, var):
        """Add ``var`` to MOVING_AVERAGE_VARIABLES unless already present."""
        # The membership check keeps repeated calls from adding duplicates.
        if var not in tf.get_collection(tf.GraphKeys.MOVING_AVERAGE_VARIABLES):
            tf.add_to_collection(tf.GraphKeys.MOVING_AVERAGE_VARIABLES, var)
# Demo: apply two independent batch-norm wrappers to the same random input.
t = tf.truncated_normal([2, 4, 4, 2])
bn = MyBatchNorm()
bn2 = MyBatchNorm()
n = bn(t)
n2 = bn2(t)

# Make fetching `n` depend on every op currently in the UPDATE_OPS collection
# (presumably the batch-norm moving-statistics updates -- confirm), so those
# ops run whenever `n` is evaluated.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    n = tf.identity(n)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    n_v, n2_v = sess.run([n, n2])
    # Show which variables ended up in the trainable and moving-average
    # collections.
    print(tf.trainable_variables())
    print(tf.moving_average_variables())