
ago109 / predictive-forward-forward


Implementation/simulation of the predictive forward-forward credit assignment algorithm for training neurobiologically-plausible recurrent neural network models.

License: MIT License



The Predictive Forward-Forward Algorithm

Bio-plausible Forward-Only Learning for Training Neural Networks

Implementation of the proposed predictive forward-forward (PFF) learning algorithm for training a neurobiologically-plausible recurrent neural system. This work combines elements of predictive coding with the recently proposed forward-forward algorithm to create a novel online learning process that dynamically adapts two neural circuits: a representation circuit and a generative circuit. Notably, the system injects noise into the latent activity updates and uses learnable lateral synapses that induce competition across neural units (emulating the cross-inhibitory and self-excitatory effects inherent to neural computation).
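To make the last two mechanisms concrete, here is a minimal NumPy sketch of one latent-state update with injected noise and a lateral synapse matrix that inhibits other units while letting each unit excite itself. It is a hypothetical illustration, not the repository's update rule: the helper names `lateral_matrix` and `latent_step`, the ReLU activation, the leak rate `beta`, and the noise scale `sigma` are all assumptions. See Algorithm 1 in the paper for the actual equations.

```python
import numpy as np

def lateral_matrix(V, self_excite=1.0):
    """Hypothetical lateral synapses built from learnable weights V.

    Off-diagonal entries are clipped to be non-negative (cross-inhibition);
    the diagonal is -self_excite. Because latent_step SUBTRACTS the lateral
    term, a negative diagonal means each unit excites itself.
    """
    n = V.shape[0]
    eye = np.eye(n)
    inhibit = np.maximum(V, 0.0) * (1.0 - eye)  # non-negative, zero diagonal
    return inhibit - self_excite * eye

def latent_step(z, bottom_up, top_down, L, beta=0.1, sigma=0.025, rng=None):
    """One noisy settling step for a layer of latent units (assumed form).

    z, bottom_up, top_down: arrays of shape (batch, n); L: shape (n, n).
    """
    rng = np.random.default_rng() if rng is None else rng
    lateral = z @ L.T                              # others inhibit, self excites
    drive = bottom_up + top_down - lateral
    noise = sigma * rng.standard_normal(z.shape)   # noise injection
    z_new = (1.0 - beta) * z + beta * drive + noise
    return np.maximum(z_new, 0.0)                  # keep activities non-negative

# Usage: let a 4-unit layer settle for 10 steps on a fixed bottom-up input.
rng = np.random.default_rng(0)
L = lateral_matrix(rng.uniform(0.0, 0.1, size=(4, 4)))
x_drive = rng.standard_normal((2, 4))
z = np.zeros((2, 4))
for _ in range(10):
    z = latent_step(z, x_drive, np.zeros_like(z), L, rng=rng)
```

Splitting the lateral matrix into a clipped off-diagonal part and a fixed diagonal is just one way to guarantee the sign pattern (inhibit others, excite self); the paper may parameterize it differently.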

Requirements

Our implementation is easy to follow: with knowledge of basic linear algebra, one can decode the inner workings of the PFF algorithm. Please see Algorithm 1 in the Appendix of our paper to better understand the overall mechanics of the inference and learning processes. We have kept the modules in this framework simple, which we hope makes the framework convenient to extend.

To run the code, you should only need the following basic packages:

  1. TensorFlow (version >= 2.0)
  2. Numpy
  3. Matplotlib
  4. Python (version >=3.5)
  5. ngc-learn (needed by some of the modules that generate image samples). fit_gmm.py uses the mixture model in ngc-learn to retro-fit the latent prior for the PFF model's generative circuit, and sample_model.py and plot_tsne.py rely on that fitted prior. If you do not install ngc-learn, you will not be able to use fit_gmm.py, so sample_model.py and plot_tsne.py will have no prior distribution model to access; in that case, simply comment out the lines that involve sample_model.py and plot_tsne.py in the analyze.sh script.
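Before running the scripts, a quick environment check can tell you whether you need to comment out the sample_model.py and plot_tsne.py lines. This is a small helper sketch, not part of the repository; it assumes ngc-learn's import name is `ngclearn`, and the function name `check_environment` is hypothetical.

```python
import importlib.util
import sys

# Packages from the list above; "ngclearn" is the assumed import name of ngc-learn.
REQUIRED = ["tensorflow", "numpy", "matplotlib"]
OPTIONAL = ["ngclearn"]

def check_environment():
    """Report whether Python is new enough and which packages are importable."""
    report = {"python>=3.5": sys.version_info >= (3, 5)}
    for name in REQUIRED + OPTIONAL:
        # find_spec locates a top-level package without importing it.
        report[name] = importlib.util.find_spec(name) is not None
    return report

if __name__ == "__main__":
    report = check_environment()
    for name, ok in report.items():
        print("{:12s} {}".format(name, "ok" if ok else "MISSING"))
    if not report["ngclearn"]:
        print("ngc-learn not found: fit_gmm.py will not run; comment out the "
              "sample_model.py and plot_tsne.py lines in analyze.sh.")
```

The helper uses `str.format` rather than f-strings so that it still parses on Python 3.5, the minimum version listed above.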

Execution

To reproduce the results from our paper, simply run the following provided Bash scripts:

  1. bash src/run.sh (This trains the model for E=60 epochs.)
  2. bash src/analyze.sh (This evaluates a trained model and produces plots/visuals.)

After running the above two scripts, you can find the simulation outputs in the example experimental results directory tree that has been pre-created for you: exp/pff/mnist/ contains the results for the MNIST model (over 2 trials) and exp/pff/kmnist/ contains the results for the KMNIST model (over 2 trials). Each directory stores the following:
  • post_train_results.txt - contains development/training cross-trial accuracy values
  • test_results.txt - contains test cross-trial accuracy values
  • trial0/ - contains model data for trial 0, as well as any visuals produced by analyze.sh
  • trial1/ - contains model data for trial 1, as well as any visuals produced by analyze.sh

Note that you should modify MODEL_DIR in analyze.sh to point to a particular trial's sub-directory. The default points to trial 0, so images are only placed inside the trial0/ sub-directory.
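If you want to confirm which of the expected outputs are present after running both scripts, a small check like the one below can help. It is a convenience sketch, not part of the repository: it assumes you run it from the repository root (so that exp/pff/ resolves), and the function name `missing_outputs` is hypothetical. The file and directory names come from the list above.

```python
from pathlib import Path

# Entries that each experiment directory is expected to hold.
EXPECTED = ["post_train_results.txt", "test_results.txt", "trial0", "trial1"]

def missing_outputs(root="exp/pff", datasets=("mnist", "kmnist")):
    """Return the expected output paths that do not exist yet."""
    missing = []
    for ds in datasets:
        for name in EXPECTED:
            path = Path(root) / ds / name
            if not path.exists():
                missing.append(str(path))
    return missing

if __name__ == "__main__":
    for p in missing_outputs():
        print("not found:", p)
```

Note that this only checks for existence; because analyze.sh writes visuals only into the sub-directory named by MODEL_DIR, an existing trial1/ may still contain no plots.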

Model-specific hyper-parameter defaults can be set/adjusted in pff_rnn.py. Training-specific hyper-parameters are available in sim_train.py; inside that script, one can create or edit an arguments dictionary, which the PFF_RNN() constructor takes as input to construct the simulation of the dual-circuit system.

Tips while using this algorithm/model on your own datasets:

  1. Track your local losses and adjust the model's hyper-parameters accordingly.
  2. Experiment with small, non-zero values for the weight decay coefficients reg_lambda (for the representation circuit) and g_reg_lambda (for the generative circuit). For K-MNIST, in our experience, a small value for reg_lambda (as indicated in the comments) slightly improved generalization performance.
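Weight decay coefficients such as reg_lambda and g_reg_lambda usually act as an L2 penalty on the weights. The sketch below shows one common way such a coefficient enters a local gradient-descent step, and records the local loss at every step as the first tip suggests. It is an illustration under assumptions: the function name `decayed_update`, the learning rate, the toy quadratic loss, and the choice to add `reg_lambda * W` to the gradient are not taken from pff_rnn.py, which may apply the coefficients differently.

```python
import numpy as np

def decayed_update(W, grad, lr=0.01, reg_lambda=0.0):
    """One SGD step on a local loss plus the penalty (reg_lambda / 2) * ||W||^2.

    The penalty's gradient is reg_lambda * W, so each step also shrinks the
    weights toward zero in proportion to their size.
    """
    return W - lr * (grad + reg_lambda * W)

# Toy local loss 0.5 * ||W - target||^2, tracked at every step.
rng = np.random.default_rng(0)
target = rng.standard_normal((3, 3))
W = np.zeros_like(target)
losses = []
for step in range(500):
    grad = W - target                                   # gradient of the toy loss
    losses.append(0.5 * float(np.sum((W - target) ** 2)))
    W = decayed_update(W, grad, lr=0.1, reg_lambda=0.01)

# With decay, W settles at target / (1 + reg_lambda) rather than at target.
```

The converged weights are slightly shrunk relative to the unregularized solution; that shrinkage is the mechanism by which a small decay coefficient can trade a little training fit for better generalization.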

Citation

If you use or adapt (portions of) this code/algorithm in any form in your project(s), or find the PFF algorithm helpful in your own work, please cite this code's source paper:

@article{ororbia2023predictive,
  title={The Predictive Forward-Forward Algorithm},
  author={Ororbia, Alexander and Mali, Ankur},
  journal={arXiv preprint arXiv:2301.01452},
  year={2023}
}

Contributors

ago109, ankurmali