gconv's Introduction

Continuous Regular Group Convolutions (WIP 👷‍♀️👷‍♂️)

This package implements a PyTorch framework for group convolutions that are easy to use and to integrate into existing PyTorch modules. The package offers premade modules for E(3) and SE(3) convolutions, as well as basic operations such as pooling and normalization for $\mathbb{R}^n \rtimes H$ input. The method is explained in the paper Regular SE(3) Group Convolutions for Volumetric Medical Image Analysis, accepted at MICCAI 2023 (see reference below).

Installation from Source

Download gconv and save it to a directory. Then, from that directory (the one containing the gconv folder), run the following command:

pip install -e gconv

Getting Started

The gconv modules are as straightforward to use as any regular PyTorch convolution module. The only difference is that the output consists of both the feature maps and the group elements on which they are defined. See the example below:

import torch                                                                        # 1
import gconv.gnn as gnn                                                             # 2
                                                                                    # 3
x1 = torch.randn(16, 3, 28, 28, 28)                                                 # 4
                                                                                    # 5
lifting_layer = gnn.GLiftingConvSE3(in_channels=3, out_channels=16, kernel_size=5)  # 6
gconv_layer = gnn.GSeparableConvSE3(in_channels=16, out_channels=32, kernel_size=5) # 7
                                                                                    # 8
pool = gnn.GAvgGlobalPool()                                                         # 9
                                                                                    # 10
x2, H1 = lifting_layer(x1)                                                          # 11
x3, H2 = gconv_layer(x2, H1)                                                        # 12
                                                                                    # 13
y = pool(x3, H2)                                                                    # 14

In line 4, a random batch of three-channel $\mathbb{R}^3$ volumes is created. Line 6 defines the lifting layer, which lifts $\mathbb{R}^3$ input to $\text{SE}(3) = \mathbb{R}^3 \rtimes \text{SO}(3)$, and line 7 defines an $\text{SE}(3)$ convolution layer; the two are applied in lines 11 and 12. In line 14, global pooling is performed (using the layer defined in line 9), resulting in $\text{SE}(3)$-invariant features.
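To illustrate why pooling over both the group and the spatial axes yields invariant features, here is a toy sketch in plain Python. It does not use gconv and swaps SE(3) for the much simpler group of four planar 90-degree rotations (C4); all function names are hypothetical and chosen for this illustration only. Integer values are used so the comparison is exact.

```python
import random


def rot90(grid):
    """Rotate a square list-of-lists by 90 degrees counter-clockwise."""
    return [list(row) for row in zip(*grid)][::-1]


def correlate_valid(image, kernel):
    """Plain 2D cross-correlation without padding."""
    n, k = len(image), len(kernel)
    size = n - k + 1
    return [
        [
            sum(image[i + a][j + b] * kernel[a][b]
                for a in range(k) for b in range(k))
            for j in range(size)
        ]
        for i in range(size)
    ]


def lift_c4(image, kernel):
    """Toy lifting correlation: one feature map per rotated copy of the kernel."""
    maps = []
    for _ in range(4):
        maps.append(correlate_valid(image, kernel))
        kernel = rot90(kernel)
    return maps


def global_avg_pool(maps):
    """Average over the group axis and both spatial axes."""
    values = [value for fmap in maps for row in fmap for value in row]
    return sum(values) / len(values)


rng = random.Random(0)
image = [[rng.randint(-5, 5) for _ in range(6)] for _ in range(6)]
kernel = [[rng.randint(-2, 2) for _ in range(3)] for _ in range(3)]

pooled = global_avg_pool(lift_c4(image, kernel))
pooled_rotated = global_avg_pool(lift_c4(rot90(image), kernel))

# Rotating the input only permutes and rotates the lifted feature maps,
# so the globally pooled value does not change.
print(pooled == pooled_rotated)
```

Rotating the input shifts the feature maps along the group axis and rotates them spatially, but the set of values being averaged stays the same. The same reasoning underlies the SE(3) case above, with a sampled grid on SO(3) in place of the four rotations.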

Furthermore, gconv offers all the necessary tools to build fully custom group convolutions. All that is required is implementing five (or fewer, depending on the type of convolution) group ops. For more details on how to implement custom group convolutions, see gconv_tutorial.ipynb.

Requirements:

python >= 3.10
torch
tqdm

Reference:

Paper accepted at MICCAI 2023.

@misc{kuipers2023regular,
      title={Regular SE(3) Group Convolutions for Volumetric Medical Image Analysis}, 
      author={Thijs P. Kuipers and Erik J. Bekkers},
      year={2023},
      eprint={2306.13960},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

gconv's People

Contributors

ebekkers, thijskuipers1995

gconv's Issues

GLiftingKernelSE2 ignores mask=False

Description

GLiftingKernelSE2 always creates and uses the spherical mask, regardless of the value of the mask parameter:

        mask: bool = True,
        grid_H: Tensor | None = None,
    ) -> None:
        """
        Implements SE2 lifting kernel.

        Arguments:
            - in_channels: int denoting the number of input channels.
            - out_channels: int denoting the number of output channels.
            - kernel_size: int denoting the spatial kernel size.
            - group_kernel_size: int denoting the group kernel size.
            - groups: number of groups for depth-wise separability.
            - sampling_mode: str indicating the sampling mode. Supports bilinear (default)
                             or nearest.
            - sampling_padding_mode: str indicating padding mode for sampling. Default
                                     border.
            - mask: bool if true, will initialize spherical mask.
            - grid_H: tensor of reference grid used for interpolation. If not
                      provided, a uniform grid of group_kernel_size will be
                      generated. If provided, will overwrite given group_kernel_size.
        """
        if grid_H is None:
            grid_H = so2.uniform_grid(group_kernel_size)

        grid_Rn = gF.create_grid_R2(kernel_size)
        mask = gF.create_spherical_mask_R2(kernel_size)
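One possible way to honour the flag is to build the mask only when it is requested and to skip masking otherwise. The sketch below is plain Python and does not use gconv: create_spherical_mask_2d is a hypothetical stand-in for gF.create_spherical_mask_R2 (whose exact output is not shown in this issue), and resolve_mask and apply_mask are illustrative names, not part of the library.

```python
def create_spherical_mask_2d(kernel_size):
    """Hypothetical stand-in: True for grid points inside the unit disc."""
    if kernel_size == 1:
        coords = [0.0]
    else:
        coords = [-1.0 + 2.0 * i / (kernel_size - 1) for i in range(kernel_size)]
    return [[x * x + y * y <= 1.0 for x in coords] for y in coords]


def resolve_mask(kernel_size, mask=True):
    """Build the mask only when requested; None means 'do not mask'."""
    return create_spherical_mask_2d(kernel_size) if mask else None


def apply_mask(weights, mask):
    """Zero out weights outside the mask, or return them untouched if mask is None."""
    if mask is None:
        return weights
    return [
        [w if keep else 0.0 for w, keep in zip(w_row, m_row)]
        for w_row, m_row in zip(weights, mask)
    ]


weights = [[1.0] * 5 for _ in range(5)]
masked = apply_mask(weights, resolve_mask(5, mask=True))
unmasked = apply_mask(weights, resolve_mask(5, mask=False))
print(masked[0][0], unmasked[0][0])
```

In the excerpt above, the equivalent change would presumably be to guard the final assignment with the mask argument instead of overwriting it unconditionally; how the kernel consumes a missing mask afterwards would need to be checked against the actual gconv code.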

Discretizing SO(3)

In GLiftingKernelSE3, the default group_kernel_size is 4. Intuitively, it seems that a much larger number of samples is needed on SO(3) to estimate the integral properly. For instance, when I uniformly sampled 128 rotations (with repulsion), the angular distance between the two closest rotations was ~43 degrees. Is group_kernel_size = 4 the value used in your paper? If so, why did you choose it?
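For readers who want to measure this kind of spacing themselves, here is a minimal sketch in plain Python. It assumes rotations are represented as unit quaternions and draws them independently at random, without the repulsion step mentioned above, so the gaps it reports will generally be smaller than for a repulsion-optimised grid. It does not reproduce gconv's own grid generation, and all function names are illustrative.

```python
import math
import random


def random_rotation(rng):
    """Random unit quaternion; a normalised 4D Gaussian sample is uniform on SO(3)."""
    q = [rng.gauss(0.0, 1.0) for _ in range(4)]
    norm = math.sqrt(sum(c * c for c in q))
    return [c / norm for c in q]


def rotation_angle_deg(q1, q2):
    """Geodesic angle between two rotations given as unit quaternions."""
    # abs() is needed because q and -q describe the same rotation.
    dot = abs(sum(a * b for a, b in zip(q1, q2)))
    return math.degrees(2.0 * math.acos(min(1.0, dot)))


def nearest_neighbour_angles(rotations):
    """Angle from each rotation to its closest other rotation in the set."""
    return [
        min(rotation_angle_deg(q, other)
            for j, other in enumerate(rotations) if j != i)
        for i, q in enumerate(rotations)
    ]


rng = random.Random(0)
samples = [random_rotation(rng) for _ in range(128)]
angles = nearest_neighbour_angles(samples)
print(f"min {min(angles):.1f} deg, mean {sum(angles) / len(angles):.1f} deg")
```

Running the same measurement on a grid of 4 rotations versus 128 gives a rough sense of how coarsely each choice covers SO(3).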
