diegofiori commented on May 18, 2024

Hello @justlike-prog!

I took a look at the Detectron2 source code and found an easy workaround for optimizing it. In Detectron2 the bulk of the computation is done by the ResNet-based backbone, so optimizing the backbone alone should already give good results.

The problem is that the backbone returns a dictionary, while the Nebullvm API expects a tuple of tensors. I suggest defining two "wrapper" classes: one around the non-optimized backbone that converts its dictionary output into a tuple, and one around the optimized model that converts the tuple back into a dictionary with the original keys. You can then use them to run the optimized Detectron2 model.

Let's define the classes as

import torch


class BaseModelWrapper(torch.nn.Module):
    """Wraps the original backbone so that it returns a tuple instead of a dict."""

    def __init__(self, core_model, output_dict):
        super().__init__()
        self.core_model = core_model
        # Remember the key order so the dict can be rebuilt later
        self.output_names = list(output_dict.keys())

    def forward(self, *args, **kwargs):
        res = self.core_model(*args, **kwargs)
        return tuple(res[key] for key in self.output_names)


class OptimizedWrapper(torch.nn.Module):
    """Wraps the optimized model so that it returns a dict again."""

    def __init__(self, optimized_model, output_keys):
        super().__init__()
        self.optimized_model = optimized_model
        self.output_keys = output_keys

    def forward(self, *args):
        res = self.optimized_model(*args)
        # Pair each output with its original key, in the saved order
        return dict(zip(self.output_keys, res))
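The round trip performed by the two wrappers can be illustrated without torch or Detectron2. The sketch below is only an illustration of the idea: `fake_backbone`, `to_tuple_model` and `to_dict_model` are hypothetical stand-ins invented for this example, and plain numbers take the place of feature-map tensors.

```python
# Torch-free sketch of the dict -> tuple -> dict round trip.
# All names here are hypothetical stand-ins, not part of Detectron2 or Nebullvm.

def fake_backbone(x):
    # Mimics a backbone that returns named feature maps.
    return {"p2": x * 2, "p3": x * 3, "p4": x * 4}


def to_tuple_model(core_model, output_names):
    # Same role as BaseModelWrapper: dict output -> tuple in a fixed key order.
    def forward(x):
        res = core_model(x)
        return tuple(res[key] for key in output_names)
    return forward


def to_dict_model(tuple_model, output_keys):
    # Same role as OptimizedWrapper: tuple output -> dict with the original keys.
    def forward(x):
        return dict(zip(output_keys, tuple_model(x)))
    return forward


# One forward pass is enough to learn the key order.
output_names = list(fake_backbone(1).keys())
tuple_model = to_tuple_model(fake_backbone, output_names)
restored = to_dict_model(tuple_model, output_names)
```

Because both wrappers share the same list of keys, the restored dictionary matches what the original backbone would have returned, which is what the rest of the model expects.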

Then you can simply run the following code

import copy

import torch
from detectron2 import model_zoo
from nebullvm import optimize_torch_model

config_path = "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"
model = model_zoo.get(config_path, trained=True)
model.eval()
model_backbone = copy.deepcopy(model.backbone)
# a dummy forward pass is needed to get the output keys
with torch.no_grad():
    res = model_backbone(torch.randn(1, 3, 256, 256))
backbone_wrapper = BaseModelWrapper(model_backbone, res)
optimized_model = optimize_torch_model(
    backbone_wrapper, batch_size=1, input_sizes=[(3, 256, 256)], save_dir="./"
)
optimized_backbone = OptimizedWrapper(optimized_model, backbone_wrapper.output_names)
# finally, replace the old backbone with the optimized one
model.backbone = optimized_backbone
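To check whether the swap actually helps on your hardware, one option is to time both backbones on the same input. The helper below is a minimal sketch and not part of Nebullvm or Detectron2; `mean_latency` and its `warmup` and `runs` parameters are names made up for this example.

```python
import time


def mean_latency(fn, *args, warmup=3, runs=10):
    """Return the mean wall-clock time in seconds of calling fn(*args)."""
    # Warm-up calls are not timed (lazy initialization, caches, etc.)
    for _ in range(warmup):
        fn(*args)
    start = time.perf_counter()
    for _ in range(runs):
        fn(*args)
    return (time.perf_counter() - start) / runs


# Hypothetical usage with the objects defined above:
# dummy_input = torch.randn(1, 3, 256, 256)
# print(mean_latency(model_backbone, dummy_input))
# print(mean_latency(optimized_backbone, dummy_input))
```

On a GPU the calls are asynchronous, so you would probably want to call torch.cuda.synchronize() before reading the clock; this sketch assumes CPU execution.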

If you need dynamic input shapes, you will have to pass a few more arguments to the optimize_torch_model function. Please take a look at issue #26 for further info.

from nebuly.
