Code Monkey home page Code Monkey logo

capsgnn's Introduction

CapsGNN

PWC codebeat badge repo size benedekrozemberczki

A PyTorch implementation of Capsule Graph Neural Network (ICLR 2019).

Abstract

The high-quality node embeddings learned from the Graph Neural Networks (GNNs) have been applied to a wide range of node-based applications and some of them have achieved state-of-the-art (SOTA) performance. However, when applying node embeddings learned from GNNs to generate graph embeddings, the scalar node representation may not suffice to preserve the node/graph properties efficiently, resulting in sub-optimal graph embeddings. Inspired by the Capsule Neural Network (CapsNet), we propose the Capsule Graph Neural Network (CapsGNN), which adopts the concept of capsules to address the weakness in existing GNN-based graph embeddings algorithms. By extracting node features in the form of capsules, routing mechanism can be utilized to capture important information at the graph level. As a result, our model generates multiple embeddings for each graph to capture graph properties from different aspects. The attention module incorporated in CapsGNN is used to tackle graphs with various sizes which also enables the model to focus on critical parts of the graphs. Our extensive evaluations with 10 graph-structured datasets demonstrate that CapsGNN has a powerful mechanism that operates to capture macroscopic properties of the whole graph by data-driven. It outperforms other SOTA techniques on several graph classification tasks, by virtue of the new instrument.

This repository provides a PyTorch implementation of CapsGNN as described in the paper:

Capsule Graph Neural Network. Zhang Xinyi, Lihui Chen. ICLR, 2019. [Paper]

The core Capsule Neural Network implementation adapted is available [here].

Requirements

The codebase is implemented in Python 3.5.2. package versions used for development are just below.

networkx          2.4
tqdm              4.28.1
numpy             1.15.4
pandas            0.23.4
texttable         1.5.0
scipy             1.1.0
argparse          1.1.0
torch             1.1.0
torch-scatter     1.4.0
torch-sparse      0.4.3
torch-cluster     1.4.5
torch-geometric   1.3.2
torchvision       0.3.0

Datasets

The code takes graphs for training from an input folder where each graph is stored as a JSON. Graphs used for testing are also stored as JSON files. Every node id and node label has to be indexed from 0. Keys of dictionaries are stored strings in order to make JSON serialization possible.

Every JSON file has the following key-value structure:

{"edges": [[0, 1],[1, 2],[2, 3],[3, 4]],
 "labels": {"0": "A", "1": "B", "2": "C", "3": "A", "4": "B"},
 "target": 1}

The **edges** key has an edge list value which descibes the connectivity structure. The **labels** key has labels for each node which are stored as a dictionary -- within this nested dictionary labels are values, node identifiers are keys. The **target** key has an integer value which is the class membership.

Outputs

The predictions are saved in the `output/` directory. Each embedding has a header and a column with the graph identifiers. Finally, the predictions are sorted by the identifier column.

Options

Training a CapsGNN model is handled by the `src/main.py` script which provides the following command line arguments.

Input and output options

  --training-graphs   STR    Training graphs folder.      Default is `input/train/`.
  --testing-graphs    STR    Testing graphs folder.       Default is `input/test/`.
  --prediction-path   STR    Output predictions file.     Default is `output/watts_predictions.csv`.

Model options

  --epochs                      INT     Number of epochs.                  Default is 100.
  --batch-size                  INT     Number fo graphs per batch.        Default is 32.
  --gcn-filters                 INT     Number of filters in GCNs.         Default is 20.
  --gcn-layers                  INT     Number of GCNs chained together.   Default is 2.
  --inner-attention-dimension   INT     Number of neurons in attention.    Default is 20.  
  --capsule-dimensions          INT     Number of capsule neurons.         Default is 8.
  --number-of-capsules          INT     Number of capsules in layer.       Default is 8.
  --weight-decay                FLOAT   Weight decay of Adam.              Defatuls is 10^-6.
  --lambd                       FLOAT   Regularization parameter.          Default is 0.5.
  --theta                       FLOAT   Reconstruction loss weight.        Default is 0.1.
  --learning-rate               FLOAT   Adam learning rate.                Default is 0.01.

Examples

The following commands learn a model and save the predictions. Training a model on the default dataset:

$ python src/main.py

Training a CapsGNNN model for a 100 epochs.

$ python src/main.py --epochs 100

Changing the batch size.

$ python src/main.py --batch-size 128

License

capsgnn's People

Contributors

benedekrozemberczki avatar lhbrichard avatar utkarsh87 avatar

Stargazers

 avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar

Watchers

 avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar

capsgnn's Issues

Not able to install torch-scatter with torch 0.4.1

Hello,

Thanks for sharing the implementation.

While I'm try to run your code I get some error for installing the environment.
I have torch 0.4.1, but not able to install torch-scatter.Got the following error:
fatal error: torch/extension.h: No such file or directory

But I can successfully install them for torch 1.0.

Is your code working for torch 1.0? Or how to install torch-scatter for torch 0.4.1?

Details:

$ pip list
Package Version


backcall 0.1.0
certifi 2018.8.24
....
torch 0.4.1.post2
torch-geometric 1.1.1
torchfile 0.1.0
torchvision 0.2.1
tornado 5.1
tqdm 4.31.1
traitlets 4.3.2
urllib3 1.23
visdom 0.1.8.5
vispy 0.5.3
.... ....

$pip install torch-scatter

Coordinate Addition module & Routing

Hi, thanks for your codes of GapsGNN. And I have some questions about Coordinate Addition module and Routing.

  1. Do you use Coordinate Addition module in this codes?
  2. In /src/layers.py, line 137 : c_ij = torch.nn.functional.softmax(b_ij, dim=0) .
    At this time, b_ij.size(0) == 1, why use dim =0 ?

Thanks again.

Other datasets

I notice some datasets in your paper such as RE-M5K and RE-M12K. May I know how to obtain these datasets? The processed datasets would be appreciated. Thank you.

Experimental result

Thank you for your work. Then the experimental results I ran out were very different from those in the paper.

Reproduce Issues

Hi, thanks for your PyTorch codes of GapsGNN. I try to run the codes on NCI, DD, and other graph classification datasets, but it doesn't work (For example, training loss converges to 2.0, and test acc is about 50% on NCI1 after several iterations.)
How should I do if I want to run these codes on NCI, DD and etc?
Thanks again.

D&D dataset

I notice some datasets in your paper such as D&D dataset. May I know how to obtain these datasets? The processed datasets would be appreciated. Thank you.

how to repeat your expriments?

Enumerating feature and target values.

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 60/60 [00:00<00:00, 14754.82it/s]

Training started.

Epochs: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:05<00:00, 1.90it/s]
CapsGNN (Loss=0.7279): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1.92it/s]

Scoring.

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [00:00<00:00, 128.47it/s]

Accuracy: 0.3333

Accuracy is too small

Train on Image Dataset

Hi.. thanks for sharing this repo.
What in need to modify so I can train this on an image dataset?
Thanks

How to generate such dataset

When I transfer the MUTAG to the specific json file, the accuracy is only round 62%. Can you give me some suggestion or the generate file

About dataset

Hi, could you tell me what dataset use in this project?

Something about reshape

Hi @benedekrozemberczki ! Thank you for your work!

I have a question at line 61 and 62 of CapsGNN/src/capsgnn.py

hidden_representations = torch.cat(tuple(hidden_representations))
hidden_representations = hidden_representations.view(1, self.args.gcn_layers, self.args.gcn_filters,-1)

Why you directly reshape L*N,D to 1,L,D,N instead of using permutation after reshape, e.g

hidden_representations = hidden_representations.view(1, self.args.gcn_layers, -1,self.args.gcn_filters).permute(0,1,3,2)

Thank you for your help!

运行代码就报错?都没修改

Traceback (most recent call last):
File "src/main.py", line 19, in
main()
File "src/main.py", line 14, in main
model.fit()
File "/lmq/CapsGNN/src/capsgnn.py", line 293, in fit
prediction, reconstruction_loss = self.model(data)
File "/usr/local/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/lmq/CapsGNN/src/capsgnn.py", line 147, in forward
first_capsule_output = self.first_capsule(hidden_representations)
File "/usr/local/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/lmq/CapsGNN/src/layers.py", line 86, in forward
u = [self.unitsi for i in range(self.num_units)]
File "/lmq/CapsGNN/src/layers.py", line 86, in
u = [self.unitsi for i in range(self.num_units)]
File "/usr/local/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/usr/local/miniconda3/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 313, in forward
return self._conv_forward(input, self.weight, self.bias)
File "/usr/local/miniconda3/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 309, in _conv_forward
return F.conv1d(input, weight, bias, self.stride,
RuntimeError: Expected 2D (unbatched) or 3D (batched) input to conv1d, but got input of size: [1, 2, 20, 15]

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.