Thursday, November 10, 2016

Handwritten digits recognition via TensorFlow based on Windows MFC (II) - Train the MNist model

Let's move on to the second episode of this series.
In this post we will train the MNist model, and later we will load it to do prediction.

WARNING: The record surrounded by asterisks below was my first attempt. Due to poor note-taking during the process, I don't remember whether it is workable or not; I just log it here for my own reference. Please feel free to skip it and jump directly to the lines below the second asterisk line.


****************
For simplicity, we will directly reuse the existing script, namely fully_connected_feed.py, to obtain such a model.

First, clone the TensorFlow source code from GitHub.
Since the master branch has some problems with running fully_connected_feed.py, we have to switch to the stable r0.11 branch:
git checkout r0.11
Then check that everything is OK:
git status

Second, make a copy of the original fully_connected_feed.py, since we have no intention of contaminating the original one:
cp  fully_connected_feed.py fully_connected_feed2.py

Third, add a named op under the default graph clause:
    # Create the recognizer
    digit = tf.argmax(tf.nn.softmax(logits), 1, name = 'recognize')
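To make the idea concrete, here is a minimal, self-contained sketch of the same technique: give the prediction op an explicit name so it can be looked up later as 'recognize:0'. The single linear layer below is only a stand-in for the tutorial's two-hidden-layer network, not the actual code of fully_connected_feed.py:

    import tensorflow as tf

    # Stand-in graph: a single 784 -> 10 linear layer instead of the
    # tutorial's two hidden layers, just to show the named prediction op.
    images = tf.placeholder(tf.float32, [None, 784], name='images')
    weights = tf.Variable(tf.zeros([784, 10]))
    biases = tf.Variable(tf.zeros([10]))
    logits = tf.matmul(images, weights) + biases

    # Create the recognizer; the name makes it retrievable as 'recognize:0'
    digit = tf.argmax(tf.nn.softmax(logits), 1, name='recognize')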

The fourth step is to get rid of the intermediate checkpoint files, so the condition becomes:
      if (step + 1) == FLAGS.max_steps:
        checkpoint_file = os.path.join(FLAGS.train_dir, 'checkpoint')
        saver.save(sess, checkpoint_file, global_step=step)

The fifth step is to save the model:
    tf.train.write_graph(sess.graph_def, 'models/', 'mnist-mlp.pb')
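As a quick sanity check (just a sketch, assuming the text-format graph written above), you can reload the GraphDef and confirm the named op survived:

    import tensorflow as tf
    from google.protobuf import text_format

    # Reload the text-format GraphDef written by write_graph above and check
    # that the 'recognize' node is present. The file name matches the one
    # used in this post.
    with open('models/mnist-mlp.pb') as f:
        graph_def = text_format.Parse(f.read(), tf.GraphDef())

    print(any(node.name == 'recognize' for node in graph_def.node))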


And the last step is to strip the unneeded parts by freezing the graph:
python /opt/tensorflow/tensorflow/python/tools/freeze_graph.py  --input_checkpoint=data/checkpoint-1999 --input_graph=models/mnist-mlp.pb --output_graph=models/frozen-mnist-mlp.pb --output_node_names=recognize

Now we have the model file prepared.
****************


For the model preparation process, I referred heavily to the post by Jacky Tung: http://jackytung8085.blogspot.kr/2016/06/loading-tensorflow-graph-with-c-api-by.html.

I cloned all of his work from GitHub and built the model based on it. It's quite easy to follow, so I will just mention a few tricky points.

1. First, it will probably complain that some file doesn't exist. This is because the script writes into a nested path, so first create a directory named "models" under the directory where the mnist.py script lives.
Or put another way, when creating a file at a path like foo/bar, the parent directory must already exist; on the command line the "-p" option of mkdir handles this (see the first sketch after this list).

2. When freezing the graph, the freeze_graph script will probably complain that the checkpoint file doesn't exist. I guess freeze_graph assumes the cwd is where it resides, so pass absolute paths for the model file and the checkpoint file (see the second sketch after this list).

3. It seems that, when doing inference, tf.argmax could further reduce the work of retrieving the result. However, I guess that since it doesn't have an obvious derivative, freeze_graph strips it from the frozen model file even if you add it to the graph. So we probably have to analyze the result returned by softmax manually (see the third sketch after this list).

4. Since freeze_graph is a Python script, it seems to prefer that the graph and checkpoint files are in text format; if you save them as binaries, it will probably complain about a decoding fault (see the last sketch after this list).
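
For tip 1, a tiny sketch of ensuring the output directory exists before the graph is written (the directory name 'models' is the one used in this post):

    import os

    # Create the output directory if it is missing, like "mkdir -p",
    # so that tf.train.write_graph has somewhere to write the model.
    model_dir = 'models'
    if not os.path.isdir(model_dir):
        os.makedirs(model_dir)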
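
For tip 2, a hedged sketch of invoking freeze_graph.py with absolute paths so the current working directory doesn't matter. The script location and file names below are just the ones from this post; adjust them to your setup:

    import os
    import subprocess

    # Build absolute paths so freeze_graph does not depend on the cwd.
    freeze_script = '/opt/tensorflow/tensorflow/python/tools/freeze_graph.py'
    cmd = [
        'python', freeze_script,
        '--input_graph=' + os.path.abspath('models/mnist-mlp.pb'),
        '--input_checkpoint=' + os.path.abspath('data/checkpoint-1999'),
        '--output_graph=' + os.path.abspath('models/frozen-mnist-mlp.pb'),
        '--output_node_names=recognize',
    ]
    subprocess.call(cmd)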
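
For tip 3, a minimal sketch of recovering the predicted digit from the raw softmax output on the caller's side. Here 'probs' stands for the (batch_size, 10) array returned by running the softmax node; it is faked for illustration:

    import numpy as np

    # Fake softmax output for one image: ten class probabilities.
    probs = np.array([[0.01, 0.02, 0.05, 0.02, 0.01, 0.03, 0.01, 0.80, 0.03, 0.02]])

    # The predicted digit is simply the index of the largest probability.
    predicted_digit = int(np.argmax(probs, axis=1)[0])
    print(predicted_digit)  # prints 7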
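
For tip 4, a small sketch of the difference: tf.train.write_graph writes a human-readable text GraphDef by default (as_text=True), which is what the freeze_graph invocation above expects; as_text=False writes a binary protobuf instead (if I remember correctly, freeze_graph also has an --input_binary flag for that case). The dummy graph and file names are only for illustration:

    import tensorflow as tf

    with tf.Graph().as_default():
        tf.constant(0, name='dummy')  # a trivial graph just for the demo
        with tf.Session() as sess:
            # Text format (default): readable, and what freeze_graph expects here.
            tf.train.write_graph(sess.graph_def, 'models/', 'dummy-as-text.pbtxt', as_text=True)
            # Binary format: more compact, but freeze_graph must be told about it.
            tf.train.write_graph(sess.graph_def, 'models/', 'dummy-as-binary.pb', as_text=False)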

Further discussion is welcome.

Keep learning, keep tweaking!
