
modelsym's Introduction

Code to run the experiments of the NeurIPS 2022 paper "On the Symmetries of Deep Learning Models and their Internal Representations".

Overview

This repository is currently organized into a module model_symmetries with submodules stitching and alignment, corresponding to sections 4 and 5 of the paper (for the network dissection results of section 6 we used the implementation at https://github.com/CSAILVision/NetDissect-Lite).

In addition there are some submodules containing code shared across stitching and alignment, namely

  • models.py, datasets.py, train.py and plotting.py (self-explanatory)
  • zoo.py: utilities to train a bunch of models from independent random seeds
  • constants.py: specify a directory in which to store data/models/results by defining the variable data_dir.
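
As a rough illustration, a constants.py could look like the sketch below. Only the variable name data_dir comes from this README; the directory name is a placeholder, and whether the rest of the code expects a string or a Path object is an assumption to check against the repository.

```python
from pathlib import Path

# Placeholder location; point this at a disk with enough space for
# datasets, trained models and results.
# Assumption: a pathlib.Path is acceptable here. If the code concatenates
# strings, use str(...) instead.
data_dir = Path.home() / "modelsym_data"
```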

stitching

The key classes for stitching layers and stitched models are in stitching.py. In particular, we direct attention to the Birkhoff class, which implements our approach using PGD on the Birkhoff polytope of doubly stochastic matrices.
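
To give a feel for what a gradient step constrained to the Birkhoff polytope involves, here is a minimal NumPy sketch. It is not the Birkhoff class from stitching.py: the function names are hypothetical, and the projection is replaced by Sinkhorn normalization (alternating row and column rescaling), which returns an approximately doubly stochastic matrix but is not the exact Euclidean projection.

```python
import numpy as np

def sinkhorn_project(M, n_iters=200, eps=1e-9):
    """Rescale rows and columns in turn until M is nearly doubly stochastic.

    Stand-in for a projection onto the Birkhoff polytope (an assumption of
    this sketch, not necessarily what the repository does).
    """
    P = np.maximum(M, eps)  # Sinkhorn scaling needs strictly positive entries
    for _ in range(n_iters):
        P = P / P.sum(axis=1, keepdims=True)  # make rows sum to 1
        P = P / P.sum(axis=0, keepdims=True)  # make columns sum to 1
    return P

def pgd_step(P, grad, lr=0.1):
    """One gradient step followed by the (approximate) projection."""
    return sinkhorn_project(P - lr * grad)

rng = np.random.default_rng(0)
P = np.full((4, 4), 0.25)        # uniform matrix, the center of the polytope
grad = rng.normal(size=(4, 4))   # placeholder for a real loss gradient
P_next = pgd_step(P, grad, lr=0.05)
print(P_next.sum(axis=0), P_next.sum(axis=1))
```

In a real training loop the gradient would come from the stitching loss, and the projection would run after each optimizer step, which is one source of the overhead mentioned below.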

train.py has more options than is typical, due to a few major implementation considerations:

  1. The need to make sure that when stitching, we only update parameters of the stitching layer.
  2. The overhead of PGD and extra $-\ell_2$ regularization.
  3. The necessity of a no-grad training epoch before validation.

The main experiment script is cifar10_stitching.py. This also has many options, due to the number of combinations of model/stitching layer type we consider.

In order to run the experiments stitching Compact Convolutional Transformers, you will need https://github.com/SHI-Labs/Compact-Transformers, which is included as a Git submodule of this repository at model_symmetries/ct. To initialize and update it, run

git submodule init && git submodule update

alignment

Core functions are located in alignment.py. The $G_{\mathrm{ReLU}}$-Procrustes and CKA metrics are wreath_{procrustes,cka} (the group $G_{\mathrm{ReLU}}$ is an example of a wreath product, hence the name).
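
For intuition, the sketch below computes a Procrustes-style distance that is invariant to permuting features and rescaling them by positive constants, which is our reading of how $G_{\mathrm{ReLU}}$ acts. It is not the wreath_procrustes function from alignment.py: the function names are hypothetical, positive scaling is removed by normalizing each feature column to unit norm (an assumption), and the permutation is found by brute force, which is only feasible for a handful of features. A practical implementation would use a linear assignment solver instead.

```python
import itertools
import numpy as np

def normalize_columns(A):
    """Scale each feature column to unit norm, removing positive rescalings."""
    norms = np.linalg.norm(A, axis=0, keepdims=True)
    return A / np.maximum(norms, 1e-12)

def perm_scale_procrustes_bruteforce(X, Y):
    """Minimize ||Xn - Yn[:, perm]||_F over all column permutations.

    X, Y: arrays of shape (n_samples, n_features) with the same shape.
    Returns the minimal distance and the minimizing permutation.
    """
    Xn, Yn = normalize_columns(X), normalize_columns(Y)
    d = X.shape[1]
    # cost[j, k] = squared distance between column j of Xn and column k of Yn
    cost = ((Xn[:, :, None] - Yn[:, None, :]) ** 2).sum(axis=0)
    best_perm, best_cost = None, np.inf
    for perm in itertools.permutations(range(d)):
        c = sum(cost[j, perm[j]] for j in range(d))
        if c < best_cost:
            best_perm, best_cost = perm, c
    return float(np.sqrt(best_cost)), best_perm

rng = np.random.default_rng(0)
X = rng.random((20, 4))
# Y is X with permuted and positively rescaled feature columns
Y = X[:, [2, 0, 3, 1]] * np.array([0.5, 2.0, 3.0, 1.5])
dist, perm = perm_scale_procrustes_bruteforce(X, Y)
print(dist, perm)
```

Because Y differs from X only by a permutation and positive scaling of its columns, the distance in this example is zero up to floating-point error.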

Visualization

plotting.py contains functions for displaying stitching penalties and dissimilarity metrics, which can be run in the notebook plotting.ipynb.

Parallelization

We ran these experiments on a cluster managed by SLURM -- files ending in .slurm are SLURM batch files. In order to distribute the many sweeps in these experiments across nodes of the cluster, we submitted batches to the queue using loops found in the bash scripts (files ending in .sh). WARNING: executing these scripts will consume many GPU days.
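
The submission loops themselves live in the repository's bash scripts. Purely as an illustration of the pattern, the Python sketch below builds one sbatch command per point of a parameter grid and prints the commands without running them. The batch file name, variable names and grid values are all hypothetical; only sbatch and its --export option are real SLURM features.

```python
import itertools
import shlex

def sweep_commands(batch_file, grid):
    """Return one sbatch command (as an argument list) per grid point.

    Each grid point is passed to the job as environment variables through
    sbatch's --export option.
    """
    keys = sorted(grid)
    cmds = []
    for values in itertools.product(*(grid[k] for k in keys)):
        export = ",".join(["ALL"] + [f"{k}={v}" for k, v in zip(keys, values)])
        cmds.append(["sbatch", f"--export={export}", batch_file])
    return cmds

# Hypothetical sweep: two placeholder models, three seeds
grid = {"MODEL": ["model_a", "model_b"], "SEED": [0, 1, 2]}
for cmd in sweep_commands("example.slurm", grid):
    print(shlex.join(cmd))  # dry run: print instead of submitting
```

Printing first makes it easy to check the size of a sweep before committing GPU time to it.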

Citation

If you find this code useful, please cite our paper.

@article{modelsyms2022,
  doi = {10.48550/ARXIV.2205.14258},
  url = {https://arxiv.org/abs/2205.14258},
  author = {Godfrey, Charles and Brown, Davis and Emerson, Tegan and Kvinge, Henry},
  keywords = {Machine Learning (cs.LG), Artificial Intelligence (cs.AI), FOS: Computer and information sciences},
  title = {On the Symmetries of Deep Learning Models and their Internal Representations},
  publisher = {arXiv},
  year = {2022},
  copyright = {Creative Commons Attribution 4.0 International}
}

Notice

This research was supported by the Mathematics for Artificial Reasoning in Science (MARS) initiative at Pacific Northwest National Laboratory. It was conducted under the Laboratory Directed Research and Development (LDRD) Program at Pacific Northwest National Laboratory (PNNL), a multiprogram National Laboratory operated by Battelle Memorial Institute for the U.S. Department of Energy under Contract DE-AC05-76RL01830.
