Comments (4)
Closing due to inactivity. Feel free to reopen if you have further questions.
from pytorch-opcounter.
Thanks for your interest. First, note that thop currently only counts FLOPs for the forward pass; backpropagation might be a future feature. If you want to profile only some layers, you can add a special check in the profile() function. An example is shown below; I hope it helps.
# Assume we want to ignore FLOPs in batchnorm layers.
import logging

import torch
import torch.nn as nn

# register_hooks is thop's built-in table of counting hooks,
# defined in the same module as profile().
def profile(model, input_size, custom_ops=None):
    custom_ops = custom_ops or {}

    def add_hooks(m):
        if len(list(m.children())) > 0:
            return
        # ======== add one check here ===========
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            return
        m.register_buffer('total_ops', torch.zeros(1))
        m.register_buffer('total_params', torch.zeros(1))
        for p in m.parameters():
            m.total_params += torch.Tensor([p.numel()])
        m_type = type(m)
        fn = None
        if m_type in custom_ops:
            fn = custom_ops[m_type]
        elif m_type in register_hooks:
            fn = register_hooks[m_type]
        else:
            logging.warning("Not implemented for %s", m)
        if fn is not None:
            logging.info("Register FLOP counter for module %s", m)
            m.register_forward_hook(fn)

    model.eval()
    model.apply(add_hooks)
    x = torch.zeros(input_size)
    model(x)

    total_ops = 0
    total_params = 0
    for m in model.modules():
        if len(list(m.children())) > 0:  # skip non-leaf modules
            continue
        if not hasattr(m, 'total_ops'):  # ignored layers (batchnorm) have no counters
            continue
        total_ops += m.total_ops
        total_params += m.total_params
    total_ops = total_ops.item()
    total_params = total_params.item()
    return total_ops, total_params
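For readers wondering what the `fn` registered above looks like: it is an ordinary forward hook that adds to `m.total_ops`. Below is a minimal sketch of such a hook for `nn.Conv2d`, assuming it counts one multiply-accumulate per kernel weight per output element and ignores the bias. The name `count_conv2d_macs` is hypothetical and is not thop's own counter.

```python
import torch
import torch.nn as nn


def count_conv2d_macs(m, inputs, output):
    """Forward hook (hypothetical name): accumulate the MACs of one Conv2d call."""
    # Each output element needs (in_channels / groups) * kH * kW multiply-accumulates.
    kernel_macs = (m.in_channels // m.groups) * m.kernel_size[0] * m.kernel_size[1]
    m.total_ops += torch.tensor([float(output.numel() * kernel_macs)])


# Attach the counter to a single layer, the same way add_hooks() does.
conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
conv.register_buffer("total_ops", torch.zeros(1))
conv.register_forward_hook(count_conv2d_macs)

with torch.no_grad():
    conv(torch.zeros(1, 3, 32, 32))

conv_macs = int(conv.total_ops.item())
print(conv_macs)
```

Passing `custom_ops={nn.Conv2d: count_conv2d_macs}` to the `profile()` above would select this hook for every `nn.Conv2d`, since `custom_ops` is checked before the built-in table.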
Thanks for the reply, @Lyken17. I am interested in how much computation each layer takes. Suppose, for simplicity, we have two conv layers with different kernel sizes and one fc layer. How can I find how many FLOPs each conv layer takes? I want to know the total reduction in FLOPs when a layer is dropped, assuming forward and backward propagation take the same number of FLOPs. My idea is to get the FLOPs of each individual layer and subtract them from the total.
Sure, you can do that. Print the per-layer counters in the final loop:
for m in model.modules():
    if len(list(m.children())) > 0:  # skip non-leaf modules
        continue
    # print layer-wise information here.
    print(str(m), m.total_ops, m.total_params)
    total_ops += m.total_ops
    total_params += m.total_params
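Once the per-layer counts are available, the saving from dropping a layer is the total minus that layer's count. Here is a small worked sketch of that subtraction for a hypothetical network (two conv layers and one fc layer on a 32x32 input; all shapes are assumptions chosen for illustration). It uses the standard multiply-accumulate formulas rather than measured `m.total_ops` values, and it assumes the remaining layers keep their shapes when one is dropped.

```python
def conv2d_macs(c_in, c_out, kernel, h_out, w_out):
    """MACs of a square-kernel conv layer (no groups, bias ignored)."""
    return c_out * h_out * w_out * c_in * kernel * kernel


def linear_macs(in_features, out_features):
    """MACs of a fully connected layer (bias ignored)."""
    return in_features * out_features


# Hypothetical layer shapes; in practice these values would come from m.total_ops.
layer_macs = {
    "conv1": conv2d_macs(3, 16, 3, 32, 32),
    "conv2": conv2d_macs(16, 32, 5, 32, 32),
    "fc": linear_macs(32 * 32 * 32, 10),
}
total = sum(layer_macs.values())

for name, macs in layer_macs.items():
    remaining = total - macs  # cost if this layer alone were dropped
    print(f"{name}: {macs} MACs ({macs / total:.1%} of total), {remaining} left if dropped")
```

In this sketch the 5x5 conv dominates the total, so dropping it gives by far the largest reduction; a real comparison should re-profile the modified model, because removing a layer usually changes the input shape of the next one.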