Hello @justlike-prog!
I took a look at the Detectron2 source code and found an easy workaround for optimizing it. In Detectron2, most of the heavy computation is done by the ResNet-based backbone, so optimizing just the backbone should already give good results.
The problem is that the backbone returns a dictionary, while the Nebullvm API works with tuples of outputs, so we need to map the dictionary to a tuple. I suggest defining two "wrapper" classes: one that turns the non-optimized backbone's dictionary into a tuple for the optimizer, and one that turns the optimized model's tuple back into a dictionary. We then use them to run the optimized Detectron2 model.
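Both wrappers rely on the same key-order contract: the tuple handed to the optimizer must be turned back into a dictionary with the same keys in the same order. A minimal, framework-free sketch of that round trip, assuming plain Python strings stand in for the backbone's tensors (the helper names `flatten_outputs` and `unflatten_outputs` and the `p2` to `p6` keys are illustrative only, not part of Detectron2 or Nebullvm):

```python
from typing import Any, Dict, List, Sequence, Tuple


def flatten_outputs(outputs: Dict[str, Any], keys: Sequence[str]) -> Tuple[Any, ...]:
    """Turn a dict of outputs into a tuple, ordered by `keys` (hypothetical helper)."""
    return tuple(outputs[key] for key in keys)


def unflatten_outputs(values: Sequence[Any], keys: Sequence[str]) -> Dict[str, Any]:
    """Rebuild the dict from a tuple produced with the same `keys` order (hypothetical helper)."""
    if len(values) != len(keys):
        raise ValueError(f"expected {len(keys)} outputs, got {len(values)}")
    return dict(zip(keys, values))


# Record the key order once, from a sample output (strings stand in for tensors).
sample_output = {"p2": "feat2", "p3": "feat3", "p4": "feat4", "p5": "feat5", "p6": "feat6"}
output_keys: List[str] = list(sample_output.keys())

flat = flatten_outputs(sample_output, output_keys)
restored = unflatten_outputs(flat, output_keys)
print(flat)
print(restored == sample_output)
```

The length check is an addition of this sketch: `zip` silently truncates, so without it a mismatch between the number of optimized outputs and the recorded keys would go unnoticed.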
Let's define the classes as follows:
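```python
import torch


class BaseModelWrapper(torch.nn.Module):
    """Wraps the original backbone so that it returns a tuple instead of a dict."""

    def __init__(self, core_model, output_dict):
        super().__init__()
        self.core_model = core_model
        # Record the output key order once, from a sample output dict
        self.output_names = list(output_dict.keys())

    def forward(self, *args, **kwargs):
        res = self.core_model(*args, **kwargs)
        # Return the outputs as a tuple, always in the recorded key order
        return tuple(res[key] for key in self.output_names)


class OptimizedWrapper(torch.nn.Module):
    """Wraps the optimized model so that it returns a dict again, as Detectron2 expects."""

    def __init__(self, optimized_model, output_keys):
        super().__init__()
        self.optimized_model = optimized_model
        self.output_keys = output_keys

    def forward(self, *args):
        res = self.optimized_model(*args)
        # Map the tuple back to a dict, using the same key order as BaseModelWrapper
        return {key: value for key, value in zip(self.output_keys, res)}
```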
Then you can run the following code:
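```python
import copy
import torch
from detectron2 import model_zoo
from nebullvm import optimize_torch_model

config_path = "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"
model = model_zoo.get(config_path, trained=True)
model.eval()
model_backbone = copy.deepcopy(model.backbone)
# needed for getting the output keys; the dummy input must be on the backbone's device
device = next(model_backbone.parameters()).device
res = model_backbone(torch.randn(1, 3, 256, 256, device=device))
backbone_wrapper = BaseModelWrapper(model_backbone, res)
optimized_model = optimize_torch_model(backbone_wrapper, batch_size=1, input_sizes=[(3, 256, 256)], save_dir="./")
optimized_backbone = OptimizedWrapper(optimized_model, backbone_wrapper.output_names)
# GeneralizedRCNN reads size_divisibility (and, in recent versions, padding_constraints)
# from its backbone when batching images, so copy them over from the original backbone
optimized_backbone.size_divisibility = model.backbone.size_divisibility
if hasattr(model.backbone, "padding_constraints"):
    optimized_backbone.padding_constraints = model.backbone.padding_constraints

# finally replace the old backbone with the optimized one
model.backbone = optimized_backbone
```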
The call above optimizes the backbone for a fixed input shape of (1, 3, 256, 256). If you need dynamic input shapes, you need to pass a few more arguments to the `optimize_torch_model` function. Please take a look at issue #26 for further info.
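After swapping the backbone, it may be worth checking that the optimized backbone reproduces the original feature maps on a sample input, since a compiled model can differ numerically from the PyTorch one. A minimal sketch, assuming both modules take a single image tensor and return a dict of tensors; `max_feature_diff` is a hypothetical helper, and `ToyBackbone` only stands in for a real Detectron2 backbone so the example runs on its own:

```python
import torch


def max_feature_diff(reference: torch.nn.Module, candidate: torch.nn.Module, x: torch.Tensor) -> dict:
    """Return the max absolute difference per output key between two dict-returning backbones."""
    with torch.no_grad():
        ref_out = reference(x)
        cand_out = candidate(x)
    if ref_out.keys() != cand_out.keys():
        raise KeyError(f"key mismatch: {sorted(ref_out)} vs {sorted(cand_out)}")
    return {key: (ref_out[key] - cand_out[key]).abs().max().item() for key in ref_out}


class ToyBackbone(torch.nn.Module):
    """Hypothetical stand-in for a Detectron2 backbone: returns a dict of feature maps."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 4, kernel_size=3, padding=1)

    def forward(self, x):
        feat = self.conv(x)
        return {"p2": feat, "p3": torch.nn.functional.max_pool2d(feat, 2)}


torch.manual_seed(0)
reference = ToyBackbone().eval()
candidate = ToyBackbone().eval()
candidate.load_state_dict(reference.state_dict())  # identical weights, so diffs should be zero

diffs = max_feature_diff(reference, candidate, torch.randn(1, 3, 32, 32))
print(diffs)
```

With the code above, the unoptimized copy `model_backbone` could serve as the reference and `optimized_backbone` as the candidate, using an input of the shape the backbone was optimized for, on the backbone's device. What tolerance counts as acceptable is an application-specific choice.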
from nebuly.