
PyTorch-Project-Framework


A high-cohesion, low-coupling, plug-and-play project framework for PyTorch.

Folder Structure

  ├── configs
  |    ├── BaseConfig.py  - the loader for all configuration files
  |    ├── BaseTest.py  - the test class for all configuration files
  |    ├── Env.py  - the loader for the environment configuration file
  |    └── Run.py  - the loader for the hyperparameter configuration file
  |
  ├── datasets
  |    ├── functional  - the package of functional methods
  |    ├── BaseDataset.py  - the abstract base class for all datasets
  |    ├── BaseTest.py  - the test class for all datasets
  |    └── ...  - any datasets of your project
  |
  ├── models
  |    ├── functional  - the package of functional methods
  |    ├── shallow  - the package of shallow methods
  |    ├── BaseModel.py  - the abstract base class for all models
  |    ├── BaseTest.py  - the test class for all models
  |    └── ...  - any models of your project
  |
  ├── res
  |    ├── env  - the folder containing JSON files for environment configurations
  |    ├── datasets  - the folder containing JSON files for dataset configurations
  |    ├── models  - the folder containing JSON files for model configurations
  |    └── run  - the folder containing JSON files for hyperparameter configurations
  |
  ├── test
  |    ├── test_configs.py  - the unit tests for the configs package
  |    ├── test_datasets.py  - the unit tests for the datasets package
  |    ├── test_models.py  - the unit tests for the models package
  |    └── test_utils.py  - the unit tests for the utils package
  |
  ├── utils
  |    ├── common.py  - the common methods
  |    ├── logger.py  - the logger class
  |    ├── summary.py  - the summary class
  |    └── ...  - any utilities of your project
  |
  ├── main.py  - the entry point of the framework
  |
  └── test_component.py  - the global test class

Main Components

Datasets

  • Base dataset

    Base dataset is an abstract class that must be inherited by any dataset you create; the idea is that all datasets share a lot of common logic. The base dataset mainly contains (a hedged interface sketch follows this subsection):

    • more - add or update dataset-specific configuration
    • load - load the dataset
    • split - split the data into a training set and a test set
  • Your dataset

    This is where you implement your dataset. You should:

    • Create your dataset class and inherit the BaseDataset class
    • Override the load method
    • Override other methods if you need a special implementation
    • Add your dataset name to datasets/__init__.py
    • Create a JSON file with your dataset's configuration in res/datasets/
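
  For orientation, here is a minimal sketch of what such an abstract base class could look like. The method names follow the lists above, but the abc-based skeleton and the signatures are assumptions, not the framework's actual datasets/BaseDataset.py:

    # Hypothetical sketch of a BaseDataset-style abstract class;
    # the real datasets/BaseDataset.py may differ.
    import abc

    class BaseDataset(abc.ABC):
        def __init__(self, cfg, **kwargs):
            self.cfg = cfg  # configuration loaded from res/datasets/*.json

        def more(self):
            """Add or update dataset-specific configuration (optional override)."""
            pass

        @abc.abstractmethod
        def load(self):
            """Load the data; must return (data_dict, data_count)."""

        def split(self):
            """Split the loaded data into a training set and a test set."""
            ...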

Models

  • Base model

    Base model is an abstract class that must be inherited by any model you create; the idea is that all models share a lot of common logic. The base model mainly contains (a hedged interface sketch follows this subsection):

    • check_cfg - filter datasets
    • train - one training step
    • test - one testing step
    • load - load a previously trained model
    • save - save the model
  • Your model

    This is where you implement your model. You should:

    • Create your model class and inherit the BaseModel class
    • Override the train and test methods
    • Override other methods if you need a special implementation
    • Add your model name to models/__init__.py
    • Create a JSON file with your model's configuration in res/models/
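
  As with the dataset side, a rough sketch of the abstract interface described above; the signatures and the check_cfg semantics are assumptions, not the actual models/BaseModel.py:

    # Hypothetical sketch of a BaseModel-style abstract class;
    # the real models/BaseModel.py may differ.
    import abc

    class BaseModel(abc.ABC):
        def __init__(self, cfg, data_cfg, run, **kwargs):
            self.cfg = cfg            # model configuration (res/models/*.json)
            self.data_cfg = data_cfg  # dataset configuration (res/datasets/*.json)
            self.run = run            # hyperparameters (res/run/*.json)

        def check_cfg(self, data_cfg):
            """Filter datasets: return True if this model supports data_cfg."""
            return True

        @abc.abstractmethod
        def train(self, epoch_info, sample_dict):
            """One training step; must return a dictionary of losses."""

        @abc.abstractmethod
        def test(self, epoch_info, sample_dict):
            """One testing step; must return a dictionary of outputs to save."""

        def load(self, epoch):
            """Load a previously trained model checkpoint."""
            ...

        def save(self, epoch):
            """Save the model checkpoint."""
            ...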

How to Use

To use this framework, do the following:

  • Dataset

    • In the datasets folder, create a class that inherits the BaseDataset class

       # YourDataset.py
       class YourDataset(datasets.BaseDataset):
           def __init__(self, cfg, **kwargs):
               super().__init__(cfg, **kwargs)
    • Override the load method to load your dataset

       # In YourDataset class
       # (requires `import numpy` at the top of YourDataset.py)
       def load(self):
           """
           Load your dataset here
           The parameters in `cfg` are loaded from your dataset's JSON configuration file
           For example:
           - Create 4 random images of size (depth, height, width) as source data
           - Create 4 random labels as target data
           Return the data dictionary and the number of samples
           """

           data_count = 4
           source = numpy.random.rand(data_count, self.cfg.depth, self.cfg.height, self.cfg.width)
           target = numpy.random.randint(0, self.cfg.label_count, (data_count, 1))

           return {'source': source, 'target': target}, data_count
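
      As a quick sanity check, the arrays above have the following shapes (a standalone snippet outside the framework; depth 3 and 128x128 come from the example configuration below, while label_count = 10 is an assumed placeholder):

       import numpy

       data_count, depth, height, width, label_count = 4, 3, 128, 128, 10
       source = numpy.random.rand(data_count, depth, height, width)
       target = numpy.random.randint(0, label_count, (data_count, 1))
       print(source.shape)  # (4, 3, 128, 128)
       print(target.shape)  # (4, 1)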
    • Add your dataset name to datasets/__init__.py

      from .YourDataset import YourDataset
    • Create a JSON file with your dataset's configuration in res/datasets/

      {
          "name": "YourDataset", // same as your dataset class name
          // All dataset parameters you need when creating the `YourDataset` class,
          // e.g. the image size and the number of folds for K-fold cross-validation
          "source": {
              "depth": 3,
              "height": 128,
              "width": 128
          },
          "cross_folds": 2
      }
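
      The exact behavior of split is internal to the framework, but for intuition, K-fold cross-validation with "cross_folds": 2 partitions the data into two halves and alternates which half serves as the test set. A minimal standalone sketch of that idea (not the framework's split implementation):

       import numpy

       data_count, folds = 4, 2
       indices = numpy.arange(data_count)
       for fold in range(folds):
           test_idx = indices[fold::folds]                 # every `folds`-th sample
           train_idx = numpy.setdiff1d(indices, test_idx)  # the remaining samples
           print(f"fold {fold}: train={train_idx}, test={test_idx}")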
      
  • Model

    • In the models folder, create a class that inherits the BaseModel class

       # YourModel.py
       # (requires `import torch` at the top of YourModel.py)
       class YourModel(models.BaseModel):
           def __init__(self, cfg, data_cfg, run, **kwargs):
               super().__init__(cfg, data_cfg, run, **kwargs)

               # The parameters in `cfg` are loaded from your model's JSON configuration file
               # The parameters in `data_cfg` are loaded from the dataset's JSON configuration file
               # The parameters in `run` are loaded from the hyperparameter JSON configuration file

               # Create the model, optimizer, criterion, etc.
               # For example:
               # - model: Linear
               # - criterion: L1 loss
               # - optimizer: Adam
               self.model = torch.nn.Linear(self.cfg.input_dims, self.cfg.output_dims).to(self.device)
               self.criterion = torch.nn.L1Loss().to(self.device)  # note: instantiate the loss, then move it
               self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.run.lr, betas=(self.run.b1, self.run.b2))
    • Override the train and test methods to implement the training and testing logic

      # In YourModel class
      def train(self, epoch_info, sample_dict):
          """
          epoch_info: the epoch information
          sample_dict: the dictionary of training data

          Implement the training logic
          For example:
              source -> [model] -> predict -> [criterion] (+target) -> loss
          Return a dictionary of losses
          """
          source = sample_dict['source'].to(self.device)
          target = sample_dict['target'].to(self.device)

          self.model.train()
          self.optimizer.zero_grad()
          predict = self.model(source)
          loss = self.criterion(predict, target)
          loss.backward()
          self.optimizer.step()

          # Anything else you need to calculate

          return {'loss': loss}

      def test(self, epoch_info, sample_dict):
          """
          epoch_info: the epoch information
          sample_dict: the dictionary of test data

          Implement the testing logic
          For example:
              source -> [model] -> predict
          Return a dictionary of the data you want to save
          """
          source = sample_dict['source'].to(self.device)
          target = sample_dict['target'].to(self.device)

          self.model.eval()
          with torch.no_grad():  # no gradients are needed during testing
              predict = self.model(source)

          # Anything else you need to calculate

          return {'target': target, 'predict': predict}
    • Add your model name to models/__init__.py

      from .YourModel import YourModel
    • Create a JSON file with your model's configuration in res/models/

      {
          "name": "YourModel", // same as your model class name
          // All model parameters you need when creating the `YourModel` class,
          // e.g. the input and output dimensions
          "input_dims": 256,
          "output_dims": 1
      }
      
  • Hyperparameter

    • Create a JSON file with your hyperparameter configuration in res/run/

      {
          "name": "YourHP",
          // Basic hyperparameters
          "batch_size": 32,
          "epochs": 200,
          "save_step": 10,
          // Hyperparameters you need when creating the optimizer in the `YourModel` class, or elsewhere
          // For example, the learning rate
          "lr": 2e-4
      }
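
      Note that the YourModel.__init__ example above also reads self.run.b1 and self.run.b2 for Adam's betas, so if you follow that example those keys belong in this file as well; the values below are only placeholders:

      {
          "name": "YourHP",
          "batch_size": 32,
          "epochs": 200,
          "save_step": 10,
          "lr": 2e-4,
          // Adam betas read by the `YourModel` example; placeholder values
          "b1": 0.5,
          "b2": 0.999
      }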
      
  • Run main.py to start training or testing

    • Train with the configuration files res/datasets/yourdataset.json, res/models/yourmodel.json, and res/run/yourhp.json on GPU 0:

      python3 -m main -d "yourdataset" -m "yourmodel" -r "yourhp" -g 0

    Every save_step epochs, the trained model and the data you want to save are written to the folder save/[yourmodel]-[yourhp]-[yourdataset]-[index of cross-validation].

    • To test the model saved at epoch 10:

      python3 -m main -d "yourdataset" -m "yourmodel" -r "yourhp" -g 0 -t 10
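
    The repository's main.py handles these flags itself; purely to illustrate the command-line surface, an argparse-based parser that accepts the same options might look like the sketch below (an assumption, not the actual main.py):

      # Hypothetical sketch of a parser for the documented flags;
      # the real main.py may handle arguments differently.
      import argparse

      parser = argparse.ArgumentParser(description='PyTorch-Project-Framework runner')
      parser.add_argument('-d', '--dataset', required=True, help='dataset config name in res/datasets/')
      parser.add_argument('-m', '--model', required=True, help='model config name in res/models/')
      parser.add_argument('-r', '--run', required=True, help='hyperparameter config name in res/run/')
      parser.add_argument('-g', '--gpu', type=int, default=0, help='GPU index to use')
      parser.add_argument('-t', '--test', type=int, default=None, help='epoch to test; omit to train')
      args = parser.parse_args()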

Contributing

Any kind of enhancement or contribution is welcome.

License

The code is licensed under the MIT license.
