PyTorch Template Project

A simple template project using PyTorch which can be modified to fit many deep learning projects.

Basic Usage

The code in this repo is an MNIST example of the template. Try running:

python main.py

The default argument list is shown below:

usage: main.py [-h] [-b BATCH_SIZE] [-e EPOCHS] [--resume RESUME]
               [--verbosity VERBOSITY] [--save-dir SAVE_DIR]
               [--save-freq SAVE_FREQ] [--data-dir DATA_DIR]
               [--validation-split VALIDATION_SPLIT] [--no-cuda]

PyTorch Template

optional arguments:
  -h, --help    show this help message and exit
  -b BATCH_SIZE, --batch-size BATCH_SIZE
                        mini-batch size (default: 32)
  -e EPOCHS, --epochs EPOCHS
                        number of total epochs (default: 32)
  --resume RESUME
                        path to latest checkpoint (default: none)
  --verbosity VERBOSITY
                        verbosity, 0: quiet, 1: per epoch, 2: complete (default: 2)
  --save-dir SAVE_DIR
                        directory of saved model (default: model/saved)
  --save-freq SAVE_FREQ
                        training checkpoint frequency (default: 1)
  --data-dir DATA_DIR
                        directory of training/testing data (default: datasets)
  --validation-split VALIDATION_SPLIT
                        ratio of split validation data, [0.0, 1.0) (default: 0.0)
  --no-cuda   use CPU in case there's no GPU support

You can add your own arguments to the parser in main.py.
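A minimal sketch, assuming main.py builds its parser with Python's standard argparse module. The parser below is reconstructed from the help text above, so its defaults match it. The --lr option is a hypothetical extra argument used only to show where a new one goes; it is not part of the template.

```python
import argparse


def build_parser():
    """Rebuild the default CLI from the help text, plus one hypothetical flag."""
    parser = argparse.ArgumentParser(description='PyTorch Template')
    parser.add_argument('-b', '--batch-size', default=32, type=int,
                        help='mini-batch size (default: 32)')
    parser.add_argument('-e', '--epochs', default=32, type=int,
                        help='number of total epochs (default: 32)')
    # Assumed: "none" in the help text means no checkpoint is resumed by default.
    parser.add_argument('--resume', default=None, type=str,
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('--verbosity', default=2, type=int,
                        help='verbosity, 0: quiet, 1: per epoch, 2: complete (default: 2)')
    parser.add_argument('--save-dir', default='model/saved', type=str,
                        help='directory of saved model (default: model/saved)')
    parser.add_argument('--save-freq', default=1, type=int,
                        help='training checkpoint frequency (default: 1)')
    parser.add_argument('--data-dir', default='datasets', type=str,
                        help='directory of training/testing data (default: datasets)')
    parser.add_argument('--validation-split', default=0.0, type=float,
                        help='ratio of split validation data, [0.0, 1.0) (default: 0.0)')
    parser.add_argument('--no-cuda', action='store_true',
                        help="use CPU in case there's no GPU support")
    # Hypothetical new argument (not in the template): a learning rate.
    parser.add_argument('--lr', default=1e-3, type=float,
                        help='learning rate (default: 0.001)')
    return parser
```

With this parser, `python main.py --lr 0.01 -b 64` would give `args.lr == 0.01` and `args.batch_size == 64`; argparse turns the dashes in option names into underscores in attribute names.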

Structure

├── base/ - abstract base classes
│   ├── base_data_loader.py - abstract base class for data loaders
│   ├── base_model.py - abstract base class for models.
│   └── base_trainer.py - abstract base class for trainers
│
├── data_loader/ - anything about data loading goes here
│   └── data_loader.py
│
├── datasets/ - default dataset folder
│
├── logger/ - for training process logging
│   └── logger.py
│
├── model/ - models, losses, and metrics
│   ├── modules/ - submodules of your model
│   ├── saved/ - default checkpoint folder
│   ├── loss.py
│   ├── metric.py
│   └── model.py
│
├── trainer/ - trainers for your project
│   └── trainer.py
│
└── utils/ - utility functions
    ├── utils.py
    └── ...

Customization

Training

In most cases, you need to modify trainer/trainer.py to fit the training logic of your project.
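The split between the base trainer and your own trainer can be pictured with the hypothetical, framework-free sketch below, which runs without PyTorch. Only the method names _train_epoch() and _save_checkpoint() come from this README. The rule "checkpoint every save_freq epochs" is an assumption based on the --save-freq option, and the toy model (fitting y = w * x with plain SGD) stands in for a real network, loss and optimizer.

```python
from abc import ABC, abstractmethod


class BaseTrainer(ABC):
    """Hypothetical, simplified stand-in for base/base_trainer.py."""

    def __init__(self, epochs, save_freq=1):
        self.epochs = epochs
        self.save_freq = save_freq
        self.saved_epochs = []  # epochs at which a checkpoint was "saved"

    def train(self):
        """Generic loop: run each epoch, log it, checkpoint every save_freq epochs."""
        history = []
        for epoch in range(1, self.epochs + 1):
            log = {'epoch': epoch, **self._train_epoch(epoch)}
            history.append(log)
            if epoch % self.save_freq == 0:
                self._save_checkpoint(epoch, log)
        return history

    @abstractmethod
    def _train_epoch(self, epoch):
        """Project-specific: train for one epoch and return a dict of logged values."""

    def _save_checkpoint(self, epoch, log):
        # A real trainer would serialize model/optimizer state here; we only record it.
        self.saved_epochs.append(epoch)


class Trainer(BaseTrainer):
    """Toy trainer: fits y = w * x with plain SGD in place of a PyTorch model."""

    def __init__(self, data, lr=0.1, **kwargs):
        super().__init__(**kwargs)
        self.data = data  # list of (x, y) pairs
        self.lr = lr
        self.w = 0.0

    def _train_epoch(self, epoch):
        total = 0.0
        for x, y in self.data:
            err = self.w * x - y
            total += err ** 2
            self.w -= self.lr * 2 * err * x  # d/dw of (w*x - y)**2 is 2*err*x
        return {'loss': total / len(self.data)}


trainer = Trainer(data=[(1.0, 2.0), (2.0, 4.0)], lr=0.05, epochs=20, save_freq=5)
history = trainer.train()
```

The generic loop stays in the base class, so adapting the template to a new project mostly means rewriting _train_epoch().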

Data loading

You can customize the data loader to fit your project: modify data_loader/data_loader.py or add other files.
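In a PyTorch project this usually means wrapping a torch.utils.data.Dataset in a torch.utils.data.DataLoader. To show the core idea without dependencies, here is a hypothetical pure-Python loader. SimpleDataLoader is not part of the template; it shuffles the samples each pass and yields mini-batches of inputs and targets.

```python
import random


class SimpleDataLoader:
    """Hypothetical mini-batch loader over an in-memory list of (x, y) samples."""

    def __init__(self, samples, batch_size=32, shuffle=True, seed=0):
        self.samples = list(samples)
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.rng = random.Random(seed)  # seeded so runs are reproducible

    def __len__(self):
        # Number of batches, counting a final partial batch.
        return (len(self.samples) + self.batch_size - 1) // self.batch_size

    def __iter__(self):
        order = list(range(len(self.samples)))
        if self.shuffle:
            self.rng.shuffle(order)
        for start in range(0, len(order), self.batch_size):
            idx = order[start:start + self.batch_size]
            xs = [self.samples[i][0] for i in idx]
            ys = [self.samples[i][1] for i in idx]
            yield xs, ys
```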

Model

Implement your model under model/ (for example in model/model.py, with submodules in model/modules/).

Loss/metrics

If you need to change the loss function or metrics, first import those functions in main.py, then modify this part:

loss = my_loss
metrics = [my_metric]
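For illustration, here are hypothetical list-based versions of my_loss and my_metric. The template's own functions work on PyTorch tensors; these pure-Python stand-ins only show the expected shape: each takes the model output and the targets and returns a single number. The choice of negative log-likelihood and accuracy is an assumption made for the example.

```python
def my_loss(log_probs, targets):
    """Mean negative log-likelihood; log_probs[i][c] is log P(class c | sample i)."""
    return -sum(row[t] for row, t in zip(log_probs, targets)) / len(targets)


def my_metric(output, targets):
    """Accuracy: fraction of samples whose highest-scoring class equals the target."""
    correct = 0
    for row, t in zip(output, targets):
        pred = max(range(len(row)), key=row.__getitem__)  # argmax over class scores
        correct += int(pred == t)
    return correct / len(targets)


loss = my_loss
metrics = [my_metric]
```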

You'll see that the training log changes accordingly:

⋯
Train Epoch: 1 [53920/53984 (100%)] Loss: 0.033256
{'epoch': 1, 'loss': 0.14182623870152963, 'my_metric': 0.9568761114404268, 'val_loss': 0.06394806604976841, 'val_my_metric': 0.9804478609625669}
Saving checkpoint: model/saved/Model_checkpoint_epoch01_loss_0.14183.pth.tar ...
Train Epoch: 2 [0/53984 (0%)] Loss: 0.013225
⋯

Multiple metrics

If you have multiple metrics in your project, just add them to the metrics list:

loss = my_loss
metrics = [my_metric, my_metric2]

Now the log shows both metrics:

⋯
Train Epoch: 1 [53920/53984 (100%)] Loss: 0.003278
{'epoch': 1, 'loss': 0.13541310020907665, 'my_metric': 0.9590804682868999, 'my_metric2': 1.9181609365737997, 'val_loss': 0.05264156081223173, 'val_my_metric': 0.9837901069518716, 'val_my_metric2': 1.9675802139037433}
Saving checkpoint: model/saved/Model_checkpoint_epoch01_loss_0.13541.pth.tar ...
Train Epoch: 2 [0/53984 (0%)] Loss: 0.023072
⋯

Currently, the name shown in the log is the name of the metric function; validation values use the same name with a val_ prefix.
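A hypothetical sketch of how such log keys can be derived, consistent with the sample logs above: each metric is keyed by its function's __name__, and validation entries get a val_ prefix. my_metric2 here simply doubles my_metric, mirroring the sample log where my_metric2 is exactly twice my_metric. The helpers evaluate_metrics() and epoch_log() are illustrative names, not the template's API.

```python
def my_metric(pred, target):
    """Toy accuracy over predicted labels."""
    return sum(p == t for p, t in zip(pred, target)) / len(target)


def my_metric2(pred, target):
    # Mirrors the sample log, where my_metric2 is twice my_metric.
    return 2 * my_metric(pred, target)


def evaluate_metrics(metrics, pred, target):
    """Key each metric value by its function's __name__."""
    return {m.__name__: m(pred, target) for m in metrics}


def epoch_log(epoch, loss, train_metrics, val_loss=None, val_metrics=None):
    """Assemble one epoch's log in the same key order as the sample output."""
    log = {'epoch': epoch, 'loss': loss, **train_metrics}
    if val_loss is not None:
        log['val_loss'] = val_loss
        log.update({'val_' + k: v for k, v in (val_metrics or {}).items()})
    return log
```

Renaming a metric function therefore renames its log entry; no separate label is needed.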

Additional logging

If you have additional information to log, modify _train_epoch() in class Trainer. For example, say you have an additional log saved as a dictionary:

additional_log = {"x": x, "y": y}

Merge it with log before returning:

log = {**log, **additional_log}
return log
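Dict unpacking is order-sensitive: when both dictionaries share a key, the value from additional_log wins and silently replaces the original entry, so extra keys should not reuse names like loss. A small sketch; merge_logs() is a hypothetical helper wrapping the same merge.

```python
def merge_logs(log, additional_log):
    """Hypothetical helper: same merge as above, returned as a new dict."""
    # Later unpacking wins: keys in additional_log replace same-named keys in log.
    return {**log, **additional_log}


log = {'epoch': 1, 'loss': 0.14}
print(merge_logs(log, {'x': 3, 'y': 4}))  # → {'epoch': 1, 'loss': 0.14, 'x': 3, 'y': 4}
print(merge_logs(log, {'loss': 99}))      # → {'epoch': 1, 'loss': 99}
```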

Validation data

If you have separate validation data, implement another data loader for validation. Otherwise, if you just want to split validation data from the training data, pass a ratio such as --validation-split 0.1 to hold out roughly 10% of the training data. In some cases you might need to modify utils/util.py.
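A minimal sketch of what a ratio-based split might look like, assuming a shuffled split of sample indices. split_indices() is a hypothetical helper, not the function in utils/util.py; it enforces the [0.0, 1.0) range stated in the help text.

```python
import random


def split_indices(n_samples, validation_split, seed=0):
    """Hypothetical helper: shuffle indices and hold out a validation ratio."""
    if not 0.0 <= validation_split < 1.0:
        raise ValueError('validation_split must be in [0.0, 1.0)')
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)  # seeded for a reproducible split
    n_val = int(n_samples * validation_split)
    return indices[n_val:], indices[:n_val]  # (train indices, validation indices)
```

In PyTorch, the two index lists could then feed separate samplers, for example torch.utils.data.SubsetRandomSampler.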

Checkpoint naming

If you need to add a prefix to your checkpoint names, modify this line in main.py:

identifier = type(model).__name__ + '_'

This changes the prefix of the checkpoint file names. If you need to further change the naming of checkpoints, modify _save_checkpoint() in class BaseTrainer.
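The checkpoint names in the sample logs follow a visible pattern: the identifier, then the epoch as two digits, then the training loss to five decimals. The sketch below reproduces that pattern. checkpoint_filename() is hypothetical, the Model class is a stand-in for your own model, and the real _save_checkpoint() may build the path differently.

```python
import os


def checkpoint_filename(save_dir, identifier, epoch, loss):
    """Hypothetical helper mirroring names like Model_checkpoint_epoch01_loss_0.14183.pth.tar."""
    name = '{}checkpoint_epoch{:02d}_loss_{:.5f}.pth.tar'.format(identifier, epoch, loss)
    return os.path.join(save_dir, name)


class Model:  # stand-in for your own model class (an nn.Module in the template)
    pass


model = Model()
identifier = type(model).__name__ + '_'   # the line from main.py
custom_identifier = 'mnist_' + identifier  # hypothetical extra prefix
```

Changing identifier only affects the front of the name; the epoch and loss parts come from the naming logic itself.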

Contributing

Feel free to contribute any kind of function or enhancement. The coding style follows PEP 8.

Acknowledgments

This project is heavily inspired by the project Tensorflow-Project-Template by Mahmoud Gemy; be sure to star it!

Contributors

victoresque