Learning distributions of increasing complexity

This repository contains code to accompany the paper "Neural networks trained with SGD learn distributions of increasing complexity" [arXiv:2211.11567] by M. Refinetti, A. Ingrosso, and S. Goldt.

In a nutshell

In this plot, we show the test accuracy of a ResNet18 evaluated on CIFAR10 during training with SGD on four different training data sets: the standard CIFAR10 training set (dark blue) and three different "clones" of the training set. The images of the clones were drawn from a Gaussian mixture fitted to CIFAR10, a mixture of Wasserstein GANs (WGAN) fitted to CIFAR10, and the cifar5m data set of Nakkiran et al. The clones form a hierarchy of approximations to CIFAR10: while the Gaussian mixture captures only the first two moments of the inputs of each class correctly, the WGAN and cifar5m data sets contain increasingly realistic images that capture higher-order statistics. The ResNet18 trained on the Gaussian mixture has the same test accuracy on CIFAR10 as the baseline model, trained directly on CIFAR10, for the first 50 steps of SGD; the ResNet18 trained on cifar5m has the same test accuracy as the baseline model for about 2000 steps. This result suggests that the network trained on CIFAR10 discriminates the images using increasingly higher-order statistics during training.
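To make the idea of a clone that matches only the first two moments of each class concrete, here is a minimal sketch in numpy. It is not taken from dist_inc_comp.py: the function names `fit_class_gaussians` and `sample_gaussian_clone` are hypothetical, the toy data stands in for flattened images, and using a single Gaussian per class is a simplifying assumption.

```python
import numpy as np


def fit_class_gaussians(x, y):
    """Estimate the mean and covariance of the inputs of each class.

    x: array of shape (n_samples, n_features), flattened inputs
    y: array of shape (n_samples,), integer class labels
    """
    params = {}
    for label in np.unique(y):
        x_c = x[y == label]
        mean = x_c.mean(axis=0)
        cov = np.cov(x_c, rowvar=False)  # (n_features, n_features)
        params[int(label)] = (mean, cov)
    return params


def sample_gaussian_clone(params, n_per_class, rng):
    """Draw a labelled clone data set with one Gaussian per class."""
    xs, ys = [], []
    for label, (mean, cov) in params.items():
        xs.append(rng.multivariate_normal(mean, cov, size=n_per_class))
        ys.append(np.full(n_per_class, label))
    return np.concatenate(xs), np.concatenate(ys)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Toy stand-in for flattened images: two classes in 5 dimensions.
    x = np.concatenate(
        [rng.normal(0.0, 1.0, (500, 5)), rng.normal(2.0, 0.5, (500, 5))]
    )
    y = np.repeat([0, 1], 500)
    params = fit_class_gaussians(x, y)
    x_clone, y_clone = sample_gaussian_clone(params, n_per_class=500, rng=rng)
    print(x_clone.shape, y_clone.shape)
```

By construction, the sampled clone reproduces the class-wise means and covariances of the original inputs up to sampling noise, while all higher-order statistics are those of a Gaussian.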

Usage

The key programme to train the network on distributions of increasing complexity is dist_inc_comp.py. Running the programme with the --help option yields an overview of the options.

To train a ResNet18 on CIFAR10, simply run

python dist_inc_comp.py --model resnet18 --dataset cifar10

If instead you would like to train the ResNet18 on a Gaussian mixture, and test it on CIFAR10 (that's the green line in the plot above), call

python dist_inc_comp.py --model resnet18 --dataset cifar10 --clone gp

where gp indicates that the clone to be used for training is the Gaussian process. If you would like to train the model on the GAN data set or on cifar5m, please contact Sebastian directly while we figure out a better way to share the raw data sets.

Requirements

To run the code, you will need up-to-date versions of

  • PyTorch
  • numpy
  • scipy
  • einops
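A quick way to check whether these packages are available in the current environment is sketched below. The helper `missing_packages` is hypothetical and not part of this repository; it only tests that each package can be found, not that its version is recent.

```python
from importlib.util import find_spec

# Import names of the packages listed above; PyTorch is imported as "torch".
REQUIRED = ["torch", "numpy", "scipy", "einops"]


def missing_packages(names=REQUIRED):
    """Return the packages from `names` that cannot be found in this environment."""
    return [name for name in names if find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages()
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All required packages are installed.")
```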
