

🔥grad



mojograd is a Mojo implementation of micrograd, a reverse-mode autodiff library with a PyTorch-like API.

The goal is to stay as close as possible to micrograd while keeping a clean syntax for defining computational graphs. Like micrograd, it only supports scalar values for now, but we plan to add Tensor support in the near future.

Note that mojograd is a work in progress and relies on static register-passable structs, so the backward pass copies values and can be really slow (Mojo traits support should improve that, so please stay tuned!). However, even now with zero optimizations, the forward pass is already up to ~40x faster than the original Python implementation (see the benchmarks below).

Using

mojograd dynamically builds a computational graph by overloading operators on the Value type, performing the forward pass as it goes. Just write your expression like a normal (non-differentiable) equation and call backward() to perform the backward pass:

from mojograd import Value

var a = Value(2.0)
var b = Value(3.0)
var c: Float32 = 2.0
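var d = b**c   # d = b^c is built, but e does not depend on it
var e = a + c  # e = a + c, where c is a plain Float32
e.backward()   # backward pass starting from e

a.print() # => <Value data: 2.0 grad: 1.0 op:  >  (de/da = 1)
b.print() # => <Value data: 3.0 grad: 0.0 op:  >  (b does not feed into e)
d.print() # => <Value data: 9.0 grad: 0.0 op: ** >
e.print() # => <Value data: 4.0 grad: 1.0 op: + >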
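To show what the operator overloading does under the hood, the same example can be mirrored in Python, micrograd's own language. This is a minimal micrograd-style sketch, not mojograd's actual code. Each operator creates a new node that records its children and a local chain-rule step (`_prev` and `_backward`, following micrograd's design). `backward()` then topologically sorts the graph and runs those steps from the output back to the inputs. Only `+` and `**` (with a numeric exponent) are implemented here.

```python
class Value:
    """A scalar node in a dynamically built computational graph (micrograd-style sketch)."""

    def __init__(self, data, children=(), op=""):
        self.data = float(data)
        self.grad = 0.0
        self._prev = children           # nodes this value was computed from
        self._backward = lambda: None   # local chain-rule step, set by the op that created it
        self.op = op

    def __add__(self, other):
        # Plain numbers are wrapped so they can take part in the graph
        other = other if isinstance(other, Value) else Value(other)
        out = Value(self.data + other.data, (self, other), "+")

        def _backward():
            # d(x + y)/dx = d(x + y)/dy = 1
            self.grad += out.grad
            other.grad += out.grad

        out._backward = _backward
        return out

    def __pow__(self, exponent):
        # The exponent is a plain number, so pow has a single child (no right operand node)
        out = Value(self.data ** exponent, (self,), "**")

        def _backward():
            # d(x^n)/dx = n * x^(n - 1)
            self.grad += exponent * self.data ** (exponent - 1) * out.grad

        out._backward = _backward
        return out

    def backward(self):
        # Recursive DFS: a node is appended only after all of its children
        topo, visited = [], set()

        def build_topo(v):
            if v not in visited:
                visited.add(v)
                for child in v._prev:
                    build_topo(child)
                topo.append(v)

        build_topo(self)
        self.grad = 1.0  # seed: d(self)/d(self) = 1
        for v in reversed(topo):  # output first, inputs last
            v._backward()

    def __repr__(self):
        return f"<Value data: {self.data} grad: {self.grad} op: {self.op} >"


a = Value(2.0)
b = Value(3.0)
c = 2.0
d = b**c
e = a + c
e.backward()

print(a)  # <Value data: 2.0 grad: 1.0 op:  >
print(b)  # <Value data: 3.0 grad: 0.0 op:  >
print(d)  # <Value data: 9.0 grad: 0.0 op: ** >
print(e)  # <Value data: 4.0 grad: 1.0 op: + >
```

Gradients accumulate with `+=` so that a node used several times in one graph collects a contribution from every use, as in micrograd.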

For a more complete example (a simple Multi-Layer Perceptron), please check the tests.mojo file. You can run it with:

mojo tests.mojo

Benchmarks

MLP binary classifier

Compared with the original Python implementation, mojograd's forward pass is up to ~40 times faster.

| # parameters | micrograd (Python), sec | mojograd (Mojo), sec | speed-up |
|---|---|---|---|
| 367 | 0.001 | 0.00006 | x20 |
| 1185 | 0.004 | 0.0001 | x40 |
| 4417 | 0.01 | 0.0005 | x20 |
| 17025 | 0.06 | 0.002 | x30 |
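For context on the "# parameters" column, a fully connected layer with `nin` inputs and `nout` neurons holds `nout * (nin + 1)` parameters (one weight per input plus one bias per neuron). The Python sketch below counts parameters and times a forward pass. It makes several assumptions. The helper names (`count_params`, `init_mlp`, `forward`, `time_forward`) are hypothetical. The `[16, 16, 1]` architecture is only an example and may not match any of the benchmarked models. The forward pass uses plain floats with ReLU hidden layers and a linear output, following micrograd's MLP, so it does not build an autodiff graph and its timings are not comparable with the table above.

```python
import random
import time


def count_params(nin, nouts):
    """Weights + biases of a fully connected MLP: each layer has nout * (nin + 1)."""
    total = 0
    for nout in nouts:
        total += nout * (nin + 1)
        nin = nout
    return total


def init_mlp(nin, nouts, seed=0):
    """One row per neuron: nin weights followed by a bias, drawn from U(-1, 1)."""
    rng = random.Random(seed)
    layers = []
    for nout in nouts:
        layers.append([[rng.uniform(-1, 1) for _ in range(nin + 1)] for _ in range(nout)])
        nin = nout
    return layers


def forward(layers, x):
    """Inference only: ReLU on hidden layers, linear output layer."""
    for i, layer in enumerate(layers):
        last = i == len(layers) - 1
        out = []
        for row in layer:
            act = sum(w * xi for w, xi in zip(row[:-1], x)) + row[-1]
            out.append(act if last else max(0.0, act))
        x = out
    return x


def time_forward(layers, inputs):
    """Wall-clock seconds for one forward pass over every input."""
    start = time.perf_counter()
    for x in inputs:
        forward(layers, x)
    return time.perf_counter() - start


layers = init_mlp(2, [16, 16, 1])
print(count_params(2, [16, 16, 1]))  # 337
print(time_forward(layers, [[0.1, -0.2]] * 100))
```

Counting parameters from the architecture rather than from the allocated rows keeps the two independent, so comparing them (as `sum(len(row) ...)` would) is a cheap sanity check on the initializer.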

Changelog

  • 2023.11.19
    • Benchmarking inference and comparing with micrograd
  • 2023.11.18
    • Optimization pass through the code
  • 2023.11.14
    • Rebuilt the whole thing, moving from (dangerous) pointer handling to register-passable structs
    • Got the full micrograd implementation working!
    • MLP example training and inference working!
  • 2023.09.05
    • Starting from scratch based on suggestions from Jack Clayton
    • Topological sort works, but I'm getting something wrong in memory handling: the gradients are not being updated
  • 2023.07.04
    • Ported Neuron, Layer and MLP
    • Went back to using yakupc55's List (a register_passable data structure is needed)
  • 2023.06.30
    • Finally got it working! Only the pow ops and a review are missing

Contributors

automata

mojograd's Issues

Pointer[Value] is not working

error in mojograd/utils.mojo:9:45: error: 'DynamicVector' parameter #0 has 'CollectionElement' type, but value has type 'Pointer[Value, 0]'

@staticmethod
fn build_topo(inout ptr_v: Pointer[Value], inout visited: DynamicVector[Pointer[Value]], inout topo: DynamicVector[Pointer[Value]]):
    # Recursive DFS: a node is appended to topo only after its children
    if ptr_v == Pointer[Value].get_null():
        return
    # Linear scan over visited to check whether this node was already seen
    var is_visited: Bool = False
    let size: Int = len(visited)
    for i in range(size):
        if ptr_v == visited[i]:
            is_visited = True
    if not is_visited:
        visited.push_back(ptr_v)
        # Make sure we don't try to access null pointers (e.g. on pow
        # where we don't have the right child)
        if ptr_v.load().l != Pointer[Int].get_null():
            var ptr_l: Pointer[Value] = ptr_v.load().l.bitcast[Value]()
            if ptr_l != Pointer[Value].get_null():
                Value.build_topo(ptr_l, visited, topo)
        if ptr_v.load().r != Pointer[Int].get_null():
            var ptr_r: Pointer[Value] = ptr_v.load().r.bitcast[Value]()
            if ptr_r != Pointer[Value].get_null():
                Value.build_topo(ptr_r, visited, topo)
        topo.push_back(ptr_v)

@always_inline
fn backward(inout self):
    var visited: DynamicVector[Pointer[Value]] = DynamicVector[Pointer[Value]]()
    var topo: DynamicVector[Pointer[Value]] = DynamicVector[Pointer[Value]]()
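    # Store a copy of self behind a pointer so it can enter the graph traversal
    var ptr_self: Pointer[Value] = Pointer[Value].alloc(1)
    ptr_self.store(self)
    Value.build_topo(ptr_self, visited, topo)
    # Seed the output gradient: d(self)/d(self) = 1
    self.grad.store(1.0)
    # Apply each node's local backward step, output first, inputs last
    var reversed: DynamicVector[Pointer[Value]] = reverse(topo)
    for i in range(len(reversed)):
        Value._backward(reversed[i])
    # Release the temporary vectors and the pointer allocated above
    visited.clear()
    topo.clear()
    reversed.clear()
    ptr_self.free()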
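The `build_topo` above checks membership with a linear scan over `visited`, so the traversal costs O(n²) in the number of nodes, and it recurses once per node. The Python sketch below shows an alternative: an iterative post-order DFS with a hash set of node identities, giving O(1) membership checks and no recursion depth limit. It does not address the Mojo `CollectionElement` error itself. The `Node` class with optional `l`/`r` children is hypothetical and only mirrors the `l`/`r` fields used above.

```python
class Node:
    """Hypothetical graph node: l/r may be None (e.g. pow has no right child)."""

    def __init__(self, name, l=None, r=None):
        self.name = name
        self.l = l
        self.r = r


def build_topo(root):
    """Iterative post-order DFS: every node appears after its children."""
    topo = []
    visited = set()  # ids of expanded nodes: O(1) membership instead of a linear scan
    stack = [(root, False)]
    while stack:
        node, children_done = stack.pop()
        if node is None:  # missing child, e.g. the right operand of pow
            continue
        if children_done:
            topo.append(node)
            continue
        if id(node) in visited:
            continue
        visited.add(id(node))
        stack.append((node, True))  # emit this node once its children are done
        stack.append((node.r, False))
        stack.append((node.l, False))  # pushed last, so the left child is explored first
    return topo


a = Node("a")
b = Node("b")
c = Node("c", a, b)
d = Node("d", c, a)  # a is shared between c and d
print([n.name for n in build_topo(d)])  # ['a', 'b', 'c', 'd']
```

Iterating over the result in reverse gives the order in which the local backward steps are applied. The shared node `a` appears only once, because it is skipped after the first expansion.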
