MNIST Classification with Numpy and Backpropagation

Classifying MNIST digits without abstracting away the mathematical optimization involved in deep learning, diving into the details of backpropagation.

mnist

Mathematical Optimization

Machine learning uses mathematical optimization to minimize a loss, a measure of how badly the model performs. The gradient of the loss is the set of its partial derivatives, or rates of change, with respect to each parameter. Repeatedly moving each parameter a small step against its gradient lowers the loss; this procedure is called gradient descent.
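As a minimal sketch of gradient descent, consider a one-parameter loss. The function `loss`, its minimum at 3, the learning rate and the step count are all hypothetical choices made for illustration; they are not taken from this project's code.

```python
def loss(w):
    # Hypothetical one-parameter loss with its minimum at w = 3.
    return (w - 3.0) ** 2


def grad(w):
    # Derivative of the loss with respect to w.
    return 2.0 * (w - 3.0)


def gradient_descent(w, learning_rate=0.1, steps=100):
    # Repeatedly step against the gradient to reduce the loss.
    for _ in range(steps):
        w -= learning_rate * grad(w)
    return w


w_final = gradient_descent(w=0.0)
print(w_final, loss(w_final))
```

Each step shrinks the distance to the minimum by a constant factor here, so `w_final` ends up very close to 3. A neural network does the same thing with many parameters at once.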

Deep Neural Networks

Multilayer feedforward networks contain input neurons, hidden neurons, and output neurons, each of which holds a single number. Every input neuron affects every hidden neuron in the second layer, and so on through each layer until the output layer.

a

Each neuron is connected to every neuron in the next layer by a multiplier, or 'weight'. A neuron in the next layer sums its weighted inputs and adds a bias. This value, later written as 'z', is then passed through an activation function such as the sigmoid, which squashes it into the range (0, 1); the result is known as the activation. The activation of a neuron in any layer except the input layer may be expressed as:

b

where the superscript is the index of the layer and the subscript is the index of the neuron within that layer. This can be simplified significantly by expressing it as a series of matrix operations on an entire layer:

c

d
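The equivalence between the per-neuron formula and the matrix form can be checked with a short NumPy sketch. The layer sizes are arbitrary, and the weight layout (rows index the previous layer, columns index the next) is an assumption chosen so that each input is a row, not necessarily the layout used in the equations above.

```python
import numpy as np


def sigmoid(z):
    # Squashes any real value into the range (0, 1).
    return 1.0 / (1.0 + np.exp(-z))


rng = np.random.default_rng(0)
a_prev = rng.random((1, 4))       # activations of the previous layer, as one row
W = rng.standard_normal((4, 3))   # W[k, j]: weight from neuron k to neuron j
b = np.zeros((1, 3))              # one bias per neuron in the next layer

# Per-neuron form: weighted sum plus bias, then the activation function.
a_loop = np.zeros((1, 3))
for j in range(3):
    z_j = sum(a_prev[0, k] * W[k, j] for k in range(4)) + b[0, j]
    a_loop[0, j] = sigmoid(z_j)

# Matrix form: the whole layer in one operation.
z = a_prev @ W + b
a = sigmoid(z)

print(np.allclose(a, a_loop))
```

Both forms compute the same activations; the matrix form is shorter and lets NumPy do the looping.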

Batches

Deep neural networks are usually trained on mini-batches of data, which balances training time against results: multiple inputs are fed in at once and multiple outputs are expected. This can be done by simply adding more rows to the input matrix.
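A small sketch of batching, assuming each row of the input matrix is one flattened 28x28 image. The batch size of 32, the random stand-in data and the `forward` helper are hypothetical and only serve to show the shapes involved.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def forward(X, W, b):
    # X has one input per row; the bias row is broadcast across the batch.
    return sigmoid(X @ W + b)


rng = np.random.default_rng(1)
W = rng.standard_normal((784, 10)) * 0.01
b = np.zeros((1, 10))

batch = rng.random((32, 784))   # stand-in for 32 flattened MNIST images

out_batch = forward(batch, W, b)        # shape (32, 10): one output row per input
out_single = forward(batch[:1], W, b)   # the first image on its own

print(out_batch.shape, np.allclose(out_batch[:1], out_single))
```

Each row is processed independently, so feeding a batch gives the same result per image as feeding the images one at a time.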

Backpropagation

The goal of backpropagation is to find how the weights and biases should be adjusted so that the network classifies MNIST digits well. We first need a measure of how badly the network performs, so that we can minimize it using gradient descent. A simple yet effective one is the Mean Squared Error:

e

'Y' holds the expected values: in our case a vector of 0's, except at the index of the handwritten input digit, which is 1. L is the number of layers, so a superscript L denotes the final activations, which are the network's prediction of the expected labels, computed through the layers of matrix operations.
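The expected vectors and the loss can be sketched as follows. The helper names `one_hot` and `mse` are hypothetical, and averaging over every entry is an assumption; the normalization constant in the equation above may differ.

```python
import numpy as np


def one_hot(labels, num_classes=10):
    # One row per label: all 0's except a 1 at the index of the digit.
    Y = np.zeros((labels.size, num_classes))
    Y[np.arange(labels.size), labels] = 1.0
    return Y


def mse(A, Y):
    # Assumption: mean of the squared differences over all entries.
    return np.mean((A - Y) ** 2)


labels = np.array([3, 0])
Y = one_hot(labels)

perfect = mse(Y, Y)                      # predictions equal to the labels
uniform = mse(np.full_like(Y, 0.1), Y)   # every class predicted at 0.1

print(Y[0], perfect, uniform)
```

A perfect prediction gives a loss of zero, and the loss grows as the final activations move away from the expected vector.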

To minimize the loss we must find its partial derivative with respect to each weight and bias (the gradients) using the chain rule. These are the derivatives for a simple neural network with one input neuron and one output neuron. Remember that 'z' is the pre-activation.

f

g

i

j
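The chain rule for this one-input, one-output network can be sketched and checked numerically. The loss is assumed to be the plain squared error of the single output, with no extra constant factor, and the parameter values are arbitrary.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def sigmoid_prime(z):
    s = sigmoid(z)
    return s * (1.0 - s)


def loss_fn(w, b, x, y):
    # Assumption: squared error of a single output.
    return (sigmoid(w * x + b) - y) ** 2


def gradients(w, b, x, y):
    z = w * x + b                  # pre-activation
    a = sigmoid(z)                 # activation
    dC_da = 2.0 * (a - y)          # derivative of the loss w.r.t. the activation
    dC_dz = dC_da * sigmoid_prime(z)
    return dC_dz * x, dC_dz        # dC/dw (since dz/dw = x), dC/db (since dz/db = 1)


w, b, x, y = 0.5, -0.2, 1.5, 1.0
dw, db = gradients(w, b, x, y)

# Central finite differences as an independent check.
eps = 1e-6
dw_numeric = (loss_fn(w + eps, b, x, y) - loss_fn(w - eps, b, x, y)) / (2 * eps)
db_numeric = (loss_fn(w, b + eps, x, y) - loss_fn(w, b - eps, x, y)) / (2 * eps)

print(dw, dw_numeric, db, db_numeric)
```

The analytic and numeric values should agree to several decimal places, which is a useful sanity check before moving on to larger networks.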

If we add another one-neuron layer on the input side, how do we find the weight and bias gradients of the first layer? We first find the gradient of the loss with respect to the activation in the second layer, then use it in the chain rule the same way we used the derivative of our loss function.

k

l

m

n
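A sketch of the same idea with a chain of one-neuron layers: input, hidden neuron, output neuron. The names `w1`, `b1`, `w2`, `b2` and the squared-error loss are assumptions made for illustration. The key step is `dC_da1`, the gradient of the intermediate activation, which plays the role the loss derivative played before.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def sigmoid_prime(z):
    s = sigmoid(z)
    return s * (1.0 - s)


def forward(params, x):
    w1, b1, w2, b2 = params
    z1 = w1 * x + b1
    a1 = sigmoid(z1)
    z2 = w2 * a1 + b2
    a2 = sigmoid(z2)
    return z1, a1, z2, a2


def loss(params, x, y):
    # Assumption: squared error of the single output.
    return (forward(params, x)[-1] - y) ** 2


def gradients(params, x, y):
    w1, b1, w2, b2 = params
    z1, a1, z2, a2 = forward(params, x)
    dC_dz2 = 2.0 * (a2 - y) * sigmoid_prime(z2)
    dC_dw2 = dC_dz2 * a1
    dC_db2 = dC_dz2
    dC_da1 = dC_dz2 * w2                  # gradient of the intermediate activation
    dC_dz1 = dC_da1 * sigmoid_prime(z1)   # reused exactly like the loss derivative
    dC_dw1 = dC_dz1 * x
    dC_db1 = dC_dz1
    return np.array([dC_dw1, dC_db1, dC_dw2, dC_db2])


params = np.array([0.8, 0.1, -1.2, 0.3])
x, y = 0.7, 1.0
analytic = gradients(params, x, y)

# Central finite differences for every parameter.
eps = 1e-6
numeric = np.zeros_like(params)
for i in range(params.size):
    step = np.zeros_like(params)
    step[i] = eps
    numeric[i] = (loss(params + step, x, y) - loss(params - step, x, y)) / (2 * eps)

print(analytic, numeric)
```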

When we add more neurons per layer, we can generalize the backpropagation process the same way we did with forward propagation: using matrix operations. The math does not change much. Because the transpose symbol would conflict with a superscript layer index, the layer index is now written as a subscript.

o

p

q
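A rough NumPy sketch of backpropagation in matrix form, under several assumptions: inputs are rows, weights have shape (inputs, outputs), and the loss is the sum of squared errors divided by the batch size. The function names, layer sizes and random data are hypothetical, and the exact form may differ from the equations above.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def sigmoid_prime(z):
    s = sigmoid(z)
    return s * (1.0 - s)


def forward(X, weights, biases):
    # Keep every pre-activation and activation; backward() needs them.
    activations, zs = [X], []
    for W, b in zip(weights, biases):
        zs.append(activations[-1] @ W + b)
        activations.append(sigmoid(zs[-1]))
    return activations, zs


def loss(X, Y, weights, biases):
    # Assumption: squared error summed over outputs, averaged over the batch.
    A = forward(X, weights, biases)[0][-1]
    return np.sum((A - Y) ** 2) / Y.shape[0]


def backward(Y, activations, zs, weights):
    batch_size = Y.shape[0]
    # Gradient of the loss with respect to the last pre-activation.
    delta = 2.0 * (activations[-1] - Y) / batch_size * sigmoid_prime(zs[-1])
    grads_W, grads_b = [], []
    for layer in reversed(range(len(weights))):
        grads_W.insert(0, activations[layer].T @ delta)
        grads_b.insert(0, delta.sum(axis=0, keepdims=True))
        if layer > 0:
            # Push the gradient back through the weights (note the transpose).
            delta = (delta @ weights[layer].T) * sigmoid_prime(zs[layer - 1])
    return grads_W, grads_b


rng = np.random.default_rng(0)
sizes = [5, 4, 3]
weights = [rng.standard_normal((m, n)) for m, n in zip(sizes[:-1], sizes[1:])]
biases = [np.zeros((1, n)) for n in sizes[1:]]
X = rng.random((6, 5))
Y = np.eye(3)[rng.integers(0, 3, size=6)]   # random one-hot labels

activations, zs = forward(X, weights, biases)
grads_W, grads_b = backward(Y, activations, zs, weights)

# Check one first-layer weight against a central finite difference.
eps = 1e-6
weights[0][2, 1] += eps
loss_plus = loss(X, Y, weights, biases)
weights[0][2, 1] -= 2 * eps
loss_minus = loss(X, Y, weights, biases)
weights[0][2, 1] += eps                     # restore the original value
numeric = (loss_plus - loss_minus) / (2 * eps)

print(grads_W[0][2, 1], numeric)
```

A gradient descent step would then subtract a small multiple of each entry of `grads_W` and `grads_b` from the matching weights and biases.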

Thanks to Online LaTeX Equations for a great equation editor, and 3Blue1Brown for the math!

Contributors

hughmarch
