
Comments (3)

Gabri95 commented on June 14, 2024

Hi @prclibo,

Thanks for your question!

Your interpretation is indeed correct, and this is precisely the purpose of the SteerableBasis class.
The 'out_channels x in_channels x 4 x len(angles)'-dimensional bases are built in line 192 by the irreps_bases (created between lines 100 and 105).
The purpose of SteerableBasis is to combine the small bases associated with the irreps appearing in the input and output representations in order to build a basis for the arbitrary representations passed as input.
This is done by padding the irreps bases and applying the change of basis matrices associated with the specified input and output representations.
In other words, SteerableBasis implements Equation 4 (page 4) of the paper, while IrrepsBasis implements the solutions of Equation 3 for a specific pair of irreps.
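To make the padding and change of basis step more concrete, here is a minimal numpy sketch for a single pair of irreps; all names and shapes are invented for illustration (irrep_pair_basis, Q_out, Q_in are placeholders, not e2cnn's actual objects):

import numpy as np

# Made-up shapes, NOT e2cnn's API: one output/input irrep pair inside larger representations
dim_out, dim_in = 6, 4          # sizes of the full output/input representations
blk_out, blk_in = 2, 2          # sizes of the output/input irreps in this pair
row, col = 2, 0                 # position of that irrep pair inside the direct sums

# Small basis for this irrep pair (what IrrepsBasis solves for): n_basis x blk_out x blk_in
irrep_pair_basis = np.random.randn(3, blk_out, blk_in)
# Change of basis matrices relating the representations to their irrep decompositions
Q_out = np.linalg.qr(np.random.randn(dim_out, dim_out))[0]
Q_in = np.linalg.qr(np.random.randn(dim_in, dim_in))[0]

# Pad the small basis into full-size kernels, then conjugate with Q_out / Q_in
# (up to the transposition convention used in the paper)
padded = np.zeros((irrep_pair_basis.shape[0], dim_out, dim_in))
padded[:, row:row + blk_out, col:col + blk_in] = irrep_pair_basis
full_basis = np.einsum('op,bpi,iq->boq', Q_out, padded, Q_in.T)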

The padded basis built by SteerableBasis is then used for the kernel expansion inside the convolution layer.

If I understood correctly, your suggestion is to use the irreps bases directly for the kernel expansion.
I agree that this would be theoretically more efficient, as those bases are generally much smaller.
However, different irreps generally have different sizes and different bases; moreover, the same layer often contains multiple copies of each of them; finally, the change of basis matrices Q_in and Q_out need to be applied to the final kernel after filling each small irreps block.
These issues prevent an efficient implementation of the kernel expansion based directly on irreps.
Indeed, one would need to iterate over every pair of input/output irreps, fill those blocks sequentially and then apply the final change of basis matrices (Q_out and Q_in in Eq. 4) to the whole filter at run time.
Because in most architectures a single representation (i.e. field type) is used per layer (e.g. in a GCNN), all these operations can be performed offline (in the SteerableBasis class) to build a single (albeit larger) basis which can be shared among all pairs of input/output fields.
Moreover, because only a single basis is used, the kernel expansion can be efficiently performed on the GPU with a simple batched matrix multiplication (see the sketch below).
Here the basis for a single pair of input/output field types (representations) is expanded.
Here the full convolutional kernel is filled with the expanded bases associated with each pair of input/output fields.
If all input fields have the same type and all output fields have the same (potentially different) type, we can expand the same basis for each pair (line 325).
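As an illustration of that batched expansion (again with invented names and shapes, not the library's internals), all field pairs can be expanded with a single contraction:

import torch

# Invented, illustrative shapes
n_pairs, n_basis, d_out, d_in, n_points = 8, 12, 6, 4, 9

# One shared, precomputed (padded) basis: n_basis x d_out x d_in x n_points
basis = torch.randn(n_basis, d_out, d_in, n_points)
# Learnable coefficients, one set per pair of input/output fields
weights = torch.randn(n_pairs, n_basis)

# Expand all field pairs at once with a single batched contraction
blocks = torch.einsum('pb,boik->poik', weights, basis)
# Each blocks[p] is then copied into the corresponding slot of the full
# convolutional kernel before it is handed to the convolution routine.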

We initially implemented an alternative version of the kernel expansion which uses the irreps bases directly, but we found it to be much slower in practice and therefore chose not to include it in the final version of the library.
However, our implementation was purely in Python, with a couple of for-loops to iterate over all pairs.
I guess this could be implemented much more efficiently in CUDA, parallelising the block-filling part of the algorithm.
The application of the change of basis matrices Q_in and Q_out would still introduce some additional cost but, overall, this method might still be faster than the currently implemented one if the CUDA-based irreps block filling saves enough computation.
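For reference, here is a rough (and deliberately naive) Python sketch of that per-irrep-pair route, with hypothetical names and data structures; the loop over blocks is the part a CUDA kernel could parallelise:

import numpy as np

def expand_kernel_per_irrep(irrep_bases, weights, offsets, Q_out, Q_in, d_out, d_in):
    # Naive run-time expansion: fill each irrep block, then change basis.
    #   irrep_bases[pair]: n_basis x dim_j x dim_i basis for that irrep pair
    #   weights[pair]:     n_basis learned coefficients for that pair
    #   offsets[pair]:     (row, col) position of the block in the full kernel
    kernel = np.zeros((d_out, d_in))
    for pair, basis in irrep_bases.items():               # sequential block filling
        row, col = offsets[pair]
        block = np.einsum('b,boi->oi', weights[pair], basis)
        kernel[row:row + block.shape[0], col:col + block.shape[1]] = block
    # change of basis applied to the whole filter at run time (Q_out, Q_in of Eq. 4)
    return Q_out @ kernel @ Q_in.T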

Do you think the kernel expansion could benefit from a CUDA-implementation as explained above?

Please let me know if this makes sense to you or if you have more questions.

Thanks!
Gabriele


prclibo commented on June 14, 2024

@Gabri95 Thank you so much for the detailed reply!

Yes, I agree it might not be straightforward to alleviate this inefficiency, since each irrep has a different number of basis elements. From a complexity point of view, a slice samples[:, :, x, y] of the tensor returned by the following code

samples = self._sample_direct_sum(angles)

is usually of the form:

array([[ 0.    ,  0.    ,  0.    ,  0.    ,  0.    ,  0.    ,  0.    ,  0.    ],
       [ 0.    ,  0.    ,  0.    ,  0.    ,  0.    ,  0.    ,  0.    ,  0.    ],
       [ 0.    ,  0.    ,  0.    ,  0.    ,  0.    ,  0.    ,  0.    ,  0.    ],
       [ 0.    ,  0.    ,  0.    ,  0.    ,  0.    ,  0.    ,  0.    ,  0.    ],
       [ 0.    ,  0.    ,  0.    ,  0.    ,  0.    ,  0.    ,  0.    ,  0.    ],
       [ 0.    ,  0.    ,  0.    ,  0.7071, -0.7071,  0.    ,  0.    ,  0.    ],
       [ 0.    ,  0.    ,  0.    ,  0.7071,  0.7071,  0.    ,  0.    ,  0.    ],
       [ 0.    ,  0.    ,  0.    ,  0.    ,  0.    ,  0.    ,  0.    ,  0.    ]])

I think the padded zeros are one of the reasons for the inefficiency. There is probably a way (though a troublesome one) to merge multiple slices before the change of basis.
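Just to illustrate the idea (with made-up code, not the actual internals), one could keep only the nonzero rows/columns of each slice together with their offsets, so that the dense change of basis can still be applied afterwards:

import numpy as np

# The 8x8 slice shown above: only a 2x2 block is actually nonzero
slice_xy = np.zeros((8, 8))
slice_xy[5:7, 3:5] = np.array([[0.7071, -0.7071],
                               [0.7071,  0.7071]])

rows = np.flatnonzero(np.any(slice_xy != 0.0, axis=1))   # -> array([5, 6])
cols = np.flatnonzero(np.any(slice_xy != 0.0, axis=0))   # -> array([3, 4])
compact = slice_xy[np.ix_(rows, cols)]                   # the 2x2 nonzero block
# (compact, rows, cols) is enough to rebuild the padded slice when needed.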

Do you think the kernel expansion could benefit from a CUDA-implementation as explained above?

Yes, I agree CUDA parallelization would help. And just to clarify: this block filling should only happen at initialization, right?


Gabri95 commented on June 14, 2024

Yes, indeed, all of this is part of the preprocessing, which is performed the first time a convolution layer is instantiated (N.B.: subsequent conv layers which share the same basis do not recompute it).

After initialization, only the final basis is stored, in the form of a single tensor, so that the construction of the filter can be done efficiently in a single batched matrix multiplication (here).
To have one unique tensor, we need to pad the smaller irreps bases, which theoretically results in an increased computational cost.
However, in practice this is much more efficient, as it can be run nicely on the GPU.
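As a sketch of that caching behaviour (hypothetical names, not e2cnn's internals): the expensive basis construction runs only the first time a given key is requested, and later layers reuse the stored tensor.

# Hypothetical caching pattern, for illustration only
_basis_cache = {}

def cached_basis(key, build_basis):
    # Build the (padded) basis once per key and reuse it afterwards
    if key not in _basis_cache:
        _basis_cache[key] = build_basis()   # expensive offline preprocessing
    return _basis_cache[key]                # later conv layers reuse this tensor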

Unfortunately, I am not proficient enough with CUDA yet to support this optimization :(

