
object-centric-library's Introduction

Object-Centric Library

Code accompanying our paper:

Generalization and Robustness Implications in Object-Centric Learning
Andrea Dittadi, Samuele Papa, Michele De Vita, Bernhard Schölkopf, Ole Winther, Francesco Locatello
ICML 2022

Summary of out-of-the-box functionalities (see Using the library):

  • training 4 object-centric models and 2 VAE baselines on 6 multi-object datasets (CLEVR, Multi-dSprites, Objects Room, Shapestacks, Tetrominoes, ClevrTex);
  • evaluating trained models in terms of:
    • object segmentation;
    • a downstream task consisting of predicting all object properties;
    • qualitative performance, e.g., showing reconstructions, segmentation masks, and separate reconstructions from each slot.
  • evaluating the generalization of trained models under a range of distribution shifts.

The image below showcases the datasets (top row) and the distribution shifts on CLEVR (bottom row) that were used in the experimental study in our paper.

Datasets and distribution shifts

Visualizations of a few object-centric models trained in our study on the datasets shown above:

Visualization trained models

Example full visualization of a single trained model, including separate slot reconstructions:

Visualization of a trained model on CLEVR6

Visualizations of a few object-centric models on the distribution shifts on CLEVR:

Visualization trained models on distribution shifts on CLEVR

The library can be extended with more models, datasets, distribution shifts, evaluation metrics, and downstream tasks.

Compared to the original library used in our paper, the current version includes the ClevrTex dataset.

Setting up the environment

  1. Install requirements from requirements.txt. Example installation with conda:

    conda create --name object_centric_lib python=3.8
    conda activate object_centric_lib
    
    # Optionally install PyTorch with a custom CUDA version. Example:
    # pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu113
    
    pip install -r requirements.txt

    Note: PyTorch might have to be installed separately before installing the requirements, depending on the required CUDA version (see the PyTorch installation instructions).

    Python 3.8 recommended (≥3.8 required).

  2. Set the environment variable OBJECT_CENTRIC_LIB_DATA to the folder where the datasets should be stored.

    Details:

    For example, on Linux or macOS, add the following line to ~/.bashrc (or ~/.zshrc, depending on your shell):

     export OBJECT_CENTRIC_LIB_DATA=/path/to/datasets

    Then, restart the shell or run . ~/.bashrc (or . ~/.zshrc).

  3. Download the datasets with download_data.py.

    Click here for examples
    # Download all datasets
    python download_data.py -d all
    
    # Download all datasets, including style transfer versions
    python download_data.py -d all --include-style
    
    # Download only some datasets, without style transfer
    python download_data.py -d multidsprites clevr

    Each dataset is a .hdf5 file and its metadata is in a corresponding ${DATASET_NAME}_metadata.npy file (a quick way to inspect these files is sketched after this list). Custom datasets may override these defaults.

  4. Check the integrity of the dataset files by running python check_data.py.
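As an additional sanity check (beyond check_data.py), the downloaded files can be inspected directly with h5py and NumPy. This is only a generic sketch: the file names below are hypothetical, and no assumptions are made about the internal layout of the HDF5 files.

    import os
    from pathlib import Path

    import h5py
    import numpy as np

    # Datasets live in the folder pointed to by OBJECT_CENTRIC_LIB_DATA (step 2 above).
    data_root = Path(os.environ["OBJECT_CENTRIC_LIB_DATA"])

    # Hypothetical file names; adjust them to the dataset you actually downloaded.
    dataset_path = data_root / "multidsprites.hdf5"
    metadata_path = data_root / "multidsprites_metadata.npy"

    with h5py.File(dataset_path, "r") as f:
        # Print the top-level keys without assuming any particular structure.
        print("HDF5 keys:", list(f.keys()))

    # Metadata is stored as a NumPy file; allow_pickle is needed if it holds a Python object.
    metadata = np.load(metadata_path, allow_pickle=True)
    print("Metadata loaded:", type(metadata))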

Using the library

Quick start

  1. Train a model with default parameters:

    python train_object_discovery.py model=monet dataset=multidsprites

    This saves the model and the logs by default in outputs/runs/${MODEL}-${DATASET}-${DATETIME}.

  2. Resume training of a run, given the path to the root folder ${RUN_ROOT} of the run:

    python train_object_discovery.py model=monet dataset=multidsprites hydra.run.dir=${RUN_ROOT} allow_resume=true
  3. Evaluate reconstruction and segmentation metrics, given ${RUN_ROOT} (the path to the root folder of the run):

    python eval_metrics.py checkpoint_path=outputs/runs/monet-multidsprites-YYYY-MM-DD_HH-MM-SS
  4. Run the downstream object property prediction task (training + evaluation):

    python eval_downstream_prediction.py downstream_model=linear checkpoint_path=outputs/runs/monet-multidsprites-YYYY-MM-DD_HH-MM-SS
  5. Save visualizations (reconstructions, masks, slot reconstructions):

    python eval_qualitative.py checkpoint_path=outputs/runs/monet-multidsprites-YYYY-MM-DD_HH-MM-SS

All evaluation results are saved in ${RUN_ROOT}/evaluation, e.g., outputs/runs/monet-multidsprites-YYYY-MM-DD_HH-MM-SS/evaluation.
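To aggregate evaluation results programmatically, a small sketch like the one below can walk the evaluation folder and print any (metric_name, metric_value) records found in JSON files. The record structure follows the results.json example quoted in the issues section further down; the folder layout itself is an assumption beyond "JSON files somewhere under evaluation".

    import json
    from pathlib import Path

    # Hypothetical run folder; point this at an actual run.
    run_root = Path("outputs/runs/monet-multidsprites-YYYY-MM-DD_HH-MM-SS")

    for json_file in sorted((run_root / "evaluation").rglob("*.json")):
        with open(json_file) as f:
            content = json.load(f)
        records = content if isinstance(content, list) else [content]
        for record in records:
            results = record.get("results", {}) if isinstance(record, dict) else {}
            if "metric_name" in results:
                print(json_file.name, results["metric_name"], results["metric_value"])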

Currently, the library includes the following models:

  • MONet - monet
  • Slot Attention - slot-attention
  • GENESIS
  • SPACE - space
  • two VAE baselines

and the following datasets:

  • CLEVR - clevr (the original dataset has 10 objects: to train on CLEVR6, add +dataset.variant=6 to the command line)
  • Multi-dSprites - multidsprites
  • Objects Room - objects_room
  • Shapestacks - shapestacks
  • Tetrominoes - tetrominoes
  • ClevrTex - clevrtex
    • This is not included in the original paper. Since textures are already present, we do not provide a style-transfer version. Also note that standard dataset variants are not yet supported.

Read the following sections for further details.

Training a model

python train_object_discovery.py model=${MODEL} dataset=${DATASET}

This command trains the specified model on the specified dataset, with default parameters defined by the hydra configuration files in config/. The base config file for this script is config/train_object_discovery.yaml.

The run folder is handled by hydra, and by default it is outputs/runs/${MODEL}-${DATASET}-${DATETIME}. This can be customized using hydra by adding, e.g., hydra.run.dir=outputs/runs/${model.name}-${dataset.name} to the command line.

The model and dataset correspond to config files -- e.g., model=slot-attention reads the model config from config/model/slot-attention.yaml and dataset=multidsprites reads the dataset config from config/dataset/multidsprites.yaml. In some cases we define custom parameters for specific combinations of dataset and model: these are defined in the folder config/special_cases.

Dataset variants can define dataset filters or transforms to test robustness to distribution shifts. A variant is picked by adding +dataset.variant=${VARIANT} to the command line: e.g. CLEVR6 is dataset=clevr +dataset.variant=6, and Tetrominoes with occlusions is dataset=tetrominoes +dataset.variant=occlusion. For more information on dataset variants, see config/dataset/variants/readme.md.

All models are configured through hydra, including the training setup. The default parameters are defined in the model's YAML file, and these can be overridden from the command line. E.g., we can change the foreground sigma, the MLP hidden size, and the learning rate schedule of MONet as follows:

python train_object_discovery.py model=monet dataset=shapestacks model.fg_sigma=0.15 model.encoder_params.mlp_hidden_size=128 trainer.exp_decay_rate=0.8
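These overrides follow standard Hydra/OmegaConf dot-list semantics. The following self-contained sketch (using OmegaConf directly, with toy values rather than the library's actual configs) shows how such command-line arguments are merged into a nested configuration:

    from omegaconf import OmegaConf

    # Toy stand-in for a model/trainer config; values are illustrative only.
    base = OmegaConf.create({
        "model": {"fg_sigma": 0.1, "encoder_params": {"mlp_hidden_size": 64}},
        "trainer": {"exp_decay_rate": 0.5},
    })

    # Command-line overrides expressed as a dot-list, the way Hydra passes them on.
    overrides = OmegaConf.from_dotlist([
        "model.fg_sigma=0.15",
        "model.encoder_params.mlp_hidden_size=128",
        "trainer.exp_decay_rate=0.8",
    ])

    cfg = OmegaConf.merge(base, overrides)
    print(OmegaConf.to_yaml(cfg))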
Click to expand details on available flags

There are some common flags that can be used with every model and dataset:

  • batch_size (default given by the model config).
  • trainer.steps: number of training steps (default given by the model config).
  • data_sizes: size of the train, validation, and test sets (defaults given by the dataset config).
  • trainer.optimizer_config: the optimizer class, learning rate, and other parameters are provided here by default (see e.g. config/model/monet.yaml); a generic sketch of how such a config can be turned into an optimizer follows this list. For more complex settings, e.g. when multiple optimizers are needed, a custom _make_optimizers() method can be implemented: see for example config/model/space.yaml and models/space/trainer.py.
  • trainer.clip_grad_norm: float value for gradient norm clipping, or None for no clipping.
  • frequency of checkpointing, validation, and logging: trainer.logweights_steps, trainer.logimages_steps, trainer.logloss_steps, trainer.checkpoint_steps, trainer.logvalid_steps.
  • allow_resume: if the directory of the run exists, this flag controls whether the script loads an existing checkpoint and resumes training, or it throws an exception.
  • num_workers: for PyTorch data loaders.
  • dataset.skip_loading: dummy data is loaded instead of the specified dataset (for debugging).
  • seed: random seed.
  • debug: if true, it launches a minimal run.
  • device: cpu or cuda (default: cuda).
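To make trainer.optimizer_config more concrete, here is a generic sketch of how a config such as alg: Adam, lr: 0.0004 (the format used in the training config quoted in the issues section below) can be turned into a PyTorch optimizer. This is illustrative only and is not the library's actual implementation:

    import torch
    from torch import nn

    def make_optimizer(params, optimizer_config):
        """Build an optimizer from a config dict such as {"alg": "Adam", "lr": 4e-4}."""
        config = dict(optimizer_config)         # copy so we can pop safely
        alg_name = config.pop("alg")            # e.g. "Adam", "RMSprop", ...
        optimizer_cls = getattr(torch.optim, alg_name)
        return optimizer_cls(params, **config)  # remaining keys (lr, ...) become kwargs

    # Usage with a toy model (illustrative).
    model = nn.Linear(8, 8)
    optimizer = make_optimizer(model.parameters(), {"alg": "Adam", "lr": 0.0004})
    print(optimizer)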

Evaluation: metrics

python eval_metrics.py checkpoint_path=/path/to/run/folder

This command evaluates the reconstruction error (MSE) and 3 segmentation metrics (ARI, SC, mSC). Typically no customization is necessary, but see config/eval_metrics.yaml.

Click to expand details on available flags

The variant_types flag allows evaluating the metrics on different variants of the original training dataset: this is used by default to evaluate generalization (see the list of default variants in config/eval_metrics.yaml). The overwrite flag allows overwriting the result folder for this evaluation, and is False by default.

The seed, debug, and device flags are also available here, with the same behavior as in train_object_discovery.py.
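For reference, ARI can be computed with scikit-learn by treating every pixel of an integer-labeled segmentation map as a sample; in the object-centric literature it is common to restrict it to foreground pixels. The sketch below is a generic illustration and not necessarily identical to the library's implementation:

    import numpy as np
    from sklearn.metrics import adjusted_rand_score

    def foreground_ari(true_mask, pred_mask, background_id=0):
        """ARI over foreground pixels of integer-labeled (H, W) segmentation maps."""
        fg = true_mask.flatten() != background_id
        return adjusted_rand_score(true_mask.flatten()[fg], pred_mask.flatten()[fg])

    # Toy 4x4 example (0 = background): labels are permuted but the clusters match.
    true = np.array([[0, 0, 1, 1], [0, 0, 1, 1], [2, 2, 0, 0], [2, 2, 0, 0]])
    pred = np.array([[0, 0, 2, 2], [0, 0, 2, 2], [1, 1, 0, 0], [1, 1, 0, 0]])
    print(foreground_ari(true, pred))  # 1.0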

Evaluation: downstream task

python eval_downstream_prediction.py checkpoint_path=/path/to/run/folder downstream_model=linear

This command trains and evaluates a downstream linear model that predicts the properties of the objects in a scene from the representations of the upstream model. This is configured by config/eval_downstream_prediction.yaml; see the comments in the file for more information. Note that a results subfolder is created for each combination of matching, downstream model, and dataset variant.

Click to expand details on available flags

Typically useful flags (see the config file for more):

  • downstream_model: the type of downstream model, such as linear or MLP3.
  • matching: method for matching objects with model slots (a generic sketch of one such matching procedure is given at the end of this section).
  • variant_types: for each of the specified variant types, train a downstream model and then test it on all variant types (including the one it was trained on).
  • steps
  • batch_size
  • learning_rate
  • train_size
  • validation_size
  • test_size

The seed, debug, overwrite, and device flags are also available here, with the same behavior as in eval_metrics.py.
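The matching step referenced above pairs ground-truth objects with model slots before the property predictor is trained. A common approach, shown here as a generic sketch rather than the library's exact procedure, is Hungarian matching on a cost matrix such as negative mask IoU:

    import numpy as np
    from scipy.optimize import linear_sum_assignment

    def match_objects_to_slots(object_masks, slot_masks):
        """Match objects to slots by maximizing mask IoU.

        object_masks: (num_objects, H, W) binary masks
        slot_masks:   (num_slots, H, W) binary masks
        Returns a list of (object_index, slot_index) pairs.
        """
        num_objects, num_slots = object_masks.shape[0], slot_masks.shape[0]
        iou = np.zeros((num_objects, num_slots))
        for i in range(num_objects):
            for j in range(num_slots):
                inter = np.logical_and(object_masks[i], slot_masks[j]).sum()
                union = np.logical_or(object_masks[i], slot_masks[j]).sum()
                iou[i, j] = inter / union if union > 0 else 0.0
        rows, cols = linear_sum_assignment(-iou)  # Hungarian algorithm, maximizing IoU
        return list(zip(rows.tolist(), cols.tolist()))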

Evaluation: save visualizations

python eval_qualitative.py checkpoint_path=/path/to/run/folder

This command saves model visualizations, and typically does not require customization. The seed, debug, overwrite, and device flags are also available here.

Sweeps

To run many experiments in a structured sweep over parameters and/or settings, the library has a "sweep" functionality.

For example, to train all object-centric models in the study in our paper, we defined a sweep in sweeps/configs/sweep_object_centric.py. This creates a sweep called "object_centric", which maps a model number to a specific configuration of command line arguments.

The first model in the sweep is trained as follows:

python sweep_train.py --sweep-name object_centric --model-num 0

Since in this case we have 10 seeds, 4 models, and 5 datasets, any model number up to 199 would be valid.
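The mapping from model number to run configuration is simply an enumeration of the Cartesian product of the sweep's options. Below is a hypothetical sketch of such a mapping (illustrative option lists, not the actual contents of sweeps/configs/sweep_object_centric.py):

    from itertools import product

    # Illustrative option lists; the real sweep config defines its own.
    seeds = list(range(10))
    models = ["monet", "slot-attention", "genesis", "space"]
    datasets = ["clevr", "multidsprites", "objects_room", "shapestacks", "tetrominoes"]

    configs = list(product(seeds, models, datasets))  # 10 * 4 * 5 = 200 configurations

    def sweep_args(model_num):
        seed, model, dataset = configs[model_num]
        return [f"model={model}", f"dataset={dataset}", f"seed={seed}"]

    print(sweep_args(0))  # overrides passed to train_object_discovery.py for model number 0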

This script internally calls train_object_discovery.py with the appropriate arguments as prescribed by the sweep, and uses outputs/sweeps/sweep_${SWEEP_NAME}/${MODEL_NUMBER}/ as output folder.

Use python -m sweeps.sweep_progress SWEEP_NAME to get an overview of the overall progress of the sweep.

Extending the library

The library easily allows adding models, datasets, dataset variants, evaluation metrics, and downstream tasks. Feel free to reach out for questions at:

andrea [đöt] dittadi [åt] gmail [đöt] com

Contributors

Citation

If you use this library in your own work, please consider citing our paper as follows:

@inproceedings{dittadi2022generalization,
  title={Generalization and Robustness Implications in Object-Centric Learning},
  author={Dittadi, Andrea and Papa, Samuele and De Vita, Michele and Sch{\"o}lkopf, Bernhard and Winther, Ole and Locatello, Francesco},
  booktitle={International Conference on Machine Learning},
  year={2022},
}

Notes

In a follow-up paper, we use this library to investigate inductive biases in unsupervised object-centric learning when the objects in the training set have complex textures:

Inductive Biases for Object-Centric Representations in the Presence of Complex Textures
Samuele Papa, Ole Winther, Andrea Dittadi
UAI workshop on Causal Representation Learning, 2022


object-centric-library's Issues

URL for dataset download

Hello! Thank you for sharing the great code.

I tried to use your repository but found that the provided remote URL for the dataset download is not working.
Can you please check if it works?

Also, can you explain how the provided datasets differ from the Multi-Object Datasets from DeepMind?

Slot attention initialization params no gradients

Hi Andrea, thank you for the interesting work!
I've noticed that in Slot Attention the initialization parameters self.slots_mu and self.slots_log_sigma are learnable. However, sampling from the associated Gaussian distribution is not based on the reparameterization trick, so no gradients are propagated to them.

Is it intended?
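For reference, the distinction raised in this issue can be illustrated with torch.distributions, independently of the library's code: .sample() draws values with no gradient path back to the distribution parameters, while .rsample() uses the reparameterization trick and does propagate gradients.

    import torch
    from torch.distributions import Normal

    # Learnable parameters, analogous in spirit to slots_mu / slots_log_sigma.
    mu = torch.zeros(3, requires_grad=True)
    log_sigma = torch.zeros(3, requires_grad=True)
    dist = Normal(mu, log_sigma.exp())

    # Plain sampling: the draw is detached from mu / log_sigma.
    s = dist.sample()
    print(s.requires_grad)  # False

    # Reparameterized sampling (mu + sigma * eps): gradients flow back to the parameters.
    r = dist.rsample()
    r.sum().backward()
    print(mu.grad)          # tensor([1., 1., 1.])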

Question for reproducing the numbers of slot attention on CLEVR

Hello. First of all, thank you for sharing valuable resources!

Thanks to this repo, I was able to easily kickstart research in a new field.

I have a question about reproducing the results on CLEVR10 using the provided .yaml training config.

When I run the experiment using the command python train_object_discovery.py model=slot-attention dataset=clevr ++num_workers=4, the ARI on the original dataset reaches ~0.66, which suggests that the result is not properly reproduced.

I am attaching the training config and evaluation results together (BTW, the logging system is so well organized!).

It would be really helpful if you could point out the part that may be causing the discrepancy. Thank you!

train_config.yaml

seed: 12345
device: cuda
debug: false
num_workers: 4
allow_resume: false
trainer:
  clip_grad_norm: null
  logweights_steps: 1000
  logimages_steps: 10000
  logloss_steps: 1000
  checkpoint_steps: 1000
  logvalid_steps: 25000
  resubmit_steps: null
  resubmit_hours: null
  _target_: models.slot_attention.trainer.SlotAttentionTrainer
  steps: 500000
  use_warmup_lr: true
  warmup_steps: 10000
  use_exp_decay: true
  exp_decay_rate: 0.5
  exp_decay_steps: 100000
  optimizer_config:
    alg: Adam
    lr: 0.0004
dataset:
  output_features: all
  skip_loading: false
  _target_: data.datasets.Clevr
  width: 128
  height: 128
  num_background_objects: 1
  max_num_objects: 11
  name: clevr
  input_channels: 3
  dataset_path: clevr_10-full.hdf5
  downstream_features:
  - x
  - 'y'
  - size
  - shape
  - material
  - color
data_sizes:
- 90000
- 5000
- 5000
model:
  height: ${dataset.height}
  width: ${dataset.width}
  _target_: models.slot_attention.model.SlotAttentionAE
  name: slot-attention
  num_slots: 7
  latent_size: 64
  encoder_params:
    channels:
    - 64
    - 64
    - 64
    - 64
    kernels:
    - 5
    - 5
    - 5
    - 5
    paddings:
    - 2
    - 2
    - 2
    - 2
    strides:
    - 1
    - 2
    - 2
    - 1
  decoder_params:
    conv_transposes: true
    channels:
    - 64
    - 64
    - 64
    - 64
    - 64
    - 4
    kernels:
    - 5
    - 5
    - 5
    - 5
    - 5
    - 3
    strides:
    - 2
    - 2
    - 2
    - 2
    - 1
    - 1
    paddings:
    - 2
    - 2
    - 2
    - 2
    - 2
    - 1
    output_paddings:
    - 1
    - 1
    - 1
    - 1
    - 0
    - 0
    activations:
    - relu
    - relu
    - relu
    - relu
    - relu
    - null
  attention_iters: 3
  mlp_size: 128
  eps: 1.0e-08
  h_broadcast: 8
  w_broadcast: 8
batch_size: 32
uuid: 1c8ed864-0352-4b1a-9812-7bb2e9b0e8cb

results.json

[
  {
    "train_config.uuid": "1c8ed864-0352-4b1a-9812-7bb2e9b0e8cb",
    "eval_config": {
      "variant_type": "original",
      "checkpoint_path": "/data02/dongwon/object-centric-library/outputs/runs/slot-attention-clevr-2023-07-07_20-14-41",
      "device": "cuda",
      "seed": 12345,
      "batch_size": 64,
      "dataset_size": null,
      "starting_index": null
    },
    "results": {
      "metric_name": "ari",
      "metric_value": 0.6597288250923157
    }
  },
  {
    "train_config.uuid": "1c8ed864-0352-4b1a-9812-7bb2e9b0e8cb",
    "eval_config": {
      "variant_type": "original",
      "checkpoint_path": "/data02/dongwon/object-centric-library/outputs/runs/slot-attention-clevr-2023-07-07_20-14-41",
      "device": "cuda",
      "seed": 12345,
      "batch_size": 64,
      "dataset_size": null,
      "starting_index": null
    },
    "results": {
      "metric_name": "mean_segcover",
      "metric_value": 0.1793033480644226
    }
  },
  {
    "train_config.uuid": "1c8ed864-0352-4b1a-9812-7bb2e9b0e8cb",
    "eval_config": {
      "variant_type": "original",
      "checkpoint_path": "/data02/dongwon/object-centric-library/outputs/runs/slot-attention-clevr-2023-07-07_20-14-41",
      "device": "cuda",
      "seed": 12345,
      "batch_size": 64,
      "dataset_size": null,
      "starting_index": null
    },
    "results": {
      "metric_name": "scaled_segcover",
      "metric_value": 0.24997809529304504
    }
  },
  {
    "train_config.uuid": "1c8ed864-0352-4b1a-9812-7bb2e9b0e8cb",
    "eval_config": {
      "variant_type": "original",
      "checkpoint_path": "/data02/dongwon/object-centric-library/outputs/runs/slot-attention-clevr-2023-07-07_20-14-41",
      "device": "cuda",
      "seed": 12345,
      "batch_size": 64,
      "dataset_size": null,
      "starting_index": null
    },
    "results": {
      "metric_name": "mse",
      "metric_value": 0.000645454041659832
    }
  },
  {
    "train_config.uuid": "1c8ed864-0352-4b1a-9812-7bb2e9b0e8cb",
    "eval_config": {
      "variant_type": "original",
      "checkpoint_path": "/data02/dongwon/object-centric-library/outputs/runs/slot-attention-clevr-2023-07-07_20-14-41",
      "device": "cuda",
      "seed": 12345,
      "batch_size": 64,
      "dataset_size": null,
      "starting_index": null
    },
    "results": {
      "metric_name": "mse_unmodified_fg",
      "metric_value": 0.00045366105041466653
    }
  },
  {
    "train_config.uuid": "1c8ed864-0352-4b1a-9812-7bb2e9b0e8cb",
    "eval_config": {
      "variant_type": "original",
      "checkpoint_path": "/data02/dongwon/object-centric-library/outputs/runs/slot-attention-clevr-2023-07-07_20-14-41",
      "device": "cuda",
      "seed": 12345,
      "batch_size": 64,
      "dataset_size": null,
      "starting_index": null
    },
    "results": {
      "metric_name": "mse_fg",
      "metric_value": 0.00045366105041466653
    }
  }
]

Pretrained weights

Hi, may I ask if you have pretrained weights for different datasets? Thanks in advance.
