simclr-tf2-distribute

Implementation of SimCLR with TensorFlow 2 / tf.distribute.Strategy.


SimCLR needs to be trained with a very large batch size, so support for distributed training is a practical necessity.

This implementation uses TPUStrategy (with 8 cores, Cloud TPU is less expensive than GPUs). If you prefer GPUs, you can use MirroredStrategy instead.

To learn more about tf.distribute.Strategy, see https://www.tensorflow.org/guide/distributed_training
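As a minimal sketch of the choice described above (not the code in this repository), the strategy can be picked from the TPU_NAME environment variable. The function names `pick_strategy_kind` and `build_strategy` are hypothetical, and falling back to MirroredStrategy when TPU_NAME is unset is an assumption. TensorFlow is imported lazily so the selection logic can be checked without it.

```python
import os


def pick_strategy_kind(environ=None):
    """Return "tpu" when TPU_NAME is set and non-empty, otherwise "mirrored"."""
    environ = os.environ if environ is None else environ
    return "tpu" if environ.get("TPU_NAME") else "mirrored"


def build_strategy(environ=None):
    """Build a tf.distribute strategy (hypothetical helper, assumes TF >= 2.3)."""
    import tensorflow as tf  # imported lazily; only needed when building

    environ = os.environ if environ is None else environ
    if pick_strategy_kind(environ) == "tpu":
        # Connect to the Cloud TPU node named by TPU_NAME and initialize it.
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            tpu=environ["TPU_NAME"]
        )
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        return tf.distribute.TPUStrategy(resolver)
    # Synchronous data parallelism across the GPUs of one machine.
    return tf.distribute.MirroredStrategy()
```

Model creation and `compile` would then go inside `with strategy.scope():` so that variables are created on the chosen devices.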

Acknowledgements

This implementation reuses some code from sayakpaul/SimCLR-in-TensorFlow-2.
Thanks for providing a minimal reference implementation in TF2.

My implementation differs in the following ways:

  • Supports distributed training, so it can train with huge batch sizes.
  • Uses tf.keras.Model.fit instead of a custom training loop.
    • Takes advantage of the tf.keras.Model class.
    • Makes it easy to use callbacks, save the optimizer state during training, and so on.
  • Uses Cloud TPU for training.

Usage

Prerequisites

To train with Cloud TPU, prepare the following:

  • Set up a GCP project and enable Cloud TPU.
  • Create a Cloud TPU node.
    • Recommended: preemptible TPU v3-8 ($2.4/h)
  • Convert the dataset to TFRecord files and upload them to GCS.
    • When using a TPU, all files used during training must be stored in GCS.
    • Example: gs://{{bucket_name}}/{{dataset_name}}/train-00001.tfrec ...
    • See the TFRecord documentation for details.
  • Set the environment variable "TPU_NAME".
# Structure of the GCS bucket (for reference).
gs://{{bucket_name}}/
├── datasets
│   ├── dataset-A
│   └── dataset-B
│       ├── train-00001.tfrec
│       ├── train-00002.tfrec
│       ├── ...
│       ├── valid-00149.tfrec
│       └── valid-00150.tfrec
└── jobs
    ├── job-A
    ├── job-B
    └── job-C
        ├── pretrain
        │   ├── checkpoints
        │   ├── logs
        │   └── saved_model
        ├── finetune
        ├── extract-feature
        └── linear-evaluation
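The layout above can be expressed as a small path helper. This is an illustrative sketch: the function names `job_paths` and `shard_pattern` are hypothetical. The tree only shows `checkpoints`, `logs`, and `saved_model` under `pretrain`, so reusing the same subdirectories for the other tasks is an assumption.

```python
TASKS = ("pretrain", "finetune", "extract-feature", "linear-evaluation")


def job_paths(bucket_name, job_name, task):
    """Return the GCS subdirectories for one task of one job."""
    if task not in TASKS:
        raise ValueError(f"unknown task: {task!r}")
    root = f"gs://{bucket_name}/jobs/{job_name}/{task}"
    # Assumption: every task uses the subdirectories shown under "pretrain".
    return {
        "checkpoints": f"{root}/checkpoints",
        "logs": f"{root}/logs",
        "saved_model": f"{root}/saved_model",
    }


def shard_pattern(bucket_name, dataset_name, split):
    """Return a glob pattern matching the TFRecord shards of one split."""
    if split not in ("train", "valid"):
        raise ValueError(f"unknown split: {split!r}")
    return f"gs://{bucket_name}/datasets/{dataset_name}/{split}-*.tfrec"


paths = job_paths("my-bucket", "job-C", "pretrain")
print(paths["logs"])  # gs://my-bucket/jobs/job-C/pretrain/logs
print(shard_pattern("my-bucket", "dataset-B", "train"))
```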

Pretrain

$ python src/pretrain.py --global_batch_size=1024 --epochs=50 --learning_rate=0.0001 \
    --temperature=0.1 --embedded_dim=128 --dataset=gs://{{bucket-name}}/{{tfrecord_dir}} \
    --model="resnet" --job_dir=gs://{{bucket-name}}/{{job_dir}}
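With tf.distribute, a dataset batched at the global batch size is split evenly across replicas. The sketch below works through that arithmetic for the command above on an 8-core TPU. The helper names are hypothetical, and whether this repository computes the contrastive loss per replica or across all replicas is not stated here.

```python
def per_replica_batch_size(global_batch_size, num_replicas):
    """Number of examples each replica processes per step."""
    if num_replicas <= 0:
        raise ValueError("num_replicas must be positive")
    if global_batch_size % num_replicas != 0:
        raise ValueError("global_batch_size must be divisible by num_replicas")
    return global_batch_size // num_replicas


def augmented_views(batch_size):
    """SimCLR draws two augmented views of every image in a batch."""
    return 2 * batch_size


# --global_batch_size=1024 on a v3-8 (8 cores):
print(per_replica_batch_size(1024, 8))  # 128 images per core
print(augmented_views(1024))            # 2048 views per step in total
```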

Finetune

$ python src/finetune.py --global_batch_size=1024 --epochs=50 --learning_rate=0.0001 --proj_head=1 \
    --percentage=10 --num_classes=1000 --dataset=gs://{{bucket-name}}/{{tfrecord_dir}} \
    --job_dir=gs://{{bucket-name}}/{{job_dir}}

Linear Evaluation

# extract
$ python scripts/extract_feature.py --batch_size=512 --proj_head=1 --dataset=gs://{{bucket-name}}/{{tfrecord_dir}} \
    --job_dir=gs://{{bucket-name}}/{{job_dir}} --le_target="finetune"

# linear evaluation
$ python scripts/linear-evaluation.py --batch_size=512 --job_dir=gs://{{bucket-name}}/{{job_dir}} \
     --embedded_dim=512 --num_classes=1000
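Linear evaluation conventionally means training a single linear layer on top of frozen, pre-extracted features. The following is a minimal sketch of such a classifier, not the contents of scripts/linear-evaluation.py. The function names, the Adam optimizer, and the loss are assumptions. TensorFlow is imported lazily inside the builder.

```python
def linear_classifier_param_count(embedded_dim, num_classes):
    """Weights plus biases of one dense layer mapping features to logits."""
    return embedded_dim * num_classes + num_classes


def build_linear_classifier(embedded_dim, num_classes):
    """Build a single-layer classifier over extracted features (hypothetical)."""
    import tensorflow as tf  # imported lazily; only needed when building

    model = tf.keras.Sequential([
        tf.keras.Input(shape=(embedded_dim,)),
        tf.keras.layers.Dense(num_classes),  # outputs logits, no activation
    ])
    model.compile(
        optimizer="adam",  # assumption: any standard optimizer would do
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )
    return model


# --embedded_dim=512 --num_classes=1000
print(linear_classifier_param_count(512, 1000))  # 513000
```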

Visualization

Records of each experiment (Pretrain/Finetune/Linear Evaluation) are stored in the gs://{{job_dir}}/{{task}}/logs directory.

To visualize the experimental results, start TensorBoard and load these logs.

$ tensorboard --logdir gs://{{job_dir}}/{{task}}/logs
