nanograd

A lightweight deep learning framework.

Description  ·  Features  ·  TODO  ·  License

Description

After verification, nanograd is not a city in Russia...

However, it is a PyTorch-like lightweight deep learning framework. Use it to implement any deep learning algorithm you want with little boilerplate code.

Nanograd is under continuous development. The goal is to implement as many features as possible while using as few abstraction layers as possible (only NumPy functions are allowed). Contributions to the repo are welcome.

The library has a built-in auto-differentiation engine that dynamically builds a computational graph. The framework ships with the basic building blocks needed to train neural nets: basic ops, layers, weight initializers, optimizers and loss functions. Additional tools help you visualize your network: a computational graph visualizer and activation map visualizers (SOON!).
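
To make the idea of a dynamically built computational graph concrete, here is a minimal sketch of reverse-mode auto-differentiation written with NumPy only. The `Value` class and its attributes are hypothetical names chosen for illustration; this is not nanograd's actual `Tensor` implementation, and it only covers same-shape addition and multiplication.

```python
import numpy as np


class Value:
    """Toy array wrapper for illustration only (not nanograd's Tensor)."""

    def __init__(self, data, parents=()):
        self.data = np.asarray(data, dtype=np.float64)
        self.grad = np.zeros_like(self.data)
        self.parents = parents     # nodes this value was computed from
        self.backward_fn = None    # pushes self.grad into the parents

    def __add__(self, other):
        out = Value(self.data + other.data, parents=(self, other))

        def backward_fn():
            # d(x + y)/dx = 1 and d(x + y)/dy = 1
            self.grad += out.grad
            other.grad += out.grad

        out.backward_fn = backward_fn
        return out

    def __mul__(self, other):
        out = Value(self.data * other.data, parents=(self, other))

        def backward_fn():
            # d(x * y)/dx = y and d(x * y)/dy = x
            self.grad += other.data * out.grad
            other.grad += self.data * out.grad

        out.backward_fn = backward_fn
        return out

    def backward(self):
        # Post-order traversal gives a topological order of the graph.
        order, seen = [], set()

        def visit(node):
            if id(node) in seen:
                return
            seen.add(id(node))
            for parent in node.parents:
                visit(parent)
            order.append(node)

        visit(self)
        # Seed the output gradient with ones, then walk the graph backwards.
        self.grad = np.ones_like(self.data)
        for node in reversed(order):
            if node.backward_fn is not None:
                node.backward_fn()


x = Value([1.0, 2.0, 3.0])
y = Value([4.0, 5.0, 6.0])
z = x * y + x      # dz/dx = y + 1, dz/dy = x
z.backward()
print(x.grad, y.grad)
```

The graph is recorded as a side effect of running the forward expression, which is what "dynamically built" means here: no graph is declared ahead of time.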

The repo will be updated regularly with new features and examples.

Inspired by geohot's tinygrad.

Features

  • PyTorch-like autodifferentiation engine (dynamically constructed computational graph)
  • Weight initialization: Glorot uniform, Glorot normal, Kaiming uniform, Kaiming normal
  • Activations: ReLU, Sigmoid, tanh, Swish, ELU, LeakyReLU
  • Convolutions: Conv1d, Conv2d, MaxPool2d, AvgPool2d
  • Layers: Linear, BatchNorm1d, BatchNorm2d, Flatten, Dropout
  • Optimizers: SGD, Adam, AdamW
  • Loss: CrossEntropyLoss, Mean squared error
  • Computational graph visualizer (see example)
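
As an illustration of the weight initialization schemes listed above, the following NumPy sketch shows the standard Glorot uniform and Kaiming normal formulas. The function names and the `(fan_in, fan_out)` weight layout are assumptions made for this example and do not reflect nanograd's own API.

```python
import numpy as np


def glorot_uniform(fan_in, fan_out, rng):
    # Glorot/Xavier uniform: U(-limit, limit), limit = sqrt(6 / (fan_in + fan_out)).
    limit = np.sqrt(6.0 / (fan_in + fan_out))
    return rng.uniform(-limit, limit, size=(fan_in, fan_out))


def kaiming_normal(fan_in, fan_out, rng):
    # Kaiming/He normal for ReLU layers: N(0, std^2), std = sqrt(2 / fan_in).
    std = np.sqrt(2.0 / fan_in)
    return rng.normal(0.0, std, size=(fan_in, fan_out))


rng = np.random.default_rng(0)
w_glorot = glorot_uniform(256, 128, rng)
w_kaiming = kaiming_normal(256, 128, rng)
print(w_glorot.shape, w_kaiming.std())
```

Both schemes scale the weights by the layer width so that activation variance stays roughly constant from layer to layer at the start of training.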

A quick side-by-side comparison between PyTorch and Nanograd for tensor computations

Basic tensor calculations

PyTorch

import torch

a = torch.empty((30, 30, 2)).normal_(mean=3, std=4)
b = torch.empty((30, 30, 1)).normal_(mean=10, std=2)

a.requires_grad = True
b.requires_grad = True

c = a + b
d = c.relu()
e = c.sigmoid()
f = d * e

f.sum().backward()

print(a.grad)
print(b.grad)

Nanograd

a = Tensor.normal(3, 4, (30, 30, 2), requires_grad=True)
b = Tensor.normal(10, 2, (30, 30, 1), requires_grad=True)

c = a + b
d = c.relu()
e = c.sigmoid()
f = d * e

f.backward()

print(a.grad)
print(b.grad)
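
In both snippets `b` has a trailing dimension of 1 and is broadcast against `a`, so its gradient has to be summed over the broadcast axis. The NumPy sketch below works out the gradients of `sum(relu(c) * sigmoid(c))` by hand and checks one entry with a finite difference. It assumes the backward pass is seeded with ones (equivalent to PyTorch's `f.sum().backward()`) and is not taken from nanograd's source.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def forward(a, b):
    c = a + b                      # b broadcasts along the last axis
    return (np.maximum(c, 0.0) * sigmoid(c)).sum()


rng = np.random.default_rng(0)
a = rng.normal(3.0, 4.0, size=(30, 30, 2))
b = rng.normal(10.0, 2.0, size=(30, 30, 1))

# Analytic gradient of relu(c) * sigmoid(c) with respect to c.
c = a + b
s = sigmoid(c)
grad_c = (c > 0) * s + np.maximum(c, 0.0) * s * (1.0 - s)

grad_a = grad_c                                # same shape as a
grad_b = grad_c.sum(axis=-1, keepdims=True)    # undo the broadcast

# Central finite difference on a single entry of b.
eps = 1e-5
b_plus, b_minus = b.copy(), b.copy()
b_plus[0, 0, 0] += eps
b_minus[0, 0, 0] -= eps
numeric = (forward(a, b_plus) - forward(a, b_minus)) / (2 * eps)
print(numeric, grad_b[0, 0, 0])
```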

Training a CNN on MNIST

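import numpy as np
from tqdm import tqdm

# Model, loss & optim (CNN, X_train, Y_train, X_test and Y_test are defined elsewhere)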
model = CNN()
loss_function = CrossEntropyLoss()
optim = SGD(model.parameters(), lr=0.01, momentum=0)

# Training loop
BS = 128
losses, accuracies = [], []
STEPS = 1000

for i in tqdm(range(STEPS), total=STEPS):
  samp = np.random.randint(0, X_train.shape[0], size=(BS))
  X = tensor.Tensor(X_train[samp])
  Y = tensor.Tensor(Y_train[samp])

  optim.zero_grad()

  out = model(X)

  cat = out.data.argmax(1)
  accuracy = (cat == Y.data).mean()

  loss = loss_function(out, Y)
  loss.backward()

  optim.step()

  loss, accuracy = float(loss.data), float(accuracy)
  losses.append(loss)
  accuracies.append(accuracy)

Y_test_preds = model(tensor.Tensor(X_test)).data.argmax(1)
print((Y_test == Y_test_preds).mean())
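
For readers who want to see what the loss and accuracy lines in the loop compute, here is a NumPy sketch of a numerically stable cross-entropy on raw logits, together with its gradient. It assumes integer class labels and mean reduction over the batch, as in PyTorch's `CrossEntropyLoss`; nanograd's own implementation may differ, and the `cross_entropy` helper is a hypothetical name.

```python
import numpy as np


def cross_entropy(logits, labels):
    # Subtract the row-wise max before exponentiating to avoid overflow.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

    n = logits.shape[0]
    loss = -log_probs[np.arange(n), labels].mean()

    # Gradient w.r.t. the logits: (softmax - one_hot) / batch_size.
    grad = np.exp(log_probs)
    grad[np.arange(n), labels] -= 1.0
    return loss, grad / n


logits = np.array([[2.0, 0.5, -1.0],
                   [0.1, 0.2, 3.0]])
labels = np.array([0, 2])

loss, grad = cross_entropy(logits, labels)
accuracy = (logits.argmax(1) == labels).mean()   # same pattern as in the loop
print(loss, accuracy)
```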

Visualizing a computational graph

Visualizing a computational graph has never been so easy: just call plot_forward and plot_backward on a tensor.

f.plot_forward()

f.plot_backward()

TODO

  • Solve batchnorm issues
  • Add GRU, LSTM cells
  • Code example with EfficientNet-B0, CIFAR-10, MNIST
  • Code a transformer with Nanograd and train it on GPU

License

MIT


GitHub @PABannier  ·  Twitter @el_PA_B
