Slot Attention Lightning


Description

This repo implements baseline methods for unsupervised object-centric learning, including IODINE, MONet, Slot Attention, and Genesis V2. The implementations of IODINE, MONet, and Genesis V2 are from here.

↑↑↑ Visualization of training results logged by WandB ↑↑↑


Repository Structure

The directory structure of this repo looks like this:
├── .github                   <- Github Actions workflows
│
├── configs                   <- Hydra configs
│   ├── callbacks                <- Callbacks configs
│   ├── data                     <- Data configs
│   ├── debug                    <- Debugging configs
│   ├── experiment               <- *** Experiment configs ***
│   │   ├── slota                 
│   │   │  ├── clv6.yaml          
│   │   │  └── ...
│   │   └── ...                  
│   ├── extras                   <- Extra utilities configs
│   ├── hparams_search           <- Hyperparameter search configs
│   ├── hydra                    <- Hydra configs
│   ├── local                    <- Local configs
│   ├── logger                   <- Logger configs (we use wandb)
│   ├── model                    <- Model configs
│   ├── paths                    <- Project paths configs
│   ├── trainer                  <- Trainer configs
│   │
│   ├── eval.yaml             <- Main config for evaluation
│   └── train.yaml            <- Main config for training
│
├── data                      <- Directory for Dataset
│   ├── CLEVR6                
│   │   ├── images            <- raw images
│   │   │   ├── train
│   │   │   │   ├── CLEVR_train_******.png
│   │   │   │   └── ...
│   │   │   └── val
│   │   │       ├── CLEVR_val_******.png
│   │   │       └── ...
│   │   ├── masks             <- mask annotations
│   │   │   ├── train
│   │   │   │   ├── CLEVR_train_******.png
│   │   │   │   └── ...
│   │   │   └── val
│   │   │       ├── CLEVR_val_******.png
│   │   │       └── ...
│   │   └── scenes          <- metadata
│   │       ├── CLEVR_train_scenes.json
│   │       └── CLEVR_val_scenes.json
│   └── ...
│
├── logs                   <- Logs generated by hydra and lightning loggers
│
├── scripts                <- Shell scripts
│
├── src                    <- Source code
│   ├── data                     <- Data scripts
│   ├── models                   <- Model scripts
│   ├── utils                    <- Utility scripts
│   │
│   ├── eval.py                  <- Run evaluation
│   └── train.py                 <- Run training
│
├── tests                  <- Tests of any kind
│
├── .env.example              <- Example of file for storing private environment variables
├── .gitignore                <- List of files ignored by git
├── .pre-commit-config.yaml   <- Configuration of pre-commit hooks for code formatting
├── .project-root             <- File for inferring the position of project root directory
├── environment.yaml          <- File for installing conda environment
├── Makefile                  <- Makefile with commands like `make train` or `make test`
├── pyproject.toml            <- Configuration options for testing and linting
├── requirements.txt          <- File for installing python dependencies
├── setup.py                  <- File for installing project as a package
└── README.md

Note
Each dataset may provide its mask annotations and metadata in a different way, so the Dataset class for each dataset should be matched to that dataset's format and configured accordingly.
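
As a rough illustration of what matching a Dataset class to a layout involves, the sketch below pairs each CLEVR6 image with its mask by filename, following the directory tree shown above. The function name `pair_images_and_masks` is hypothetical and not part of this repo, and it assumes that an image and its mask share the same filename, as the tree suggests. A dataset with a different layout would need its own pairing logic.

```python
from pathlib import Path
from typing import List, Tuple


def pair_images_and_masks(data_dir: str, split: str = "train") -> List[Tuple[Path, Path]]:
    """Return (image, mask) path pairs for one split of a CLEVR6-style directory.

    Assumes the layout <data_dir>/images/<split>/*.png and
    <data_dir>/masks/<split>/*.png with identical filenames. This is an
    assumption based on the directory tree, not taken from the repo's code.
    """
    image_dir = Path(data_dir) / "images" / split
    mask_dir = Path(data_dir) / "masks" / split

    pairs = []
    for image_path in sorted(image_dir.glob("*.png")):
        mask_path = mask_dir / image_path.name
        if not mask_path.exists():
            # Fail early rather than silently training on unlabeled images.
            raise FileNotFoundError(f"No mask found for {image_path.name}")
        pairs.append((image_path, mask_path))
    return pairs
```

A Dataset class could call a helper like this once in its constructor and then load one pair per index.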


Installation

This repo is developed based on Lightning-Hydra-Template 1.5.3 with Python 3.8.12 and PyTorch 1.11.0.

Pip

# clone project
git clone https://github.com/janghyuk-choi/slot-attention-lightning.git
cd slot-attention-lightning

# [OPTIONAL] create conda environment
conda create -n slota python=3.8
conda activate slota

# install pytorch according to instructions
# https://pytorch.org/get-started/

# install requirements
pip install -r requirements.txt

Conda

# clone project
git clone https://github.com/janghyuk-choi/slot-attention-lightning.git
cd slot-attention-lightning

# create conda environment and install dependencies
conda env create -f environment.yaml

# activate conda environment
conda activate slota

How to run

Train a model with a chosen experiment configuration from configs/experiment/.

Training

# training Slot Attention over CLEVR6 dataset
python src/train.py \
experiment=slota/clv6.yaml

# training Genesis V2 over CLEVRTEX dataset
python src/train.py \
experiment=genesis2/clvt.yaml

You can create your own experiment configs for your purposes.
For simple modifications, however, you can override any parameter from the command line.

# training Slot Attention over CLEVR6 dataset with custom config
python src/train.py \
experiment=slota/clv6.yaml \
data.data_dir=/workspace/dataset/clevr_with_masks/CLEVR6 \
trainer.check_val_every_n_epoch=10 \
model.net.num_slots=10 \
model.net.num_iter=5 \
model.name="slota_k10_t5" # model.name will be used for logging on wandb

Evaluation

You can evaluate a trained model using its checkpoint.
Evaluation also runs during training, once every trainer.check_val_every_n_epoch epochs.

# evaluating Slot Attention over CLEVR6 dataset.
# similar to the training phase, you can also customize the config with command line
python src/eval.py \
experiment=slota/clv6.yaml \
ckpt_path=logs/train/runs/clv6_slota/{timestamp}/checkpoints/last.ckpt
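
Because the `{timestamp}` part of the checkpoint path differs for every run, a small helper can look up the most recent one. The sketch below is an illustration only: `find_latest_checkpoint` is a hypothetical function, and it assumes that run directories sit directly under the runs folder and that their timestamp names sort chronologically as plain strings.

```python
from pathlib import Path
from typing import Optional


def find_latest_checkpoint(runs_dir: str, filename: str = "last.ckpt") -> Optional[Path]:
    """Return the checkpoint from the most recent run directory, or None.

    Assumes each run lives in <runs_dir>/<timestamp>/checkpoints/ and that
    timestamp directory names sort chronologically as plain strings.
    """
    candidates = sorted(Path(runs_dir).glob(f"*/checkpoints/{filename}"))
    return candidates[-1] if candidates else None


if __name__ == "__main__":
    # Prints None when no run has produced a checkpoint yet.
    print(find_latest_checkpoint("logs/train/runs/clv6_slota"))
```

The printed path can then be passed to `src/eval.py` as the `ckpt_path` value.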

slot-attention-lightning's Issues

Question for dataset download

Hello, thank you for sharing the excellent code!

I would like to run your code, but I noticed that it only accepts the CLEVR dataset in the .png format.

Unfortunately, I only have the data in .tfrecord format from the official multi-object-dataset.

Should I use the original CLEVR dataset from https://cs.stanford.edu/people/jcjohns/clevr/ (though it does not provide mask labels) or should I manually convert the multi-object-dataset to the .png format?

It would also be extremely helpful for newcomers in this field like me if you could specify the source of each dataset. 😄
