Monday, October 31, 2016

Handwritten digits recognition via TensorFlow based on Windows MFC (I) - Load MNist image data

I would like to write a series of posts on how to use TensorFlow for handwritten digit recognition in a Windows MFC application. The procedure is to first train a model on a Linux system such as Ubuntu, then export the model data to Windows, and then build an application that can load the model. The MFC application will act as the front end, sending images to the application hosting the model and retrieving the recognition results.

In this post, I will focus on how to load the MNIST image data. Please refer to the following link for details:
http://yann.lecun.com/exdb/mnist/

I only paste the source code here without further explanation. Please note that I have only tested the code with VC++ 2015, so I am not sure whether it works with other VC++ versions:

Header file:

#pragma once

#include <Windows.h>

// Magic numbers expected in the first header field of an MNIST image file and
// an MNIST label file. CMnistReader compares the byte-swapped header field
// against these values to validate the file it opened.
#define MNIST_MAGIC_IMAGE 0x00000803
#define MNIST_MAGIC_LABEL 0x00000801

// Dimensions of the fixed-size records handed to callers: IMAGE_DATA is
// IMAGE_HEIGHT x IMAGE_WIDTH bytes, LABEL_DATA is LABEL_SIZE bytes.
#define IMAGE_HEIGHT 28
#define IMAGE_WIDTH 28
#define LABEL_SIZE 1

// Selects which kind of MNIST file a CMnistReader instance is asked to parse.
enum tagMNIST_TYPE
{
    IMAGE = 0,  // file holds image records (the default reader type)
    LABEL = 1   // file holds label records
};
typedef enum tagMNIST_TYPE MNIST_TYPE;

// Pack the header structures below on 1-byte boundaries; the matching
// #pragma pack(pop) restores the previous packing.
#pragma pack(push, 1)
// In-memory copy of an MNIST image-file header. CMnistReader fills it with a
// memcpy from the start of the mapped file and then byte-swaps the four int
// fields. The layout is relied upon: offsetof(MNIST_IMAGE, data) is used as
// the byte offset of the first image record in the file.
typedef struct tagMNIST_IMAGE
{
// Compared against MNIST_MAGIC_IMAGE after byte-swapping.
int magic;
// NOTE(review): presumably the number of images in the file — confirm against the format spec.
int items;
// NOTE(review): presumably image height and width in pixels — confirm against the format spec.
int rows;
int cols;
// Never assigned a meaningful address or dereferenced in this file; it marks
// where the record payload starts so that offsetof() yields the header size.
unsigned char* data;
} MNIST_IMAGE;

// In-memory copy of an MNIST label-file header, filled and byte-swapped the
// same way as MNIST_IMAGE. offsetof(MNIST_LABEL, label) is used as the byte
// offset of the first label record in the file.
typedef struct tagMNIST_LABEL
{
// Compared against MNIST_MAGIC_LABEL after byte-swapping.
int magic;
// NOTE(review): presumably the number of labels in the file — confirm against the format spec.
int items;
// Payload marker only; not dereferenced in this file.
unsigned char* label;
} MNIST_LABEL;

#pragma pack(pop)

// Caller-supplied buffer for one image record: IMAGE_HEIGHT rows of
// IMAGE_WIDTH bytes. As a function parameter this array type decays to a
// pointer to its first row, so callees write into the caller's buffer.
typedef unsigned char IMAGE_DATA[IMAGE_HEIGHT][IMAGE_WIDTH];
// Caller-supplied buffer for one label record (LABEL_SIZE bytes).
typedef unsigned char LABEL_DATA[LABEL_SIZE];

// Sequential reader for one MNIST image or label file.
//
// The constructor opens the file, maps it into memory and validates its
// header; the destructor unmaps the view and closes the handles. Use
// operator! to find out whether construction succeeded.
//
// The reader keeps a cursor into the mapped record data. GetNext* copies the
// record at the cursor and then advances; GetPrev* first steps the cursor back
// one record and then copies the record it now points at.
class CMnistReader
{
public:
    // path: file to open. type: whether the file is parsed as an image file
    // (default) or as a label file.
    CMnistReader(LPCTSTR path, MNIST_TYPE type = IMAGE);
    ~CMnistReader();

    // Returns true when the reader failed to initialise, so callers can
    // write `if (!reader) { ... }`.
    bool operator !() const
    {
        return !m_bInit;
    }

    // Copy the record at the cursor into the caller's buffer and advance.
    // Return false (buffer filled with 0 for images, 0xff for labels) when no
    // complete record is left or the reader was opened for the other type.
    bool GetNextImage(IMAGE_DATA img);
    bool GetNextLabel(LABEL_DATA lb);

    // Step the cursor back one record and copy that record. Return false
    // (buffer filled as above) when the cursor is already at the first record
    // or the reader was opened for the other type.
    bool GetPrevImage(IMAGE_DATA img);
    bool GetPrevLabel(LABEL_DATA lb);

    // Index-based accessors; see their definitions for the current behaviour.
    bool GetImage(IMAGE_DATA img, int idx);
    bool GetLabel(LABEL_DATA lb, int idx);

private:
    // Copying is disabled: every instance owns the file handle, the mapping
    // handle and the mapped view, all released by the destructor, so a copied
    // object would release them a second time. Declared but intentionally
    // never defined.
    CMnistReader(const CMnistReader&);
    CMnistReader& operator=(const CMnistReader&);

    MNIST_TYPE m_type;               // record type requested at construction
    bool m_bInit;                    // true once the header was validated
    HANDLE m_hFile;                  // file handle, INVALID_HANDLE_VALUE if not open
    HANDLE m_hMapFile;               // file-mapping handle, NULL if not created
    DWORD m_dwFileSize;              // size of the file in bytes
    unsigned char* m_lpMapAddress;   // start of the mapped view
    unsigned char* m_lpCurAddress;   // cursor: next record GetNext* will return
    // Byte-swapped header of the opened file; which member is meaningful
    // depends on m_type.
    union {
        MNIST_IMAGE m_image;
        MNIST_LABEL m_label;
    };

};

// Owns a DIB section sized for one IMAGE_DATA record and copies image records
// into its pixel buffer so they can be drawn through the returned HBITMAP.
class CBitmapConverter
{
public:
    // Creates the DIB section using the given device context.
    CBitmapConverter(HDC hdc);
    // Deletes the bitmap if one was created.
    ~CBitmapConverter();

    // Copies the rows of `data` into the bitmap's pixel buffer.
    bool Convert(IMAGE_DATA data);
    // Returns the bitmap handle. The object keeps ownership, so callers must
    // not delete it and must not use it after this object is destroyed.
    HBITMAP GetBmpHandle() const { return m_hBitmap; }

private:
    // Copying is disabled: the destructor deletes m_hBitmap, so a copied
    // object would delete the same GDI object twice. Declared but
    // intentionally never defined.
    CBitmapConverter(const CBitmapConverter&);
    CBitmapConverter& operator=(const CBitmapConverter&);

    HBITMAP m_hBitmap;              // DIB section handle, NULL if creation failed
    unsigned char* m_lpBitmapBits;  // pixel buffer belonging to m_hBitmap
};


Source Files:

#include "stdafx.h"

#include <WinSock2.h>
#include "MnistReader.h"

#pragma comment(lib, "Ws2_32.lib")

// Opens the MNIST file at `path`, maps it read-only into memory and parses the
// header for the requested `type`.
//
// path: file to open.
// type: IMAGE to parse an image file, LABEL to parse a label file.
//
// The constructor never throws. On any failure it writes a message with
// OutputDebugString and returns with m_bInit == false, which callers detect
// through operator!. Handles acquired before the failure are released by the
// destructor.
CMnistReader::CMnistReader(LPCTSTR path, MNIST_TYPE type) :
    m_type(type),
    m_bInit(false),
    m_hFile(INVALID_HANDLE_VALUE),
    m_hMapFile(NULL),
    m_dwFileSize(0),
    m_lpMapAddress(NULL),
    m_lpCurAddress(NULL)
{
    // Give the header union a defined value on every path. m_image is the
    // larger member, so clearing it clears m_label as well.
    memset(&m_image, 0, sizeof(m_image));

    // The on-disk header is only the int fields. The trailing pointer member
    // of each struct marks where the records start, so its offset (not the
    // struct size) is the header length.
    size_t headerSize = 0;
    switch (type)
    {
    case IMAGE:
        headerSize = offsetof(MNIST_IMAGE, data);
        break;
    case LABEL:
        headerSize = offsetof(MNIST_LABEL, label);
        break;
    default:
        OutputDebugString(_T("Invalid MNIST Type"));
        return;
    }

    // FILE_SHARE_READ lets other readers open the same file while it is mapped.
    m_hFile = CreateFile(path, GENERIC_READ, FILE_SHARE_READ, NULL, OPEN_EXISTING,
        FILE_ATTRIBUTE_NORMAL, NULL);

    if (m_hFile == INVALID_HANDLE_VALUE)
    {
        OutputDebugString(_T("CreateFile() Failed"));
        return;
    }

    // Keep the size in a local until the view exists, so a failed reader
    // reports a size of 0 rather than a size with no mapping behind it.
    const DWORD fileSize = GetFileSize(m_hFile, NULL);
    if (fileSize == INVALID_FILE_SIZE || fileSize < headerSize)
    {
        OutputDebugString(_T("Invalid File Size"));
        return;
    }

    m_hMapFile = CreateFileMapping(m_hFile, NULL, PAGE_READONLY,
        0, fileSize, NULL);

    if (m_hMapFile == NULL)
    {
        OutputDebugString(_T("CreateFileMapping() Failed"));
        return;
    }

    m_lpMapAddress = static_cast<unsigned char*>(
        MapViewOfFile(m_hMapFile, FILE_MAP_READ, 0, 0, fileSize));

    if (m_lpMapAddress == NULL)
    {
        OutputDebugString(_T("MapViewOfFile() Failed"));
        return;
    }

    m_dwFileSize = fileSize;
    // The cursor starts at the first record, directly after the header.
    m_lpCurAddress = m_lpMapAddress + headerSize;

    // Header fields are stored big-endian; ntohl converts them to host order.
    if (type == IMAGE)
    {
        memcpy(&m_image, m_lpMapAddress, headerSize);
        m_image.magic = ntohl(m_image.magic);
        m_image.items = ntohl(m_image.items);
        m_image.rows = ntohl(m_image.rows);
        m_image.cols = ntohl(m_image.cols);

        if (m_image.magic != MNIST_MAGIC_IMAGE)
        {
            OutputDebugString(_T("Invalid Image File Format"));
            return;
        }

        // The accessors copy fixed IMAGE_HEIGHT x IMAGE_WIDTH records, so a
        // file with any other geometry would be sliced at the wrong offsets.
        if (m_image.rows != IMAGE_HEIGHT || m_image.cols != IMAGE_WIDTH)
        {
            OutputDebugString(_T("Unsupported Image Dimensions"));
            return;
        }
    }
    else
    {
        memcpy(&m_label, m_lpMapAddress, headerSize);
        m_label.magic = ntohl(m_label.magic);
        m_label.items = ntohl(m_label.items);

        if (m_label.magic != MNIST_MAGIC_LABEL)
        {
            OutputDebugString(_T("Invalid Label File Format"));
            return;
        }
    }

    m_bInit = true;
}


// Tears down whatever the constructor managed to acquire, in reverse order of
// acquisition: the mapped view, then the mapping object, then the file.
CMnistReader::~CMnistReader()
{
    if (m_lpMapAddress != NULL)
    {
        UnmapViewOfFile(m_lpMapAddress);
    }

    if (m_hMapFile != NULL)
    {
        CloseHandle(m_hMapFile);
    }

    if (INVALID_HANDLE_VALUE != m_hFile)
    {
        CloseHandle(m_hFile);
    }
}

// Copies the image record at the cursor into `img` and advances the cursor by
// one record.
//
// Returns true on success. Returns false and fills `img` with zeros when the
// reader is not initialised, was opened for labels, or fewer than
// sizeof(IMAGE_DATA) bytes remain in the file.
bool CMnistReader::GetNextImage(IMAGE_DATA img)
{
    // A reader that failed to initialise has no usable mapping or cursor, so
    // check m_bInit before doing any pointer arithmetic.
    if (m_bInit && m_type == IMAGE)
    {
        const unsigned char* const end = m_lpMapAddress + m_dwFileSize;

        // Establish cursor <= end first so the difference is non-negative
        // before it is compared as an unsigned byte count.
        if (m_lpCurAddress <= end &&
            static_cast<size_t>(end - m_lpCurAddress) >= sizeof(IMAGE_DATA))
        {
            memcpy(img, m_lpCurAddress, sizeof(IMAGE_DATA));
            m_lpCurAddress += sizeof(IMAGE_DATA);
            return true;
        }
    }

    memset(img, 0, sizeof(IMAGE_DATA));
    return false;
}

// Copies the label record at the cursor into `lb` and advances the cursor by
// one record.
//
// Returns true on success. Returns false and fills `lb` with 0xff when the
// reader is not initialised, was opened for images, or fewer than
// sizeof(LABEL_DATA) bytes remain in the file.
bool CMnistReader::GetNextLabel(LABEL_DATA lb)
{
    // A reader that failed to initialise has no usable mapping or cursor, so
    // check m_bInit before doing any pointer arithmetic.
    if (m_bInit && m_type == LABEL)
    {
        const unsigned char* const end = m_lpMapAddress + m_dwFileSize;

        // Establish cursor <= end first so the difference is non-negative
        // before it is compared as an unsigned byte count.
        if (m_lpCurAddress <= end &&
            static_cast<size_t>(end - m_lpCurAddress) >= sizeof(LABEL_DATA))
        {
            memcpy(lb, m_lpCurAddress, sizeof(LABEL_DATA));
            m_lpCurAddress += sizeof(LABEL_DATA);
            return true;
        }
    }

    memset(lb, 0xff, sizeof(LABEL_DATA));
    return false;
}

// Moves the cursor back by one image record and copies the record it then
// points at into `img`. Because GetNextImage leaves the cursor just past the
// record it returned, calling GetPrevImage right after GetNextImage yields
// that same record again.
//
// Returns true on success. Returns false and fills `img` with zeros when the
// reader is not initialised, was opened for labels, or the cursor is already
// at the first record.
bool CMnistReader::GetPrevImage(IMAGE_DATA img)
{
    if (m_bInit && m_type == IMAGE)
    {
        const unsigned char* const first =
            m_lpMapAddress + offsetof(MNIST_IMAGE, data);

        // Measure the distance from the first record instead of computing
        // cursor - size, which would form a pointer before the mapping when
        // the cursor is at (or near) the start.
        if (m_lpCurAddress >= first &&
            static_cast<size_t>(m_lpCurAddress - first) >= sizeof(IMAGE_DATA))
        {
            m_lpCurAddress -= sizeof(IMAGE_DATA);
            memcpy(img, m_lpCurAddress, sizeof(IMAGE_DATA));
            return true;
        }
    }

    memset(img, 0, sizeof(IMAGE_DATA));
    return false;
}

// Moves the cursor back by one label record and copies the record it then
// points at into `lb`. Because GetNextLabel leaves the cursor just past the
// record it returned, calling GetPrevLabel right after GetNextLabel yields
// that same record again.
//
// Returns true on success. Returns false and fills `lb` with 0xff when the
// reader is not initialised, was opened for images, or the cursor is already
// at the first record.
bool CMnistReader::GetPrevLabel(LABEL_DATA lb)
{
    if (m_bInit && m_type == LABEL)
    {
        const unsigned char* const first =
            m_lpMapAddress + offsetof(MNIST_LABEL, label);

        // Measure the distance from the first record instead of computing
        // cursor - size, which would form a pointer before the mapping when
        // the cursor is at the start.
        if (m_lpCurAddress >= first &&
            static_cast<size_t>(m_lpCurAddress - first) >= sizeof(LABEL_DATA))
        {
            m_lpCurAddress -= sizeof(LABEL_DATA);
            memcpy(lb, m_lpCurAddress, sizeof(LABEL_DATA));
            return true;
        }
    }

    memset(lb, 0xff, sizeof(LABEL_DATA));
    return false;
}

// Copies the image record with zero-based index `idx` into `img` without
// moving the sequential cursor used by GetNextImage/GetPrevImage.
//
// Returns true on success. Returns false and fills `img` with zeros (the same
// failure convention as GetNextImage) when the reader is not initialised, was
// opened for labels, `idx` is negative or not below the item count from the
// header, or the record would extend past the end of the file.
bool CMnistReader::GetImage(IMAGE_DATA img, int idx)
{
    if (m_bInit && m_type == IMAGE && idx >= 0 && idx < m_image.items)
    {
        // 64-bit arithmetic so a large index cannot wrap on 32-bit builds.
        const unsigned long long offset = offsetof(MNIST_IMAGE, data) +
            static_cast<unsigned long long>(idx) * sizeof(IMAGE_DATA);

        // The header's item count is not trusted on its own: the record must
        // also lie entirely inside the mapped file.
        if (offset + sizeof(IMAGE_DATA) <= m_dwFileSize)
        {
            memcpy(img, m_lpMapAddress + static_cast<size_t>(offset), sizeof(IMAGE_DATA));
            return true;
        }
    }

    memset(img, 0, sizeof(IMAGE_DATA));
    return false;
}

// Copies the label record with zero-based index `idx` into `lb` without
// moving the sequential cursor used by GetNextLabel/GetPrevLabel.
//
// Returns true on success. Returns false and fills `lb` with 0xff (the same
// failure convention as GetNextLabel) when the reader is not initialised, was
// opened for images, `idx` is negative or not below the item count from the
// header, or the record would extend past the end of the file.
bool CMnistReader::GetLabel(LABEL_DATA lb, int idx)
{
    if (m_bInit && m_type == LABEL && idx >= 0 && idx < m_label.items)
    {
        // 64-bit arithmetic so a large index cannot wrap on 32-bit builds.
        const unsigned long long offset = offsetof(MNIST_LABEL, label) +
            static_cast<unsigned long long>(idx) * sizeof(LABEL_DATA);

        // The header's item count is not trusted on its own: the record must
        // also lie entirely inside the mapped file.
        if (offset + sizeof(LABEL_DATA) <= m_dwFileSize)
        {
            memcpy(lb, m_lpMapAddress + static_cast<size_t>(offset), sizeof(LABEL_DATA));
            return true;
        }
    }

    memset(lb, 0xff, sizeof(LABEL_DATA));
    return false;
}

// Creates a top-down, 8-bits-per-pixel DIB section of IMAGE_WIDTH x
// IMAGE_HEIGHT pixels and remembers its handle and pixel buffer.
//
// hdc: device context passed through to CreateDIBSection.
//
// On failure m_hBitmap and m_lpBitmapBits are both NULL and a message is
// written with OutputDebugString; the constructor itself never throws.
CBitmapConverter::CBitmapConverter(HDC hdc) :
    m_hBitmap(NULL),
    m_lpBitmapBits(NULL)
{
    // An 8-bpp DIB carries a 256-entry colour table directly after the
    // header. BITMAPINFO declares room for a single RGBQUAD, so passing a
    // plain BITMAPINFO would let CreateDIBSection read past the end of the
    // stack object. This struct has the same leading layout plus the full table.
    struct
    {
        BITMAPINFOHEADER bmiHeader;
        RGBQUAD bmiColors[256];
    } bi;
    ZeroMemory(&bi, sizeof(bi));

    bi.bmiHeader.biSize = sizeof(BITMAPINFOHEADER);
    bi.bmiHeader.biWidth = IMAGE_WIDTH;
    bi.bmiHeader.biHeight = -IMAGE_HEIGHT;   // negative height = top-down rows
    bi.bmiHeader.biPlanes = 1;
    bi.bmiHeader.biBitCount = 8;
    bi.bmiHeader.biCompression = BI_RGB;
    bi.bmiHeader.biClrUsed = 256;

    // Identity grey ramp: a pixel byte v is displayed as RGB(v, v, v).
    for (int i = 0; i < 256; i++)
    {
        const BYTE level = static_cast<BYTE>(i);
        bi.bmiColors[i].rgbRed = level;
        bi.bmiColors[i].rgbGreen = level;
        bi.bmiColors[i].rgbBlue = level;
    }

    m_hBitmap = CreateDIBSection(hdc, reinterpret_cast<BITMAPINFO*>(&bi), DIB_RGB_COLORS,
        reinterpret_cast<void**>(&m_lpBitmapBits), NULL, 0);

    if (m_hBitmap == NULL)
    {
        m_lpBitmapBits = NULL;
        OutputDebugString(_T("CreateDIBSection() Failed"));
    }
}

// Deletes the DIB section, if the constructor managed to create one.
CBitmapConverter::~CBitmapConverter()
{
    if (m_hBitmap == NULL)
    {
        return;
    }

    DeleteObject(m_hBitmap);
}

// Copies the IMAGE_HEIGHT rows of `data` into the DIB section's pixel buffer.
//
// data: source image, IMAGE_HEIGHT rows of IMAGE_WIDTH bytes.
//
// Returns true after copying. Returns false without touching anything when
// the bitmap has no pixel buffer (creation failed) or `data` is NULL.
bool CBitmapConverter::Convert(IMAGE_DATA data)
{
    if (m_lpBitmapBits == NULL || data == NULL)
    {
        return false;
    }

    const int width = IMAGE_WIDTH;
    // Each DIB scan line starts on a 4-byte boundary, so the destination
    // stride is the row width rounded up to a multiple of 4.
    const int pitch = (width + 3) & ~3;

    unsigned char* dst = m_lpBitmapBits;
    for (int row = 0; row < IMAGE_HEIGHT; ++row)
    {
        memcpy(dst, data[row], width);
        dst += pitch;
    }

    return true;
}


The running result:



No comments:

Post a Comment