
torch2trt dynamic

This is a branch of torch2trt with support for dynamic input shapes.

Note that not all layers support dynamic input; torch.split(), for example, does not.

Usage

Here are some examples

Convert

from torch2trt_dynamic import torch2trt_dynamic
import torch
from torch import nn
from torchvision.models.resnet import resnet50

# create some regular pytorch model...
model = resnet50().cuda().eval()

# create example data
x = torch.ones((1, 3, 224, 224)).cuda()

# one [min, opt, max] shape range per input
opt_shape_param = [
    [
        [1, 3, 128, 128],   # min
        [1, 3, 256, 256],   # opt
        [1, 3, 512, 512]    # max
    ]
]

# convert to TensorRT, feeding the sample data as input
model_trt = torch2trt_dynamic(model, [x], fp16_mode=False, opt_shape_param=opt_shape_param)
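Each entry of opt_shape_param holds three shapes for one input: the minimum, optimal, and maximum shape the engine should accept (these presumably map onto a TensorRT optimization profile). The pure-Python sketch below checks that structure and tests whether a runtime shape falls inside the range. validate_opt_shape_param and shape_in_range are hypothetical helpers for illustration, not part of torch2trt_dynamic.

```python
from typing import List, Sequence

Shape = Sequence[int]


def validate_opt_shape_param(opt_shape_param: List[List[Shape]]) -> None:
    """Raise ValueError if any [min, opt, max] triple is malformed.

    Hypothetical helper: torch2trt_dynamic does not provide this function.
    """
    for i, triple in enumerate(opt_shape_param):
        if len(triple) != 3:
            raise ValueError(f"input {i}: expected [min, opt, max], got {len(triple)} shapes")
        min_s, opt_s, max_s = triple
        if not (len(min_s) == len(opt_s) == len(max_s)):
            raise ValueError(f"input {i}: min/opt/max must have the same rank")
        # Every dimension must satisfy min <= opt <= max.
        for lo, mid, hi in zip(min_s, opt_s, max_s):
            if not lo <= mid <= hi:
                raise ValueError(f"input {i}: need min <= opt <= max per dimension")


def shape_in_range(shape: Shape, triple: List[Shape]) -> bool:
    """Return True if `shape` has the right rank and each dim lies in [min, max]."""
    min_s, _, max_s = triple
    return len(shape) == len(min_s) and all(
        lo <= d <= hi for d, lo, hi in zip(shape, min_s, max_s)
    )


opt_shape_param = [
    [
        [1, 3, 128, 128],  # min
        [1, 3, 256, 256],  # opt
        [1, 3, 512, 512],  # max
    ]
]
validate_opt_shape_param(opt_shape_param)
print(shape_in_range([1, 3, 224, 224], opt_shape_param[0]))  # True
print(shape_in_range([1, 3, 640, 640], opt_shape_param[0]))  # False
```

Inputs whose shapes fall outside the [min, max] range are expected to fail at execution time, so a check like this can be useful before calling model_trt on new data.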

Execute

We can execute the returned TRTModule just like the original PyTorch model

x = torch.rand(1,3,256,256).cuda()
with torch.no_grad():
    y = model(x)
    y_trt = model_trt(x)

# check the output against PyTorch
print(torch.max(torch.abs(y - y_trt)))

Save and load

We can save the model as a state_dict.

torch.save(model_trt.state_dict(), 'resnet50_trt.pth')

We can load the saved model into a TRTModule.

from torch2trt_dynamic import TRTModule

model_trt = TRTModule()

model_trt.load_state_dict(torch.load('resnet50_trt.pth'))

Setup

To install without compiling plugins, run the following:

git clone https://github.com/grimoire/torch2trt_dynamic.git torch2trt_dynamic
cd torch2trt_dynamic
python setup.py develop

Set up plugins (optional)

Some layers, such as GroupNorm (GN), require C++ plugins. Install the plugin project below:

amirstan_plugin

DO NOT FORGET to export the environment variable AMIRSTAN_LIBRARY_PATH
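A quick sanity check on AMIRSTAN_LIBRARY_PATH before converting models with plugin-backed layers can save debugging time. The sketch below assumes the variable should name an existing directory (presumably the one holding the built plugin library); check_amirstan_library_path is a hypothetical helper, and what that directory must contain is left to the amirstan_plugin documentation.

```python
import os


def check_amirstan_library_path(env=None) -> str:
    """Return AMIRSTAN_LIBRARY_PATH, raising if it is unset or not a directory.

    Hypothetical helper. Assumes the variable names a directory; consult the
    amirstan_plugin documentation for what that directory must contain.
    """
    env = os.environ if env is None else env
    path = env.get("AMIRSTAN_LIBRARY_PATH")
    if not path:
        raise RuntimeError("AMIRSTAN_LIBRARY_PATH is not set; export it before using plugin layers")
    if not os.path.isdir(path):
        raise RuntimeError(f"AMIRSTAN_LIBRARY_PATH={path!r} is not a directory")
    return path
```

Calling check_amirstan_library_path() with no arguments inspects the current process environment.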

How to add (or override) a converter

Here we show how to add a converter for the ReLU module using the TensorRT Python API.

import tensorrt as trt
from torch2trt_dynamic import tensorrt_converter

@tensorrt_converter('torch.nn.ReLU.forward')
def convert_ReLU(ctx):
    input = ctx.method_args[1]
    output = ctx.method_return
    layer = ctx.network.add_activation(input=input._trt, type=trt.ActivationType.RELU)
    output._trt = layer.get_output(0)
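The @tensorrt_converter('torch.nn.ReLU.forward') decorator associates a converter with a fully qualified PyTorch method name. The stdlib-only toy below models that idea with a plain dictionary, which also illustrates how registering a second converter under the same name can override the first. CONVERTERS and register_converter are hypothetical names; this is a sketch of the concept, not torch2trt_dynamic's actual implementation.

```python
from typing import Callable, Dict

# Hypothetical registry: maps a qualified method name to its converter.
CONVERTERS: Dict[str, Callable] = {}


def register_converter(method_name: str) -> Callable[[Callable], Callable]:
    """Toy stand-in for a converter decorator; later registrations win."""
    def decorator(fn: Callable) -> Callable:
        CONVERTERS[method_name] = fn
        return fn
    return decorator


@register_converter("torch.nn.ReLU.forward")
def convert_relu_v1(ctx):
    return "v1"


# Registering again under the same key replaces (overrides) the first converter.
@register_converter("torch.nn.ReLU.forward")
def convert_relu_v2(ctx):
    return "v2"


print(CONVERTERS["torch.nn.ReLU.forward"](None))  # v2
```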

The converter takes one argument, a ConversionContext, which contains the following:

  • ctx.network - The TensorRT network that is being constructed.

  • ctx.method_args - Positional arguments that were passed to the specified PyTorch function. The _trt attribute is set for relevant input tensors.

  • ctx.method_kwargs - Keyword arguments that were passed to the specified PyTorch function.

  • ctx.method_return - The value returned by the specified PyTorch function. The converter must set the _trt attribute where relevant.
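Because torch.nn.ReLU.forward is called as forward(self, input), ctx.method_args[0] is the module itself and ctx.method_args[1] is the input tensor, which is why the ReLU converter reads index 1. The sketch below exercises the same converter logic without TensorRT by substituting hypothetical stand-ins (FakeNetwork, FakeLayer, and a SimpleNamespace context) for the real objects; only the attribute names network, method_args, method_kwargs, method_return, and _trt come from the description above, and the string "RELU" replaces trt.ActivationType.RELU.

```python
from types import SimpleNamespace


class FakeLayer:
    """Stand-in for a TensorRT layer; records what it was built from."""

    def __init__(self, input, type):
        self.input, self.type = input, type

    def get_output(self, index):
        return f"relu_output_{index}_of_{self.input}"


class FakeNetwork:
    """Stand-in for ctx.network exposing only add_activation."""

    def __init__(self):
        self.layers = []

    def add_activation(self, input, type):
        layer = FakeLayer(input, type)
        self.layers.append(layer)
        return layer


def convert_ReLU(ctx, relu_type="RELU"):
    # Same logic as the converter above; relu_type replaces trt.ActivationType.RELU.
    input = ctx.method_args[1]  # method_args[0] is the nn.ReLU module (self)
    output = ctx.method_return
    layer = ctx.network.add_activation(input=input._trt, type=relu_type)
    output._trt = layer.get_output(0)  # attach the TensorRT tensor to the PyTorch output


module = object()                    # placeholder for the nn.ReLU instance
x = SimpleNamespace(_trt="x_trt")    # input tensor with its TensorRT counterpart
y = SimpleNamespace()                # output tensor, _trt not yet set
ctx = SimpleNamespace(network=FakeNetwork(), method_args=(module, x),
                      method_kwargs={}, method_return=y)
convert_ReLU(ctx)
print(y._trt)  # relu_output_0_of_x_trt
```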

Please see this folder for more examples.

Contributors

jaybdub, grimoire, narendasan, geoffreychen777, mt1871, snowmasaya, v-qjqs, vfdev-5