Thursday, November 10, 2016

Handwritten Digit Recognition with TensorFlow in a Windows MFC Application (IV) - Loading the Trained Model

Two good articles already explain this in detail; many thanks to their authors for the effort:
https://medium.com/jim-fleming/loading-a-tensorflow-graph-with-the-c-api-4caaff88463f#.t78tjznzu by Jim Fleming, and http://jackytung8085.blogspot.kr/2016/06/loading-tensorflow-graph-with-c-api-by.html by Jacky Tung.

So I simply paste my code here for reference:

MnistModel.cc:

#include <Windows.h>

#include <stdio.h>
#include <string.h>

#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/core/public/session.h"
#include "tensorflow/core/platform/env.h"

#include "MNistComm.h"

using std::vector;
using std::string;
using std::ostringstream;
using std::endl;
using std::pair;

using namespace tensorflow;

void fillErrMsg(MNIST_COMM_ERROR *err, MNIST_ERROR_CODE c, Status& status)
{
    memset(err, 0, sizeof(MNIST_COMM_ERROR));
        
    err->err = c;
        
    ostringstream ost;
    ost << status.ToString() << endl;
        
    snprintf(err->msg, MAX_MSG_SIZ, "%s", ost.str().c_str());
}

// Windows are Unicode supportted, so everything is natively Unicode
int wmain(wchar_t* argc, wchar_t* argv[])
{    
    // Open file mapping object
    MnistShm mnistShm(false);
    if (!mnistShm)
        return MNIST_OPEN_SHM_FAILED;        
   
    MnistEvent mnistEvent(false);
    if (!mnistEvent)
        return MNIST_OPEN_EVT_FAILED;

    Session* session = NULL;
    Status status = NewSession(SessionOptions(), &session);
    if(!status.ok())
    {
        MNIST_COMM_ERROR err;
        fillErrMsg(&err, MNIST_SESSION_CREATION_FAILED, status);
        mnistShm.SetError(reinterpret_cast<char*>(&err));
        
        return MNIST_SESSION_CREATION_FAILED;
    }
        
    char modelPath[MAX_PATH];
    CMnistComm::WChar2Char(modelPath, argv[1], MAX_PATH - 1);
    
    GraphDef graph_def;    
    status = ReadBinaryProto(Env::Default(), modelPath, &graph_def);
    if (!status.ok())
    {        
        MNIST_COMM_ERROR err;
        fillErrMsg(&err, MNIST_MODEL_LOAD_FAILED, status);
        mnistShm.SetError(reinterpret_cast<char*>(&err));

        return MNIST_MODEL_LOAD_FAILED;
    }
    
    status = session->Create(graph_def);
    if (!status.ok()) {

        MNIST_COMM_ERROR err;
        fillErrMsg(&err, MNIST_GRAPH_CREATION_FAILED, status);
        mnistShm.SetError(reinterpret_cast<char*>(&err));

        return MNIST_GRAPH_CREATION_FAILED;
    }
    
    // Setup inputs and outputs:
    Tensor img(DT_FLOAT, TensorShape({1, MNIST_IMG_DIM}));

    MNIST_COMM_EVENT evt;
    
    while (evt = mnistEvent.WaitForEvent(MNIST_EVENT_PROC))
    {        
        auto buf = img.flat<float>().data();
    
        mnistShm.GetImageData(reinterpret_cast<char*>(buf));

        vector<pair<string, Tensor>> inputs = {
            { "input", img}
        };
        
        // The session will initialize the outputs
        vector<Tensor> outputs;
        // Run the session, evaluating our "logits" operation from the graph
        status = session->Run(inputs, {"recognize"}, {}, &outputs);
        if (!status.ok()) {
            MNIST_COMM_ERROR err;
            fillErrMsg(&err, MNIST_MODEL_RUN_FAILED, status);
            mnistShm.SetError(reinterpret_cast<char*>(&err));
            
            return MNIST_MODEL_RUN_FAILED;
        }
        
        auto weights = outputs[0].shaped<float, 1>({10});
        int index = 0;
        int digit = -1;
        
        float min_ = 0.0;
        for (int i = 0; i < 10; i ++, index ++)
        {
            if (weights(i) > min_)
            {
                min_ = weights(i);
                digit = index;
            }
        }
                
        mnistShm.SetImageLabel(reinterpret_cast<char*>(&digit));
        mnistEvent.NotifyReady();
                
    }

    session->Close();
    
    return 0;
}






No comments:

Post a Comment