Saturday, May 27, 2017

How to write a word recitation program with TensorFlow and Sonnet

Obviously, if you give a network a word like "congratulations", it will learn the correlation and generate "congratulations" back. With an RNN, there should be a lag of one or more steps. In the one-step-lag case, that means if you input "congratulations[ ]", it will generate "[ ]congratulations", where "[ ]" denotes a space character. I had written a Matlab script to play with this, but when I tried the same thing in Sonnet, to my surprise the speed was astonishing. Thanks to Google for always shipping great tools.
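To make the one-step lag concrete, here is a tiny plain-Python sketch of how the input and target strings line up (the space character doubles as padding, matching the snippet below):

word = "congratulations"
input_text = word + " "     # "congratulations[ ]"
target_text = " " + word    # "[ ]congratulations"
# At step t the network reads input_text[t] and is trained to emit target_text[t];
# since target_text[t] equals input_text[t - 1], every character is echoed one step late.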

The full snippet follows; feel free to tweak it, and good luck:

import tensorflow as tf
import sonnet as snt

label_size = 27
hidden_size = 128
batch_size = 1

class MyOneHotData(snt.AbstractModule):
    def __init__(self, depth = label_size, on_value = 1.0, off_value = 0.0, name = 'my_one_hot_data'):
        super(MyOneHotData, self).__init__(name = name)
        self._on_value = on_value
        self._off_value = off_value
        self._depth = depth

    def _build(self, inputs, axis = -1, append_head = None, append_tail = None):
        # 'a' -> 1, ..., 'z' -> 26; index 0 is reserved for the padding/space character
        indices = [(ord(c) - 96) for c in inputs]
        if append_head:
            indices = [0] * append_head + indices
        if append_tail:
            indices = indices + [0] * append_tail

        return tf.one_hot(indices, self._depth, self._on_value, self._off_value, axis, tf.float32)

class MySoftmax(snt.AbstractModule):
    def __init__(self, hidden_size = hidden_size, label_size = label_size, name = "my_softmax"):
        super(MySoftmax, self).__init__(name = name)
        self._hidden_size = hidden_size
        self._label_size = label_size

    @snt.experimental.reuse_vars
    def _trans(self, inputs):
        w = tf.get_variable("w", shape = [self._hidden_size, self._label_size])
        b = tf.get_variable("b", shape = [self._label_size])
        return tf.matmul(inputs, w) + b
        
    def _build(self, inputs):
        # apply the same linear projection (shared w and b) to every timestep
        unstack_along_time_series_inputs = tf.unstack(inputs)
        return tf.stack([self._trans(c) for c in unstack_along_time_series_inputs])
        

class MyRNN(snt.AbstractModule):
    def __init__(self, batch_size = batch_size, hidden_size = hidden_size, name = "my_rnn"):
        super(MyRNN, self).__init__(name = name)
        self._batch_size = batch_size
        self._hidden_size = hidden_size

    def _build(self, inputs):
        # inputs are time-major: [time, batch, features]; snt.LSTM is an RNNCore,
        # so it can be driven by tf.nn.dynamic_rnn directly
        lstm = snt.LSTM(self._hidden_size)
        init_state = lstm.initial_state(self._batch_size)
        output_sequence, final_state = tf.nn.dynamic_rnn(lstm, inputs, initial_state = init_state, time_major = True)
        return output_sequence

class MyWord(snt.AbstractModule):
    def __init__(self, label_size = label_size, name = "my_word"):
        super(MyWord, self).__init__(name = name)
        self._label_size = label_size

    def _build(self, inputs):
        # pick the most likely class per timestep, then map index 0 back to a
        # space (ASCII 32) and indices 1..26 back to 'a'..'z' (index + 96)
        indices = tf.argmax(inputs, 1)
        chars = [tf.cond(tf.equal(indices[i], 0), lambda: tf.constant(32, tf.int64), lambda: indices[i] + 96) \
            for i in range(indices.get_shape().as_list()[0])]
        return chars

with tf.Session() as sess:
    my_one_hot_data = MyOneHotData()

    encoded_input = my_one_hot_data("congratulations", append_tail = 1)
    input_with_batch_dim = tf.expand_dims(encoded_input, axis = 1)

    my_rnn = MyRNN()
    outputs = my_rnn(input_with_batch_dim)


    my_softmax = MySoftmax()   
    label_pred_with_batch = my_softmax(outputs)
    
    label_pred = tf.squeeze(label_pred_with_batch, axis = 1)

    encoded_label = my_one_hot_data("congratulations", append_head = 1)

    loss = tf.nn.softmax_cross_entropy_with_logits(labels = encoded_label, logits = label_pred)

    graph_regularizers = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    total_regularization_loss = tf.reduce_sum(graph_regularizers)

    total_loss = tf.reduce_mean(loss) + total_regularization_loss

    train_op = tf.train.GradientDescentOptimizer(0.05).minimize(total_loss)    

    my_word = MyWord()
    chars = my_word(label_pred)

    tf.summary.scalar("model-loss", total_loss)
    summ_op = tf.summary.merge_all()


    sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])

    writer = tf.summary.FileWriter("char_pred_train", sess.graph)

    
    for i in range(1000):
        _, summaries = sess.run([train_op, summ_op])
        #writer.add_summary(summaries, global_step = i)
        sole_chars = sess.run(chars)
        
        print(''.join([chr(c) for c in sole_chars]))

    writer.close()
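If training goes well, the printed string should converge to "[ ]congratulations" after a few hundred iterations. As a quick sanity check on the character coding (index 0 for the space, 1 through 26 for 'a' through 'z'), this small round-trip sketch mirrors what MyOneHotData and MyWord do, with no TensorFlow involved:

def encode(word):
    # mirrors MyOneHotData: 'a' -> 1, ..., 'z' -> 26
    return [ord(c) - 96 for c in word]

def decode(indices):
    # mirrors MyWord: 0 -> space, otherwise index + 96 is the letter's ASCII code
    return ''.join(' ' if i == 0 else chr(i + 96) for i in indices)

assert decode([0] + encode("congratulations")) == " congratulations"

To watch the loss curve in TensorBoard, uncomment the writer.add_summary line and point TensorBoard at the char_pred_train directory.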
