Lyken17 commented on August 29, 2024
    flop = (input_size[0] / m.stride[0] * input_size[1] / m.stride[1]) * m.kernel_size[0] ** 2 * ((m.in_channels/m.groups) * (m.out_channels/m.groups) * m.groups)

is equivalent to

     out_w * out_h * kw * kh * cin / groups * cout / groups * groups 
     => out_w * out_h * cout * kw * kh * cin / groups 
     => output_elements * kw * kh * cin / groups 
     => output_elements * ops_per_element * cin / groups 

This is exactly the same as the formula in count_hooks.py#L23; I don't see any difference.
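A minimal Python sketch (not taken from either codebase) that checks the equivalence numerically for one hypothetical layer. `conv_out_size` follows the output-size formula given in the PyTorch `nn.Conv2d` documentation; the helper names and layer sizes are illustrative assumptions. With an odd kernel size k, padding k // 2 and an input size divisible by the stride, out_h equals H / stride, which is the substitution the derivation above makes. Both expressions count one operation per multiply-accumulate.

```python
import math


def conv_out_size(size, kernel, stride, padding, dilation=1):
    # Output size along one spatial dimension, per the nn.Conv2d docs.
    return math.floor((size + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1)


def flops_compute_flops_style(h, w, k, stride, cin, cout, groups):
    # Mirrors the EXTD expression: (H/s * W/s) * k^2 * (cin/g) * (cout/g) * g
    return (h / stride * w / stride) * k ** 2 * ((cin / groups) * (cout / groups) * groups)


def flops_output_elements_style(out_h, out_w, k, cin, cout, groups):
    # output_elements * kw * kh * cin / groups
    output_elements = out_h * out_w * cout
    return output_elements * k * k * cin / groups


# Hypothetical grouped 3x3 conv, stride 2, padding 1, on a 640x640 input.
out = conv_out_size(640, kernel=3, stride=2, padding=1)
a = flops_compute_flops_style(640, 640, 3, 2, cin=16, cout=32, groups=4)
b = flops_output_elements_style(out, out, 3, cin=16, cout=32, groups=4)
print(out, a, b)
```

Both expressions give the same count for this layer; the only difference is the order in which the factors are grouped.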

li3cmz commented on August 29, 2024

The code I used for testing with pytorch-OpCounter (thop) is as follows:

import torch

from EXTD_64 import build_extd
from thop import profile
from thop import clever_format

num_classes = 2
img_dim = 640
s3fd_net = build_extd('train', num_classes)  # This is the [EXTD-Pytorch](https://github.com/clovaai/EXTD_Pytorch)
net = s3fd_net

# Dummy RGB input of size 640x640 with batch size 1
dummy_input = torch.randn(1, 3, img_dim, img_dim)
flops, params = profile(net, inputs=(dummy_input, ))
flops, params = clever_format([flops, params], "%.8f")
print(flops, params)

And the output is:

1.08439040G 162.35200000K

When I use the "compute_flops" function from EXTD instead:

import torch
import numpy as np
import torch.nn as nn

from EXTD_64 import build_extd


def compute_flops(model, image_size):
    flops = 0.
    input_size = image_size  # current spatial size (H, W) as a NumPy array
    for m in model.modules():
        if len(list(m.children())) > 0:  # skip non-leaf modules
            continue

        # Assumes every pooling layer halves the spatial size
        if isinstance(m, nn.AvgPool2d) or isinstance(m, nn.MaxPool2d):
            input_size = input_size / 2.
        if isinstance(m, nn.Conv2d):
            # (H / stride) * (W / stride) output positions times k^2 * cin * cout / groups;
            # assumes a square kernel and padding that keeps H / stride
            if m.groups == 1:
                flop = (input_size[0] / m.stride[0] * input_size[1] / m.stride[1]) * m.kernel_size[0] ** 2 * m.in_channels * m.out_channels
            else:
                flop = (input_size[0] / m.stride[0] * input_size[1] / m.stride[1]) * m.kernel_size[0] ** 2 * ((m.in_channels/m.groups) * (m.out_channels/m.groups) * m.groups)
            flops += flop
            if m.stride[0] == 2:
                input_size = input_size / 2.
        print(m, flops, input_size)

    return flops / 1000000000., flops / 1000000


num_classes = 2
img_dim = 640
s3fd_net = build_extd('train', num_classes)
net = s3fd_net

gflops, mflops = compute_flops(net, np.array([img_dim, img_dim]))
print('# of params in Classification model: %d, flops: %.2f GFLOPS, %.2f MFLOPS, image_size: %d' %
      (sum(p.numel() for p in net.parameters()), gflops, mflops, img_dim))

I got the result:

# of params in Classification model: 162352, flops: 11.15 GFLOPS, 11152.59 MFLOPS, image_size: 640

EXTD has a module that is applied iteratively (several times) during the forward pass. I don't know whether that affects OpCounter's calculation.

Lyken17 commented on August 29, 2024

Nice catch. I realized that instead of

    m.total_ops = torch.Tensor([int(total_ops)])

It should be

    m.total_ops += torch.Tensor([int(total_ops)])

to support structures in which a module is called iteratively during forward. I have fixed this in a recent commit. Can you upgrade thop and try again?
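To illustrate why assignment undercounts a reused module, here is a toy pure-Python model of a per-call counting hook. `CountedLayer` and `run_shared_layer` are hypothetical stand-ins, not thop or PyTorch APIs; each call to the layer plays the role of one forward-hook invocation.

```python
class CountedLayer:
    """Hypothetical layer whose 'hook' records an op count on every call."""

    def __init__(self, ops_per_call, accumulate):
        self.ops_per_call = ops_per_call
        self.accumulate = accumulate
        self.total_ops = 0

    def __call__(self, x):
        # Simulates the counting hook firing once per forward call.
        if self.accumulate:
            self.total_ops += self.ops_per_call  # like m.total_ops += ...
        else:
            self.total_ops = self.ops_per_call   # like m.total_ops = ...
        return x


def run_shared_layer(layer, num_iterations):
    # Applies the same layer object repeatedly within one forward pass,
    # similar to a module that is used iteratively.
    x = 0
    for _ in range(num_iterations):
        x = layer(x)
    return layer.total_ops


print(run_shared_layer(CountedLayer(100, accumulate=False), 5))
print(run_shared_layer(CountedLayer(100, accumulate=True), 5))
```

With assignment, a layer applied five times reports the ops of a single call; with `+=`, all five calls contribute to the total.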

li3cmz commented on August 29, 2024

Thanks! I have tried again and the problem is solved now!
