deepsphere-pytorch's Introduction

DeepSphere: a graph-based spherical CNN

This is a PyTorch implementation of DeepSphere.

Resources

Code:

Papers:

  • DeepSphere: Efficient spherical CNN with HEALPix sampling for cosmological applications, 2018.
    [paper, blog, slides]
  • DeepSphere: towards an equivariant graph-based spherical CNN, 2019.
    [paper, poster]
  • DeepSphere: a graph-based spherical CNN, 2020.
    [paper, slides, video]

Data

The data used for the experiments is a downsampled snapshot of the Community Atmospheric Model v5 (CAM5) simulation. The data is based on the UGSCNN paper (Jiang et al., 2019). The simulation can be thought of as a 16-channel "image", where each channel corresponds to a climate-related measurement. The task is to infer the correct class of each pixel given the 16 channels. Each pixel is labelled as background, as part of a tropical cyclone, or as part of an atmospheric river.

Quick Start

To reproduce the results, you first need to install the PyGSP branch that contains the graph constructions for the equiangular, icosahedron and HEALPix samplings. In future versions, PyGSP will be part of the requirements. Then, refer to the PyTorch Getting Started page to find the conda install command matching your operating system, Python version and CUDA version. Once these requirements are met, you can install the deepsphere package in your environment.

Our recommendation for a linux based machine is:

conda create --name deepsphere python=3.7

source activate deepsphere

pip install git+https://github.com/epfl-lts2/pygsp.git@39a0665f637191152605911cf209fc16a36e5ae9#egg=PyGSP

conda install pytorch=1.3.1 torchvision=0.4.2 cudatoolkit=10.0 -c pytorch

pip install git+https://github.com/deepsphere/deepsphere-pytorch

The experiment parameters are stored in a YAML config file, which is passed to a script run from the command line.

A special note concerns the PyTorch computation device. If nothing is specified on the command line, the device is set to CPU. To run on GPU (CUDA), pass --gpu on the command line, with or without the desired GPU device IDs (e.g. --gpu 1 2 to run the model on GPUs 1 and 2).
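The device selection described above can be sketched with Python's standard argparse module. This is a minimal illustration assuming only the behaviour stated in this section; the repository's actual argument parser may be organised differently, and `parse_args` and `select_devices` are hypothetical names. The device strings follow the `"cuda:N"` format accepted by `torch.device`.

```python
import argparse


def parse_args(argv=None):
    """Hypothetical re-creation of the --config-file / --gpu flags."""
    parser = argparse.ArgumentParser(description="AR/TC segmentation (sketch)")
    parser.add_argument("--config-file", required=True,
                        help="path to the YAML experiment configuration")
    # nargs="*": "--gpu" alone gives [], "--gpu 1 2" gives [1, 2];
    # default=None means the flag was not given at all.
    parser.add_argument("--gpu", nargs="*", type=int, default=None,
                        help="run on CUDA, optionally on the listed device IDs")
    return parser.parse_args(argv)


def select_devices(gpu_ids):
    """Map the parsed --gpu value to device strings."""
    if gpu_ids is None:
        return ["cpu"]                 # flag absent: CPU is the default
    if not gpu_ids:
        return ["cuda"]                # flag without IDs: default CUDA device
    return [f"cuda:{i}" for i in gpu_ids]


args = parse_args(["--config-file", "config.example.yml", "--gpu", "1", "2"])
print(select_devices(args.gpu))  # ['cuda:1', 'cuda:2']
```

Using `nargs="*"` with `default=None` is what lets a single flag distinguish the three cases: absent, present without IDs, and present with IDs.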

To visualize icosahedron or equiangular data, the package provides a demonstration Jupyter notebook for data in 2D or 3D.

Using the predefined parameters you can train and validate the model using the following command:

python run_ar_tc.py --config-file config.example.yml --gpu

If you don't have the data yet, please create the folder /data/climate/ (or change the file location in the yaml file) and add download True to the command.

Mathematical Background

The DeepSphere package performs convolutions on spherical data by discretizing the sphere as a graph. Underlying this graph-based application of convolutional networks to spherical data lies the field of Graph Signal Processing (GSP), which generalizes classical spectral methods from the time domain to signals defined on graphs.

This section gives the key concepts behind representing the sphere as a graph, and shows how working in the eigenvector space of the graph Laplacian (the graph spectral domain) makes it possible to define convolution on the sphere. For an in-depth introduction to the topic, see for example Graph Signal Processing: Overview, Challenges and Applications (2017) or The Emerging Field of Signal Processing on Graphs (2012). For gentler introductions, you may refer to Chapter 1.2 of J. Paratte's PhD thesis or Chapter 2.1 of L. Martin's PhD thesis. For an introduction to graph convolutions in the context of neural networks, see for example Convolutional neural networks on graphs with fast localized spectral filtering (2016).

Following the GSP paradigm, convolution on a graph can be computed simply as a multiplication in the right domain, just as in classical signal processing, where filtering (i.e., convolution) becomes a pointwise multiplication once the signal has been transformed to the Fourier domain. Accordingly, the graph Fourier transform of a graph signal x is defined as its projection onto the eigenvectors of the graph Laplacian:

$$\hat{x} = U^\top x,$$

where U and Λ come from the eigendecomposition of the Laplacian, i.e. $L = U \Lambda U^\top$.

Several Laplacians could be used to bring the data to the spectral domain. Here we use the combinatorial Laplacian $L$, which is simply defined as:

$$L = D - W,$$

where W is the weighted adjacency matrix of the graph and D is the diagonal degree matrix, whose i-th diagonal entry is the degree of node i, i.e. the sum of the weights of all edges incident to it.
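As a toy illustration of these definitions (a NumPy sketch, not code from the package), the following builds the combinatorial Laplacian of a small weighted graph, eigendecomposes it and applies the graph Fourier transform and its inverse. The graph and its edge weights are arbitrary.

```python
import numpy as np

# Toy weighted, undirected graph on 4 nodes (a cycle with arbitrary weights).
W = np.array([
    [0.0, 1.0, 0.0, 0.5],
    [1.0, 0.0, 2.0, 0.0],
    [0.0, 2.0, 0.0, 1.0],
    [0.5, 0.0, 1.0, 0.0],
])

D = np.diag(W.sum(axis=1))   # degree matrix: sum of edge weights per node
L = D - W                    # combinatorial Laplacian

# L is real and symmetric, so eigh returns real eigenvalues in ascending order
# and an orthonormal matrix of eigenvectors U (one eigenvector per column).
lam, U = np.linalg.eigh(L)

x = np.array([1.0, -2.0, 0.5, 3.0])   # graph signal: one value per node
x_hat = U.T @ x                        # graph Fourier transform
x_back = U @ x_hat                     # inverse transform

print(np.allclose(U @ np.diag(lam) @ U.T, L))  # True: L = U Λ U^T
print(np.allclose(x_back, x))                  # True: U is orthonormal
```

Because L is symmetric positive semi-definite, its eigenvalues are non-negative; for a connected graph the smallest one is 0, with a constant eigenvector.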

Filtering, i.e. the convolution operator, is defined through a graph filter g: a continuous function applied directly in the graph Fourier domain, where it acts on the eigenvalues so that filtering reduces to a pointwise multiplication. Using the definition of the graph Fourier transform, the graph filtering equation can then be rewritten as a matrix-vector operation in the original (vertex) domain:

$$g(L)\,x = U\, g(\Lambda)\, U^\top x.$$
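The exact filtering equation translates directly into NumPy. The sketch below uses a heat-kernel filter g(λ) = exp(-τλ) purely as an example of a low-pass response; any function of the eigenvalues could be used. Because it calls `numpy.linalg.eigh`, it also shows why this route does not scale: a dense eigendecomposition costs O(N³) operations.

```python
import numpy as np


def spectral_filter(L, x, g):
    """Exact graph filtering g(L) x = U g(Λ) U^T x through a full eigendecomposition."""
    lam, U = np.linalg.eigh(L)          # dense O(N^3) step: fine for toy graphs only
    return U @ (g(lam) * (U.T @ x))     # transform, multiply pointwise, transform back


def heat_kernel(lam, tau=0.5):
    """Example low-pass response; tau controls how fast high frequencies decay."""
    return np.exp(-tau * lam)


# Path graph on 5 nodes with unit edge weights.
N = 5
W = np.zeros((N, N))
idx = np.arange(N - 1)
W[idx, idx + 1] = W[idx + 1, idx] = 1.0
L = np.diag(W.sum(axis=1)) - W

x = np.array([0.0, 0.0, 1.0, 0.0, 0.0])   # impulse on the middle node
y = spectral_filter(L, x, heat_kernel)    # the impulse spreads to its neighbours
```

Since g(0) = 1 and the constant vector is the eigenvector of eigenvalue 0, this particular filter leaves constant signals unchanged and preserves the sum of x.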

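However, the filtering equation above requires the full set of eigenvectors U, i.e. a diagonalization of the Laplacian L, which is extremely costly for large graphs (cubic in the number of vertices). To circumvent this problem, one can approximate the filter g by a truncated expansion in Chebyshev polynomials of degree up to M. The relation between the graph filter g(L), the graph signal x, and the Chebyshev polynomials lies in the approximation:

$$g(L)\,x \approx \sum_{m=0}^{M} c_m\, T_m(\tilde{L})\, x,$$

where the coefficients $c_m$ entirely describe the shape of the graph filter g, and $\tilde{L} = \frac{2}{\lambda_{\max}} L - I$ is the Laplacian rescaled so that its eigenvalues lie in [-1, 1], the interval on which the Chebyshev approximation is valid ($\lambda_{\max}$ being the largest eigenvalue of L, or an upper bound on it).

Since the Chebyshev polynomials of the first kind satisfy a three-term recurrence relation, computing the approximation is very efficient compared to diagonalizing L: it only requires the recursive computation of

$$\bar{x}_m = T_m(\tilde{L})\,x = 2\tilde{L}\,\bar{x}_{m-1} - \bar{x}_{m-2},$$

where $\bar{x}_0 = x$ and $\bar{x}_1 = \tilde{L}\,x$. Each step is a single (sparse) matrix-vector product, so for a fixed M the cost grows linearly with the number of edges instead of cubically with the number of vertices.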
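The recurrence can be checked numerically. The NumPy sketch below (not the package's implementation) rescales L with the Gershgorin bound λ_max ≤ 2 × maximum degree, so no eigendecomposition is needed for that step, obtains the Chebyshev coefficients of an example heat-kernel filter with `numpy.polynomial.chebyshev.chebinterpolate`, and compares the result with exact spectral filtering. The graph, the filter and the degree M are illustrative assumptions.

```python
import numpy as np
from numpy.polynomial import chebyshev as C


def cheb_filter(L, x, coeffs, lmax):
    """Approximate g(L) x by sum_m c_m T_m(L_tilde) x using the three-term recurrence."""
    L_tilde = (2.0 / lmax) * L - np.eye(L.shape[0])   # spectrum [0, lmax] -> [-1, 1]
    x_prev, x_curr = x, L_tilde @ x                    # x_0 = x, x_1 = L_tilde x
    y = coeffs[0] * x_prev
    if len(coeffs) > 1:
        y = y + coeffs[1] * x_curr
    for c in coeffs[2:]:
        # x_m = 2 L_tilde x_{m-1} - x_{m-2}: one matrix-vector product per order
        x_prev, x_curr = x_curr, 2.0 * (L_tilde @ x_curr) - x_prev
        y = y + c * x_curr
    return y


def heat_kernel(lam, tau=0.5):
    return np.exp(-tau * lam)


# Toy path graph on 6 nodes with unit weights.
N = 6
W = np.zeros((N, N))
idx = np.arange(N - 1)
W[idx, idx + 1] = W[idx + 1, idx] = 1.0
L = np.diag(W.sum(axis=1)) - W

# Gershgorin: every eigenvalue of D - W is at most twice the largest degree.
lmax = 2.0 * W.sum(axis=1).max()

# Coefficients c_0..c_M of g on [0, lmax], via interpolation at Chebyshev points
# t in [-1, 1] mapped back to lambda = (t + 1) * lmax / 2.
M = 10
coeffs = C.chebinterpolate(lambda t: heat_kernel((t + 1.0) * lmax / 2.0), M)

x = np.random.default_rng(0).standard_normal(N)
y_cheb = cheb_filter(L, x, coeffs, lmax)

# Reference: exact filtering through the full eigendecomposition.
lam, U = np.linalg.eigh(L)
y_exact = U @ (heat_kernel(lam) * (U.T @ x))
err = np.max(np.abs(y_cheb - y_exact))   # tiny here: exp is smooth, so c_m decay fast
```

On real spherical graphs L would be a sparse matrix, and each `L_tilde @ x_curr` a sparse product; the dense `np.eye` and `eigh` calls are only there to keep the toy self-contained and verifiable.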

Thus, learning the coefficients $c_m$ of the polynomial approximation makes it possible to learn generic graph filters: a convolution on the spherical graph is trained by tuning the Chebyshev coefficients through backpropagation.
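To make this concrete, the sketch below fits the coefficients c_m of a single Chebyshev filter by gradient descent so that it reproduces a "hidden" target filter (again a heat kernel, chosen only for illustration). In PyTorch, autograd would compute the gradient; here it is written out by hand in NumPy, which works because the filter output is linear in c_m. There is no nonlinearity, bias or channel mixing, so this is a simplified stand-in for a real graph convolution layer, not the package's layer.

```python
import numpy as np


def cheb_basis(L_tilde, X, M):
    """Return T_0(L_tilde) X, ..., T_M(L_tilde) X stacked on a new last axis."""
    terms = [X, L_tilde @ X]
    for _ in range(2, M + 1):
        terms.append(2.0 * (L_tilde @ terms[-1]) - terms[-2])
    return np.stack(terms[: M + 1], axis=-1)      # shape (N, B, M + 1)


rng = np.random.default_rng(0)
N, B, M = 8, 16, 4                                # vertices, signals, filter order

# Ring graph with unit weights; lmax bounded by twice the maximum degree.
W = np.zeros((N, N))
i = np.arange(N)
W[i, (i + 1) % N] = W[(i + 1) % N, i] = 1.0
L = np.diag(W.sum(axis=1)) - W
L_tilde = L / W.sum(axis=1).max() - np.eye(N)     # = 2 L / lmax - I, lmax = 2 * max degree

# Training pairs: random signals and their image under the hidden target filter.
X = rng.standard_normal((N, B))
lam, U = np.linalg.eigh(L)
Y = U @ (np.exp(-0.5 * lam)[:, None] * (U.T @ X))

basis = cheb_basis(L_tilde, X, M)                 # fixed, since the output is linear in c
c = np.zeros(M + 1)                               # learnable Chebyshev coefficients
lr, losses = 0.1, []
for _ in range(300):
    resid = basis @ c - Y                         # prediction error, shape (N, B)
    losses.append(0.5 * np.mean(resid ** 2))
    grad = np.einsum("nbm,nb->m", basis, resid) / resid.size   # d loss / d c
    c -= lr * grad

print(f"loss: {losses[0]:.4f} -> {losses[-1]:.6f}")
```

The learning rate is small enough for this toy problem that every step decreases the loss; in a deep network, an optimizer such as Adam plays that role.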

Unet

The deep learning model uses a classic U-Net architecture. Pooling and unpooling are implemented for three possible spherical samplings: icosahedron, HEALPix and equiangular.
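For HEALPix, pooling is particularly simple when pixels are stored in the NESTED ordering: the four children of pixel p at resolution nside/2 are pixels 4p, 4p + 1, 4p + 2 and 4p + 3 at resolution nside. The following NumPy sketch relies on that property; it assumes a NESTED-ordered input of shape (V, F) and is not the package's pooling code (`healpix_pool` and `healpix_unpool` are hypothetical names). The icosahedron and equiangular poolings work differently and are not covered.

```python
import numpy as np


def healpix_pool(x, mode="max"):
    """Pool a NESTED-ordered (V, F) signal from nside to nside // 2.

    Children of parent p are rows 4p .. 4p + 3, so each consecutive
    group of four rows is reduced to a single row.
    """
    V, F = x.shape
    assert V % 4 == 0, "expected V = 12 * nside**2 with nside >= 2"
    groups = x.reshape(V // 4, 4, F)
    return groups.max(axis=1) if mode == "max" else groups.mean(axis=1)


def healpix_unpool(x):
    """Nearest-neighbour unpooling: copy each parent row to its four children."""
    return np.repeat(x, 4, axis=0)


nside = 4
V = 12 * nside ** 2                          # 192 pixels
x = np.arange(V * 2, dtype=float).reshape(V, 2)
pooled = healpix_pool(x)                     # (48, 2), NESTED ordering at nside = 2
restored = healpix_unpool(pooled)            # (192, 2)
```

Because the parent index is simply the child index divided by four (integer division), the pooled signal is itself in NESTED ordering at the coarser resolution.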

Temporality

Beyond reproducing the ARTC (atmospheric rivers and tropical cyclones) experiment in PyTorch, we added a new dimension to the package: temporality. We did so following two approaches. First, we combined the U-Net with a recurrent neural network (LSTM), as presented in Recurrent Fully Convolutional Networks for Video Segmentation. Second, we augmented the feature space of the U-Net so that it takes more than one time sample as input.
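The second approach can be sketched as stacking consecutive time steps along the feature axis, so that a U-Net whose first layer expects k × F input channels sees a short history. This is a hypothetical NumPy helper assuming data of shape (T, V, F), with T time steps, V vertices and F features; the package may organise its temporal inputs differently, and the LSTM variant is not shown.

```python
import numpy as np


def stack_time_steps(x, t, k):
    """Concatenate the k time steps ending at t along the feature axis.

    x has shape (T, V, F); the result has shape (V, k * F), oldest step first.
    """
    if t - k + 1 < 0:
        raise ValueError("not enough past time steps before t")
    window = x[t - k + 1 : t + 1]                   # (k, V, F)
    return np.concatenate(list(window), axis=-1)    # (V, k * F)


# Example: 5 time steps, 10 vertices, the 16 climate channels, a window of 3 steps.
x = np.random.default_rng(0).standard_normal((5, 10, 16))
features = stack_time_steps(x, t=4, k=3)            # shape (10, 48)
```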

Metric

The metric used to evaluate the model is the mean of the average precisions (AP) of the classes "Atmospheric River" and "Tropical Cyclone". Only around 2% of the data is labelled as atmospheric river and 0.1% as tropical cyclone. For such unbalanced datasets, average precision is an appropriate metric: it summarizes the precision-recall trade-off over all decision thresholds in a single number instead of fixing one operating point, and, unlike accuracy, it is not inflated by the dominant background class. Average precision averages the precision values obtained for recall values from 0 to 1; in other words, it measures the area under the precision-recall curve using a piecewise-constant discretization. Computing the average precision for each class therefore indicates how well the model detects each event type despite the class imbalance.
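The definition above corresponds to AP = Σ_n (R_n − R_{n−1}) P_n, where P_n and R_n are the precision and recall after the n highest-scoring pixels. A minimal NumPy sketch follows, including the mean over the two event classes. The class indices passed to `mean_ap` are a hypothetical label encoding, and, unlike scikit-learn's `average_precision_score`, this sketch does not group tied scores.

```python
import numpy as np


def average_precision(y_true, scores):
    """AP = sum_n (R_n - R_{n-1}) * P_n over the ranking induced by scores."""
    order = np.argsort(-np.asarray(scores), kind="stable")
    y = np.asarray(y_true, dtype=float)[order]
    tp = np.cumsum(y)                               # true positives in the top n
    precision = tp / np.arange(1, y.size + 1)
    recall = tp / y.sum()                           # the class must occur at least once
    recall_prev = np.concatenate(([0.0], recall[:-1]))
    return float(np.sum((recall - recall_prev) * precision))


def mean_ap(labels, probs, classes):
    """Mean one-vs-rest AP over the given class indices (e.g. TC and AR)."""
    return float(np.mean([average_precision(labels == c, probs[:, c]) for c in classes]))


# Positives ranked 1st and 3rd: AP = (1 + 2/3) / 2, about 0.8333.
print(average_precision([1, 0, 1, 0], [0.9, 0.8, 0.7, 0.1]))
```

Because recall only increases at positive pixels, AP is the mean of the precision values measured at each positive, which is why the very large number of true-negative background pixels does not enter the score.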

Tools

Ignite. Ignite provides a clean training-validation-testing loop. Through Ignite, engines corresponding to a trainer, a validator and a tester can be created for the model. The behaviour of these engines can be customized with handlers. For example, the trainer can have a handler that prints information during training, and the validator can have a handler that logs the metrics or one that adjusts the learning rate based on the metrics of the epoch.

Tensorboard. TensorBoard is used to log metrics, the training loss and the learning-rate schedule. In the script, one can create a SummaryWriter and attach various logging options to it.

Visualizations. Visualizations are possible in 2D and 3D. The 2D representation is a flattened version of the sphere, with a 2D projection of the sampling used (at the moment, this is possible for the icosahedron and equiangular samplings). The 3D GIF rendering shows the labels on a rotating globe. Finally, an interactive plotting notebook is provided as a starting point for interactive plots. It plots the metrics at a given point in training (for a certain epoch) alongside the predicted labels in 2D, next to the 2D plot of the ground truth.
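For the equiangular sampling, the vertices lie on a latitude-longitude grid, so the flattened 2D view amounts to a reshape between an (n_lat, n_lon, C) image and a (V, C) graph signal. The sketch below assumes row-major vertex ordering (latitude rows, longitude columns); this ordering is an assumption to check against the sampling actually used before plotting, and both function names are hypothetical. The icosahedron sampling has no such regular grid and needs per-vertex coordinates instead.

```python
import numpy as np


def grid_to_vertices(img):
    """(n_lat, n_lon, C) equiangular image -> (V, C) graph signal, V = n_lat * n_lon."""
    n_lat, n_lon, c = img.shape
    return img.reshape(n_lat * n_lon, c)      # vertex v = row * n_lon + col (assumed)


def vertices_to_grid(x, n_lat, n_lon):
    """Inverse mapping, e.g. to display a (V, C) prediction as a flat 2D map."""
    return x.reshape(n_lat, n_lon, x.shape[-1])
```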

License & co

The content of this repository is released under the terms of the MIT license.

The code, based on the TensorFlow implementation of DeepSphere, was mostly developed by Laure Vancauwenberghe and Michael Allemann while they were interning at Arcanite Solutions under the supervision of Yoann Ponti, Basile Chatillon, Julien Eberle, Lionel Martin, Johan Paratte and Michaël Defferrard.

Please consider citing our papers if you find this repository useful.

@inproceedings{deepsphere_iclr,
  title = {{DeepSphere}: a graph-based spherical {CNN}},
  author = {Defferrard, Michaël and Milani, Martino and Gusset, Frédérick and Perraudin, Nathanaël},
  booktitle = {International Conference on Learning Representations (ICLR)},
  year = {2020},
  url = {https://openreview.net/forum?id=B1e3OlStPB},
}
@inproceedings{deepsphere_rlgm,
  title = {{DeepSphere}: towards an equivariant graph-based spherical {CNN}},
  author = {Defferrard, Michaël and Perraudin, Nathanaël and Kacprzak, Tomasz and Sgier, Raphael},
  booktitle = {ICLR Workshop on Representation Learning on Graphs and Manifolds},
  year = {2019},
  archiveprefix = {arXiv},
  eprint = {1904.05146},
  url = {https://arxiv.org/abs/1904.05146},
}

deepsphere-pytorch's People

Contributors

aluo-x, aurelio-amerio, lionel-martin, maxjiang93, mdeff, onanypoint

deepsphere-pytorch's Issues

Bug in Healpix sampling in pygsp

Not sure this is the correct place to report this bug, but the HEALPix sampling is currently broken.

If we pass n_neighbors=None, the code in pygsp checks for Nside, whereas the function accepts nside.

pygsp Laplacian calculation

Thank you for the great research on spherical graphs.

In https://github.com/deepsphere/deepsphere-pytorch/blob/master/deepsphere/utils/laplacian_funcs.py
the code attempts to import the class SphereIcosahedron and later call the method compute_laplacian on it.

I have installed the version of pygsp as per the readme instructions:

pip install git+https://github.com/epfl-lts2/pygsp.git@39a0665f637191152605911cf209fc16a36e5ae9#egg=PyGSP

However this version has no class SphereIcosahedron

Using Batch Norm after each Convolution layer

Good Evening,
I appreciate your great work on spherical graphs. I am using the Chebyshev convolution (in PyTorch) with symmetric normalization. In our case, we also apply batch norm or instance norm after each convolution layer. My question is: does it make sense to follow a Chebyshev convolution with symmetric normalization by a batch norm or an instance norm? Does it help convergence speed?

Thanks for your time :)

No module named 'pygsp.graphs.nngraphs.spherehealpix'

  • Thank you for your excellent research.
  • I can install PyGSP following the README on Windows, but the installation fails on Ubuntu.
  • I solved this problem by copying files like "SphereHealpix", "SphereIcosahedron" and "SphereEquiangular" from windows to ubuntu.

root@6482b2ce164f:/usr/data/gzy# pip install git+https://github.com/Droxef/pygsp.git@6b216395beae25bf062d13fbf9abc251eeb5bbff#egg=PyGSP
Collecting PyGSP
Cloning https://github.com/Droxef/pygsp.git (to revision 6b216395beae25bf062d13fbf9abc251eeb5bbff) to /tmp/pip-install-mox7o3jf/pygsp
ERROR: Command errored out with exit status 1:
command: /usr/data/gzy/software/anaconda3/envs/ML/bin/python -c 'import sys, setuptools, tokenize; sys.argv[0] = '"'"'/tmp/pip-install-mox7o3jf/pygsp/setup.py'"'"'; file='"'"'/tmp/pip-install-mox7o3jf/pygsp/setup.py'"'"';f=getattr(tokenize, '"'"'open'"'"', open)(file);code=f.read().replace('"'"'\r\n'"'"', '"'"'\n'"'"');f.close();exec(compile(code, file, '"'"'exec'"'"'))' egg_info --egg-base /tmp/pip-pip-egg-info-xkq5bw01
cwd: /tmp/pip-install-mox7o3jf/pygsp/
Complete output (7 lines):
Traceback (most recent call last):
File "<string>", line 1, in <module>
File "/tmp/pip-install-mox7o3jf/pygsp/setup.py", line 11, in <module>
long_description=open('README.rst').read(),
File "/usr/data/gzy/software/anaconda3/envs/ML/lib/python3.6/encodings/ascii.py", line 26, in decode
return codecs.ascii_decode(input, self.errors)[0]
UnicodeDecodeError: 'ascii' codec can't decode byte 0xc5 in position 2512: ordinal not in range(128)
----------------------------------------
ERROR: Command errored out with exit status 1: python setup.py egg_info Check the logs for full command output.

Newer python version

I was wondering if there are plans to support newer Python versions, as 3.7 has reached end of life and no longer receives security updates. The main issue is that many HPC systems no longer build wheels for it, which may cause problems in the future.

What kind of data should I provide to use deepsphere-pytorch?

  • Thanks for the excellent and useful work.
  • I wish to deploy this spherical convolutional model on my spherical data (signal from a spherical detector).

A few weeks ago, I implemented it and ran some training with a small amount of data using the DeepSphere package (based on TensorFlow 1.x). But then I found that my data was too big to be read into memory, so I want to try this PyTorch implementation.

I understand DataLoader and other PyTorch utilities, but I am not sure what kind of data I should prepare. Should I prepare arrays that have been sampled with HEALPix? Or should I prepare panoramic images?

I found that the image in the code seems to have only two dimensions, which are vertices and features. I think vertices is the index after sampling the spherical data, and the feature is the channel. Do I understand it right?

If so, how can I transform the (height, width, channel) image into (vertices, feature) data? Or what sampling method should I use to get the (V, F) data from spherical data in the first place?

In my network environment, the sample data is too big so I can't download it. So it's a bit difficult for me to get clues from the data directly. Sorry, I am still quite confused now. Could you give me some hints? Thanks again!

No module named 'pygsp.graphs.nngraphs.spherehealpix'

Traceback (most recent call last):
File "/projs/AE/zhaoying/projects/sdu/deepsphere-pytorch/scripts/run_ar_tc.py", line 21, in <module>
from deepsphere.models.spherical_unet.unet_model import SphericalUNet
File "/projs/AE/zhaoying/projects/sdu/deepsphere-pytorch/deepsphere/models/spherical_unet/unet_model.py", line 14, in <module>
from deepsphere.utils.laplacian_funcs import get_equiangular_laplacians, get_healpix_laplacians, get_icosahedron_laplacians
File "/projs/AE/zhaoying/projects/sdu/deepsphere-pytorch/deepsphere/utils/laplacian_funcs.py", line 6, in <module>
from pygsp.graphs.nngraphs.spherehealpix import SphereHealpix
ModuleNotFoundError: No module named 'pygsp.graphs.nngraphs.spherehealpix'

What does Healpix pooling return?

Hey. Thanks for the great research work on spherical graphs. I am trying to adapt the code for spherical images.
Coming to my question: you mentioned that HEALPix pooling requires the pixels to be in nested ordering. What will the ordering of the pixels (ring or nested) be after HEALPix pooling is applied to the spherical graph?

Thanks.
