GrimAI

GrimAI is a general-purpose library built on top of PyTorch.

The objective of the library is to provide a simple and flexible approach to model training.

How to use

  • Load your data and create your DataLoaders (for example, train_loader and valid_loader).
  • Implement a callback class that inherits from BaseCallBack. Override the hooks you want to use during training; the engine calls each one automatically at the training step its name describes. The fetch_data, loss_function, forward_step and backward_step steps are mandatory; the before_* and after_* hooks are optional.
    from core.callback.base_callback import BaseCallBack

    class MyCBS(BaseCallBack):
        def __init__(self):
            super().__init__()

        # Optional hooks: override only the ones you need.
        def before_fit(self, *args, **kwargs): pass
        def before_epoch(self, *args, **kwargs): pass
        def before_batch(self, *args, **kwargs): pass
        def before_forward_step(self, *args, **kwargs): pass
        def after_forward_step(self, *args, **kwargs): pass

        # Mandatory steps: you must implement these four.
        def fetch_data(self, *args, **kwargs): ...
        def loss_function(self, *args, **kwargs): ...
        def forward_step(self, *args, **kwargs): ...
        def backward_step(self, *args, **kwargs): ...

        # More optional hooks.
        def after_batch(self, *args, **kwargs): pass
        def after_epoch(self, *args, **kwargs): pass
        def after_fit(self, *args, **kwargs): pass
  • Inside your subclass, self.engine gives you access to the training state, such as the model, optimizer, device, current batch, loss and gradient scaler. For example:
    from torch.cuda import amp
    from core.callback.base_callback import BaseCallBack

    class MyCBS(BaseCallBack):
        def fetch_data(self, *args, **kwargs):
            # Move the inputs and targets of the current batch to the engine's device.
            return self.engine.batch[0].to(self.engine.device), self.engine.batch[1].to(self.engine.device)

        def forward_step(self, *args, **kwargs):
            # Run the forward pass under autocast when mixed precision (a scaler) is enabled.
            if self.engine.scaler is not None:
                with amp.autocast():
                    outputs = self.engine.model(self.engine.data)
            else:
                outputs = self.engine.model(self.engine.data)
            return outputs

        def backward_step(self, *args, **kwargs):
            loss = self.engine.loss
            self.engine.optimizer.zero_grad()
            if self.engine.scaler is not None:
                # Scale the loss to avoid fp16 gradient underflow, then step and update the scaler.
                self.engine.scaler.scale(loss).backward()
                self.engine.scaler.step(self.engine.optimizer)
                self.engine.scaler.update()
            else:
                loss.backward()
                self.engine.optimizer.step()
            return loss
  • Pass an instance of your callback class to the Engine, then call fit:
    import torch.optim as optim

    optimizer = optim.SGD(my_model.parameters(), lr=0.001, momentum=0.9)
    cbs = MyCBS()
    device = [0]  # list of GPU indices to train on
    engine = Engine(model=my_model, optimizer=optimizer, cbs=cbs, fp16=True, scheduler=None, device=device)
    engine.fit(epochs=10, train_dataloader=train_loader, valid_dataloader=valid_loader)
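The README does not document the exact order in which the engine invokes the hooks, so the order below is an assumption. This pure-Python sketch uses a hypothetical `MiniEngine` and `RecordingCallbacks` (neither is part of GrimAI) to show one plausible dispatch: fit-level hooks wrap epoch-level hooks, which wrap the per-batch steps, and each step's return value is stored on the engine so later hooks can read it through `self.engine`.

```python
class RecordingCallbacks:
    """Hypothetical callbacks object that records the name of every hook called."""

    def __init__(self):
        self.calls = []

    def __getattr__(self, name):
        # Reached only for attributes not set in __init__, i.e. the hook names.
        if name.startswith("__"):
            raise AttributeError(name)

        def hook(*args, **kwargs):
            self.calls.append(name)

        return hook


class MiniEngine:
    """Hypothetical engine; the hook order below is an assumption, not GrimAI's code."""

    def __init__(self, cbs):
        self.cbs = cbs
        cbs.engine = self  # callbacks reach shared state through self.engine

    def fit(self, epochs, train_dataloader):
        self.cbs.before_fit()
        for epoch in range(epochs):
            self.epoch = epoch
            self.cbs.before_epoch()
            for batch in train_dataloader:
                self.batch = batch
                self.cbs.before_batch()
                self.data = self.cbs.fetch_data()
                self.cbs.before_forward_step()
                self.outputs = self.cbs.forward_step()
                self.cbs.after_forward_step()
                self.loss = self.cbs.loss_function()
                self.cbs.backward_step()
                self.cbs.after_batch()
            self.cbs.after_epoch()
        self.cbs.after_fit()


cbs = RecordingCallbacks()
MiniEngine(cbs).fit(epochs=1, train_dataloader=[("x", "y")])
print(cbs.calls)
```

Storing each return value on the engine (`data`, `loss`) is what would let later steps read it, in the same way the `forward_step` and `backward_step` examples read `self.engine.data` and `self.engine.loss`.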

See the MNIST example for details.

Features

If you use the provided CBS callback class, you get:

  • Mixed-precision training: pass fp16=True to the Engine.
  • Parallel training on multiple GPUs: pass a list of GPU indices as device. For example, with [0, 1] your model uses GPU:0 and GPU:1.
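How the engine maps the `device` list onto GPUs is not documented here. A common PyTorch convention, used by `torch.nn.DataParallel`, is to keep the model on the first listed GPU and replicate it across all listed ids. The hypothetical helper `resolve_devices` below (not part of GrimAI) sketches that mapping in plain Python so it runs without a GPU; the CPU fallback for an empty list is also an assumption.

```python
def resolve_devices(device):
    """Hypothetical helper: map a device list such as [0, 1] to PyTorch settings.

    Returns (primary_device, device_ids, use_parallel):
      primary_device - string the model would be moved to, e.g. "cuda:0"
      device_ids     - ids that would be passed to torch.nn.DataParallel
      use_parallel   - True when more than one GPU is listed
    """
    if not device:
        # Assumed fallback: no GPUs listed means train on the CPU.
        return "cpu", [], False
    ids = [int(d) for d in device]
    return f"cuda:{ids[0]}", ids, len(ids) > 1


print(resolve_devices([0]))     # ('cuda:0', [0], False)
print(resolve_devices([0, 1]))  # ('cuda:0', [0, 1], True)
```

With PyTorch installed, the result would typically be used as `model.to(primary_device)`, followed by `torch.nn.DataParallel(model, device_ids=device_ids)` when `use_parallel` is true.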

You can also use the provided CBS class and replace individual steps by assigning your own function to the instance. For example:

    cbs = CBS()

    def fetch_data(*args, **kwargs):
        # Replaces fetch_data on this instance only; no self is passed to it.
        print("my fetch data")

    cbs.fetch_data = fetch_data
    device = [0, 1]  # train on GPU:0 and GPU:1
    engine = Engine(model=my_model, optimizer=optimizer, cbs=cbs, fp16=True, scheduler=None, device=device)
    engine.fit(epochs=10, train_dataloader=train_loader, valid_dataloader=valid_loader)
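Assigning a function to an attribute of an instance shadows the class method for that one object only, and Python does not bind it, so the function receives no `self`; that is why the injected `fetch_data` takes only `*args, **kwargs`. In the sketch below, `DemoCBS` is a hypothetical stand-in for `CBS`, and `"engine-placeholder"` stands in for a real engine. It shows both effects, plus how `types.MethodType` can bind `self` explicitly if the injected function needs `self.engine`.

```python
import types


class DemoCBS:
    """Hypothetical stand-in for the provided CBS class."""

    def __init__(self):
        self.engine = "engine-placeholder"

    def fetch_data(self, *args, **kwargs):
        return "default fetch_data"


def fetch_data(*args, **kwargs):
    # Plain function: assigned to an instance it is NOT bound, so there is no self.
    return "my fetch data"


def fetch_data_with_engine(self, *args, **kwargs):
    # Bound explicitly below, so self (and self.engine) are available.
    return f"fetching via {self.engine}"


patched, untouched, bound = DemoCBS(), DemoCBS(), DemoCBS()
patched.fetch_data = fetch_data
bound.fetch_data = types.MethodType(fetch_data_with_engine, bound)

print(patched.fetch_data())    # my fetch data
print(untouched.fetch_data())  # default fetch_data
print(bound.fetch_data())      # fetching via engine-placeholder
```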

What's next

  • More stable callback classes and functions available by default
  • Automatic dataloader creation for some classes of data
  • More examples
  • Model export to ONNX
  • Installation via pip

Installation

Work in progress

License

MIT