Digit Recognizer for MNIST Data Set

In this project, I try different classification methods to recognize the digits in the MNIST data set.

MNIST Data Set

The MNIST data set contains 60000 training samples and 10000 testing samples. A detailed description of the data set is available at http://yann.lecun.com/exdb/mnist/index.html. Note that the MNIST data has been normalized to the range [0, 1].

Features

Each digit image is 28x28 pixels. In the data set, each sample is represented by a normalized 784-dimensional vector (28x28 = 784). I use this vector directly as the feature vector of each sample.
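
As a small illustration of this representation, the sketch below flattens a 28x28 image into a 784-dimensional vector and restores it. The image is random placeholder data rather than an MNIST sample, and the variable names are chosen only for this illustration.

```matlab
% Placeholder 28x28 "image" with values in [0, 1] (random data, not MNIST)
img = rand(28, 28);

% Flatten column by column into a 1x784 feature vector
feature = reshape(img, 1, 784);

% Restore the 28x28 layout, e.g. to display it with imagesc(img_back)
img_back = reshape(feature, 28, 28);

disp(size(feature));            % size of the feature vector
disp(isequal(img, img_back));   % 1 if the round trip is lossless
```

Because reshape works in column-major order, the same order must be used in both directions for the image to come back unchanged.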

Classification Methods

  • KNN:
    accuracy 97.05%, no training time, testing time Mmin
  • Linear kernel SVM:
    accuracy 93.98%, training time ≈ 6.658min, testing time ≈ 2.76min
  • Polynomial kernel SVM with degree 2:
    accuracy 98.08%, training time ≈ 3.9min, testing time ≈ 2.4min
  • Radial basis kernel SVM with default gamma:
    accuracy 94.46%, training time ≈ 10.43min, testing time ≈ 4.84min
  • Artificial Neural Network with 1 hidden layer (784-300-10):
    1 iteration: accuracy 87.17%
    20 iterations: accuracy 96.19%
    100 iterations: accuracy 97.94%
  • Convolutional Neural Network:
    1 epoch: accuracy 88.17%, training time + testing time ≈ 90s
    100 epochs: accuracy 98.85%, training time + testing time ≈ 2hr 30min
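
To make the first method in the list concrete, here is a minimal sketch of a 3-nearest-neighbor classifier on toy 2-D data. It is not the code in knn/. The squared Euclidean distance, the majority vote via mode, and all data values are assumptions made for this illustration.

```matlab
% Toy data: 6 training points in 2-D with labels 0 or 1 (hypothetical values)
train_set   = [0 0; 0 1; 1 0; 5 5; 5 6; 6 5];
train_label = [0; 0; 0; 1; 1; 1];
test_set    = [0.5 0.5; 5.5 5.5];
k = 3;   % number of neighbors

pred = zeros(size(test_set, 1), 1);
for i = 1:size(test_set, 1)
    % squared Euclidean distance from test sample i to every training sample
    delta = train_set - repmat(test_set(i, :), size(train_set, 1), 1);
    dist = sum(delta .^ 2, 2);
    % indices of the training samples, nearest first
    [~, idx] = sort(dist);
    % majority vote among the k nearest labels
    pred(i) = mode(train_label(idx(1:k)));
end
disp(pred');
```

With MNIST, train_set and test_set would be the N x 784 matrices of feature vectors. There is no training step, which matches the "no training time" entry above, but every prediction scans the whole training set.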

Directories included in the toolbox

load_data/ - the data set and functions to load it
knn/ - 3-nearest-neighbor classifier
svm/ - linear, polynomial, and RBF kernel SVM classifiers based on libsvm
ann/ - fully connected feedforward neural network and convolutional neural network (the CNN is based on DeepLearnToolbox)

Setup

1. Download the toolbox.
2. Add it to the MATLAB path: addpath(genpath('MNIST_digit_recognition'));

Example

```matlab
% load training set and testing set
clear all;
train_set = loadMNISTImages('train-images.idx3-ubyte')';
train_label = loadMNISTLabels('train-labels.idx1-ubyte');
test_set = loadMNISTImages('t10k-images.idx3-ubyte')';
test_label = loadMNISTLabels('t10k-labels.idx1-ubyte');

% parameter setting
alpha = 0.1; % learning rate
beta = 0.01; % scaling factor for sigmoid function
train_size = size(train_set);
N = train_size(1); % number of training samples
D = train_size(2); % dimension of feature vector
n_hidden = 300; % number of hidden layer units
K = 10; % number of output layer units
% initialize all weights between -1 and 1
W1 = 2*rand(1+D, n_hidden)-1; % weight matrix from input layer to hidden layer
W2 = 2*rand(1+n_hidden, K)-1; % weight matrix from hidden layer to output layer
max_iter = 100; % number of iterations
Y = eye(K); % target output vectors, one column per digit

% training
for i=1:max_iter
    disp([num2str(i), ' iteration']);
    for j=1:N
        % propagate the input forward through the network
        input_x = [1; train_set(j, :)'];
        hidden_output = [1; sigmf(W1'*input_x, [beta 0])];
        output = sigmf(W2'*hidden_output, [beta 0]);
        % propagate the error backward through the network
        % compute the error of output unit c
        delta_c = (output-Y(:,train_label(j)+1)).*output.*(1-output);
        % compute the error of hidden unit h
        delta_h = (W2*delta_c).*(hidden_output).*(1-hidden_output);
        delta_h = delta_h(2:end); % drop the bias unit
        % update weight matrix
        W1 = W1 - alpha*(input_x*delta_h');
        W2 = W2 - alpha*(hidden_output*delta_c');
    end
end

% testing
test_size = size(test_set);
num_correct = 0;
for i=1:test_size(1)
    input_x = [1; test_set(i,:)'];
    hidden_output = [1; sigmf(W1'*input_x, [beta 0])];
    output = sigmf(W2'*hidden_output, [beta 0]);
    [max_unit, max_idx] = max(output);
    if(max_idx == test_label(i)+1)
        num_correct = num_correct + 1;
    end
end
% computing accuracy
accuracy = num_correct/test_size(1);
```
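
The example relies on sigmf, which in MATLAB is provided by the Fuzzy Logic Toolbox. With parameters [a c] it computes 1/(1+exp(-a*(x-c))), so when the function is unavailable, an anonymous function can stand in. The sketch below shows one way to write it. The name my_sigmf and the test inputs are hypothetical.

```matlab
% Stand-in for sigmf(x, [a c]) = 1 ./ (1 + exp(-a .* (x - c)))
my_sigmf = @(x, params) 1 ./ (1 + exp(-params(1) .* (x - params(2))));

beta = 0.01;            % same scaling factor as in the example
x = [-100; 0; 100];     % hypothetical pre-activation values
y = my_sigmf(x, [beta 0]);
disp(y');
```

Replacing each sigmf call in the example with my_sigmf should leave the rest of the code unchanged, since both take the input first and the [a c] parameter pair second.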


<h2> References </h2>
<ul>
<li>http://ufldl.stanford.edu/wiki/index.php/Using_the_MNIST_Dataset </li>
<li>https://github.com/cjlin1/libsvm</li>
<li>https://github.com/rasmusbergpalm/DeepLearnToolbox</li>
</ul>

<h2> Remark </h2>
If you test the CNN in Octave, check your Octave version: versions before 3.8.0 have a bug in convolution. On older versions the following error is reported:
<br>
error: Octave 3.8.0 or greater is required for CNNs as there is a bug in convolution in previous versions.

Contributors

jincheng9
